Mirror of https://github.com/ssciwr/AMMICO.git
Writing docs to multimodal_search.py
Parent: ca3994c72f
Commit: f524ecaad8
Author: Petr Andriushchenko
@@ -20,7 +20,18 @@ class MultimodalSearch(AnalysisMethod):
 
     multimodal_device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 
-    def load_feature_extractor_model_blip2(self, device):
+    def load_feature_extractor_model_blip2(self, device: str = "cpu"):
+        """
+        Load the pretrained blip2_feature_extractor model and preprocessors for visual and text inputs from lavis.models.
+
+        Args:
+            device (str): device to use. Can be "cpu" or "cuda". Default: "cpu".
+
+        Returns:
+            model (torch.nn.Module): model.
+            vis_processors (dict): preprocessors for visual inputs.
+            txt_processors (dict): preprocessors for text inputs.
+        """
         model, vis_processors, txt_processors = load_model_and_preprocess(
             name="blip2_feature_extractor",
             model_type="pretrain",
@@ -29,7 +40,18 @@ class MultimodalSearch(AnalysisMethod):
         )
         return model, vis_processors, txt_processors
 
-    def load_feature_extractor_model_blip(self, device):
+    def load_feature_extractor_model_blip(self, device: str = "cpu"):
+        """
+        Load the base blip_feature_extractor model and preprocessors for visual and text inputs from lavis.models.
+
+        Args:
+            device (str): device to use. Can be "cpu" or "cuda". Default: "cpu".
+
+        Returns:
+            model (torch.nn.Module): model.
+            vis_processors (dict): preprocessors for visual inputs.
+            txt_processors (dict): preprocessors for text inputs.
+        """
         model, vis_processors, txt_processors = load_model_and_preprocess(
             name="blip_feature_extractor",
             model_type="base",
@@ -38,7 +60,18 @@ class MultimodalSearch(AnalysisMethod):
         )
         return model, vis_processors, txt_processors
 
-    def load_feature_extractor_model_albef(self, device):
+    def load_feature_extractor_model_albef(self, device: str = "cpu"):
+        """
+        Load the base albef_feature_extractor model and preprocessors for visual and text inputs from lavis.models.
+
+        Args:
+            device (str): device to use. Can be "cpu" or "cuda". Default: "cpu".
+
+        Returns:
+            model (torch.nn.Module): model.
+            vis_processors (dict): preprocessors for visual inputs.
+            txt_processors (dict): preprocessors for text inputs.
+        """
         model, vis_processors, txt_processors = load_model_and_preprocess(
             name="albef_feature_extractor",
             model_type="base",
@@ -47,7 +80,18 @@ class MultimodalSearch(AnalysisMethod):
         )
         return model, vis_processors, txt_processors
 
-    def load_feature_extractor_model_clip_base(self, device):
+    def load_feature_extractor_model_clip_base(self, device: str = "cpu"):
+        """
+        Load the base clip_feature_extractor model and preprocessors for visual and text inputs from lavis.models.
+
+        Args:
+            device (str): device to use. Can be "cpu" or "cuda". Default: "cpu".
+
+        Returns:
+            model (torch.nn.Module): model.
+            vis_processors (dict): preprocessors for visual inputs.
+            txt_processors (dict): preprocessors for text inputs.
+        """
         model, vis_processors, txt_processors = load_model_and_preprocess(
             name="clip_feature_extractor",
             model_type="base",
@@ -56,7 +100,18 @@ class MultimodalSearch(AnalysisMethod):
         )
         return model, vis_processors, txt_processors
 
-    def load_feature_extractor_model_clip_vitl14(self, device):
+    def load_feature_extractor_model_clip_vitl14(self, device: str = "cpu"):
+        """
+        Load the ViT-L-14 clip_feature_extractor model and preprocessors for visual and text inputs from lavis.models.
+
+        Args:
+            device (str): device to use. Can be "cpu" or "cuda". Default: "cpu".
+
+        Returns:
+            model (torch.nn.Module): model.
+            vis_processors (dict): preprocessors for visual inputs.
+            txt_processors (dict): preprocessors for text inputs.
+        """
         model, vis_processors, txt_processors = load_model_and_preprocess(
             name="clip_feature_extractor",
             model_type="ViT-L-14",
@@ -65,7 +120,18 @@ class MultimodalSearch(AnalysisMethod):
         )
         return model, vis_processors, txt_processors
 
-    def load_feature_extractor_model_clip_vitl14_336(self, device):
+    def load_feature_extractor_model_clip_vitl14_336(self, device: str = "cpu"):
+        """
+        Load the ViT-L-14-336 clip_feature_extractor model and preprocessors for visual and text inputs from lavis.models.
+
+        Args:
+            device (str): device to use. Can be "cpu" or "cuda". Default: "cpu".
+
+        Returns:
+            model (torch.nn.Module): model.
+            vis_processors (dict): preprocessors for visual inputs.
+            txt_processors (dict): preprocessors for text inputs.
+        """
         model, vis_processors, txt_processors = load_model_and_preprocess(
             name="clip_feature_extractor",
             model_type="ViT-L-14-336",
@@ -74,11 +140,31 @@ class MultimodalSearch(AnalysisMethod):
         )
         return model, vis_processors, txt_processors
 
-    def read_img(self, filepath):
+    def read_img(self, filepath: str) -> Image:
+        """
+        Load an image from filepath.
+
+        Args:
+            filepath (str): path to the image.
+
+        Returns:
+            raw_image (PIL.Image): image.
+        """
         raw_image = Image.open(filepath).convert("RGB")
         return raw_image
 
-    def read_and_process_images(self, image_paths, vis_processor):
+    def read_and_process_images(self, image_paths: list, vis_processor) -> tuple:
+        """
+        Read and process images with vis_processor.
+
+        Args:
+            image_paths (list): paths to the images.
+            vis_processor (dict): preprocessors for visual inputs.
+
+        Returns:
+            raw_images (list): list of images.
+            images_tensors (torch.Tensor): tensors of images stacked on the device.
+        """
         raw_images = [MultimodalSearch.read_img(self, path) for path in image_paths]
         images = [
             vis_processor["eval"](r_img)
@@ -90,7 +176,19 @@ class MultimodalSearch(AnalysisMethod):
 
         return raw_images, images_tensors
 
-    def extract_image_features_blip2(self, model, images_tensors):
+    def extract_image_features_blip2(
+        self, model, images_tensors: torch.Tensor
+    ) -> torch.Tensor:
+        """
+        Extract image features from images_tensors with the blip2_feature_extractor model.
+
+        Args:
+            model (torch.nn.Module): model.
+            images_tensors (torch.Tensor): tensors of images stacked on the device.
+
+        Returns:
+            features_image_stacked (torch.Tensor): tensor of image features stacked on the device.
+        """
         with torch.cuda.amp.autocast(
             enabled=(MultimodalSearch.multimodal_device != torch.device("cpu"))
         ):
@@ -103,7 +201,19 @@ class MultimodalSearch(AnalysisMethod):
             )
         return features_image_stacked
 
-    def extract_image_features_clip(self, model, images_tensors):
+    def extract_image_features_clip(
+        self, model, images_tensors: torch.Tensor
+    ) -> torch.Tensor:
+        """
+        Extract image features from images_tensors with the clip_feature_extractor model.
+
+        Args:
+            model (torch.nn.Module): model.
+            images_tensors (torch.Tensor): tensors of images stacked on the device.
+
+        Returns:
+            features_image_stacked (torch.Tensor): tensor of image features stacked on the device.
+        """
         features_image = [
             model.extract_features({"image": ten}) for ten in images_tensors
         ]
@@ -112,7 +222,19 @@ class MultimodalSearch(AnalysisMethod):
         )
         return features_image_stacked
 
-    def extract_image_features_basic(self, model, images_tensors):
+    def extract_image_features_basic(
+        self, model, images_tensors: torch.Tensor
+    ) -> torch.Tensor:
+        """
+        Extract image features from images_tensors with the blip_feature_extractor or albef_feature_extractor model.
+
+        Args:
+            model (torch.nn.Module): model.
+            images_tensors (torch.Tensor): tensors of images stacked on the device.
+
+        Returns:
+            features_image_stacked (torch.Tensor): tensor of image features stacked on the device.
+        """
         features_image = [
             model.extract_features({"image": ten, "text_input": ""}, mode="image")
             for ten in images_tensors
@@ -124,11 +246,23 @@ class MultimodalSearch(AnalysisMethod):
 
     def save_tensors(
         self,
-        model_type,
-        features_image_stacked,
-        name="saved_features_image.pt",
-        path="./saved_tensors/",
-    ):
+        model_type: str,
+        features_image_stacked: torch.Tensor,
+        name: str = "saved_features_image.pt",
+        path: str = "./saved_tensors/",
+    ) -> str:
+        """
+        Save tensors as binary to the given path.
+
+        Args:
+            model_type (str): type of the model.
+            features_image_stacked (torch.Tensor): tensor of image features stacked on the device.
+            name (str): name of the file. Default: "saved_features_image.pt".
+            path (str): path to save the file. Default: "./saved_tensors/".
+
+        Returns:
+            name (str): name of the file.
+        """
         if not os.path.exists(path):
             os.makedirs(path)
         with open(
@@ -143,11 +277,30 @@ class MultimodalSearch(AnalysisMethod):
             torch.save(features_image_stacked, f)
         return name
 
-    def load_tensors(self, name):
+    def load_tensors(self, name: str) -> torch.Tensor:
+        """
+        Load tensors from the given path.
+
+        Args:
+            name (str): name of the file.
+
+        Returns:
+            features_image_stacked (torch.Tensor): tensor of image features.
+        """
         features_image_stacked = torch.load(name)
         return features_image_stacked
 
-    def extract_text_features(self, model, text_input):
+    def extract_text_features(self, model, text_input: str) -> torch.Tensor:
+        """
+        Extract text features from text_input with the feature_extractor model.
+
+        Args:
+            model (torch.nn.Module): model.
+            text_input (str): text.
+
+        Returns:
+            features_text (torch.Tensor): tensor of text features.
+        """
         sample_text = {"text_input": [text_input]}
         features_text = model.extract_features(sample_text, mode="text")
 
@@ -155,10 +308,26 @@ class MultimodalSearch(AnalysisMethod):
 
     def parsing_images(
         self,
-        model_type,
-        path_to_saved_tensors="./saved_tensors/",
-        path_to_load_tensors=None,
-    ):
+        model_type: str,
+        path_to_save_tensors: str = "./saved_tensors/",
+        path_to_load_tensors: str = None,
+    ) -> tuple:
+        """
+        Parse images with the feature_extractor model.
+
+        Args:
+            model_type (str): type of the model.
+            path_to_save_tensors (str): path to save the tensors. Default: "./saved_tensors/".
+            path_to_load_tensors (str): path to load the tensors. Default: None.
+
+        Returns:
+            model (torch.nn.Module): model.
+            vis_processors (dict): preprocessors for visual inputs.
+            txt_processors (dict): preprocessors for text inputs.
+            image_keys (list): sorted list of image keys.
+            image_names (list): sorted list of image names.
+            features_image_stacked (torch.Tensor): tensor of image features stacked on the device.
+        """
         if model_type in ("clip_base", "clip_vitl14_336", "clip_vitl14"):
             path_to_lib = lavis.__file__[:-11] + "models/clip_models/"
             url = "https://raw.githubusercontent.com/salesforce/LAVIS/main/lavis/models/clip_models/bpe_simple_vocab_16e6.txt.gz"
@@ -208,7 +377,7 @@ class MultimodalSearch(AnalysisMethod):
                     self, model, images_tensors
                 )
             MultimodalSearch.save_tensors(
-                self, model_type, features_image_stacked, path=path_to_saved_tensors
+                self, model_type, features_image_stacked, path=path_to_save_tensors
             )
         else:
             features_image_stacked = MultimodalSearch.load_tensors(
@@ -225,8 +394,21 @@ class MultimodalSearch(AnalysisMethod):
         )
 
     def querys_processing(
-        self, search_query, model, txt_processors, vis_processors, model_type
-    ):
+        self, search_query: list, model, txt_processors, vis_processors, model_type: str
+    ) -> torch.Tensor:
+        """
+        Process queries.
+
+        Args:
+            search_query (list): list of queries.
+            model (torch.nn.Module): model.
+            txt_processors (dict): preprocessors for text inputs.
+            vis_processors (dict): preprocessors for visual inputs.
+            model_type (str): type of the model.
+
+        Returns:
+            multi_features_stacked (torch.Tensor): tensor of query features.
+        """
         select_extract_image_features = {
             "blip2": MultimodalSearch.extract_image_features_blip2,
             "blip": MultimodalSearch.extract_image_features_basic,
@@ -295,14 +477,33 @@ class MultimodalSearch(AnalysisMethod):
         model,
         vis_processors,
         txt_processors,
-        model_type,
-        image_keys,
-        features_image_stacked,
-        search_query,
-        filter_number_of_images=None,
-        filter_val_limit=None,
-        filter_rel_error=None,
-    ):
+        model_type: str,
+        image_keys: list,
+        features_image_stacked: torch.Tensor,
+        search_query: list,
+        filter_number_of_images: int = None,
+        filter_val_limit: float = None,
+        filter_rel_error: float = None,
+    ) -> tuple:
+        """
+        Search for images with the given queries.
+
+        Args:
+            model (torch.nn.Module): model.
+            vis_processors (dict): preprocessors for visual inputs.
+            txt_processors (dict): preprocessors for text inputs.
+            model_type (str): type of the model.
+            image_keys (list): sorted list of image keys.
+            features_image_stacked (torch.Tensor): tensor of image features stacked on the device.
+            search_query (list): list of queries.
+            filter_number_of_images (int): number of images to show. Default: None.
+            filter_val_limit (float): limit of the similarity value. Default: None.
+            filter_rel_error (float): limit of the relative error. Default: None.
+
+        Returns:
+            similarity (torch.Tensor): similarity between images and queries.
+            sorted_lists (list): sorted list of similarity.
+        """
         if filter_number_of_images is None:
             filter_number_of_images = len(self)
         if filter_val_limit is None:
@@ -343,7 +544,16 @@ class MultimodalSearch(AnalysisMethod):
                     self[image_keys[key]][list(search_query[q].values())[0]] = 0
         return similarity, sorted_lists
 
-    def itm_text_precessing(self, search_query):
+    def itm_text_precessing(self, search_query: list[dict[str, str]]) -> list:
+        """
+        Process text queries for the itm model.
+
+        Args:
+            search_query (list): list of queries.
+
+        Returns:
+            text_query_index (list): list of indexes of text queries.
+        """
         for query in search_query:
             if (len(query) != 1) and (query in ("image", "text_input")):
                 raise SyntaxError(
@@ -356,7 +566,17 @@ class MultimodalSearch(AnalysisMethod):
 
         return text_query_index
 
-    def get_pathes_from_query(self, query):
+    def get_pathes_from_query(self, query: dict[str, str]) -> tuple:
+        """
+        Get paths and image names from the query.
+
+        Args:
+            query (dict): query.
+
+        Returns:
+            paths (list): list of paths.
+            image_names (list): list of image names.
+        """
         paths = []
         image_names = []
         for s in sorted(
@@ -368,7 +588,18 @@ class MultimodalSearch(AnalysisMethod):
             image_names.append(s[0])
         return paths, image_names
 
-    def read_and_process_images_itm(self, image_paths, vis_processor):
+    def read_and_process_images_itm(self, image_paths: list, vis_processor) -> tuple:
+        """
+        Read and process images with vis_processor for the itm model.
+
+        Args:
+            image_paths (list): paths to the images.
+            vis_processor (dict): preprocessors for visual inputs.
+
+        Returns:
+            raw_images (list): list of images.
+            images_tensors (torch.Tensor): tensors of images stacked on the device.
+        """
         raw_images = [MultimodalSearch.read_img(self, path) for path in image_paths]
         images = [vis_processor(r_img) for r_img in raw_images]
         images_tensors = torch.stack(images).to(MultimodalSearch.multimodal_device)
@@ -377,12 +608,26 @@ class MultimodalSearch(AnalysisMethod):
 
     def compute_gradcam_batch(
         self,
-        model,
-        visual_input,
-        text_input,
-        tokenized_text,
-        block_num=6,
-    ):
+        model: torch.nn.Module,
+        visual_input: torch.Tensor,
+        text_input: str,
+        tokenized_text: torch.Tensor,
+        block_num: int = 6,
+    ) -> tuple:
+        """
+        Compute gradcam for the itm model.
+
+        Args:
+            model (torch.nn.Module): model.
+            visual_input (torch.Tensor): tensor of image features stacked on the device.
+            text_input (str): text.
+            tokenized_text (torch.Tensor): tokenized text.
+            block_num (int): number of the block. Default: 6.
+
+        Returns:
+            gradcam (torch.Tensor): gradcam.
+            output (torch.Tensor): output of the model.
+        """
         model.text_encoder.base_model.base_model.encoder.layer[
             block_num
         ].crossattention.self.save_attention = True
@@ -430,7 +675,16 @@ class MultimodalSearch(AnalysisMethod):
 
         return gradcam, output
 
-    def resize_img(self, raw_img):
+    def resize_img(self, raw_img: Image):
+        """
+        Proportionally resize an image to a width of 240 px.
+
+        Args:
+            raw_img (PIL.Image): image.
+
+        Returns:
+            resized_image (PIL.Image): image proportionally resized to a width of 240 px.
+        """
         w, h = raw_img.size
         scaling_factor = 240 / w
         resized_image = raw_img.resize(
@@ -438,7 +692,25 @@ class MultimodalSearch(AnalysisMethod):
         )
         return resized_image
 
-    def get_att_map(self, img, att_map, blur=True, overlap=True):
+    def get_att_map(
+        self,
+        img: np.ndarray,
+        att_map: np.ndarray,
+        blur: bool = True,
+        overlap: bool = True,
+    ) -> np.ndarray:
+        """
+        Get the attention map.
+
+        Args:
+            img (np.ndarray): image.
+            att_map (np.ndarray): attention map.
+            blur (bool): blur the attention map. Default: True.
+            overlap (bool): overlap the attention map with the image. Default: True.
+
+        Returns:
+            att_map (np.ndarray): attention map.
+        """
         att_map -= att_map.min()
         if att_map.max() > 0:
             att_map /= att_map.max()
@@ -459,7 +731,14 @@ class MultimodalSearch(AnalysisMethod):
             )
         return att_map
 
-    def upload_model_blip2_coco(self):
+    def upload_model_blip2_coco(self) -> tuple:
+        """
+        Load the coco blip2_image_text_matching model and preprocessor for visual inputs from lavis.models.
+
+        Returns:
+            itm_model (torch.nn.Module): model.
+            vis_processor (dict): preprocessors for visual inputs.
+        """
         itm_model = load_model(
             "blip2_image_text_matching",
             "coco",
@@ -469,7 +748,14 @@ class MultimodalSearch(AnalysisMethod):
         vis_processor = load_processor("blip_image_eval").build(image_size=364)
         return itm_model, vis_processor
 
-    def upload_model_blip_base(self):
+    def upload_model_blip_base(self) -> tuple:
+        """
+        Load the base blip_image_text_matching model and preprocessor for visual inputs from lavis.models.
+
+        Returns:
+            itm_model (torch.nn.Module): model.
+            vis_processor (dict): preprocessors for visual inputs.
+        """
         itm_model = load_model(
             "blip_image_text_matching",
             "base",
@@ -479,7 +765,14 @@ class MultimodalSearch(AnalysisMethod):
         vis_processor = load_processor("blip_image_eval").build(image_size=384)
         return itm_model, vis_processor
 
-    def upload_model_blip_large(self):
+    def upload_model_blip_large(self) -> tuple:
+        """
+        Load the large blip_image_text_matching model and preprocessor for visual inputs from lavis.models.
+
+        Returns:
+            itm_model (torch.nn.Module): model.
+            vis_processor (dict): preprocessors for visual inputs.
+        """
         itm_model = load_model(
             "blip_image_text_matching",
             "large",
@@ -491,13 +784,28 @@ class MultimodalSearch(AnalysisMethod):
 
     def image_text_match_reordering(
         self,
-        search_query,
-        itm_model_type,
-        image_keys,
-        sorted_lists,
-        batch_size=1,
-        need_grad_cam=False,
-    ):
+        search_query: list[dict[str, str]],
+        itm_model_type: str,
+        image_keys: list,
+        sorted_lists: list[list],
+        batch_size: int = 1,
+        need_grad_cam: bool = False,
+    ) -> tuple:
+        """
+        Reorder images with the itm model.
+
+        Args:
+            search_query (list): list of queries.
+            itm_model_type (str): type of the model.
+            image_keys (list): sorted list of image keys.
+            sorted_lists (list): sorted list of similarity.
+            batch_size (int): batch size. Default: 1.
+            need_grad_cam (bool): need gradcam. Default: False. The blip2_coco model does not yet work with gradcam.
+
+        Returns:
+            itm_scores2: list of itm scores.
+            image_gradcam_with_itm: dict of image names and gradcam.
+        """
         if itm_model_type == "blip2_coco" and need_grad_cam is True:
             raise SyntaxError(
                 "The blip2_coco model does not yet work with gradcam. Please set need_grad_cam to False"
@@ -624,7 +932,17 @@ class MultimodalSearch(AnalysisMethod):
         torch.cuda.empty_cache()
         return itm_scores2, image_gradcam_with_itm
 
-    def show_results(self, query, itm=False, image_gradcam_with_itm=False):
+    def show_results(
+        self, query: dict, itm=False, image_gradcam_with_itm=False
+    ) -> None:
+        """
+        Show the results of the search.
+
+        Args:
+            query (dict): query.
+            itm (bool): use the itm model. Default: False.
+            image_gradcam_with_itm (bool): use gradcam. Default: False.
+        """
         if "image" in query.keys():
             pic = Image.open(query["image"]).convert("RGB")
             pic.thumbnail((400, 400))
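The docstrings added in this commit describe a two-stage workflow: build stacked image features and run a similarity search over them, then optionally re-rank the hits with an image-text matching model. Below is a minimal usage sketch of the first stage, built only from the signatures documented above; the import path and the layout of the image dictionary are assumptions and may differ from the actual AMMICO code.

# Sketch only: package import path and dictionary layout are assumptions.
from ammico.multimodal_search import MultimodalSearch

mydict = {
    "img1": {"filename": "images/img1.jpg"},
    "img2": {"filename": "images/img2.jpg"},
}
ms = MultimodalSearch(mydict)

# Load a feature extractor ("blip", "blip2", "albef", "clip_base",
# "clip_vitl14", "clip_vitl14_336") and stack the image features;
# previously computed features can be reused via path_to_load_tensors.
(
    model,
    vis_processors,
    txt_processors,
    image_keys,
    image_names,
    features_image_stacked,
) = ms.parsing_images("blip", path_to_save_tensors="./saved_tensors/")

# Queries are dicts with a single "text_input" or "image" key.
search_query = [
    {"text_input": "a crowd of people"},
    {"image": "images/example.jpg"},
]
similarity, sorted_lists = ms.multimodal_search(
    model,
    vis_processors,
    txt_processors,
    "blip",
    image_keys,
    features_image_stacked,
    search_query,
    filter_number_of_images=10,
)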
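The optional re-ranking stage described by image_text_match_reordering and show_results could then look roughly like the sketch below; the itm_model_type values are inferred from the upload_model_* helpers above and are an assumption.

# Sketch only: itm_model_type values ("blip_base", "blip_large", "blip2_coco")
# are assumed to mirror the upload_model_* helpers; need_grad_cam must stay
# False for "blip2_coco", as the docstring notes.
itm_scores, image_gradcam_with_itm = ms.image_text_match_reordering(
    search_query,
    "blip_base",
    image_keys,
    sorted_lists,
    batch_size=1,
    need_grad_cam=True,
)

# Show the ranked images (and gradcam overlays, if computed) for one query.
ms.show_results(
    search_query[0],
    itm=True,
    image_gradcam_with_itm=image_gradcam_with_itm,
)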