diff --git a/misinformation/multimodal_search.py b/misinformation/multimodal_search.py
index 3405493..854469b 100644
--- a/misinformation/multimodal_search.py
+++ b/misinformation/multimodal_search.py
@@ -3,9 +3,14 @@ import torch
 import torch.nn.functional as Func
 import requests
 import lavis
+import numpy as np
 from PIL import Image
+from skimage import transform as skimage_transform
+from scipy.ndimage import filters
+from matplotlib import pyplot as plt
 from IPython.display import display
-from lavis.models import load_model_and_preprocess
+from lavis.models import load_model_and_preprocess, load_model, BlipBase
+from lavis.processors import load_processor
 
 
 class MultimodalSearch(AnalysisMethod):
@@ -301,7 +306,6 @@ class MultimodalSearch(AnalysisMethod):
 
         for q in range(len(search_query)):
             max_val = similarity[sorted_lists[q][0]][q].item()
-            print(max_val)
             for i, key in zip(range(len(image_keys)), sorted_lists[q]):
                 if (
                     i < filter_number_of_images
@@ -322,7 +326,278 @@ class MultimodalSearch(AnalysisMethod):
                 self[image_keys[key]][list(search_query[q].values())[0]] = 0
         return similarity, sorted_lists
 
-    def show_results(self, query):
+    def itm_text_precessing(self, search_query):
+        for query in search_query:
+            if len(query) != 1 or not ("image" in query or "text_input" in query):
+                raise SyntaxError(
+                    'Each query must contain either an "image" or a "text_input"'
+                )
+        text_query_index = []
+        for i, query in zip(range(len(search_query)), search_query):
+            if "text_input" in query.keys():
+                text_query_index.append(i)
+
+        return text_query_index
+
+    def itm_images_processing(self, image_paths, vis_processor):
+        raw_images = [MultimodalSearch.read_img(self, path) for path in image_paths]
+        images = [vis_processor(r_img) for r_img in raw_images]
+        images_tensors = torch.stack(images).to(MultimodalSearch.multimodal_device)
+
+        return raw_images, images_tensors
+
+    def get_pathes_from_query(self, query):
+        paths = []
+        image_names = []
+        for s in sorted(
+            self.items(), key=lambda t: t[1][list(query.values())[0]], reverse=True
+        ):
+            if s[1]["rank " + list(query.values())[0]] is None:
+                break
+            paths.append(s[1]["filename"])
+            image_names.append(s[0])
+        return paths, image_names
+
+    def read_and_process_images_itm(self, image_paths, vis_processor):
+        raw_images = [MultimodalSearch.read_img(self, path) for path in image_paths]
+        images = [vis_processor(r_img) for r_img in raw_images]
+        images_tensors = torch.stack(images).to(MultimodalSearch.multimodal_device)
+
+        return raw_images, images_tensors
+
+    def compute_gradcam_batch(
+        self,
+        itm_model_type,
+        model,
+        visual_input,
+        text_input,
+        tokenized_text,
+        block_num=6,
+    ):
+        if itm_model_type != "blip2_coco":
+            model.text_encoder.base_model.base_model.encoder.layer[
+                block_num
+            ].crossattention.self.save_attention = True
+
+        output = model(
+            {"image": visual_input, "text_input": text_input}, match_head="itm"
+        )
+        loss = output[:, 1].sum()
+
+        model.zero_grad()
+        loss.backward()
+        with torch.no_grad():
+            mask = tokenized_text.attention_mask.view(
+                tokenized_text.attention_mask.size(0), 1, -1, 1, 1
+            )  # (bsz, 1, token_len, 1, 1)
+            token_length = mask.sum() - 2
+            token_length = token_length.cpu()
+            # grads and cams [bsz, num_head, seq_len, image_patch]
+            grads = model.text_encoder.base_model.base_model.encoder.layer[
+                block_num
+            ].crossattention.self.get_attn_gradients()
+            cams = model.text_encoder.base_model.base_model.encoder.layer[
+                block_num
+            ].crossattention.self.get_attention_map()
+
+            # assumes ViT-large with 576 image patches (24 x 24)
+            cams = (
+                cams[:, :, :, 1:].reshape(visual_input.size(0), 12, -1, 24, 24) * mask
+            )
+            grads = (
+                grads[:, :, :, 1:]
+                .clamp(0)
+                .reshape(visual_input.size(0), 12, -1, 24, 24)
+                * mask
+            )
+
+            gradcam = cams * grads
+            # [enc token gradcam, average gradcam across token, gradcam for individual token]
+            # gradcam = torch.cat((gradcam[0:1,:], gradcam[1:token_length+1, :].sum(dim=0, keepdim=True)/token_length, gradcam[1:, :]))
+            gradcam = gradcam.mean(1).cpu().detach()
+            gradcam = (
+                gradcam[:, 1 : token_length + 1, :].sum(dim=1, keepdim=True)
+                / token_length
+            )
+
+        return gradcam, output
+
+    def resize_img(self, raw_img):
+        w, h = raw_img.size
+        scaling_factor = 240 / w
+        resized_image = raw_img.resize(
+            (int(w * scaling_factor), int(h * scaling_factor))
+        )
+        return resized_image
+
+    def getAttMap(self, img, attMap, blur=True, overlap=True):
+        attMap -= attMap.min()
+        if attMap.max() > 0:
+            attMap /= attMap.max()
+        attMap = skimage_transform.resize(
+            attMap, (img.shape[:2]), order=3, mode="constant"
+        )
+        if blur:
+            attMap = filters.gaussian_filter(attMap, 0.02 * max(img.shape[:2]))
+            attMap -= attMap.min()
+            attMap /= attMap.max()
+        cmap = plt.get_cmap("jet")
+        attMapV = cmap(attMap)
+        attMapV = np.delete(attMapV, 3, 2)
+        if overlap:
+            attMap = (
+                1 * (1 - attMap**0.7).reshape(attMap.shape + (1,)) * img
+                + (attMap**0.7).reshape(attMap.shape + (1,)) * attMapV
+            )
+        return attMap
+
+    def upload_model_blip2_coco(self):
+
+        itm_model = load_model(
+            "blip2_image_text_matching",
+            "coco",
+            is_eval=True,
+            device=MultimodalSearch.multimodal_device,
+        )
+        vis_processor = load_processor("blip_image_eval").build(image_size=364)
+        return itm_model, vis_processor
+
+    def upload_model_blip_base(self):
+
+        itm_model = load_model(
+            "blip_image_text_matching",
+            "base",
+            is_eval=True,
+            device=MultimodalSearch.multimodal_device,
+        )
+        vis_processor = load_processor("blip_image_eval").build(image_size=384)
+        return itm_model, vis_processor
+
+    def upload_model_blip_large(self):
+
+        itm_model = load_model(
+            "blip_image_text_matching",
+            "large",
+            is_eval=True,
+            device=MultimodalSearch.multimodal_device,
+        )
+        vis_processor = load_processor("blip_image_eval").build(image_size=384)
+        return itm_model, vis_processor
+
+    def image_text_match_reordering(
+        self,
+        search_query,
+        itm_model_type,
+        image_keys,
+        sorted_lists,
+        batch_size=1,
+        need_grad_cam=False,
+    ):
+        choose_model = {
+            "blip_base": MultimodalSearch.upload_model_blip_base,
+            "blip_large": MultimodalSearch.upload_model_blip_large,
+            "blip2_coco": MultimodalSearch.upload_model_blip2_coco,
+        }
+        itm_model, vis_processor = choose_model[itm_model_type](self)
+        text_processor = load_processor("blip_caption")
+        tokenizer = BlipBase.init_tokenizer()
+
+        text_query_index = MultimodalSearch.itm_text_precessing(self, search_query)
+
+        avg_gradcams = []
+        itm_scores = []
+        itm_scores2 = []
+        image_gradcam_with_itm = {}
+
+        for index_text_query in text_query_index:
+            query = search_query[index_text_query]
+            paths, image_names = MultimodalSearch.get_pathes_from_query(self, query)
+            num_batches = int(len(paths) / batch_size)
+            num_batches_residue = len(paths) % batch_size
+
+            local_itm_scores = []
+            local_avg_gradcams = []
+
+            if num_batches_residue != 0:
+                num_batches = num_batches + 1
+            for i in range(num_batches):
+                filenames_in_batch = paths[i * batch_size : (i + 1) * batch_size]
+                current_len = len(filenames_in_batch)
+                raw_images, images = MultimodalSearch.read_and_process_images_itm(
+                    self, filenames_in_batch, vis_processor
+                )
+                queries_batch = [text_processor(query["text_input"])] * current_len
+                queries_tok_batch = tokenizer(queries_batch, return_tensors="pt").to(
+                    MultimodalSearch.multimodal_device
+                )
+
+                if need_grad_cam:
+                    gradcam, itm_output = MultimodalSearch.compute_gradcam_batch(
+                        self,
+                        itm_model_type,
+                        itm_model,
+                        images,
+                        queries_batch,
+                        queries_tok_batch,
+                    )
+                    norm_imgs = [np.float32(r_img) / 255 for r_img in raw_images]
+
+                    for norm_img, grad_cam in zip(norm_imgs, gradcam):
+                        avg_gradcam = MultimodalSearch.getAttMap(
+                            self, norm_img, grad_cam[0], blur=True
+                        )
+                        local_avg_gradcams.append(avg_gradcam)
+
+                else:
+                    itm_output = itm_model(
+                        {"image": images, "text_input": queries_batch}, match_head="itm"
+                    )
+
+                with torch.no_grad():
+                    itm_score = torch.nn.functional.softmax(itm_output, dim=1)
+
+                local_itm_scores.append(itm_score)
+
+            local_itm_scores2 = torch.cat(local_itm_scores)[:, 1]
+            if need_grad_cam:
+                local_image_gradcam_with_itm = {
+                    n: i * 255 for n, i in zip(image_names, local_avg_gradcams)
+                }
+            else:
+                local_image_gradcam_with_itm = ""
+            image_names_with_itm = {
+                n: i.item() for n, i in zip(image_names, local_itm_scores2)
+            }
+            itm_rank = torch.argsort(local_itm_scores2, descending=True)
+            image_names_with_new_rank = {
+                image_names[i.item()]: rank
+                for i, rank in zip(itm_rank, range(len(itm_rank)))
+            }
+            for i, key in zip(range(len(image_keys)), sorted_lists[index_text_query]):
+                if image_keys[key] in image_names:
+                    self[image_keys[key]][
+                        "itm " + list(search_query[index_text_query].values())[0]
+                    ] = image_names_with_itm[image_keys[key]]
+                    self[image_keys[key]][
+                        "itm_rank " + list(search_query[index_text_query].values())[0]
+                    ] = image_names_with_new_rank[image_keys[key]]
+                else:
+                    self[image_keys[key]][
+                        "itm " + list(search_query[index_text_query].values())[0]
+                    ] = 0
+                    self[image_keys[key]][
+                        "itm_rank " + list(search_query[index_text_query].values())[0]
+                    ] = None
+
+            avg_gradcams.append(local_avg_gradcams)
+            itm_scores.append(local_itm_scores)
+            itm_scores2.append(local_itm_scores2)
+            image_gradcam_with_itm[
+                list(search_query[index_text_query].values())[0]
+            ] = local_image_gradcam_with_itm
+        return itm_scores2, image_gradcam_with_itm
+
+    def show_results(self, query, itm=False, image_gradcam_with_itm=False):
         if "image" in query.keys():
             pic = Image.open(query["image"]).convert("RGB")
             pic.thumbnail((400, 400))
@@ -338,18 +613,29 @@ class MultimodalSearch(AnalysisMethod):
             "--------------------------------------------------",
             "Results:",
         )
+        if itm:
+            current_query_val = "itm " + list(query.values())[0]
+            current_query_rank = "itm_rank " + list(query.values())[0]
+        else:
+            current_query_val = list(query.values())[0]
+            current_query_rank = "rank " + list(query.values())[0]
+
         for s in sorted(
-            self.items(), key=lambda t: t[1][list(query.values())[0]], reverse=True
+            self.items(), key=lambda t: t[1][current_query_val], reverse=True
         ):
-            if s[1]["rank " + list(query.values())[0]] is None:
+            if s[1][current_query_rank] is None:
                 break
-            p1 = Image.open(s[1]["filename"]).convert("RGB")
+            if image_gradcam_with_itm is False:
+                p1 = Image.open(s[1]["filename"]).convert("RGB")
+            else:
+                image = image_gradcam_with_itm[list(query.values())[0]][s[0]]
+                p1 = Image.fromarray(image.astype("uint8"), "RGB")
             p1.thumbnail((400, 400))
             display(
                 "Rank: "
-                + str(s[1]["rank " + list(query.values())[0]])
+                + str(s[1][current_query_rank])
                 + " Val: "
-                + str(s[1][list(query.values())[0]]),
+                + str(s[1][current_query_val]),
                 s[0],
                 p1,
             )
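
A minimal usage sketch (not part of the patch): the `image_gradcam_with_itm` dictionary returned by `image_text_match_reordering` maps each text query to a dictionary of image names and RGB overlay arrays scaled to 0-255, and the sketch below shows how those heat maps could be written to disk in the same way `show_results` renders them. The helper name `save_gradcam_overlays` and the output folder are illustrative assumptions.

# Sketch only, assuming image_gradcam_with_itm as built above:
# {query_text: {image_name: overlay_array_scaled_to_0_255}, ...}
import os

from PIL import Image


def save_gradcam_overlays(image_gradcam_with_itm, out_dir="gradcam_overlays"):
    os.makedirs(out_dir, exist_ok=True)
    for query_text, overlays in image_gradcam_with_itm.items():
        if not overlays:  # empty string when need_grad_cam was False
            continue
        # keep only filesystem-friendly characters from the query text
        safe_query = "".join(c if c.isalnum() else "_" for c in query_text)[:40]
        for image_name, overlay in overlays.items():
            pil_img = Image.fromarray(overlay.astype("uint8"), "RGB")
            pil_img.save(os.path.join(out_dir, f"{image_name}_{safe_query}.png"))
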
diff --git a/notebooks/multimodal_search.ipynb b/notebooks/multimodal_search.ipynb
index 24664a7..98c14f6 100644
--- a/notebooks/multimodal_search.ipynb
+++ b/notebooks/multimodal_search.ipynb
@@ -192,7 +192,7 @@
    "metadata": {},
    "outputs": [],
    "source": [
-    "similarity = ms.MultimodalSearch.multimodal_search(\n",
+    "similarity, sorted_lists = ms.MultimodalSearch.multimodal_search(\n",
     "    mydict,\n",
     "    model,\n",
     "    vis_processors,\n",
@@ -201,6 +201,7 @@
     "    image_keys,\n",
     "    features_image_stacked,\n",
     "    search_query3,\n",
+    "    filter_number_of_images=20,\n",
     ")"
    ]
   },
@@ -237,7 +238,68 @@
    "metadata": {},
    "outputs": [],
    "source": [
-    "ms.MultimodalSearch.show_results(mydict, search_query3[4])"
+    "ms.MultimodalSearch.show_results(\n",
+    "    mydict,\n",
+    "    search_query3[5],\n",
+    ")"
    ]
   },
+  {
+   "cell_type": "markdown",
+   "id": "0b750e9f-fe64-4028-9caf-52d7187462f1",
+   "metadata": {},
+   "source": [
+    "To further improve the results, a second, more accurate approach is available. It is quite resource-intensive, so it is applied only to the most relevant images already found by the main algorithm, and it works only with text queries. You can choose between three models: `\"blip_base\"`, `\"blip_large\"` and `\"blip2_coco\"`. If you run into an out-of-memory error, reduce the `batch_size` value (minimum 1), which is the number of images processed simultaneously. With the parameter `need_grad_cam=True/False` you can enable the computation of a heat map for each processed image. The `image_text_match_reordering` function then calculates new similarity values and new ranks for each image and adds the resulting values to the general dictionary."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "b3af7b39-6d0d-4da3-9b8f-7dfd3f5779be",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "itm_model = \"blip_base\"\n",
+    "# itm_model = \"blip_large\"\n",
+    "# itm_model = \"blip2_coco\""
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "caf1f4ae-4b37-4954-800e-7120f0419de5",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "itm_scores, image_gradcam_with_itm = ms.MultimodalSearch.image_text_match_reordering(\n",
+    "    mydict,\n",
+    "    search_query3,\n",
+    "    itm_model,\n",
+    "    image_keys,\n",
+    "    sorted_lists,\n",
+    "    batch_size=1,\n",
+    "    need_grad_cam=True,\n",
+    ")"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "9e98c150-5fab-4251-bce7-0d8fc7b385b9",
+   "metadata": {},
+   "source": [
+    "You can then call the same output function with the `itm=True` argument to display the new image ranking. Adding the `image_gradcam_with_itm` argument additionally displays the computed heat maps of the processed images."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "6a829b99-5230-463a-8b11-30ffbb67fc3a",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "ms.MultimodalSearch.show_results(\n",
+    "    mydict, search_query3[0], itm=True, image_gradcam_with_itm=image_gradcam_with_itm\n",
+    ")"
+   ]
+  },
  {
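
The reordering step writes `"itm <query>"` and `"itm_rank <query>"` entries into the general dictionary (`mydict` in the notebook). Below is a hedged sketch of how these entries could be collected into a table for inspection; the pandas dependency and the helper name `itm_results_table` are assumptions for illustration, not part of the patch.

# Sketch only: tabulate the ITM scores and ranks stored per image.
import pandas as pd


def itm_results_table(mydict, query_text):
    rows = []
    for image_name, entry in mydict.items():
        rows.append(
            {
                "image": image_name,
                "itm_score": entry.get("itm " + query_text, 0),
                "itm_rank": entry.get("itm_rank " + query_text),
            }
        )
    # images filtered out of the reordering keep rank None and sort to the bottom
    return pd.DataFrame(rows).sort_values("itm_rank", na_position="last")
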