diff --git a/misinformation/multimodal_search.py b/misinformation/multimodal_search.py
index 1d1aee4..de0df14 100644
--- a/misinformation/multimodal_search.py
+++ b/misinformation/multimodal_search.py
@@ -374,10 +374,9 @@ class MultimodalSearch(AnalysisMethod):
         tokenized_text,
         block_num=6,
     ):
-        if itm_model_type != "blip2_coco":
-            model.text_encoder.base_model.base_model.encoder.layer[
-                block_num
-            ].crossattention.self.save_attention = True
+        model.text_encoder.base_model.base_model.encoder.layer[
+            block_num
+        ].crossattention.self.save_attention = True
         output = model(
             {"image": visual_input, "text_input": text_input}, match_head="itm"
         )
@@ -493,15 +492,23 @@ class MultimodalSearch(AnalysisMethod):
         batch_size=1,
         need_grad_cam=False,
     ):
+        if itm_model_type == "blip2_coco" and need_grad_cam is True:
+            raise SyntaxError(
+                "The blip2_coco model does not yet work with gradcam. Please set need_grad_cam to False"
+            )
+
         choose_model = {
             "blip_base": MultimodalSearch.upload_model_blip_base,
             "blip_large": MultimodalSearch.upload_model_blip_large,
             "blip2_coco": MultimodalSearch.upload_model_blip2_coco,
         }
-        itm_model, vis_processor = choose_model[itm_model_type](self)
+        itm_model, vis_processor_itm = choose_model[itm_model_type](self)
         text_processor = load_processor("blip_caption")
         tokenizer = BlipBase.init_tokenizer()
 
+        if itm_model_type == "blip2_coco":
+            need_grad_cam = False
+
         text_query_index = MultimodalSearch.itm_text_precessing(self, search_query)
 
         avg_gradcams = []
@@ -524,7 +531,7 @@ class MultimodalSearch(AnalysisMethod):
             filenames_in_batch = pathes[i * batch_size : (i + 1) * batch_size]
             current_len = len(filenames_in_batch)
             raw_images, images = MultimodalSearch.read_and_process_images_itm(
-                self, filenames_in_batch, vis_processor
+                self, filenames_in_batch, vis_processor_itm
             )
             queries_batch = [text_processor(query["text_input"])] * current_len
             queries_tok_batch = tokenizer(queries_batch, return_tensors="pt").to(
diff --git a/misinformation/test/test_multimodal_search.py b/misinformation/test/test_multimodal_search.py
index 8f82e15..668ae76 100644
--- a/misinformation/test/test_multimodal_search.py
+++ b/misinformation/test/test_multimodal_search.py
@@ -370,20 +370,55 @@ sorted_clip_vitl14_336 = [
     [6, 5, 4, 10, 2, 3, 0, 1, 11, 7, 9, 8],
 ]
 
-itm_scores_for_blib_base = [
-    0.09277121722698212,
-    0.020782141014933586,
-    0.020832309499382973,
-    0.004225197248160839,
-    0.00020702541223727167,
-    0.00410003075376153,
-    0.0009893759852275252,
-    0.00015318581426981837,
-    1.9936736862291582e-05,
-    4.0083985368255526e-05,
-    0.0006117734010331333,
-    4.1486648115096614e-05,
-]
+dict_itm_scores_for_blib = {
+    "blip_base": [
+        0.07107225805521011,
+        0.02078203856945038,
+        0.02083236537873745,
+        0.0042252070270478725,
+        0.0002070252230623737,
+        0.004100032616406679,
+        0.0009893750539049506,
+        0.00015318625082727522,
+        1.9936736862291582e-05,
+        4.0084025386022404e-05,
+        0.0006117739249020815,
+        4.1486648115096614e-05,
+    ],
+    "blip_large": [
+        0.07890705019235611,
+        0.04954551160335541,
+        0.05564938113093376,
+        0.002710158471018076,
+        0.0026644798927009106,
+        0.01277624536305666,
+        0.003585426602512598,
+        0.0019450040999799967,
+        0.0036240608897060156,
+        0.0013280785642564297,
+        0.015366943553090096,
+        0.0030039174016565084,
+    ],
+    "blip2_coco": [
+        0.0833505243062973,
+        0.046232130378484726,
+        0.04996354877948761,
+        0.004187352955341339,
+        2.5233526685042307e-05,
+        0.002679687924683094,
+        2.4826533262967132e-05,
+        5.1878203521482646e-05,
+        1.3434584616334178e-05,
+        9.76747560343938e-06,
+        7.34204331820365e-06,
+        1.1423194337112363e-05,
+    ],
+}
+
+dict_image_gradcam_with_itm_for_blip = {
+    "blip_base": [125.12124404, 132.07243145, 65.43589668],
+    "blip_large": [118.75610679, 125.35366997, 69.63849807],
+}
 
 
 @pytest.mark.parametrize(
@@ -630,7 +665,7 @@ def test_parsing_images(
         cuda.empty_cache()
 
     if pre_model == "blip":
-        for itm_model in ["blip_base", "blip_large", "blip2_coco"]:
+        for itm_model in ["blip_base", "blip_large", "blip2_coco"]:
             (
                 itm_scores,
                 image_gradcam_with_itm,
@@ -637,18 +672,17 @@ def test_parsing_images(
             ) = ms.MultimodalSearch.image_text_match_reordering(
                 self,
                 search_query3,
                 itm_model,
                 image_keys,
                 sorted_list,
                 batch_size=1,
-                need_grad_cam=True,
+                need_grad_cam=False,
             )
             for i, itm in zip(
-                range(len(itm_scores_for_blib_base)), itm_scores_for_blib_base
+                range(len(dict_itm_scores_for_blib[itm_model])),
+                dict_itm_scores_for_blib[itm_model],
            ):
                 assert (
-                    math.isclose(
-                        itm_scores[0].tolist()[i], itm, rel_tol=100 * related_error
-                    )
+                    math.isclose(itm_scores[0].tolist()[i], itm, rel_tol=related_error)
                     is True
                 )
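
Below, for orientation, is a minimal sketch (not part of the patch) of the calling pattern the updated test now exercises; search_obj is a hypothetical stand-in for the object the test passes as self, and search_query3, image_keys, and sorted_list are the fixtures built earlier in test_parsing_images:

    import misinformation.multimodal_search as ms

    for itm_model in ["blip_base", "blip_large", "blip2_coco"]:
        # blip2_coco does not support grad-CAM yet: the patched
        # image_text_match_reordering raises SyntaxError when it is
        # combined with need_grad_cam=True, so the flag stays False.
        (
            itm_scores,
            image_gradcam_with_itm,
        ) = ms.MultimodalSearch.image_text_match_reordering(
            search_obj,
            search_query3,
            itm_model,
            image_keys,
            sorted_list,
            batch_size=1,
            need_grad_cam=False,
        )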