diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 4c1c8d2..ffd5a1d 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -32,7 +32,7 @@ jobs: - name: Run pytest run: | cd misinformation - python -m pytest -m "not gcv" -svv --cov=. --cov-report=xml + python -m pytest -m "not gcv and not long" -svv --cov=. --cov-report=xml - name: Upload coverage if: matrix.os == 'ubuntu-22.04' && matrix.python-version == '3.9' uses: codecov/codecov-action@v3
diff --git a/docs/source/notebooks/Example multimodal.ipynb b/docs/source/notebooks/Example multimodal.ipynb index 7d0abf0..e9030f2 100644 --- a/docs/source/notebooks/Example multimodal.ipynb +++ b/docs/source/notebooks/Example multimodal.ipynb @@ -113,7 +113,7 @@ " image_keys,\n", " image_names,\n", " features_image_stacked,\n", - ") = ms.MultimodalSearch.parsing_images(mydict, model_type)" + ") = ms.MultimodalSearch.parsing_images(mydict, model_type, path_to_saved_tensors=\".\")" ] }, { @@ -128,7 +128,9 @@ "cell_type": "code", "execution_count": null, "id": "56c6d488-f093-4661-835a-5c73a329c874", - "metadata": {}, + "metadata": { + "tags": [] + }, "outputs": [], "source": [ "# (\n",
diff --git a/misinformation/multimodal_search.py b/misinformation/multimodal_search.py index 2b002de..e7d7cdf 100644 --- a/misinformation/multimodal_search.py +++ b/misinformation/multimodal_search.py @@ -4,9 +4,14 @@ import torch.nn.functional as Func import requests import lavis import os +import numpy as np from PIL import Image +from skimage import transform as skimage_transform +from scipy.ndimage import filters +from matplotlib import pyplot as plt from IPython.display import display -from lavis.models import load_model_and_preprocess +from lavis.models import load_model_and_preprocess, load_model, BlipBase +from lavis.processors import load_processor class MultimodalSearch(AnalysisMethod):
@@ -233,7 +238,7 @@ class MultimodalSearch(AnalysisMethod): } for query in search_query: - if not (len(query) == 1) and (query in ("image", "text_input")): + if len(query) != 1 or not ("image" in query or "text_input" in query): raise SyntaxError( 'Each query must contain either an "image" or a "text_input"' )
@@ -343,7 +348,288 @@ class MultimodalSearch(AnalysisMethod): self[image_keys[key]][list(search_query[q].values())[0]] = 0 return similarity, sorted_lists - def show_results(self, query): + def itm_text_precessing(self, search_query): + for query in search_query: + if len(query) != 1 or not ("image" in query or "text_input" in query): + raise SyntaxError( + 'Each query must contain either an "image" or a "text_input"' + ) + text_query_index = [] + for i, query in zip(range(len(search_query)), search_query): + if "text_input" in query.keys(): + text_query_index.append(i) + + return text_query_index + + def get_pathes_from_query(self, query): + paths = [] + image_names = [] + for s in sorted( + self.items(), key=lambda t: t[1][list(query.values())[0]], reverse=True + ): + if s[1]["rank " + list(query.values())[0]] is None: + break + paths.append(s[1]["filename"]) + image_names.append(s[0]) + return paths, image_names + + def read_and_process_images_itm(self, image_paths, vis_processor): + raw_images = [MultimodalSearch.read_img(self, path) for path in image_paths] + images = [vis_processor(r_img) for r_img in raw_images] + images_tensors = torch.stack(images).to(MultimodalSearch.multimodal_device) + + return raw_images, images_tensors + + def compute_gradcam_batch( + self, + model, + visual_input, + text_input, + tokenized_text, + block_num=6, + ): + 
model.text_encoder.base_model.base_model.encoder.layer[ + block_num + ].crossattention.self.save_attention = True + + output = model( + {"image": visual_input, "text_input": text_input}, match_head="itm" + ) + loss = output[:, 1].sum() + + model.zero_grad() + loss.backward() + with torch.no_grad(): + mask = tokenized_text.attention_mask.view( + tokenized_text.attention_mask.size(0), 1, -1, 1, 1 + ) # (bsz,1,token_len, 1,1) + token_length = mask.sum() - 2 + token_length = token_length.cpu() + # grads and cams [bsz, num_head, seq_len, image_patch] + grads = model.text_encoder.base_model.base_model.encoder.layer[ + block_num + ].crossattention.self.get_attn_gradients() + cams = model.text_encoder.base_model.base_model.encoder.layer[ + block_num + ].crossattention.self.get_attention_map() + + # assume using vit large with 576 num image patch + cams = ( + cams[:, :, :, 1:].reshape(visual_input.size(0), 12, -1, 24, 24) * mask + ) + grads = ( + grads[:, :, :, 1:] + .clamp(0) + .reshape(visual_input.size(0), 12, -1, 24, 24) + * mask + ) + + gradcam = cams * grads + # [enc token gradcam, average gradcam across token, gradcam for individual token] + # gradcam = torch.cat((gradcam[0:1,:], gradcam[1:token_length+1, :].sum(dim=0, keepdim=True)/token_length, gradcam[1:, :])) + gradcam = gradcam.mean(1).cpu().detach() + gradcam = ( + gradcam[:, 1 : token_length + 1, :].sum(dim=1, keepdim=True) + / token_length + ) + + return gradcam, output + + def resize_img(self, raw_img): + w, h = raw_img.size + scaling_factor = 240 / w + resized_image = raw_img.resize( + (int(w * scaling_factor), int(h * scaling_factor)) + ) + return resized_image + + def get_att_map(self, img, att_map, blur=True, overlap=True): + att_map -= att_map.min() + if att_map.max() > 0: + att_map /= att_map.max() + att_map = skimage_transform.resize( + att_map, (img.shape[:2]), order=3, mode="constant" + ) + if blur: + att_map = filters.gaussian_filter(att_map, 0.02 * max(img.shape[:2])) + att_map -= att_map.min() + att_map /= att_map.max() + cmap = plt.get_cmap("jet") + att_mapv = cmap(att_map) + att_mapv = np.delete(att_mapv, 3, 2) + if overlap: + att_map = ( + 1 * (1 - att_map**0.7).reshape(att_map.shape + (1,)) * img + + (att_map**0.7).reshape(att_map.shape + (1,)) * att_mapv + ) + return att_map + + def upload_model_blip2_coco(self): + itm_model = load_model( + "blip2_image_text_matching", + "coco", + is_eval=True, + device=MultimodalSearch.multimodal_device, + ) + vis_processor = load_processor("blip_image_eval").build(image_size=364) + return itm_model, vis_processor + + def upload_model_blip_base(self): + itm_model = load_model( + "blip_image_text_matching", + "base", + is_eval=True, + device=MultimodalSearch.multimodal_device, + ) + vis_processor = load_processor("blip_image_eval").build(image_size=384) + return itm_model, vis_processor + + def upload_model_blip_large(self): + itm_model = load_model( + "blip_image_text_matching", + "large", + is_eval=True, + device=MultimodalSearch.multimodal_device, + ) + vis_processor = load_processor("blip_image_eval").build(image_size=384) + return itm_model, vis_processor + + def image_text_match_reordering( + self, + search_query, + itm_model_type, + image_keys, + sorted_lists, + batch_size=1, + need_grad_cam=False, + ): + if itm_model_type == "blip2_coco" and need_grad_cam is True: + raise SyntaxError( + "The blip2_coco model does not yet work with gradcam. 
Please set need_grad_cam to False" + ) + + choose_model = { + "blip_base": MultimodalSearch.upload_model_blip_base, + "blip_large": MultimodalSearch.upload_model_blip_large, + "blip2_coco": MultimodalSearch.upload_model_blip2_coco, + } + + itm_model, vis_processor_itm = choose_model[itm_model_type](self) + text_processor = load_processor("blip_caption") + tokenizer = BlipBase.init_tokenizer() + + if itm_model_type == "blip2_coco": + need_grad_cam = False + + text_query_index = MultimodalSearch.itm_text_precessing(self, search_query) + + avg_gradcams = [] + itm_scores = [] + itm_scores2 = [] + image_gradcam_with_itm = {} + + for index_text_query in text_query_index: + query = search_query[index_text_query] + pathes, image_names = MultimodalSearch.get_pathes_from_query(self, query) + num_batches = int(len(pathes) / batch_size) + num_batches_residue = len(pathes) % batch_size + + local_itm_scores = [] + local_avg_gradcams = [] + + if num_batches_residue != 0: + num_batches = num_batches + 1 + for i in range(num_batches): + filenames_in_batch = pathes[i * batch_size : (i + 1) * batch_size] + current_len = len(filenames_in_batch) + raw_images, images = MultimodalSearch.read_and_process_images_itm( + self, filenames_in_batch, vis_processor_itm + ) + queries_batch = [text_processor(query["text_input"])] * current_len + queries_tok_batch = tokenizer(queries_batch, return_tensors="pt").to( + MultimodalSearch.multimodal_device + ) + + if need_grad_cam: + gradcam, itm_output = MultimodalSearch.compute_gradcam_batch( + self, + itm_model, + images, + queries_batch, + queries_tok_batch, + ) + norm_imgs = [np.float32(r_img) / 255 for r_img in raw_images] + + for norm_img, grad_cam in zip(norm_imgs, gradcam): + avg_gradcam = MultimodalSearch.get_att_map( + self, norm_img, np.float32(grad_cam[0]), blur=True + ) + local_avg_gradcams.append(avg_gradcam) + + else: + itm_output = itm_model( + {"image": images, "text_input": queries_batch}, match_head="itm" + ) + + with torch.no_grad(): + itm_score = torch.nn.functional.softmax(itm_output, dim=1) + + local_itm_scores.append(itm_score) + + local_itm_scores2 = torch.cat(local_itm_scores)[:, 1] + if need_grad_cam: + localimage_gradcam_with_itm = { + n: i * 255 for n, i in zip(image_names, local_avg_gradcams) + } + else: + localimage_gradcam_with_itm = "" + image_names_with_itm = { + n: i.item() for n, i in zip(image_names, local_itm_scores2) + } + itm_rank = torch.argsort(local_itm_scores2, descending=True) + image_names_with_new_rank = { + image_names[i.item()]: rank + for i, rank in zip(itm_rank, range(len(itm_rank))) + } + for i, key in zip(range(len(image_keys)), sorted_lists[index_text_query]): + if image_keys[key] in image_names: + self[image_keys[key]][ + "itm " + list(search_query[index_text_query].values())[0] + ] = image_names_with_itm[image_keys[key]] + self[image_keys[key]][ + "itm_rank " + list(search_query[index_text_query].values())[0] + ] = image_names_with_new_rank[image_keys[key]] + else: + self[image_keys[key]][ + "itm " + list(search_query[index_text_query].values())[0] + ] = 0 + self[image_keys[key]][ + "itm_rank " + list(search_query[index_text_query].values())[0] + ] = None + + avg_gradcams.append(local_avg_gradcams) + itm_scores.append(local_itm_scores) + itm_scores2.append(local_itm_scores2) + image_gradcam_with_itm[ + list(search_query[index_text_query].values())[0] + ] = localimage_gradcam_with_itm + del ( + itm_model, + vis_processor_itm, + text_processor, + raw_images, + images, + tokenizer, + queries_batch, + queries_tok_batch, + 
itm_score, + ) + if need_grad_cam: + del itm_output, gradcam, norm_img, grad_cam, avg_gradcam + torch.cuda.empty_cache() + return itm_scores2, image_gradcam_with_itm + + def show_results(self, query, itm=False, image_gradcam_with_itm=False): if "image" in query.keys(): pic = Image.open(query["image"]).convert("RGB") pic.thumbnail((400, 400)) @@ -359,18 +645,29 @@ class MultimodalSearch(AnalysisMethod): "--------------------------------------------------", "Results:", ) + if itm: + current_querry_val = "itm " + list(query.values())[0] + current_querry_rank = "itm_rank " + list(query.values())[0] + else: + current_querry_val = list(query.values())[0] + current_querry_rank = "rank " + list(query.values())[0] + for s in sorted( - self.items(), key=lambda t: t[1][list(query.values())[0]], reverse=True + self.items(), key=lambda t: t[1][current_querry_val], reverse=True ): - if s[1]["rank " + list(query.values())[0]] is None: + if s[1][current_querry_rank] is None: break - p1 = Image.open(s[1]["filename"]).convert("RGB") + if image_gradcam_with_itm is False: + p1 = Image.open(s[1]["filename"]).convert("RGB") + else: + image = image_gradcam_with_itm[list(query.values())[0]][s[0]] + p1 = Image.fromarray(image.astype("uint8"), "RGB") p1.thumbnail((400, 400)) display( "Rank: " - + str(s[1]["rank " + list(query.values())[0]]) + + str(s[1][current_querry_rank]) + " Val: " - + str(s[1][list(query.values())[0]]), + + str(s[1][current_querry_val]), s[0], p1, ) diff --git a/misinformation/test/conftest.py b/misinformation/test/conftest.py index 3a43142..cb42774 100644 --- a/misinformation/test/conftest.py +++ b/misinformation/test/conftest.py @@ -16,3 +16,33 @@ def set_environ(request): mypath + "/../../data/seismic-bonfire-329406-412821a70264.json" ) print(os.environ.get("GOOGLE_APPLICATION_CREDENTIALS")) + + +@pytest.fixture +def get_testdict(get_path): + testdict = { + "IMG_2746": {"filename": get_path + "IMG_2746.png"}, + "IMG_2809": {"filename": get_path + "IMG_2809.png"}, + } + return testdict + + +@pytest.fixture +def get_test_my_dict(get_path): + test_my_dict = { + "IMG_2746": { + "filename": get_path + "IMG_2746.png", + "rank A bus": 1, + "A bus": 0.15640679001808167, + "rank " + get_path + "IMG_3758.png": 1, + get_path + "IMG_3758.png": 0.7533495426177979, + }, + "IMG_2809": { + "filename": get_path + "IMG_2809.png", + "rank A bus": 0, + "A bus": 0.1970970332622528, + "rank " + get_path + "IMG_3758.png": 0, + get_path + "IMG_3758.png": 0.8907483816146851, + }, + } + return test_my_dict diff --git a/misinformation/test/test_multimodal_search.py b/misinformation/test/test_multimodal_search.py index 98ce575..1bfb954 100644 --- a/misinformation/test/test_multimodal_search.py +++ b/misinformation/test/test_multimodal_search.py @@ -5,22 +5,17 @@ import numpy from torch import device, cuda import misinformation.multimodal_search as ms - -testdict = { - "IMG_2746": {"filename": "./test/data/IMG_2746.png"}, - "IMG_2809": {"filename": "./test/data/IMG_2809.png"}, -} - related_error = 1e-2 gpu_is_not_available = not cuda.is_available() - cuda.empty_cache() -def test_read_img(): +def test_read_img(get_testdict): my_dict = {} - test_img = ms.MultimodalSearch.read_img(my_dict, testdict["IMG_2746"]["filename"]) + test_img = ms.MultimodalSearch.read_img( + my_dict, get_testdict["IMG_2746"]["filename"] + ) assert list(numpy.array(test_img)[257][34]) == [70, 66, 63] @@ -205,29 +200,29 @@ dict_image_gradcam_with_itm_for_blip = { "pre_sorted", ), [ - # ( - # device("cpu"), - # "blip2", - # 
pre_proc_pic_blip2_blip_albef, - # pre_proc_text_blip2_blip_albef, - # pre_extracted_feature_img_blip2, - # pre_extracted_feature_text_blip2, - # simularity_blip2, - # sorted_blip2, - # ), - # pytest.param( - # device("cuda"), - # "blip2", - # pre_proc_pic_blip2_blip_albef, - # pre_proc_text_blip2_blip_albef, - # pre_extracted_feature_img_blip2, - # pre_extracted_feature_text_blip2, - # simularity_blip2, - # sorted_blip2, - # marks=pytest.mark.skipif( - # gpu_is_not_available, reason="gpu_is_not_availible" - # ), - # ), + ( + device("cpu"), + "blip2", + pre_proc_pic_blip2_blip_albef, + pre_proc_text_blip2_blip_albef, + pre_extracted_feature_img_blip2, + pre_extracted_feature_text_blip2, + simularity_blip2, + sorted_blip2, + ), + pytest.param( + device("cuda"), + "blip2", + pre_proc_pic_blip2_blip_albef, + pre_proc_text_blip2_blip_albef, + pre_extracted_feature_img_blip2, + pre_extracted_feature_text_blip2, + simularity_blip2, + sorted_blip2, + marks=pytest.mark.skipif( + gpu_is_not_available, reason="gpu_is_not_availible" + ), + ), ( device("cpu"), "blip", @@ -354,6 +349,8 @@ def test_parsing_images( pre_extracted_feature_text, pre_simularity, pre_sorted, + get_path, + get_testdict, tmp_path, ): ms.MultimodalSearch.multimodal_device = pre_multimodal_device @@ -365,7 +362,7 @@ def test_parsing_images( _, features_image_stacked, ) = ms.MultimodalSearch.parsing_images( - testdict, pre_model, path_to_saved_tensors=tmp_path + get_testdict, pre_model, path_to_saved_tensors=tmp_path ) for i, num in zip(range(10), features_image_stacked[0, 10:12].tolist()): @@ -374,7 +371,7 @@ def test_parsing_images( is True ) - test_pic = Image.open(testdict["IMG_2746"]["filename"]).convert("RGB") + test_pic = Image.open(get_testdict["IMG_2746"]["filename"]).convert("RGB") test_querry = ( "The bird sat on a tree located at the intersection of 23rd and 43rd streets." 
) @@ -390,10 +387,10 @@ def test_parsing_images( search_query = [ {"text_input": test_querry}, - {"image": testdict["IMG_2746"]["filename"]}, + {"image": get_testdict["IMG_2746"]["filename"]}, ] multi_features_stacked = ms.MultimodalSearch.querys_processing( - testdict, search_query, model, txt_processor, vis_processor, pre_model + get_testdict, search_query, model, txt_processor, vis_processor, pre_model ) for i, num in zip(range(10), multi_features_stacked[0, 10:12].tolist()): @@ -410,11 +407,11 @@ def test_parsing_images( search_query2 = [ {"text_input": "A bus"}, - {"image": "../misinformation/test/data/IMG_3758.png"}, + {"image": get_path + "IMG_3758.png"}, ] similarity, sorted_list = ms.MultimodalSearch.multimodal_search( - testdict, + get_testdict, model, vis_processor, txt_processor, @@ -445,3 +442,81 @@ def test_parsing_images( multi_features_stacked, ) cuda.empty_cache() + + +@pytest.mark.long +def test_itm(get_test_my_dict, get_path): + search_query3 = [ + {"text_input": "A bus"}, + {"image": get_path + "IMG_3758.png"}, + ] + image_keys = ["IMG_2746", "IMG_2809"] + sorted_list = [[1, 0], [1, 0]] + for itm_model in ["blip_base", "blip_large"]: + ( + itm_scores, + image_gradcam_with_itm, + ) = ms.MultimodalSearch.image_text_match_reordering( + get_test_my_dict, + search_query3, + itm_model, + image_keys, + sorted_list, + batch_size=1, + need_grad_cam=True, + ) + for i, itm in zip( + range(len(dict_itm_scores_for_blib[itm_model])), + dict_itm_scores_for_blib[itm_model], + ): + assert ( + math.isclose(itm_scores[0].tolist()[i], itm, rel_tol=10 * related_error) + is True + ) + for i, grad_cam in zip( + range(len(dict_image_gradcam_with_itm_for_blip[itm_model])), + dict_image_gradcam_with_itm_for_blip[itm_model], + ): + assert ( + math.isclose( + image_gradcam_with_itm["A bus"]["IMG_2809"][0][0].tolist()[i], + grad_cam, + rel_tol=10 * related_error, + ) + is True + ) + del itm_scores, image_gradcam_with_itm + cuda.empty_cache() + + +@pytest.mark.long +def test_itm_blip2_coco(get_test_my_dict, get_path): + search_query3 = [ + {"text_input": "A bus"}, + {"image": get_path + "IMG_3758.png"}, + ] + image_keys = ["IMG_2746", "IMG_2809"] + sorted_list = [[1, 0], [1, 0]] + + ( + itm_scores, + image_gradcam_with_itm, + ) = ms.MultimodalSearch.image_text_match_reordering( + get_test_my_dict, + search_query3, + "blip2_coco", + image_keys, + sorted_list, + batch_size=1, + need_grad_cam=False, + ) + for i, itm in zip( + range(len(dict_itm_scores_for_blib["blip2_coco"])), + dict_itm_scores_for_blib["blip2_coco"], + ): + assert ( + math.isclose(itm_scores[0].tolist()[i], itm, rel_tol=10 * related_error) + is True + ) + del itm_scores, image_gradcam_with_itm + cuda.empty_cache() diff --git a/notebooks/multimodal_search.ipynb b/notebooks/multimodal_search.ipynb index d7b0abb..6ff9b4f 100644 --- a/notebooks/multimodal_search.ipynb +++ b/notebooks/multimodal_search.ipynb @@ -138,9 +138,7 @@ " image_keys,\n", " image_names,\n", " features_image_stacked,\n", - ") = ms.MultimodalSearch.parsing_images(\n", - " mydict, model_type, path_to_saved_tensors=\"./saved_tensors/\"\n", - ")" + ") = ms.MultimodalSearch.parsing_images(mydict, model_type, path_to_saved_tensors=\".\")" ] }, { @@ -170,7 +168,7 @@ "# ) = ms.MultimodalSearch.parsing_images(\n", "# mydict,\n", "# model_type,\n", - "# path_to_load_tensors=\"./saved_tensors/18_blip_saved_features_image.pt\",\n", + "# path_to_load_tensors=\".5_blip_saved_features_image.pt\",\n", "# )" ] }, @@ -194,15 +192,14 @@ "cell_type": "code", "execution_count": null, 
"id": "c4196a52-d01e-42e4-8674-5712f7d6f792", - "metadata": {}, + "metadata": { + "tags": [] + }, "outputs": [], "source": [ "search_query3 = [\n", " {\"text_input\": \"politician press conference\"},\n", " {\"text_input\": \"a world map\"},\n", - " {\"image\": \"../data/haos.png\"},\n", - " {\"image\": \"../data/image-34098-800.png\"},\n", - " {\"image\": \"../data/LeonPresserMorocco20032015_600.png\"},\n", " {\"text_input\": \"a dog\"},\n", "]" ] @@ -222,10 +219,12 @@ "cell_type": "code", "execution_count": null, "id": "7f7dc52f-7ee9-4590-96b7-e0d9d3b82378", - "metadata": {}, + "metadata": { + "tags": [] + }, "outputs": [], "source": [ - "similarity = ms.MultimodalSearch.multimodal_search(\n", + "similarity, sorted_lists = ms.MultimodalSearch.multimodal_search(\n", " mydict,\n", " model,\n", " vis_processors,\n", @@ -234,6 +233,7 @@ " image_keys,\n", " features_image_stacked,\n", " search_query3,\n", + " filter_number_of_images=20,\n", ")" ] }, @@ -249,10 +249,12 @@ "cell_type": "code", "execution_count": null, "id": "9ad74b21-6187-4a58-9ed8-fd3e80f5a4ed", - "metadata": {}, + "metadata": { + "tags": [] + }, "outputs": [], "source": [ - "mydict[\"100127S_ara\"]" + "mydict[\"109237S_spa\"]" ] }, { @@ -267,10 +269,79 @@ "cell_type": "code", "execution_count": null, "id": "4324e4fd-e9aa-4933-bb12-074d54e0c510", - "metadata": {}, + "metadata": { + "tags": [] + }, "outputs": [], "source": [ - "ms.MultimodalSearch.show_results(mydict, search_query3[4])" + "ms.MultimodalSearch.show_results(\n", + " mydict,\n", + " search_query3[0],\n", + ")" + ] + }, + { + "cell_type": "markdown", + "id": "0b750e9f-fe64-4028-9caf-52d7187462f1", + "metadata": {}, + "source": [ + "For even better results, a slightly different approach has been prepared that can improve search results. It is quite resource-intensive, so it is applied after the main algorithm has found the most relevant images. This approach works only with text queries. Among the parameters you can choose 3 models: `\"blip_base\"`, `\"blip_large\"`, `\"blip2_coco\"`. If you get the Out of Memory error, try reducing the batch_size value (minimum = 1), which is the number of images being processed simultaneously. With the parameter `need_grad_cam = True/False` you can enable the calculation of the heat map of each image to be processed. Thus the `image_text_match_reordering` function calculates new similarity values and new ranks for each image. The resulting values are added to the general dictionary." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "b3af7b39-6d0d-4da3-9b8f-7dfd3f5779be", + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "itm_model = \"blip_base\"\n", + "# itm_model = \"blip_large\"\n", + "# itm_model = \"blip2_coco\"" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "caf1f4ae-4b37-4954-800e-7120f0419de5", + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "itm_scores, image_gradcam_with_itm = ms.MultimodalSearch.image_text_match_reordering(\n", + " mydict,\n", + " search_query3,\n", + " itm_model,\n", + " image_keys,\n", + " sorted_lists,\n", + " batch_size=1,\n", + " need_grad_cam=True,\n", + ")" + ] + }, + { + "cell_type": "markdown", + "id": "9e98c150-5fab-4251-bce7-0d8fc7b385b9", + "metadata": {}, + "source": [ + "Then using the same output function you can add the `ITM=True` arguments to output the new image order. You can also add the `image_gradcam_with_itm` argument to output the heat maps of the calculated images. 
" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "6a829b99-5230-463a-8b11-30ffbb67fc3a", + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "ms.MultimodalSearch.show_results(\n", + " mydict, search_query3[0], itm=True, image_gradcam_with_itm=image_gradcam_with_itm\n", + ")" ] }, { @@ -320,7 +391,9 @@ "cell_type": "code", "execution_count": null, "id": "e78646d6-80be-4d3e-8123-3360957bcaa8", - "metadata": {}, + "metadata": { + "tags": [] + }, "outputs": [], "source": [ "df.head(10)" @@ -338,11 +411,21 @@ "cell_type": "code", "execution_count": null, "id": "185f7dde-20dc-44d8-9ab0-de41f9b5734d", - "metadata": {}, + "metadata": { + "tags": [] + }, "outputs": [], "source": [ "df.to_csv(\"./data_out.csv\")" ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "b6a79201-7c17-496c-a6a1-b8ecfd3dd1e8", + "metadata": {}, + "outputs": [], + "source": [] } ], "metadata": {