Mirror of https://github.com/ssciwr/AMMICO.git (synced 2025-10-31 22:16:05 +02:00)

fix code smells

Commit: c208039b7c
Parent: 4e4b7fac75
Author: Petr Andriushchenko
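This commit addresses static-analysis warnings ("code smells") in the MultimodalSearch class: helpers that were declared without self, and so only worked when invoked through the class object, now take an explicit self parameter, and every call site passes self by hand. A minimal sketch of the resulting calling convention, using a hypothetical Example class rather than the AMMICO source:

    # Plain instance methods, invoked through the class with `self` passed
    # explicitly -- the convention the diff below standardizes on.
    class Example:
        def helper(self, x):  # was: def helper(x) -- the flagged smell
            return x * 2

        def run(self, x):
            return Example.helper(self, x)

    print(Example().run(3))  # 6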
@@ -15,7 +15,7 @@ class MultimodalSearch(AnalysisMethod):
 
     multimodal_device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 
-    def load_feature_extractor_model_blip2(device):
+    def load_feature_extractor_model_blip2(self, device):
         model, vis_processors, txt_processors = load_model_and_preprocess(
             name="blip2_feature_extractor",
             model_type="pretrain",
@@ -24,7 +24,7 @@ class MultimodalSearch(AnalysisMethod):
         )
         return model, vis_processors, txt_processors
 
-    def load_feature_extractor_model_blip(device):
+    def load_feature_extractor_model_blip(self, device):
         model, vis_processors, txt_processors = load_model_and_preprocess(
             name="blip_feature_extractor",
             model_type="base",
@@ -33,7 +33,7 @@ class MultimodalSearch(AnalysisMethod):
         )
         return model, vis_processors, txt_processors
 
-    def load_feature_extractor_model_albef(device):
+    def load_feature_extractor_model_albef(self, device):
         model, vis_processors, txt_processors = load_model_and_preprocess(
             name="albef_feature_extractor",
             model_type="base",
@@ -42,7 +42,7 @@ class MultimodalSearch(AnalysisMethod):
         )
         return model, vis_processors, txt_processors
 
-    def load_feature_extractor_model_clip_base(device):
+    def load_feature_extractor_model_clip_base(self, device):
         model, vis_processors, txt_processors = load_model_and_preprocess(
             name="clip_feature_extractor",
             model_type="base",
@@ -51,7 +51,7 @@ class MultimodalSearch(AnalysisMethod):
         )
         return model, vis_processors, txt_processors
 
-    def load_feature_extractor_model_clip_vitl14(device):
+    def load_feature_extractor_model_clip_vitl14(self, device):
         model, vis_processors, txt_processors = load_model_and_preprocess(
             name="clip_feature_extractor",
             model_type="ViT-L-14",
@@ -60,7 +60,7 @@ class MultimodalSearch(AnalysisMethod):
         )
         return model, vis_processors, txt_processors
 
-    def load_feature_extractor_model_clip_vitl14_336(device):
+    def load_feature_extractor_model_clip_vitl14_336(self, device):
         model, vis_processors, txt_processors = load_model_and_preprocess(
             name="clip_feature_extractor",
             model_type="ViT-L-14-336",
@@ -69,12 +69,12 @@ class MultimodalSearch(AnalysisMethod):
         )
         return model, vis_processors, txt_processors
 
-    def read_img(filepath):
+    def read_img(self, filepath):
         raw_image = Image.open(filepath).convert("RGB")
         return raw_image
 
-    def read_and_process_images(image_paths, vis_processor):
-        raw_images = [MultimodalSearch.read_img(path) for path in image_paths]
+    def read_and_process_images(self, image_paths, vis_processor):
+        raw_images = [MultimodalSearch.read_img(self, path) for path in image_paths]
         images = [
             vis_processor["eval"](r_img)
             .unsqueeze(0)
@@ -85,7 +85,7 @@ class MultimodalSearch(AnalysisMethod):
 
         return raw_images, images_tensors
 
-    def extract_image_features_blip2(model, images_tensors):
+    def extract_image_features_blip2(self, model, images_tensors):
         with torch.cuda.amp.autocast(
             enabled=(MultimodalSearch.multimodal_device != torch.device("cpu"))
         ):
@@ -98,7 +98,7 @@ class MultimodalSearch(AnalysisMethod):
             )
         return features_image_stacked
 
-    def extract_image_features_clip(model, images_tensors):
+    def extract_image_features_clip(self, model, images_tensors):
         features_image = [
             model.extract_features({"image": ten}) for ten in images_tensors
         ]
@@ -107,7 +107,7 @@ class MultimodalSearch(AnalysisMethod):
         )
         return features_image_stacked
 
-    def extract_image_features_basic(model, images_tensors):
+    def extract_image_features_basic(self, model, images_tensors):
         features_image = [
             model.extract_features({"image": ten, "text_input": ""}, mode="image")
             for ten in images_tensors
@@ -118,7 +118,7 @@ class MultimodalSearch(AnalysisMethod):
         return features_image_stacked
 
     def save_tensors(
-        model_type, features_image_stacked, name="saved_features_image.pt"
+        self, model_type, features_image_stacked, name="saved_features_image.pt"
     ):
         with open(
             str(len(features_image_stacked)) + "_" + model_type + "_" + name, "wb"
@@ -126,11 +126,11 @@ class MultimodalSearch(AnalysisMethod):
             torch.save(features_image_stacked, f)
         return name
 
-    def load_tensors(name="saved_features_image.pt"):
+    def load_tensors(self, name="saved_features_image.pt"):
         features_image_stacked = torch.load(name)
         return features_image_stacked
 
-    def extract_text_features(model, text_input):
+    def extract_text_features(self, model, text_input):
         sample_text = {"text_input": [text_input]}
         features_text = model.extract_features(sample_text, mode="text")
 
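Review note on the two hunks above: save_tensors writes through torch.save to a file whose name is prefixed with the tensor count and model type, while returning only the bare name, and load_tensors reads a tensor back with torch.load. A self-contained sketch of that round trip, with an illustrative file name that only mimics AMMICO's prefixing scheme:

    import torch

    features = torch.rand(10, 768)  # stand-in for features_image_stacked
    path = str(len(features)) + "_blip2_saved_features_image.pt"
    with open(path, "wb") as f:
        torch.save(features, f)      # as in save_tensors
    restored = torch.load(path)      # as in load_tensors
    assert torch.equal(features, restored)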
@@ -168,24 +168,24 @@ class MultimodalSearch(AnalysisMethod):
         if model_type in select_model.keys():
             (model, vis_processors, txt_processors,) = select_model[
                 model_type
-            ](MultimodalSearch.multimodal_device)
+            ](self, MultimodalSearch.multimodal_device)
         else:
             raise SyntaxError(
                 "Please, use one of the following models: blip2, blip, albef, clip_base, clip_vitl14, clip_vitl14_336"
             )
 
         raw_images, images_tensors = MultimodalSearch.read_and_process_images(
-            image_names, vis_processors
+            self, image_names, vis_processors
         )
         if path_to_saved_tensors is None:
             with torch.no_grad():
                 features_image_stacked = select_extract_image_features[model_type](
-                    model, images_tensors
+                    self, model, images_tensors
                 )
-            MultimodalSearch.save_tensors(model_type, features_image_stacked)
+            MultimodalSearch.save_tensors(self, model_type, features_image_stacked)
         else:
             features_image_stacked = MultimodalSearch.load_tensors(
-                str(path_to_saved_tensors)
+                self, str(path_to_saved_tensors)
             )
 
         return (
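The hunk above is why the whole commit threads self around: select_model and select_extract_image_features are dispatch tables mapping a model-type string to a function looked up on the class, and a function fetched from a dict is not a bound method, so self must be supplied explicitly. A minimal sketch of the pattern, with hypothetical names:

    class Dispatcher:
        def load_a(self, device):
            return "a on " + device

        def load_b(self, device):
            return "b on " + device

        def load(self, model_type, device):
            select_model = {"a": Dispatcher.load_a, "b": Dispatcher.load_b}
            # dict lookup yields a plain function, not a bound method,
            # so `self` is passed by hand -- as in the diff above
            return select_model[model_type](self, device)

    print(Dispatcher().load("a", "cpu"))  # a on cpu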
@@ -222,7 +222,7 @@ class MultimodalSearch(AnalysisMethod):
                 images_tensors = ""
             elif "image" in query.keys():
                 _, images_tensors = MultimodalSearch.read_and_process_images(
-                    [query["image"]], vis_processors
+                    self, [query["image"]], vis_processors
                 )
                 text_processing = ""
             multi_sample.append(
@@ -253,7 +253,9 @@ class MultimodalSearch(AnalysisMethod):
                     multi_features_query.append(features_squeeze)
             if query["text_input"] == "":
                 multi_features_query.append(
-                    select_extract_image_features[model_type](model, query["image"])
+                    select_extract_image_features[model_type](
+                        self, model, query["image"]
+                    )
                 )
 
         multi_features_stacked = torch.stack(
@@ -280,7 +282,7 @@ class MultimodalSearch(AnalysisMethod):
             )
 
         similarity = features_image_stacked @ multi_features_stacked.t()
-        similarity_soft_max = torch.nn.Softmax(dim=0)(similarity / 0.01)
+        # similarity_soft_max = torch.nn.Softmax(dim=0)(similarity / 0.01)
         sorted_lists = [
             sorted(range(len(similarity)), key=lambda k: similarity[k, i], reverse=True)
             for i in range(len(similarity[0]))
@@ -292,7 +294,7 @@ class MultimodalSearch(AnalysisMethod):
                 self[key]["rank " + list(search_query[q].values())[0]] = places[q][i]
                 self[key][list(search_query[q].values())[0]] = similarity[i][q].item()
 
-        return similarity, similarity_soft_max
+        return similarity, sorted_lists
 
     def show_results(self, query):
         if "image" in query.keys():