Mirror of https://github.com/ssciwr/AMMICO.git (synced 2025-10-30 21:46:04 +02:00)
Commit 80665e5f82 by Petr Andriushchenko

.github/workflows/ci.yml (2 changes, vendored)
@@ -32,7 +32,7 @@ jobs:
     - name: Run pytest
       run: |
         cd misinformation
-        python -m pytest -m "not gcv" -svv --cov=. --cov-report=xml
+        python -m pytest -m "not gcv and not long" -svv --cov=. --cov-report=xml
     - name: Upload coverage
       if: matrix.os == 'ubuntu-22.04' && matrix.python-version == '3.9'
       uses: codecov/codecov-action@v3
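The new "not long" filter relies on a custom pytest marker. A minimal sketch of how such a marker is declared and used, assuming the project registers it in its pytest configuration (the registration itself is not part of this diff):

    # Hypothetical test module illustrating the "long" marker.
    import pytest

    @pytest.mark.long  # CI deselects these via -m "not gcv and not long"
    def test_download_and_run_itm_model():
        ...

    def test_fast_unit():
        assert 1 + 1 == 2

The marked tests can still be run on demand by selecting the marker, e.g. python -m pytest -m long.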
@@ -113,7 +113,7 @@
     "    image_keys,\n",
     "    image_names,\n",
     "    features_image_stacked,\n",
-    ") = ms.MultimodalSearch.parsing_images(mydict, model_type)"
+    ") = ms.MultimodalSearch.parsing_images(mydict, model_type, path_to_saved_tensors=\".\")"
    ]
   },
   {
@@ -128,7 +128,9 @@
    "cell_type": "code",
    "execution_count": null,
    "id": "56c6d488-f093-4661-835a-5c73a329c874",
-   "metadata": {},
+   "metadata": {
+    "tags": []
+   },
    "outputs": [],
    "source": [
     "# (\n",
@@ -4,9 +4,14 @@ import torch.nn.functional as Func
 import requests
 import lavis
 import os
+import numpy as np
 from PIL import Image
+from skimage import transform as skimage_transform
+from scipy.ndimage import filters
+from matplotlib import pyplot as plt
 from IPython.display import display
-from lavis.models import load_model_and_preprocess
+from lavis.models import load_model_and_preprocess, load_model, BlipBase
+from lavis.processors import load_processor
 
 
 class MultimodalSearch(AnalysisMethod):
@@ -233,7 +238,7 @@ class MultimodalSearch(AnalysisMethod):
         }
 
         for query in search_query:
-            if not (len(query) == 1) and (query in ("image", "text_input")):
+            if len(query) != 1 and (query in ("image", "text_input")):
                 raise SyntaxError(
                     'Each query must contain either an "image" or a "text_input"'
                 )
@@ -343,7 +348,288 @@ class MultimodalSearch(AnalysisMethod):
                     self[image_keys[key]][list(search_query[q].values())[0]] = 0
         return similarity, sorted_lists
 
-    def show_results(self, query):
+    def itm_text_precessing(self, search_query):
+        for query in search_query:
+            if not (len(query) == 1) and (query in ("image", "text_input")):
+                raise SyntaxError(
+                    'Each querry must contain either an "image" or a "text_input"'
+                )
+        text_query_index = []
+        for i, query in zip(range(len(search_query)), search_query):
+            if "text_input" in query.keys():
+                text_query_index.append(i)
+
+        return text_query_index
+
+    def get_pathes_from_query(self, query):
+        paths = []
+        image_names = []
+        for s in sorted(
+            self.items(), key=lambda t: t[1][list(query.values())[0]], reverse=True
+        ):
+            if s[1]["rank " + list(query.values())[0]] is None:
+                break
+            paths.append(s[1]["filename"])
+            image_names.append(s[0])
+        return paths, image_names
+
+    def read_and_process_images_itm(self, image_paths, vis_processor):
+        raw_images = [MultimodalSearch.read_img(self, path) for path in image_paths]
+        images = [vis_processor(r_img) for r_img in raw_images]
+        images_tensors = torch.stack(images).to(MultimodalSearch.multimodal_device)
+
+        return raw_images, images_tensors
+
+    def compute_gradcam_batch(
+        self,
+        model,
+        visual_input,
+        text_input,
+        tokenized_text,
+        block_num=6,
+    ):
+        model.text_encoder.base_model.base_model.encoder.layer[
+            block_num
+        ].crossattention.self.save_attention = True
+
+        output = model(
+            {"image": visual_input, "text_input": text_input}, match_head="itm"
+        )
+        loss = output[:, 1].sum()
+
+        model.zero_grad()
+        loss.backward()
+        with torch.no_grad():
+            mask = tokenized_text.attention_mask.view(
+                tokenized_text.attention_mask.size(0), 1, -1, 1, 1
+            )  # (bsz,1,token_len, 1,1)
+            token_length = mask.sum() - 2
+            token_length = token_length.cpu()
+            # grads and cams [bsz, num_head, seq_len, image_patch]
+            grads = model.text_encoder.base_model.base_model.encoder.layer[
+                block_num
+            ].crossattention.self.get_attn_gradients()
+            cams = model.text_encoder.base_model.base_model.encoder.layer[
+                block_num
+            ].crossattention.self.get_attention_map()
+
+            # assume using vit large with 576 num image patch
+            cams = (
+                cams[:, :, :, 1:].reshape(visual_input.size(0), 12, -1, 24, 24) * mask
+            )
+            grads = (
+                grads[:, :, :, 1:]
+                .clamp(0)
+                .reshape(visual_input.size(0), 12, -1, 24, 24)
+                * mask
+            )
+
+            gradcam = cams * grads
+            # [enc token gradcam, average gradcam across token, gradcam for individual token]
+            # gradcam = torch.cat((gradcam[0:1,:], gradcam[1:token_length+1, :].sum(dim=0, keepdim=True)/token_length, gradcam[1:, :]))
+            gradcam = gradcam.mean(1).cpu().detach()
+            gradcam = (
+                gradcam[:, 1 : token_length + 1, :].sum(dim=1, keepdim=True)
+                / token_length
+            )
+
+        return gradcam, output
+
+    def resize_img(self, raw_img):
+        w, h = raw_img.size
+        scaling_factor = 240 / w
+        resized_image = raw_img.resize(
+            (int(w * scaling_factor), int(h * scaling_factor))
+        )
+        return resized_image
+
+    def get_att_map(self, img, att_map, blur=True, overlap=True):
+        att_map -= att_map.min()
+        if att_map.max() > 0:
+            att_map /= att_map.max()
+        att_map = skimage_transform.resize(
+            att_map, (img.shape[:2]), order=3, mode="constant"
+        )
+        if blur:
+            att_map = filters.gaussian_filter(att_map, 0.02 * max(img.shape[:2]))
+            att_map -= att_map.min()
+            att_map /= att_map.max()
+        cmap = plt.get_cmap("jet")
+        att_mapv = cmap(att_map)
+        att_mapv = np.delete(att_mapv, 3, 2)
+        if overlap:
+            att_map = (
+                1 * (1 - att_map**0.7).reshape(att_map.shape + (1,)) * img
+                + (att_map**0.7).reshape(att_map.shape + (1,)) * att_mapv
+            )
+        return att_map
+
+    def upload_model_blip2_coco(self):
+        itm_model = load_model(
+            "blip2_image_text_matching",
+            "coco",
+            is_eval=True,
+            device=MultimodalSearch.multimodal_device,
+        )
+        vis_processor = load_processor("blip_image_eval").build(image_size=364)
+        return itm_model, vis_processor
+
+    def upload_model_blip_base(self):
+        itm_model = load_model(
+            "blip_image_text_matching",
+            "base",
+            is_eval=True,
+            device=MultimodalSearch.multimodal_device,
+        )
+        vis_processor = load_processor("blip_image_eval").build(image_size=384)
+        return itm_model, vis_processor
+
+    def upload_model_blip_large(self):
+        itm_model = load_model(
+            "blip_image_text_matching",
+            "large",
+            is_eval=True,
+            device=MultimodalSearch.multimodal_device,
+        )
+        vis_processor = load_processor("blip_image_eval").build(image_size=384)
+        return itm_model, vis_processor
+
+    def image_text_match_reordering(
+        self,
+        search_query,
+        itm_model_type,
+        image_keys,
+        sorted_lists,
+        batch_size=1,
+        need_grad_cam=False,
+    ):
+        if itm_model_type == "blip2_coco" and need_grad_cam is True:
+            raise SyntaxError(
+                "The blip2_coco model does not yet work with gradcam. Please set need_grad_cam to False"
+            )
+
+        choose_model = {
+            "blip_base": MultimodalSearch.upload_model_blip_base,
+            "blip_large": MultimodalSearch.upload_model_blip_large,
+            "blip2_coco": MultimodalSearch.upload_model_blip2_coco,
+        }
+
+        itm_model, vis_processor_itm = choose_model[itm_model_type](self)
+        text_processor = load_processor("blip_caption")
+        tokenizer = BlipBase.init_tokenizer()
+
+        if itm_model_type == "blip2_coco":
+            need_grad_cam = False
+
+        text_query_index = MultimodalSearch.itm_text_precessing(self, search_query)
+
+        avg_gradcams = []
+        itm_scores = []
+        itm_scores2 = []
+        image_gradcam_with_itm = {}
+
+        for index_text_query in text_query_index:
+            query = search_query[index_text_query]
+            pathes, image_names = MultimodalSearch.get_pathes_from_query(self, query)
+            num_batches = int(len(pathes) / batch_size)
+            num_batches_residue = len(pathes) % batch_size
+
+            local_itm_scores = []
+            local_avg_gradcams = []
+
+            if num_batches_residue != 0:
+                num_batches = num_batches + 1
+            for i in range(num_batches):
+                filenames_in_batch = pathes[i * batch_size : (i + 1) * batch_size]
+                current_len = len(filenames_in_batch)
+                raw_images, images = MultimodalSearch.read_and_process_images_itm(
+                    self, filenames_in_batch, vis_processor_itm
+                )
+                queries_batch = [text_processor(query["text_input"])] * current_len
+                queries_tok_batch = tokenizer(queries_batch, return_tensors="pt").to(
+                    MultimodalSearch.multimodal_device
+                )
+
+                if need_grad_cam:
+                    gradcam, itm_output = MultimodalSearch.compute_gradcam_batch(
+                        self,
+                        itm_model,
+                        images,
+                        queries_batch,
+                        queries_tok_batch,
+                    )
+                    norm_imgs = [np.float32(r_img) / 255 for r_img in raw_images]
+
+                    for norm_img, grad_cam in zip(norm_imgs, gradcam):
+                        avg_gradcam = MultimodalSearch.get_att_map(
+                            self, norm_img, np.float32(grad_cam[0]), blur=True
+                        )
+                        local_avg_gradcams.append(avg_gradcam)
+
+                else:
+                    itm_output = itm_model(
+                        {"image": images, "text_input": queries_batch}, match_head="itm"
+                    )
+
+                with torch.no_grad():
+                    itm_score = torch.nn.functional.softmax(itm_output, dim=1)
+
+                local_itm_scores.append(itm_score)
+
+            local_itm_scores2 = torch.cat(local_itm_scores)[:, 1]
+            if need_grad_cam:
+                localimage_gradcam_with_itm = {
+                    n: i * 255 for n, i in zip(image_names, local_avg_gradcams)
+                }
+            else:
+                localimage_gradcam_with_itm = ""
+            image_names_with_itm = {
+                n: i.item() for n, i in zip(image_names, local_itm_scores2)
+            }
+            itm_rank = torch.argsort(local_itm_scores2, descending=True)
+            image_names_with_new_rank = {
+                image_names[i.item()]: rank
+                for i, rank in zip(itm_rank, range(len(itm_rank)))
+            }
+            for i, key in zip(range(len(image_keys)), sorted_lists[index_text_query]):
+                if image_keys[key] in image_names:
+                    self[image_keys[key]][
+                        "itm " + list(search_query[index_text_query].values())[0]
+                    ] = image_names_with_itm[image_keys[key]]
+                    self[image_keys[key]][
+                        "itm_rank " + list(search_query[index_text_query].values())[0]
+                    ] = image_names_with_new_rank[image_keys[key]]
+                else:
+                    self[image_keys[key]][
+                        "itm " + list(search_query[index_text_query].values())[0]
+                    ] = 0
+                    self[image_keys[key]][
+                        "itm_rank " + list(search_query[index_text_query].values())[0]
+                    ] = None
+
+            avg_gradcams.append(local_avg_gradcams)
+            itm_scores.append(local_itm_scores)
+            itm_scores2.append(local_itm_scores2)
+            image_gradcam_with_itm[
+                list(search_query[index_text_query].values())[0]
+            ] = localimage_gradcam_with_itm
+        del (
+            itm_model,
+            vis_processor_itm,
+            text_processor,
+            raw_images,
+            images,
+            tokenizer,
+            queries_batch,
+            queries_tok_batch,
+            itm_score,
+        )
+        if need_grad_cam:
+            del itm_output, gradcam, norm_img, grad_cam, avg_gradcam
+        torch.cuda.empty_cache()
+        return itm_scores2, image_gradcam_with_itm
+
+    def show_results(self, query, itm=False, image_gradcam_with_itm=False):
         if "image" in query.keys():
             pic = Image.open(query["image"]).convert("RGB")
             pic.thumbnail((400, 400))
@@ -359,18 +645,29 @@ class MultimodalSearch(AnalysisMethod):
                 "--------------------------------------------------",
                 "Results:",
             )
+        if itm:
+            current_querry_val = "itm " + list(query.values())[0]
+            current_querry_rank = "itm_rank " + list(query.values())[0]
+        else:
+            current_querry_val = list(query.values())[0]
+            current_querry_rank = "rank " + list(query.values())[0]
+
         for s in sorted(
-            self.items(), key=lambda t: t[1][list(query.values())[0]], reverse=True
+            self.items(), key=lambda t: t[1][current_querry_val], reverse=True
         ):
-            if s[1]["rank " + list(query.values())[0]] is None:
+            if s[1][current_querry_rank] is None:
                 break
-            p1 = Image.open(s[1]["filename"]).convert("RGB")
+            if image_gradcam_with_itm is False:
+                p1 = Image.open(s[1]["filename"]).convert("RGB")
+            else:
+                image = image_gradcam_with_itm[list(query.values())[0]][s[0]]
+                p1 = Image.fromarray(image.astype("uint8"), "RGB")
             p1.thumbnail((400, 400))
             display(
                 "Rank: "
-                + str(s[1]["rank " + list(query.values())[0]])
+                + str(s[1][current_querry_rank])
                 + " Val: "
-                + str(s[1][list(query.values())[0]]),
+                + str(s[1][current_querry_val]),
                 s[0],
                 p1,
             )
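For orientation: the ITM match head returns two logits per image-text pair, and the code keeps column 1 of the softmax as the match probability. A standalone sketch of that scoring step (the tensor values are made up):

    import torch

    # Hypothetical ITM logits for three image-text pairs:
    # column 0 = "no match", column 1 = "match".
    itm_output = torch.tensor([[1.2, -0.3], [-0.8, 2.1], [0.0, 0.5]])

    itm_score = torch.nn.functional.softmax(itm_output, dim=1)
    match_prob = itm_score[:, 1]  # probability that the text matches each image
    itm_rank = torch.argsort(match_prob, descending=True)  # new ITM-based ordering
    print(match_prob, itm_rank)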
@@ -16,3 +16,33 @@ def set_environ(request):
         mypath + "/../../data/seismic-bonfire-329406-412821a70264.json"
     )
     print(os.environ.get("GOOGLE_APPLICATION_CREDENTIALS"))
+
+
+@pytest.fixture
+def get_testdict(get_path):
+    testdict = {
+        "IMG_2746": {"filename": get_path + "IMG_2746.png"},
+        "IMG_2809": {"filename": get_path + "IMG_2809.png"},
+    }
+    return testdict
+
+
+@pytest.fixture
+def get_test_my_dict(get_path):
+    test_my_dict = {
+        "IMG_2746": {
+            "filename": get_path + "IMG_2746.png",
+            "rank A bus": 1,
+            "A bus": 0.15640679001808167,
+            "rank " + get_path + "IMG_3758.png": 1,
+            get_path + "IMG_3758.png": 0.7533495426177979,
+        },
+        "IMG_2809": {
+            "filename": get_path + "IMG_2809.png",
+            "rank A bus": 0,
+            "A bus": 0.1970970332622528,
+            "rank " + get_path + "IMG_3758.png": 0,
+            get_path + "IMG_3758.png": 0.8907483816146851,
+        },
+    }
+    return test_my_dict
@@ -5,22 +5,17 @@ import numpy
 from torch import device, cuda
 import misinformation.multimodal_search as ms
 
-
-testdict = {
-    "IMG_2746": {"filename": "./test/data/IMG_2746.png"},
-    "IMG_2809": {"filename": "./test/data/IMG_2809.png"},
-}
-
 related_error = 1e-2
 gpu_is_not_available = not cuda.is_available()
 
-
 cuda.empty_cache()
 
 
-def test_read_img():
+def test_read_img(get_testdict):
     my_dict = {}
-    test_img = ms.MultimodalSearch.read_img(my_dict, testdict["IMG_2746"]["filename"])
+    test_img = ms.MultimodalSearch.read_img(
+        my_dict, get_testdict["IMG_2746"]["filename"]
+    )
     assert list(numpy.array(test_img)[257][34]) == [70, 66, 63]
 
 
@@ -205,29 +200,29 @@ dict_image_gradcam_with_itm_for_blip = {
         "pre_sorted",
     ),
     [
-        # (
-        #     device("cpu"),
-        #     "blip2",
-        #     pre_proc_pic_blip2_blip_albef,
-        #     pre_proc_text_blip2_blip_albef,
-        #     pre_extracted_feature_img_blip2,
-        #     pre_extracted_feature_text_blip2,
-        #     simularity_blip2,
-        #     sorted_blip2,
-        # ),
-        # pytest.param(
-        #     device("cuda"),
-        #     "blip2",
-        #     pre_proc_pic_blip2_blip_albef,
-        #     pre_proc_text_blip2_blip_albef,
-        #     pre_extracted_feature_img_blip2,
-        #     pre_extracted_feature_text_blip2,
-        #     simularity_blip2,
-        #     sorted_blip2,
-        #     marks=pytest.mark.skipif(
-        #         gpu_is_not_available, reason="gpu_is_not_availible"
-        #     ),
-        # ),
+        (
+            device("cpu"),
+            "blip2",
+            pre_proc_pic_blip2_blip_albef,
+            pre_proc_text_blip2_blip_albef,
+            pre_extracted_feature_img_blip2,
+            pre_extracted_feature_text_blip2,
+            simularity_blip2,
+            sorted_blip2,
+        ),
+        pytest.param(
+            device("cuda"),
+            "blip2",
+            pre_proc_pic_blip2_blip_albef,
+            pre_proc_text_blip2_blip_albef,
+            pre_extracted_feature_img_blip2,
+            pre_extracted_feature_text_blip2,
+            simularity_blip2,
+            sorted_blip2,
+            marks=pytest.mark.skipif(
+                gpu_is_not_available, reason="gpu_is_not_availible"
+            ),
+        ),
         (
             device("cpu"),
             "blip",
@@ -354,6 +349,8 @@ def test_parsing_images(
     pre_extracted_feature_text,
     pre_simularity,
     pre_sorted,
+    get_path,
+    get_testdict,
     tmp_path,
 ):
     ms.MultimodalSearch.multimodal_device = pre_multimodal_device
@@ -365,7 +362,7 @@
         _,
         features_image_stacked,
     ) = ms.MultimodalSearch.parsing_images(
-        testdict, pre_model, path_to_saved_tensors=tmp_path
+        get_testdict, pre_model, path_to_saved_tensors=tmp_path
     )
 
     for i, num in zip(range(10), features_image_stacked[0, 10:12].tolist()):
@@ -374,7 +371,7 @@
             is True
         )
 
-    test_pic = Image.open(testdict["IMG_2746"]["filename"]).convert("RGB")
+    test_pic = Image.open(get_testdict["IMG_2746"]["filename"]).convert("RGB")
     test_querry = (
         "The bird sat on a tree located at the intersection of 23rd and 43rd streets."
     )
@@ -390,10 +387,10 @@
 
     search_query = [
         {"text_input": test_querry},
-        {"image": testdict["IMG_2746"]["filename"]},
+        {"image": get_testdict["IMG_2746"]["filename"]},
     ]
     multi_features_stacked = ms.MultimodalSearch.querys_processing(
-        testdict, search_query, model, txt_processor, vis_processor, pre_model
+        get_testdict, search_query, model, txt_processor, vis_processor, pre_model
     )
 
     for i, num in zip(range(10), multi_features_stacked[0, 10:12].tolist()):
@@ -410,11 +407,11 @@
 
     search_query2 = [
         {"text_input": "A bus"},
-        {"image": "../misinformation/test/data/IMG_3758.png"},
+        {"image": get_path + "IMG_3758.png"},
     ]
 
     similarity, sorted_list = ms.MultimodalSearch.multimodal_search(
-        testdict,
+        get_testdict,
         model,
         vis_processor,
         txt_processor,
@@ -445,3 +442,81 @@
         multi_features_stacked,
     )
     cuda.empty_cache()
+
+
+@pytest.mark.long
+def test_itm(get_test_my_dict, get_path):
+    search_query3 = [
+        {"text_input": "A bus"},
+        {"image": get_path + "IMG_3758.png"},
+    ]
+    image_keys = ["IMG_2746", "IMG_2809"]
+    sorted_list = [[1, 0], [1, 0]]
+    for itm_model in ["blip_base", "blip_large"]:
+        (
+            itm_scores,
+            image_gradcam_with_itm,
+        ) = ms.MultimodalSearch.image_text_match_reordering(
+            get_test_my_dict,
+            search_query3,
+            itm_model,
+            image_keys,
+            sorted_list,
+            batch_size=1,
+            need_grad_cam=True,
+        )
+        for i, itm in zip(
+            range(len(dict_itm_scores_for_blib[itm_model])),
+            dict_itm_scores_for_blib[itm_model],
+        ):
+            assert (
+                math.isclose(itm_scores[0].tolist()[i], itm, rel_tol=10 * related_error)
+                is True
+            )
+        for i, grad_cam in zip(
+            range(len(dict_image_gradcam_with_itm_for_blip[itm_model])),
+            dict_image_gradcam_with_itm_for_blip[itm_model],
+        ):
+            assert (
+                math.isclose(
+                    image_gradcam_with_itm["A bus"]["IMG_2809"][0][0].tolist()[i],
+                    grad_cam,
+                    rel_tol=10 * related_error,
+                )
+                is True
+            )
+        del itm_scores, image_gradcam_with_itm
+        cuda.empty_cache()
+
+
+@pytest.mark.long
+def test_itm_blip2_coco(get_test_my_dict, get_path):
+    search_query3 = [
+        {"text_input": "A bus"},
+        {"image": get_path + "IMG_3758.png"},
+    ]
+    image_keys = ["IMG_2746", "IMG_2809"]
+    sorted_list = [[1, 0], [1, 0]]
+
+    (
+        itm_scores,
+        image_gradcam_with_itm,
+    ) = ms.MultimodalSearch.image_text_match_reordering(
+        get_test_my_dict,
+        search_query3,
+        "blip2_coco",
+        image_keys,
+        sorted_list,
+        batch_size=1,
+        need_grad_cam=False,
+    )
+    for i, itm in zip(
+        range(len(dict_itm_scores_for_blib["blip2_coco"])),
+        dict_itm_scores_for_blib["blip2_coco"],
+    ):
+        assert (
+            math.isclose(itm_scores[0].tolist()[i], itm, rel_tol=10 * related_error)
+            is True
+        )
+    del itm_scores, image_gradcam_with_itm
+    cuda.empty_cache()

notebooks/multimodal_search.ipynb (115 changes, generated)
@@ -138,9 +138,7 @@
     "    image_keys,\n",
     "    image_names,\n",
     "    features_image_stacked,\n",
-    ") = ms.MultimodalSearch.parsing_images(\n",
-    "    mydict, model_type, path_to_saved_tensors=\"./saved_tensors/\"\n",
-    ")"
+    ") = ms.MultimodalSearch.parsing_images(mydict, model_type, path_to_saved_tensors=\".\")"
    ]
   },
   {
@@ -170,7 +168,7 @@
     "# ) = ms.MultimodalSearch.parsing_images(\n",
     "#     mydict,\n",
     "#     model_type,\n",
-    "#     path_to_load_tensors=\"./saved_tensors/18_blip_saved_features_image.pt\",\n",
+    "#     path_to_load_tensors=\".5_blip_saved_features_image.pt\",\n",
     "# )"
    ]
   },
@@ -194,15 +192,14 @@
    "cell_type": "code",
    "execution_count": null,
    "id": "c4196a52-d01e-42e4-8674-5712f7d6f792",
-   "metadata": {},
+   "metadata": {
+    "tags": []
+   },
    "outputs": [],
    "source": [
     "search_query3 = [\n",
     "    {\"text_input\": \"politician press conference\"},\n",
     "    {\"text_input\": \"a world map\"},\n",
-    "    {\"image\": \"../data/haos.png\"},\n",
-    "    {\"image\": \"../data/image-34098-800.png\"},\n",
-    "    {\"image\": \"../data/LeonPresserMorocco20032015_600.png\"},\n",
     "    {\"text_input\": \"a dog\"},\n",
     "]"
    ]
@@ -222,10 +219,12 @@
    "cell_type": "code",
    "execution_count": null,
    "id": "7f7dc52f-7ee9-4590-96b7-e0d9d3b82378",
-   "metadata": {},
+   "metadata": {
+    "tags": []
+   },
    "outputs": [],
    "source": [
-    "similarity = ms.MultimodalSearch.multimodal_search(\n",
+    "similarity, sorted_lists = ms.MultimodalSearch.multimodal_search(\n",
     "    mydict,\n",
     "    model,\n",
     "    vis_processors,\n",
@@ -234,6 +233,7 @@
     "    image_keys,\n",
     "    features_image_stacked,\n",
     "    search_query3,\n",
+    "    filter_number_of_images=20,\n",
     ")"
    ]
   },
@@ -249,10 +249,12 @@
    "cell_type": "code",
    "execution_count": null,
    "id": "9ad74b21-6187-4a58-9ed8-fd3e80f5a4ed",
-   "metadata": {},
+   "metadata": {
+    "tags": []
+   },
    "outputs": [],
    "source": [
-    "mydict[\"100127S_ara\"]"
+    "mydict[\"109237S_spa\"]"
    ]
   },
   {
@@ -267,10 +269,79 @@
    "cell_type": "code",
    "execution_count": null,
    "id": "4324e4fd-e9aa-4933-bb12-074d54e0c510",
-   "metadata": {},
+   "metadata": {
+    "tags": []
+   },
    "outputs": [],
    "source": [
-    "ms.MultimodalSearch.show_results(mydict, search_query3[4])"
+    "ms.MultimodalSearch.show_results(\n",
+    "    mydict,\n",
+    "    search_query3[0],\n",
+    ")"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "0b750e9f-fe64-4028-9caf-52d7187462f1",
+   "metadata": {},
+   "source": [
+    "For even better results, a slightly different approach is available that can further improve the search output. It is quite resource-intensive, so it is applied only after the main algorithm has found the most relevant images, and it works only with text queries. You can choose between three models: `\"blip_base\"`, `\"blip_large\"` and `\"blip2_coco\"`. If you run into an out-of-memory error, try reducing the `batch_size` value (minimum 1), which is the number of images processed simultaneously. The `need_grad_cam=True/False` parameter enables the calculation of a heat map for each processed image. The `image_text_match_reordering` function then calculates new similarity values and new ranks for each image and adds the results to the general dictionary."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "b3af7b39-6d0d-4da3-9b8f-7dfd3f5779be",
+   "metadata": {
+    "tags": []
+   },
+   "outputs": [],
+   "source": [
+    "itm_model = \"blip_base\"\n",
+    "# itm_model = \"blip_large\"\n",
+    "# itm_model = \"blip2_coco\""
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "caf1f4ae-4b37-4954-800e-7120f0419de5",
+   "metadata": {
+    "tags": []
+   },
+   "outputs": [],
+   "source": [
+    "itm_scores, image_gradcam_with_itm = ms.MultimodalSearch.image_text_match_reordering(\n",
+    "    mydict,\n",
+    "    search_query3,\n",
+    "    itm_model,\n",
+    "    image_keys,\n",
+    "    sorted_lists,\n",
+    "    batch_size=1,\n",
+    "    need_grad_cam=True,\n",
+    ")"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "9e98c150-5fab-4251-bce7-0d8fc7b385b9",
+   "metadata": {},
+   "source": [
+    "Then, using the same output function, you can pass `itm=True` to output the new image ordering. You can also pass the `image_gradcam_with_itm` argument to output the heat maps of the processed images."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "6a829b99-5230-463a-8b11-30ffbb67fc3a",
+   "metadata": {
+    "tags": []
+   },
+   "outputs": [],
+   "source": [
+    "ms.MultimodalSearch.show_results(\n",
+    "    mydict, search_query3[0], itm=True, image_gradcam_with_itm=image_gradcam_with_itm\n",
+    ")"
    ]
   },
   {
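A note on the heat maps: image_text_match_reordering stores, per text query, a mapping from image name to an RGB array scaled to 0-255, which show_results converts back to an image. A hypothetical sketch of rendering one overlay by hand (the image key is made up):

    import numpy as np
    from matplotlib import pyplot as plt

    overlay = image_gradcam_with_itm["politician press conference"]["some_image_key"]
    plt.imshow(np.asarray(overlay).astype("uint8"))  # HxWx3 array in 0-255
    plt.axis("off")
    plt.show()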
@@ -320,7 +391,9 @@
    "cell_type": "code",
    "execution_count": null,
    "id": "e78646d6-80be-4d3e-8123-3360957bcaa8",
-   "metadata": {},
+   "metadata": {
+    "tags": []
+   },
    "outputs": [],
    "source": [
     "df.head(10)"
@@ -338,11 +411,21 @@
    "cell_type": "code",
    "execution_count": null,
    "id": "185f7dde-20dc-44d8-9ab0-de41f9b5734d",
-   "metadata": {},
+   "metadata": {
+    "tags": []
+   },
    "outputs": [],
    "source": [
     "df.to_csv(\"./data_out.csv\")"
    ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "b6a79201-7c17-496c-a6a1-b8ecfd3dd1e8",
+   "metadata": {},
+   "outputs": [],
+   "source": []
   }
  ],
  "metadata": {