Mirror of https://github.com/ssciwr/AMMICO.git, synced 2025-10-30 21:46:04 +02:00
fix code smells

Author: Petr Andriushchenko
Parent: 4e4b7fac75
Commit: c208039b7c
@@ -15,7 +15,7 @@ class MultimodalSearch(AnalysisMethod):
 
     multimodal_device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 
-    def load_feature_extractor_model_blip2(device):
+    def load_feature_extractor_model_blip2(self, device):
         model, vis_processors, txt_processors = load_model_and_preprocess(
             name="blip2_feature_extractor",
             model_type="pretrain",
@@ -24,7 +24,7 @@ class MultimodalSearch(AnalysisMethod):
         )
         return model, vis_processors, txt_processors
 
-    def load_feature_extractor_model_blip(device):
+    def load_feature_extractor_model_blip(self, device):
         model, vis_processors, txt_processors = load_model_and_preprocess(
             name="blip_feature_extractor",
             model_type="base",
@@ -33,7 +33,7 @@ class MultimodalSearch(AnalysisMethod):
         )
         return model, vis_processors, txt_processors
 
-    def load_feature_extractor_model_albef(device):
+    def load_feature_extractor_model_albef(self, device):
         model, vis_processors, txt_processors = load_model_and_preprocess(
             name="albef_feature_extractor",
             model_type="base",
@@ -42,7 +42,7 @@ class MultimodalSearch(AnalysisMethod):
         )
         return model, vis_processors, txt_processors
 
-    def load_feature_extractor_model_clip_base(device):
+    def load_feature_extractor_model_clip_base(self, device):
         model, vis_processors, txt_processors = load_model_and_preprocess(
             name="clip_feature_extractor",
             model_type="base",
@@ -51,7 +51,7 @@ class MultimodalSearch(AnalysisMethod):
         )
         return model, vis_processors, txt_processors
 
-    def load_feature_extractor_model_clip_vitl14(device):
+    def load_feature_extractor_model_clip_vitl14(self, device):
         model, vis_processors, txt_processors = load_model_and_preprocess(
             name="clip_feature_extractor",
             model_type="ViT-L-14",
@@ -60,7 +60,7 @@ class MultimodalSearch(AnalysisMethod):
         )
         return model, vis_processors, txt_processors
 
-    def load_feature_extractor_model_clip_vitl14_336(device):
+    def load_feature_extractor_model_clip_vitl14_336(self, device):
         model, vis_processors, txt_processors = load_model_and_preprocess(
             name="clip_feature_extractor",
             model_type="ViT-L-14-336",
@@ -69,12 +69,12 @@ class MultimodalSearch(AnalysisMethod):
         )
         return model, vis_processors, txt_processors
 
-    def read_img(filepath):
+    def read_img(self, filepath):
         raw_image = Image.open(filepath).convert("RGB")
         return raw_image
 
-    def read_and_process_images(image_paths, vis_processor):
-        raw_images = [MultimodalSearch.read_img(path) for path in image_paths]
+    def read_and_process_images(self, image_paths, vis_processor):
+        raw_images = [MultimodalSearch.read_img(self, path) for path in image_paths]
         images = [
             vis_processor["eval"](r_img)
             .unsqueeze(0)
@@ -85,7 +85,7 @@ class MultimodalSearch(AnalysisMethod):
 
         return raw_images, images_tensors
 
-    def extract_image_features_blip2(model, images_tensors):
+    def extract_image_features_blip2(self, model, images_tensors):
         with torch.cuda.amp.autocast(
             enabled=(MultimodalSearch.multimodal_device != torch.device("cpu"))
         ):
@@ -98,7 +98,7 @@ class MultimodalSearch(AnalysisMethod):
             )
         return features_image_stacked
 
-    def extract_image_features_clip(model, images_tensors):
+    def extract_image_features_clip(self, model, images_tensors):
         features_image = [
             model.extract_features({"image": ten}) for ten in images_tensors
         ]
@@ -107,7 +107,7 @@ class MultimodalSearch(AnalysisMethod):
         )
         return features_image_stacked
 
-    def extract_image_features_basic(model, images_tensors):
+    def extract_image_features_basic(self, model, images_tensors):
         features_image = [
             model.extract_features({"image": ten, "text_input": ""}, mode="image")
             for ten in images_tensors
@@ -118,7 +118,7 @@ class MultimodalSearch(AnalysisMethod):
         return features_image_stacked
 
     def save_tensors(
-        model_type, features_image_stacked, name="saved_features_image.pt"
+        self, model_type, features_image_stacked, name="saved_features_image.pt"
     ):
         with open(
             str(len(features_image_stacked)) + "_" + model_type + "_" + name, "wb"
@@ -126,11 +126,11 @@ class MultimodalSearch(AnalysisMethod):
             torch.save(features_image_stacked, f)
         return name
 
-    def load_tensors(name="saved_features_image.pt"):
+    def load_tensors(self, name="saved_features_image.pt"):
         features_image_stacked = torch.load(name)
         return features_image_stacked
 
-    def extract_text_features(model, text_input):
+    def extract_text_features(self, model, text_input):
         sample_text = {"text_input": [text_input]}
         features_text = model.extract_features(sample_text, mode="text")
 
@@ -168,24 +168,24 @@ class MultimodalSearch(AnalysisMethod):
         if model_type in select_model.keys():
             (model, vis_processors, txt_processors,) = select_model[
                 model_type
-            ](MultimodalSearch.multimodal_device)
+            ](self, MultimodalSearch.multimodal_device)
         else:
             raise SyntaxError(
                 "Please, use one of the following models: blip2, blip, albef, clip_base, clip_vitl14, clip_vitl14_336"
             )
 
         raw_images, images_tensors = MultimodalSearch.read_and_process_images(
-            image_names, vis_processors
+            self, image_names, vis_processors
         )
         if path_to_saved_tensors is None:
             with torch.no_grad():
                 features_image_stacked = select_extract_image_features[model_type](
-                    model, images_tensors
+                    self, model, images_tensors
                 )
-            MultimodalSearch.save_tensors(model_type, features_image_stacked)
+            MultimodalSearch.save_tensors(self, model_type, features_image_stacked)
         else:
             features_image_stacked = MultimodalSearch.load_tensors(
-                str(path_to_saved_tensors)
+                self, str(path_to_saved_tensors)
             )
 
         return (
@@ -222,7 +222,7 @@ class MultimodalSearch(AnalysisMethod):
                 images_tensors = ""
             elif "image" in query.keys():
                 _, images_tensors = MultimodalSearch.read_and_process_images(
-                    [query["image"]], vis_processors
+                    self, [query["image"]], vis_processors
                 )
                 text_processing = ""
             multi_sample.append(
@@ -253,7 +253,9 @@ class MultimodalSearch(AnalysisMethod):
                    multi_features_query.append(features_squeeze)
            if query["text_input"] == "":
                multi_features_query.append(
-                    select_extract_image_features[model_type](model, query["image"])
+                    select_extract_image_features[model_type](
+                        self, model, query["image"]
+                    )
                )
 
        multi_features_stacked = torch.stack(
@@ -280,7 +282,7 @@ class MultimodalSearch(AnalysisMethod):
            )
 
        similarity = features_image_stacked @ multi_features_stacked.t()
-        similarity_soft_max = torch.nn.Softmax(dim=0)(similarity / 0.01)
+        # similarity_soft_max = torch.nn.Softmax(dim=0)(similarity / 0.01)
        sorted_lists = [
            sorted(range(len(similarity)), key=lambda k: similarity[k, i], reverse=True)
            for i in range(len(similarity[0]))
@@ -292,7 +294,7 @@ class MultimodalSearch(AnalysisMethod):
                self[key]["rank " + list(search_query[q].values())[0]] = places[q][i]
                self[key][list(search_query[q].values())[0]] = similarity[i][q].item()
 
-        return similarity, similarity_soft_max
+        return similarity, sorted_lists
 
    def show_results(self, query):
        if "image" in query.keys():
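For context, a minimal, self-contained sketch of the code smell this commit addresses, using hypothetical LoaderBefore/LoaderAfter classes rather than AMMICO code: methods declared without an explicit self parameter are only usable through the class object, while adding self keeps the MultimodalSearch.method(self, ...) call style seen in the diff working and also allows normal instance calls.

# Hypothetical sketch, not AMMICO code: why the methods gain an explicit `self`.
class LoaderBefore:
    def load(device):
        # No `self`: a normal instance call, LoaderBefore().load("cpu"), raises
        # TypeError because the instance is passed implicitly as a first argument.
        return "loading on " + str(device)


class LoaderAfter:
    def load(self, device):
        # Explicit `self`, as in the commit: callable through the class with the
        # instance passed explicitly, or as a plain instance call.
        return "loading on " + str(device)


obj = LoaderAfter()
print(LoaderAfter.load(obj, "cpu"))  # class-style call, mirroring MultimodalSearch's call sites
print(obj.load("cpu"))               # equivalent normal instance call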