Mirror of https://github.com/ssciwr/AMMICO.git
Synced 2025-10-30 21:46:04 +02:00

Change input format for multimodal search

Parent: b709f69d58
Commit: 70866dfc69
@@ -84,12 +84,12 @@ class MultimodalSearch(AnalysisMethod):
                 enabled=(MultimodalSearch.multimodal_device != torch.device("cpu"))
             ):
                 features_image = [
-                    model.extract_features({"image": ten}, mode="image")
+                    model.extract_features({"image": ten, "text_input": ""}, mode="image")
                     for ten in images_tensors
                 ]
         else:
             features_image = [
-                model.extract_features({"image": ten}, mode="image")
+                model.extract_features({"image": ten, "text_input": ""}, mode="image")
                 for ten in images_tensors
             ]
 
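The only change in this hunk is that image-only feature extraction now passes an empty "text_input" alongside the image tensor. For context, a minimal sketch of such a call with the LAVIS feature-extractor API this class builds on; the model name "blip_feature_extractor"/"base" and the image path are illustrative assumptions, not taken from this commit:

import torch
from PIL import Image
from lavis.models import load_model_and_preprocess

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Illustrative model choice; AMMICO may load a different checkpoint.
model, vis_processors, txt_processors = load_model_and_preprocess(
    name="blip_feature_extractor", model_type="base", is_eval=True, device=device
)
image = vis_processors["eval"](Image.open("example.jpg").convert("RGB")).unsqueeze(0).to(device)
# The empty "text_input" mirrors the change above: image-only mode still gets the key.
features = model.extract_features({"image": image, "text_input": ""}, mode="image")
print(features.image_embeds_proj.shape)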
@@ -114,6 +114,7 @@ class MultimodalSearch(AnalysisMethod):
 
         return features_text
 
+
     def parsing_images(self, model_type):
         image_keys = sorted(self.keys())
         image_names = [self[k]["filename"] for k in image_keys]
@@ -132,31 +133,40 @@ class MultimodalSearch(AnalysisMethod):
         )
         MultimodalSearch.save_tensors(features_image_stacked)
 
-        return image_keys, image_names, features_image_stacked
+        return model, vis_processors, txt_processors, image_keys, image_names, features_image_stacked
 
     def multimodal_search(
-        self, model_type, image_keys, features_image_stacked, search_query
+        self, model, vis_processors, txt_processors, model_type, image_keys, features_image_stacked, search_query
     ):
         features_image_stacked.to(MultimodalSearch.multimodal_device)
-        (
-            model,
-            vis_processors,
-            txt_processors,
-        ) = MultimodalSearch.load_feature_extractor_model(
-            MultimodalSearch.multimodal_device, model_type
-        )
-        multi_text_input = [txt_processors["eval"](query) for query in search_query]
-        multi_sample = [{"text_input": [query]} for query in multi_text_input]
-        multi_features_text = [
-            model.extract_features(sample, mode="text") for sample in multi_sample
-        ]
-        multi_features_text_stacked = torch.stack(
-            [
-                features.text_embeds_proj[:, 0, :].squeeze(0)
-                for features in multi_features_text
-            ]
-        ).to(MultimodalSearch.multimodal_device)
-        similarity = features_image_stacked @ multi_features_text_stacked.t()
+
+        for query in search_query:
+            if len(query) != 1:
+                raise SyntaxError('Each query must contain either an "image" or a "text_input"')
+
+        multi_sample = []
+        for query in search_query:
+            if "text_input" in query.keys():
+                text_processing = txt_processors["eval"](query["text_input"])
+                image_processing = ""
+            elif "image" in query.keys():
+                _, image_processing = MultimodalSearch.read_and_process_images([query["image"]], vis_processors)
+                text_processing = ""
+            multi_sample.append({"image": image_processing, "text_input": text_processing})
+
+        multi_features_query = []
+        for query in multi_sample:
+            if query["image"] == "":
+                features = model.extract_features(query, mode="text")
+                features_squeeze = features.text_embeds_proj[:, 0, :].squeeze(0).to(MultimodalSearch.multimodal_device)
+                multi_features_query.append(features_squeeze)
+            if query["text_input"] == "":
+                multi_features_query.append(MultimodalSearch.extract_image_features(
+                    model, query["image"], model_type))
+
+        multi_features_stacked = torch.stack([query.squeeze(0) for query in multi_features_query]).to(MultimodalSearch.multimodal_device)
+
+        similarity = features_image_stacked @ multi_features_stacked.t()
         sorted_lists = [
             sorted(range(len(similarity)), key=lambda k: similarity[k, i], reverse=True)
             for i in range(len(similarity[0]))
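With this hunk, parsing_images also hands back the loaded model and processors, and each entry of search_query is now a one-key dict holding either a "text_input" string or an "image" path. A hedged usage sketch against the new signatures; the instance name my_obj, the model_type value "blip", and the file paths are assumptions for illustration:

search_query = [
    {"text_input": "politicians talking to the press"},
    {"image": "data/query_image.png"},
]

(
    model,
    vis_processors,
    txt_processors,
    image_keys,
    image_names,
    features_image_stacked,
) = my_obj.parsing_images("blip")

# Returns the raw image-by-query similarity matrix and annotates my_obj in place.
similarity = my_obj.multimodal_search(
    model, vis_processors, txt_processors, "blip",
    image_keys, features_image_stacked, search_query,
)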
@@ -165,13 +175,19 @@ class MultimodalSearch(AnalysisMethod):
 
         for q in range(len(search_query)):
             for i, key in zip(range(len(image_keys)), image_keys):
-                self[key]["rank " + search_query[q]] = places[q][i]
-                self[key][search_query[q]] = similarity[i][q].item()
+                self[key]["rank " + list(search_query[q].values())[0]] = places[q][i]
+                self[key][list(search_query[q].values())[0]] = similarity[i][q].item()
 
-        return self
+        return similarity
 
     def show_results(self, query):
-        for s in sorted(self.items(), key=lambda t: t[1][query], reverse=True):
+        if "image" in query.keys():
+            pic = Image.open(query["image"]).convert("RGB")
+            pic.thumbnail((400, 400))
+            display("Your search query: ", pic, "--------------------------------------------------", "Results:")
+        elif "text_input" in query.keys():
+            display("Your search query: " + query["text_input"], "--------------------------------------------------", "Results:")
+        for s in sorted(self.items(), key=lambda t: t[1][list(query.values())[0]], reverse=True):
             p1 = Image.open(s[1]["filename"]).convert("RGB")
             p1.thumbnail((400, 400))
-            display(p1, s[1][query])
+            display(p1, "Rank: " + str(s[1]["rank " + list(query.values())[0]]) + " Val: " + str(s[1][list(query.values())[0]]))
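show_results now takes the same one-key query dict, displays the query itself (rendering a thumbnail for image queries), and lists every indexed image with the rank and similarity value that multimodal_search stored. Continuing the sketch above, with the same assumed names:

my_obj.show_results({"text_input": "politicians talking to the press"})
my_obj.show_results({"image": "data/query_image.png"})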
							
								
								
									
notebooks/multimodal_search.ipynb (generated): 944 lines changed
File diff hidden because one or more lines are too long.
Author: Petr Andriushchenko