diff --git a/ammico/test/test_multimodal_search.py b/ammico/test/test_multimodal_search.py index 4ba0733..e3ae791 100644 --- a/ammico/test/test_multimodal_search.py +++ b/ammico/test/test_multimodal_search.py @@ -354,6 +354,7 @@ def test_parsing_images( tmp_path, ): ms.MultimodalSearch.multimodal_device = pre_multimodal_device + my_obj = ms.MultimodalSearch(get_testdict) ( model, vis_processor, @@ -361,9 +362,7 @@ def test_parsing_images( image_keys, _, features_image_stacked, - ) = ms.MultimodalSearch.parsing_images( - get_testdict, pre_model, path_to_saved_tensors=tmp_path - ) + ) = my_obj.parsing_images(pre_model, path_to_save_tensors=tmp_path) for i, num in zip(range(10), features_image_stacked[0, 10:12].tolist()): assert ( @@ -371,7 +370,7 @@ def test_parsing_images( is True ) - test_pic = Image.open(get_testdict["IMG_2746"]["filename"]).convert("RGB") + test_pic = Image.open(my_obj.subdict["IMG_2746"]["filename"]).convert("RGB") test_querry = ( "The bird sat on a tree located at the intersection of 23rd and 43rd streets." ) @@ -387,10 +386,10 @@ def test_parsing_images( search_query = [ {"text_input": test_querry}, - {"image": get_testdict["IMG_2746"]["filename"]}, + {"image": my_obj.subdict["IMG_2746"]["filename"]}, ] - multi_features_stacked = ms.MultimodalSearch.querys_processing( - get_testdict, search_query, model, txt_processor, vis_processor, pre_model + multi_features_stacked = my_obj.querys_processing( + search_query, model, txt_processor, vis_processor, pre_model ) for i, num in zip(range(10), multi_features_stacked[0, 10:12].tolist()): @@ -410,8 +409,7 @@ def test_parsing_images( {"image": get_path + "IMG_3758.png"}, ] - similarity, sorted_list = ms.MultimodalSearch.multimodal_search( - get_testdict, + similarity, sorted_list = my_obj.multimodal_search( model, vis_processor, txt_processor, @@ -440,6 +438,7 @@ def test_parsing_images( features_image_stacked, processed_pic, multi_features_stacked, + my_obj, ) cuda.empty_cache() @@ -452,12 +451,12 @@ def test_itm(get_test_my_dict, get_path): ] image_keys = ["IMG_2746", "IMG_2809"] sorted_list = [[1, 0], [1, 0]] + my_obj = ms.MultimodalSearch(get_test_my_dict) for itm_model in ["blip_base", "blip_large"]: ( itm_scores, image_gradcam_with_itm, - ) = ms.MultimodalSearch.image_text_match_reordering( - get_test_my_dict, + ) = my_obj.image_text_match_reordering( search_query3, itm_model, image_keys, @@ -497,12 +496,12 @@ def test_itm_blip2_coco(get_test_my_dict, get_path): ] image_keys = ["IMG_2746", "IMG_2809"] sorted_list = [[1, 0], [1, 0]] + my_obj = ms.MultimodalSearch(get_test_my_dict) ( itm_scores, image_gradcam_with_itm, - ) = ms.MultimodalSearch.image_text_match_reordering( - get_test_my_dict, + ) = my_obj.image_text_match_reordering( search_query3, "blip2_coco", image_keys,