Mirror of https://github.com/ssciwr/AMMICO.git, synced 2025-10-30 21:46:04 +02:00

Commit 4e4b7fac75 (parent 779c5227ae) by Petr Andriushchenko

    fix if-else, added clip ViT-L-14-336 model
@@ -15,54 +15,58 @@ class MultimodalSearch(AnalysisMethod):
 
     multimodal_device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 
-    def load_feature_extractor_model(device, model_type):
-        if model_type == "blip2":
+    def load_feature_extractor_model_blip2(device):
         model, vis_processors, txt_processors = load_model_and_preprocess(
             name="blip2_feature_extractor",
             model_type="pretrain",
             is_eval=True,
             device=device,
         )
-        elif model_type == "blip":
+        return model, vis_processors, txt_processors
+
+    def load_feature_extractor_model_blip(device):
         model, vis_processors, txt_processors = load_model_and_preprocess(
             name="blip_feature_extractor",
             model_type="base",
             is_eval=True,
             device=device,
         )
-        elif model_type == "albef":
+        return model, vis_processors, txt_processors
+
+    def load_feature_extractor_model_albef(device):
         model, vis_processors, txt_processors = load_model_and_preprocess(
             name="albef_feature_extractor",
             model_type="base",
             is_eval=True,
             device=device,
         )
-        elif model_type == "clip_base":
+        return model, vis_processors, txt_processors
+
+    def load_feature_extractor_model_clip_base(device):
         model, vis_processors, txt_processors = load_model_and_preprocess(
             name="clip_feature_extractor",
             model_type="base",
             is_eval=True,
             device=device,
         )
-        elif model_type == "clip_rn50":
-            model, vis_processors, txt_processors = load_model_and_preprocess(
-                name="clip_feature_extractor",
-                model_type="RN50",
-                is_eval=True,
-                device=device,
-            )
-        elif model_type == "clip_vitl14":
+        return model, vis_processors, txt_processors
+
+    def load_feature_extractor_model_clip_vitl14(device):
         model, vis_processors, txt_processors = load_model_and_preprocess(
             name="clip_feature_extractor",
             model_type="ViT-L-14",
             is_eval=True,
             device=device,
         )
-        else:
-            print(
-                "Please, use one of the following models: blip2, blip, albef, clip_base, clip_rn50, clip_vitl14"
-            )
+        return model, vis_processors, txt_processors
+
+    def load_feature_extractor_model_clip_vitl14_336(device):
+        model, vis_processors, txt_processors = load_model_and_preprocess(
+            name="clip_feature_extractor",
+            model_type="ViT-L-14-336",
+            is_eval=True,
+            device=device,
+        )
         return model, vis_processors, txt_processors
 
     def read_img(filepath):
@@ -81,34 +85,29 @@ class MultimodalSearch(AnalysisMethod):
 
         return raw_images, images_tensors
 
-    def extract_image_features(model, images_tensors, model_type):
-        if model_type == "blip2":
+    def extract_image_features_blip2(model, images_tensors):
         with torch.cuda.amp.autocast(
             enabled=(MultimodalSearch.multimodal_device != torch.device("cpu"))
         ):
             features_image = [
-                    model.extract_features(
-                        {"image": ten, "text_input": ""}, mode="image"
-                    )
+                model.extract_features({"image": ten, "text_input": ""}, mode="image")
                 for ten in images_tensors
             ]
             features_image_stacked = torch.stack(
-                    [
-                        feat.image_embeds_proj[:, 0, :].squeeze(0)
-                        for feat in features_image
-                    ]
+                [feat.image_embeds_proj[:, 0, :].squeeze(0) for feat in features_image]
             )
-        elif model_type in ("clip_base", "clip_rn50", "clip_vitl14"):
+        return features_image_stacked
+
+    def extract_image_features_clip(model, images_tensors):
         features_image = [
             model.extract_features({"image": ten}) for ten in images_tensors
         ]
         features_image_stacked = torch.stack(
-                [
-                    Func.normalize(feat.float(), dim=-1).squeeze(0)
-                    for feat in features_image
-                ]
+            [Func.normalize(feat.float(), dim=-1).squeeze(0) for feat in features_image]
         )
-        else:
+        return features_image_stacked
+
+    def extract_image_features_basic(model, images_tensors):
         features_image = [
             model.extract_features({"image": ten, "text_input": ""}, mode="image")
             for ten in images_tensors
@@ -137,9 +136,9 @@ class MultimodalSearch(AnalysisMethod):
 
         return features_text
 
-    def parsing_images(self, model_type):
+    def parsing_images(self, model_type, path_to_saved_tensors=None):
 
-        if model_type in ("clip_base", "clip_rn50", "clip_vitl14"):
+        if model_type in ("clip_base", "clip_vitl14_336", "clip_vitl14"):
             path_to_lib = lavis.__file__[:-11] + "models/clip_models/"
             url = "https://raw.githubusercontent.com/salesforce/LAVIS/main/lavis/models/clip_models/bpe_simple_vocab_16e6.txt.gz"
             r = requests.get(url, allow_redirects=False)
@@ -148,20 +147,46 @@ class MultimodalSearch(AnalysisMethod):
         image_keys = sorted(self.keys())
         image_names = [self[k]["filename"] for k in image_keys]
 
-        (
-            model,
-            vis_processors,
-            txt_processors,
-        ) = MultimodalSearch.load_feature_extractor_model(
-            MultimodalSearch.multimodal_device, model_type
+        select_model = {
+            "blip2": MultimodalSearch.load_feature_extractor_model_blip2,
+            "blip": MultimodalSearch.load_feature_extractor_model_blip,
+            "albef": MultimodalSearch.load_feature_extractor_model_albef,
+            "clip_base": MultimodalSearch.load_feature_extractor_model_clip_base,
+            "clip_vitl14": MultimodalSearch.load_feature_extractor_model_clip_vitl14,
+            "clip_vitl14_336": MultimodalSearch.load_feature_extractor_model_clip_vitl14_336,
+        }
+
+        select_extract_image_features = {
+            "blip2": MultimodalSearch.extract_image_features_blip2,
+            "blip": MultimodalSearch.extract_image_features_basic,
+            "albef": MultimodalSearch.extract_image_features_basic,
+            "clip_base": MultimodalSearch.extract_image_features_clip,
+            "clip_vitl14": MultimodalSearch.extract_image_features_clip,
+            "clip_vitl14_336": MultimodalSearch.extract_image_features_clip,
+        }
+
+        if model_type in select_model.keys():
+            (model, vis_processors, txt_processors,) = select_model[
+                model_type
+            ](MultimodalSearch.multimodal_device)
+        else:
+            raise SyntaxError(
+                "Please, use one of the following models: blip2, blip, albef, clip_base, clip_vitl14, clip_vitl14_336"
             )
+
         raw_images, images_tensors = MultimodalSearch.read_and_process_images(
             image_names, vis_processors
         )
-        features_image_stacked = MultimodalSearch.extract_image_features(
-            model, images_tensors, model_type
+        if path_to_saved_tensors is None:
+            with torch.no_grad():
+                features_image_stacked = select_extract_image_features[model_type](
+                    model, images_tensors
                 )
             MultimodalSearch.save_tensors(model_type, features_image_stacked)
+        else:
+            features_image_stacked = MultimodalSearch.load_tensors(
+                str(path_to_saved_tensors)
+            )
 
         return (
             model,
@@ -175,6 +200,16 @@ class MultimodalSearch(AnalysisMethod):
     def querys_processing(
         self, search_query, model, txt_processors, vis_processors, model_type
     ):
+
+        select_extract_image_features = {
+            "blip2": MultimodalSearch.extract_image_features_blip2,
+            "blip": MultimodalSearch.extract_image_features_basic,
+            "albef": MultimodalSearch.extract_image_features_basic,
+            "clip_base": MultimodalSearch.extract_image_features_clip,
+            "clip_vitl14": MultimodalSearch.extract_image_features_clip,
+            "clip_vitl14_336": MultimodalSearch.extract_image_features_clip,
+        }
+
         for query in search_query:
             if not (len(query) == 1) and (query in ("image", "text_input")):
                 raise SyntaxError(
@@ -194,10 +229,10 @@ class MultimodalSearch(AnalysisMethod):
                 {"image": images_tensors, "text_input": text_processing}
             )
 
-        if model_type in ("clip_base", "clip_rn50", "clip_vitl14"):
         multi_features_query = []
         for query in multi_sample:
             if query["image"] == "":
+                if model_type in ("clip_base", "clip_vitl14_336", "clip_vitl14"):
                     features = model.extract_features(
                         {"text_input": query["text_input"]}
                     )
@@ -208,17 +243,7 @@ class MultimodalSearch(AnalysisMethod):
                     multi_features_query.append(
                         Func.normalize(features_squeeze, dim=-1)
                     )
-                if query["text_input"] == "":
-                    multi_features_query.append(
-                        MultimodalSearch.extract_image_features(
-                            model, query["image"], model_type
-                        )
-                    )
-
                 else:
-            multi_features_query = []
-            for query in multi_sample:
-                if query["image"] == "":
                     features = model.extract_features(query, mode="text")
                     features_squeeze = (
                         features.text_embeds_proj[:, 0, :]
@@ -228,9 +253,7 @@ class MultimodalSearch(AnalysisMethod):
                     multi_features_query.append(features_squeeze)
             if query["text_input"] == "":
                 multi_features_query.append(
-                        MultimodalSearch.extract_image_features(
-                            model, query["image"], model_type
-                        )
+                    select_extract_image_features[model_type](model, query["image"])
                 )
 
         multi_features_stacked = torch.stack(
@@ -251,11 +274,13 @@ class MultimodalSearch(AnalysisMethod):
     ):
         features_image_stacked.to(MultimodalSearch.multimodal_device)
 
+        with torch.no_grad():
             multi_features_stacked = MultimodalSearch.querys_processing(
                 self, search_query, model, txt_processors, vis_processors, model_type
             )
 
         similarity = features_image_stacked @ multi_features_stacked.t()
+        similarity_soft_max = torch.nn.Softmax(dim=0)(similarity / 0.01)
         sorted_lists = [
             sorted(range(len(similarity)), key=lambda k: similarity[k, i], reverse=True)
             for i in range(len(similarity[0]))
@@ -267,7 +292,7 @@ class MultimodalSearch(AnalysisMethod):
                 self[key]["rank " + list(search_query[q].values())[0]] = places[q][i]
                 self[key][list(search_query[q].values())[0]] = similarity[i][q].item()
 
-        return similarity
+        return similarity, similarity_soft_max
 
     def show_results(self, query):
         if "image" in query.keys():
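For orientation, here is a rough usage sketch of the refactored call, following the notebook changes further down. The `ms` module alias and the image dictionary `mydict` are set up earlier in the notebook and are assumed here; they are not part of this diff.

    # Sketch only: `ms` is assumed to expose MultimodalSearch, and `mydict` is the
    # image dictionary prepared earlier in the notebook.
    model_type = "clip_vitl14_336"  # newly added key in select_model

    # First run: features are extracted and written to
    # <Number_of_images>_<model_name>_saved_features_image.pt
    (
        model,
        vis_processors,
        txt_processors,
        image_keys,
        image_names,
        features_image_stacked,
    ) = ms.MultimodalSearch.parsing_images(mydict, model_type)

    # Later runs can reuse the saved tensors via the new optional third argument:
    # ms.MultimodalSearch.parsing_images(mydict, model_type, "18_clip_base_saved_features_image.pt")

Note that an unsupported model_type now raises a SyntaxError instead of only printing a hint.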
							
								
								
									
notebooks/multimodal_search.ipynb (generated, 29 changed lines)
@@ -81,7 +81,7 @@
    "id": "66d6ede4-00bc-4aeb-9a36-e52d7de33fe5",
    "metadata": {},
    "source": [
-    "You can choose one of the following models: blip, blip2, albef, clip_base, clip_rn50, clip_vitl14"
+    "You can choose one of the following models: blip, blip2, albef, clip_base, clip_vitl14, clip_vitl14_336"
    ]
   },
   {
@@ -91,7 +91,7 @@
    "metadata": {},
    "outputs": [],
    "source": [
-    "model_type = \"blip\""
+    "model_type = \"clip_vitl14_336\""
    ]
   },
   {
@@ -116,17 +116,32 @@
    "id": "9ff8a894-566b-4c4f-acca-21c50b5b1f52",
    "metadata": {},
    "source": [
-    "The of all images `features_image_stacked` was saved in `saved_features_image.pt`. If you run it once for current model and set of images you do not need to repeat it again. Instead you can load this features with the command:"
+    "The tensors of all images `features_image_stacked` was saved in `<Number_of_images>_<model_name>_saved_features_image.pt`. If you run it once for current model and current set of images you do not need to repeat it again. Instead you can load this features with the command:"
    ]
   },
   {
    "cell_type": "code",
    "execution_count": null,
-   "id": "c40e93f0-6bea-4886-b904-8b46ed6ec819",
+   "id": "56c6d488-f093-4661-835a-5c73a329c874",
    "metadata": {},
    "outputs": [],
    "source": [
-    "# features_image_stacked = ms.MultimodalSearch.load_tensors('saved_features_image.pt')"
+    "# (\n",
+    "#    model,\n",
+    "#    vis_processors,\n",
+    "#    txt_processors,\n",
+    "#    image_keys,\n",
+    "#    image_names,\n",
+    "#    features_image_stacked,\n",
+    "# ) = ms.MultimodalSearch.parsing_images(mydict, model_type,\"18_clip_base_saved_features_image.pt\")"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "309923c1-d6f8-4424-8fca-bde5f3a98b38",
+   "metadata": {},
+   "source": [
+    "Here we already processed our image folder with 18 images with `clip_base` model. So you need just write the name `18_clip_base_saved_features_image.pt` of the saved file that consists of tensors of all images as a 3rd argument to the previous function. "
    ]
   },
   {
@@ -170,7 +185,7 @@
     "    image_keys,\n",
     "    features_image_stacked,\n",
     "    search_query3,\n",
-    ");"
+    ")"
    ]
   },
   {
@@ -206,7 +221,7 @@
    "metadata": {},
    "outputs": [],
    "source": [
-    "ms.MultimodalSearch.show_results(mydict, search_query3[0])"
+    "ms.MultimodalSearch.show_results(mydict, search_query3[4])"
    ]
   },
   {
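A matching sketch for the search step after this change; the method name `multimodal_search` and the exact argument order are inferred from the partially visible notebook cell and should be treated as assumptions where the diff does not show them.

    # Sketch only: the search call now returns two tensors (see the
    # `return similarity, similarity_soft_max` change above). similarity_soft_max
    # is Softmax(dim=0)(similarity / 0.01), i.e. a sharply peaked, probability-like
    # ranking of the images for each query.
    similarity, similarity_soft_max = ms.MultimodalSearch.multimodal_search(
        mydict,
        model,
        vis_processors,
        txt_processors,
        model_type,
        image_keys,
        features_image_stacked,
        search_query3,
    )

    # Display the ranked images for the fifth query, as in the updated notebook cell.
    ms.MultimodalSearch.show_results(mydict, search_query3[4])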