Mirror of https://github.com/ssciwr/AMMICO.git, synced 2025-10-30 21:46:04 +02:00

Writing docs for multimodal_search.py

This commit is contained in:
Parent: ca3994c72f
Commit: f524ecaad8
				| @ -20,7 +20,18 @@ class MultimodalSearch(AnalysisMethod): | ||||
| 
 | ||||
|     multimodal_device = torch.device("cuda" if torch.cuda.is_available() else "cpu") | ||||
| 
 | ||||
|     def load_feature_extractor_model_blip2(self, device): | ||||
|     def load_feature_extractor_model_blip2(self, device: str = "cpu"): | ||||
|         """ | ||||
|         Load the pretrained blip2_feature_extractor model and preprocessors for visual and text inputs from lavis.models. | ||||
| 
 | ||||
|         Args: | ||||
|             device (str): device to use. Can be "cpu" or "cuda". Default: "cpu". | ||||
| 
 | ||||
|         Returns: | ||||
|             model (torch.nn.Module): model. | ||||
|             vis_processors (dict): preprocessors for visual inputs. | ||||
|             txt_processors (dict): preprocessors for text inputs. | ||||
|         """ | ||||
|         model, vis_processors, txt_processors = load_model_and_preprocess( | ||||
|             name="blip2_feature_extractor", | ||||
|             model_type="pretrain", | ||||
| @ -29,7 +40,18 @@ class MultimodalSearch(AnalysisMethod): | ||||
|         ) | ||||
|         return model, vis_processors, txt_processors | ||||
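For orientation while reviewing, here is a minimal, hedged usage sketch of these loader helpers. It assumes LAVIS is installed and that MultimodalSearch is instantiated from an AMMICO image dictionary as in the package notebooks; the import path, mydict, and the file paths are illustrative.

import torch
from ammico.multimodal_search import MultimodalSearch  # import path is an assumption

mydict = {"img1": {"filename": "images/img1.jpg"}}      # illustrative image dictionary
my_obj = MultimodalSearch(mydict)

# Any of the load_feature_extractor_model_* helpers follows the same pattern.
device = "cuda" if torch.cuda.is_available() else "cpu"
model, vis_processors, txt_processors = my_obj.load_feature_extractor_model_blip2(device)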
| 
 | ||||
|     def load_feature_extractor_model_blip(self, device): | ||||
|     def load_feature_extractor_model_blip(self, device: str = "cpu"): | ||||
|         """ | ||||
|         Load base blip_feature_extractor model and preprocessors for visual and text inputs from lavis.models. | ||||
| 
 | ||||
|         Args: | ||||
|             device (str): device to use. Can be "cpu" or "cuda". Default: "cpu". | ||||
| 
 | ||||
|         Returns: | ||||
|             model (torch.nn.Module): model. | ||||
|             vis_processors (dict): preprocessors for visual inputs. | ||||
|             txt_processors (dict): preprocessors for text inputs. | ||||
|         """ | ||||
|         model, vis_processors, txt_processors = load_model_and_preprocess( | ||||
|             name="blip_feature_extractor", | ||||
|             model_type="base", | ||||
| @ -38,7 +60,18 @@ class MultimodalSearch(AnalysisMethod): | ||||
|         ) | ||||
|         return model, vis_processors, txt_processors | ||||
| 
 | ||||
|     def load_feature_extractor_model_albef(self, device): | ||||
|     def load_feature_extractor_model_albef(self, device: str = "cpu"): | ||||
|         """ | ||||
|         Load base albef_feature_extractor model and preprocessors for visual and text inputs from lavis.models. | ||||
| 
 | ||||
|         Args: | ||||
|             device (str): device to use. Can be "cpu" or "cuda". Default: "cpu". | ||||
| 
 | ||||
|         Returns: | ||||
|             model (torch.nn.Module): model. | ||||
|             vis_processors (dict): preprocessors for visual inputs. | ||||
|             txt_processors (dict): preprocessors for text inputs. | ||||
|         """ | ||||
|         model, vis_processors, txt_processors = load_model_and_preprocess( | ||||
|             name="albef_feature_extractor", | ||||
|             model_type="base", | ||||
| @ -47,7 +80,18 @@ class MultimodalSearch(AnalysisMethod): | ||||
|         ) | ||||
|         return model, vis_processors, txt_processors | ||||
| 
 | ||||
|     def load_feature_extractor_model_clip_base(self, device): | ||||
|     def load_feature_extractor_model_clip_base(self, device: str = "cpu"): | ||||
|         """ | ||||
|         Load base clip_feature_extractor model and preprocessors for visual and text inputs from lavis.models. | ||||
| 
 | ||||
|         Args: | ||||
|             device (str): device to use. Can be "cpu" or "cuda". Default: "cpu". | ||||
| 
 | ||||
|         Returns: | ||||
|             model (torch.nn.Module): model. | ||||
|             vis_processors (dict): preprocessors for visual inputs. | ||||
|             txt_processors (dict): preprocessors for text inputs. | ||||
|         """ | ||||
|         model, vis_processors, txt_processors = load_model_and_preprocess( | ||||
|             name="clip_feature_extractor", | ||||
|             model_type="base", | ||||
| @ -56,7 +100,18 @@ class MultimodalSearch(AnalysisMethod): | ||||
|         ) | ||||
|         return model, vis_processors, txt_processors | ||||
| 
 | ||||
|     def load_feature_extractor_model_clip_vitl14(self, device): | ||||
|     def load_feature_extractor_model_clip_vitl14(self, device: str = "cpu"): | ||||
|         """ | ||||
|         Load ViT-L-14 clip_feature_extractor model and preprocessors for visual and text inputs from lavis.models. | ||||
| 
 | ||||
|         Args: | ||||
|             device (str): device to use. Can be "cpu" or "cuda". Default: "cpu". | ||||
| 
 | ||||
|         Returns: | ||||
|             model (torch.nn.Module): model. | ||||
|             vis_processors (dict): preprocessors for visual inputs. | ||||
|             txt_processors (dict): preprocessors for text inputs. | ||||
|         """ | ||||
|         model, vis_processors, txt_processors = load_model_and_preprocess( | ||||
|             name="clip_feature_extractor", | ||||
|             model_type="ViT-L-14", | ||||
| @ -65,7 +120,18 @@ class MultimodalSearch(AnalysisMethod): | ||||
|         ) | ||||
|         return model, vis_processors, txt_processors | ||||
| 
 | ||||
|     def load_feature_extractor_model_clip_vitl14_336(self, device): | ||||
|     def load_feature_extractor_model_clip_vitl14_336(self, device: str = "cpu"): | ||||
|         """ | ||||
|         Load ViT-L-14-336 clip_feature_extractor model and preprocessors for visual and text inputs from lavis.models. | ||||
| 
 | ||||
|         Args: | ||||
|             device (str): device to use. Can be "cpu" or "cuda". Default: "cpu". | ||||
| 
 | ||||
|         Returns: | ||||
|             model (torch.nn.Module): model. | ||||
|             vis_processors (dict): preprocessors for visual inputs. | ||||
|             txt_processors (dict): preprocessors for text inputs. | ||||
|         """ | ||||
|         model, vis_processors, txt_processors = load_model_and_preprocess( | ||||
|             name="clip_feature_extractor", | ||||
|             model_type="ViT-L-14-336", | ||||
| @ -74,11 +140,31 @@ class MultimodalSearch(AnalysisMethod): | ||||
|         ) | ||||
|         return model, vis_processors, txt_processors | ||||
| 
 | ||||
|     def read_img(self, filepath): | ||||
|     def read_img(self, filepath: str) -> Image: | ||||
|         """ | ||||
|         Load Image from filepath. | ||||
| 
 | ||||
|         Args: | ||||
|             filepath (str): path to image. | ||||
| 
 | ||||
|         Returns: | ||||
|             raw_image (PIL.Image): image. | ||||
|         """ | ||||
|         raw_image = Image.open(filepath).convert("RGB") | ||||
|         return raw_image | ||||
| 
 | ||||
|     def read_and_process_images(self, image_paths, vis_processor): | ||||
|     def read_and_process_images(self, image_paths: list, vis_processor) -> tuple: | ||||
|         """ | ||||
|         Read and process images with vis_processor. | ||||
| 
 | ||||
|         Args: | ||||
|             image_paths (list): list of paths to images. | ||||
|             vis_processor (dict): preprocessors for visual inputs. | ||||
| 
 | ||||
|         Returns: | ||||
|             raw_images (list): list of images. | ||||
|             images_tensors (torch.Tensor): tensors of images stacked in device. | ||||
|         """ | ||||
|         raw_images = [MultimodalSearch.read_img(self, path) for path in image_paths] | ||||
|         images = [ | ||||
|             vis_processor["eval"](r_img) | ||||
| @ -90,7 +176,19 @@ class MultimodalSearch(AnalysisMethod): | ||||
| 
 | ||||
|         return raw_images, images_tensors | ||||
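Continuing the sketch above, the preprocessing helper can be exercised directly (paths are illustrative; vis_processors comes from one of the loaders):

image_paths = ["images/img1.jpg", "images/img2.jpg"]   # hypothetical files
raw_images, images_tensors = my_obj.read_and_process_images(image_paths, vis_processors)
# images_tensors is a stacked torch.Tensor on multimodal_device, one entry per image.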
| 
 | ||||
|     def extract_image_features_blip2(self, model, images_tensors): | ||||
|     def extract_image_features_blip2( | ||||
|         self, model, images_tensors: torch.Tensor | ||||
|     ) -> torch.Tensor: | ||||
|         """ | ||||
|         Extract image features from images_tensors with blip2_feature_extractor model. | ||||
| 
 | ||||
|         Args: | ||||
|             model (torch.nn.Module): model. | ||||
|             images_tensors (torch.Tensor): tensors of images stacked in device. | ||||
| 
 | ||||
|         Returns: | ||||
|             features_image_stacked (torch.Tensor): tensors of images features stacked in device. | ||||
|         """ | ||||
|         with torch.cuda.amp.autocast( | ||||
|             enabled=(MultimodalSearch.multimodal_device != torch.device("cpu")) | ||||
|         ): | ||||
| @ -103,7 +201,19 @@ class MultimodalSearch(AnalysisMethod): | ||||
|             ) | ||||
|         return features_image_stacked | ||||
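A short follow-up to the sketch above showing how the BLIP-2 image features would be obtained (tensor shapes depend on the model and are not spelled out here):

# model and images_tensors come from the earlier steps of the sketch.
features_image_stacked = my_obj.extract_image_features_blip2(model, images_tensors)
print(features_image_stacked.shape[0])  # one feature row per input image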
| 
 | ||||
|     def extract_image_features_clip(self, model, images_tensors): | ||||
|     def extract_image_features_clip( | ||||
|         self, model, images_tensors: torch.Tensor | ||||
|     ) -> torch.Tensor: | ||||
|         """ | ||||
|         Extract image features from images_tensors with clip_feature_extractor model. | ||||
| 
 | ||||
|         Args: | ||||
|             model (torch.nn.Module): model. | ||||
|             images_tensors (torch.Tensor): tensors of images stacked in device. | ||||
| 
 | ||||
|         Returns: | ||||
|             features_image_stacked (torch.Tensor): tensors of images features stacked in device. | ||||
|         """ | ||||
|         features_image = [ | ||||
|             model.extract_features({"image": ten}) for ten in images_tensors | ||||
|         ] | ||||
| @ -112,7 +222,19 @@ class MultimodalSearch(AnalysisMethod): | ||||
|         ) | ||||
|         return features_image_stacked | ||||
| 
 | ||||
|     def extract_image_features_basic(self, model, images_tensors): | ||||
|     def extract_image_features_basic( | ||||
|         self, model, images_tensors: torch.Tensor | ||||
|     ) -> torch.Tensor: | ||||
|         """ | ||||
|         Extract image features from images_tensors with blip_feature_extractor or albef_feature_extractor model. | ||||
| 
 | ||||
|         Args: | ||||
|             model (torch.nn.Module): model. | ||||
|             images_tensors (torch.Tensor): tensors of images stacked in device. | ||||
| 
 | ||||
|         Returns: | ||||
|             features_image_stacked (torch.Tensor): tensors of images features stacked in device. | ||||
|         """ | ||||
|         features_image = [ | ||||
|             model.extract_features({"image": ten, "text_input": ""}, mode="image") | ||||
|             for ten in images_tensors | ||||
| @ -124,11 +246,23 @@ class MultimodalSearch(AnalysisMethod): | ||||
| 
 | ||||
|     def save_tensors( | ||||
|         self, | ||||
|         model_type, | ||||
|         features_image_stacked, | ||||
|         name="saved_features_image.pt", | ||||
|         path="./saved_tensors/", | ||||
|     ): | ||||
|         model_type: str, | ||||
|         features_image_stacked: torch.Tensor, | ||||
|         name: str = "saved_features_image.pt", | ||||
|         path: str = "./saved_tensors/", | ||||
|     ) -> str: | ||||
|         """ | ||||
|         Save tensors as binary to given path. | ||||
| 
 | ||||
|         Args: | ||||
|             model_type (str): type of the model. | ||||
|             features_image_stacked (torch.Tensor): tensors of images features stacked in device. | ||||
|             name (str): name of the file. Default: "saved_features_image.pt". | ||||
|             path (str): path to save the file. Default: "./saved_tensors/". | ||||
| 
 | ||||
|         Returns: | ||||
|             name (str): name of the file. | ||||
|         """ | ||||
|         if not os.path.exists(path): | ||||
|             os.makedirs(path) | ||||
|         with open( | ||||
| @ -143,11 +277,30 @@ class MultimodalSearch(AnalysisMethod): | ||||
|             torch.save(features_image_stacked, f) | ||||
|         return name | ||||
| 
 | ||||
|     def load_tensors(self, name): | ||||
|     def load_tensors(self, name: str) -> torch.Tensor: | ||||
|         """ | ||||
|         Load tensors from given path. | ||||
| 
 | ||||
|         Args: | ||||
|             name (str): name of the file. | ||||
| 
 | ||||
|         Returns: | ||||
|             features_image_stacked (torch.Tensor): tensors of images features. | ||||
|         """ | ||||
|         features_image_stacked = torch.load(name) | ||||
|         return features_image_stacked | ||||
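A hedged round-trip sketch for the two tensor helpers; the exact file name is whatever save_tensors returns, and the directory is the default shown above:

saved_name = my_obj.save_tensors("blip2", features_image_stacked, path="./saved_tensors/")
# Later, e.g. in a new session, reload the cached features:
features_image_stacked = my_obj.load_tensors("./saved_tensors/" + saved_name)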
| 
 | ||||
|     def extract_text_features(self, model, text_input): | ||||
|     def extract_text_features(self, model, text_input: str) -> torch.Tensor: | ||||
|         """ | ||||
|         Extract text features from text_input with feature_extractor model. | ||||
| 
 | ||||
|         Args: | ||||
|             model (torch.nn.Module): model. | ||||
|             text_input (str): text. | ||||
| 
 | ||||
|         Returns: | ||||
|             features_text (torch.Tensor): tensors of text features. | ||||
|         """ | ||||
|         sample_text = {"text_input": [text_input]} | ||||
|         features_text = model.extract_features(sample_text, mode="text") | ||||
| 
 | ||||
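For the text side, a minimal illustrative call (the query string is made up; the model must be one whose extract_features supports mode="text"):

features_text = my_obj.extract_text_features(model, "a politician giving a speech")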
| @ -155,10 +308,26 @@ class MultimodalSearch(AnalysisMethod): | ||||
| 
 | ||||
|     def parsing_images( | ||||
|         self, | ||||
|         model_type, | ||||
|         path_to_saved_tensors="./saved_tensors/", | ||||
|         path_to_load_tensors=None, | ||||
|     ): | ||||
|         model_type: str, | ||||
|         path_to_save_tensors: str = "./saved_tensors/", | ||||
|         path_to_load_tensors: str = None, | ||||
|     ) -> tuple: | ||||
|         """ | ||||
|         Parse images with the selected feature_extractor model. | ||||
| 
 | ||||
|         Args: | ||||
|             model_type (str): type of the model. | ||||
|             path_to_save_tensors (str): path to save the tensors. Default: "./saved_tensors/". | ||||
|             path_to_load_tensors (str): path to load the tensors from. Default: None. | ||||
| 
 | ||||
|         Returns: | ||||
|             model (torch.nn.Module): model. | ||||
|             vis_processors (dict): preprocessors for visual inputs. | ||||
|             txt_processors (dict): preprocessors for text inputs. | ||||
|             image_keys (list): sorted list of image keys. | ||||
|             image_names (list): sorted list of image names. | ||||
|             features_image_stacked (torch.Tensor): tensors of images features stacked in device. | ||||
|         """ | ||||
|         if model_type in ("clip_base", "clip_vitl14_336", "clip_vitl14"): | ||||
|             path_to_lib = lavis.__file__[:-11] + "models/clip_models/" | ||||
|             url = "https://raw.githubusercontent.com/salesforce/LAVIS/main/lavis/models/clip_models/bpe_simple_vocab_16e6.txt.gz" | ||||
| @ -208,7 +377,7 @@ class MultimodalSearch(AnalysisMethod): | ||||
|                     self, model, images_tensors | ||||
|                 ) | ||||
|             MultimodalSearch.save_tensors( | ||||
|                 self, model_type, features_image_stacked, path=path_to_saved_tensors | ||||
|                 self, model_type, features_image_stacked, path=path_to_save_tensors | ||||
|             ) | ||||
|         else: | ||||
|             features_image_stacked = MultimodalSearch.load_tensors( | ||||
| @ -225,8 +394,21 @@ class MultimodalSearch(AnalysisMethod): | ||||
|         ) | ||||
| 
 | ||||
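Putting the pieces together, the high-level entry point documented here is typically called once per image collection. A hedged sketch (model_type strings follow the loader names above):

(
    model,
    vis_processors,
    txt_processors,
    image_keys,
    image_names,
    features_image_stacked,
) = my_obj.parsing_images("blip2", path_to_save_tensors="./saved_tensors/")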
|     def querys_processing( | ||||
|         self, search_query, model, txt_processors, vis_processors, model_type | ||||
|     ): | ||||
|         self, search_query: list, model, txt_processors, vis_processors, model_type: str | ||||
|     ) -> torch.Tensor: | ||||
|         """ | ||||
|         Process queries. | ||||
| 
 | ||||
|         Args: | ||||
|             search_query (list): list of queries. | ||||
|             model (torch.nn.Module): model. | ||||
|             txt_processors (dict): preprocessors for text inputs. | ||||
|             vis_processors (dict): preprocessors for visual inputs. | ||||
|             model_type (str): type of the model. | ||||
| 
 | ||||
|         Returns: | ||||
|             multi_features_stacked (torch.Tensor): tensors of query features. | ||||
|         """ | ||||
|         select_extract_image_features = { | ||||
|             "blip2": MultimodalSearch.extract_image_features_blip2, | ||||
|             "blip": MultimodalSearch.extract_image_features_basic, | ||||
| @ -295,14 +477,33 @@ class MultimodalSearch(AnalysisMethod): | ||||
|         model, | ||||
|         vis_processors, | ||||
|         txt_processors, | ||||
|         model_type, | ||||
|         image_keys, | ||||
|         features_image_stacked, | ||||
|         search_query, | ||||
|         filter_number_of_images=None, | ||||
|         filter_val_limit=None, | ||||
|         filter_rel_error=None, | ||||
|     ): | ||||
|         model_type: str, | ||||
|         image_keys: list, | ||||
|         features_image_stacked: torch.Tensor, | ||||
|         search_query: list, | ||||
|         filter_number_of_images: int = None, | ||||
|         filter_val_limit: float = None, | ||||
|         filter_rel_error: float = None, | ||||
|     ) -> tuple: | ||||
|         """ | ||||
|         Search for images with the given queries. | ||||
| 
 | ||||
|         Args: | ||||
|             model (torch.nn.Module): model. | ||||
|             vis_processors (dict): preprocessors for visual inputs. | ||||
|             txt_processors (dict): preprocessors for text inputs. | ||||
|             model_type (str): type of the model. | ||||
|             image_keys (list): sorted list of image keys. | ||||
|             features_image_stacked (torch.Tensor): tensors of images features stacked in device. | ||||
|             search_query (list): list of queries. | ||||
|             filter_number_of_images (int): number of images to show. Default: None. | ||||
|             filter_val_limit (float): limit of similarity value. Default: None. | ||||
|             filter_rel_error (float): limit of relative error. Default: None. | ||||
| 
 | ||||
|         Returns: | ||||
|             similarity (torch.Tensor): similarity between images and querys. | ||||
|             sorted_lists (list): sorted list of similarity. | ||||
|         """ | ||||
|         if filter_number_of_images is None: | ||||
|             filter_number_of_images = len(self) | ||||
|         if filter_val_limit is None: | ||||
| @ -343,7 +544,16 @@ class MultimodalSearch(AnalysisMethod): | ||||
|                     self[image_keys[key]][list(search_query[q].values())[0]] = 0 | ||||
|         return similarity, sorted_lists | ||||
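A hedged example of running the search itself; the method name is cut off in this hunk, so multimodal_search is assumed from the AMMICO notebooks, and the query contents are illustrative:

search_query = [
    {"text_input": "a politician giving a speech"},
    {"image": "images/example_query.jpg"},
]
similarity, sorted_lists = my_obj.multimodal_search(   # assumed method name
    model,
    vis_processors,
    txt_processors,
    "blip2",
    image_keys,
    features_image_stacked,
    search_query,
    filter_number_of_images=20,
)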
| 
 | ||||
|     def itm_text_precessing(self, search_query): | ||||
|     def itm_text_precessing(self, search_query: list[dict[str, str]]) -> list: | ||||
|         """ | ||||
|         Process text queries for the itm model. | ||||
| 
 | ||||
|         Args: | ||||
|             search_query (list): list of queries. | ||||
| 
 | ||||
|         Returns: | ||||
|             text_query_index (list): list of indices of text queries. | ||||
|         """ | ||||
|         for query in search_query: | ||||
|             if (len(query) != 1) and (query in ("image", "text_input")): | ||||
|                 raise SyntaxError( | ||||
| @ -356,7 +566,17 @@ class MultimodalSearch(AnalysisMethod): | ||||
| 
 | ||||
|         return text_query_index | ||||
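As implied by this validation, a query is a list of single-key dictionaries using either "image" or "text_input"; a small sketch of the expected input and output:

search_query = [
    {"text_input": "a crowd of people"},   # index 0 -> text query
    {"image": "images/example.jpg"},       # index 1 -> image query
]
text_query_index = my_obj.itm_text_precessing(search_query)  # presumably [0]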
| 
 | ||||
|     def get_pathes_from_query(self, query): | ||||
|     def get_pathes_from_query(self, query: dict[str, str]) -> tuple: | ||||
|         """ | ||||
|         Get paths and image names from query. | ||||
| 
 | ||||
|         Args: | ||||
|             query (dict): query. | ||||
| 
 | ||||
|         Returns: | ||||
|             paths (list): list of paths. | ||||
|             image_names (list): list of image names. | ||||
|         """ | ||||
|         paths = [] | ||||
|         image_names = [] | ||||
|         for s in sorted( | ||||
| @ -368,7 +588,18 @@ class MultimodalSearch(AnalysisMethod): | ||||
|             image_names.append(s[0]) | ||||
|         return paths, image_names | ||||
| 
 | ||||
|     def read_and_process_images_itm(self, image_paths, vis_processor): | ||||
|     def read_and_process_images_itm(self, image_paths: list, vis_processor) -> tuple: | ||||
|         """ | ||||
|         Read and process images with vis_processor for itm model. | ||||
| 
 | ||||
|         Args: | ||||
|             image_paths (list): paths to images. | ||||
|             vis_processor (dict): preprocessors for visual inputs. | ||||
| 
 | ||||
|         Returns: | ||||
|             raw_images (list): list of images. | ||||
|             images_tensors (torch.Tensor): tensors of images stacked in device. | ||||
|         """ | ||||
|         raw_images = [MultimodalSearch.read_img(self, path) for path in image_paths] | ||||
|         images = [vis_processor(r_img) for r_img in raw_images] | ||||
|         images_tensors = torch.stack(images).to(MultimodalSearch.multimodal_device) | ||||
| @ -377,12 +608,26 @@ class MultimodalSearch(AnalysisMethod): | ||||
| 
 | ||||
|     def compute_gradcam_batch( | ||||
|         self, | ||||
|         model, | ||||
|         visual_input, | ||||
|         text_input, | ||||
|         tokenized_text, | ||||
|         block_num=6, | ||||
|     ): | ||||
|         model: torch.nn.Module, | ||||
|         visual_input: torch.Tensor, | ||||
|         text_input: str, | ||||
|         tokenized_text: torch.Tensor, | ||||
|         block_num: int = 6, | ||||
|     ) -> tuple: | ||||
|         """ | ||||
|         Compute gradcam for itm model. | ||||
| 
 | ||||
|         Args: | ||||
|             model (torch.nn.Module): model. | ||||
|             visual_input (torch.Tensor): tensors of images features stacked in device. | ||||
|             text_input (str): text. | ||||
|             tokenized_text (torch.Tensor): tokenized text. | ||||
|             block_num (int): number of the cross-attention block. Default: 6. | ||||
| 
 | ||||
|         Returns: | ||||
|             gradcam (torch.Tensor): gradcam. | ||||
|             output (torch.Tensor): output of model. | ||||
|         """ | ||||
|         model.text_encoder.base_model.base_model.encoder.layer[ | ||||
|             block_num | ||||
|         ].crossattention.self.save_attention = True | ||||
| @ -430,7 +675,16 @@ class MultimodalSearch(AnalysisMethod): | ||||
| 
 | ||||
|         return gradcam, output | ||||
| 
 | ||||
|     def resize_img(self, raw_img): | ||||
|     def resize_img(self, raw_img: Image): | ||||
|         """ | ||||
|         Proportionally resize an image to 240 px width. | ||||
| 
 | ||||
|         Args: | ||||
|             raw_img (PIL.Image): image. | ||||
| 
 | ||||
|         Returns: | ||||
|             resized_image (PIL.Image): proportionally resized image (240 px wide). | ||||
|         """ | ||||
|         w, h = raw_img.size | ||||
|         scaling_factor = 240 / w | ||||
|         resized_image = raw_img.resize( | ||||
| @ -438,7 +692,25 @@ class MultimodalSearch(AnalysisMethod): | ||||
|         ) | ||||
|         return resized_image | ||||
| 
 | ||||
|     def get_att_map(self, img, att_map, blur=True, overlap=True): | ||||
|     def get_att_map( | ||||
|         self, | ||||
|         img: np.ndarray, | ||||
|         att_map: np.ndarray, | ||||
|         blur: bool = True, | ||||
|         overlap: bool = True, | ||||
|     ) -> np.ndarray: | ||||
|         """ | ||||
|         Get attention map. | ||||
| 
 | ||||
|         Args: | ||||
|             img (np.ndarray): image. | ||||
|             att_map (np.ndarray): attention map. | ||||
|             blur (bool): blur attention map. Default: True. | ||||
|             overlap (bool): overlap attention map with image. Default: True. | ||||
| 
 | ||||
|         Returns: | ||||
|             att_map (np.ndarray): attention map. | ||||
|         """ | ||||
|         att_map -= att_map.min() | ||||
|         if att_map.max() > 0: | ||||
|             att_map /= att_map.max() | ||||
| @ -459,7 +731,14 @@ class MultimodalSearch(AnalysisMethod): | ||||
|             ) | ||||
|         return att_map | ||||
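A toy illustration of the normalization step at the start of this helper (the blur/overlay part lives in the omitted lines; the array values are made up):

import numpy as np

att_map = np.array([[0.2, 0.8], [0.5, 1.4]])
att_map -= att_map.min()       # shift so the minimum becomes 0
if att_map.max() > 0:
    att_map /= att_map.max()   # scale so the maximum becomes 1
# the result lies in [0, 1] and can be blurred and overlaid on the resized image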
| 
 | ||||
|     def upload_model_blip2_coco(self): | ||||
|     def upload_model_blip2_coco(self) -> tuple: | ||||
|         """ | ||||
|         Load coco blip2_image_text_matching model and preprocessors for visual inputs from lavis.models. | ||||
| 
 | ||||
|         Returns: | ||||
|             itm_model (torch.nn.Module): model. | ||||
|             vis_processor (dict): preprocessors for visual inputs. | ||||
|         """ | ||||
|         itm_model = load_model( | ||||
|             "blip2_image_text_matching", | ||||
|             "coco", | ||||
| @ -469,7 +748,14 @@ class MultimodalSearch(AnalysisMethod): | ||||
|         vis_processor = load_processor("blip_image_eval").build(image_size=364) | ||||
|         return itm_model, vis_processor | ||||
| 
 | ||||
|     def upload_model_blip_base(self): | ||||
|     def upload_model_blip_base(self) -> tuple: | ||||
|         """ | ||||
|         Load base blip_image_text_matching model and preprocessors for visual input from lavis.models. | ||||
| 
 | ||||
|         Returns: | ||||
|             itm_model (torch.nn.Module): model. | ||||
|             vis_processor (dict): preprocessors for visual inputs. | ||||
|         """ | ||||
|         itm_model = load_model( | ||||
|             "blip_image_text_matching", | ||||
|             "base", | ||||
| @ -479,7 +765,14 @@ class MultimodalSearch(AnalysisMethod): | ||||
|         vis_processor = load_processor("blip_image_eval").build(image_size=384) | ||||
|         return itm_model, vis_processor | ||||
| 
 | ||||
|     def upload_model_blip_large(self): | ||||
|     def upload_model_blip_large(self) -> tuple: | ||||
|         """ | ||||
|         Load large blip_image_text_matching model and preprocessors for visual input from lavis.models. | ||||
| 
 | ||||
|         Returns: | ||||
|             itm_model (torch.nn.Module): model. | ||||
|             vis_processor (dict): preprocessors for visual inputs. | ||||
|         """ | ||||
|         itm_model = load_model( | ||||
|             "blip_image_text_matching", | ||||
|             "large", | ||||
| @ -491,13 +784,28 @@ class MultimodalSearch(AnalysisMethod): | ||||
| 
 | ||||
|     def image_text_match_reordering( | ||||
|         self, | ||||
|         search_query, | ||||
|         itm_model_type, | ||||
|         image_keys, | ||||
|         sorted_lists, | ||||
|         batch_size=1, | ||||
|         need_grad_cam=False, | ||||
|     ): | ||||
|         search_query: list[dict[str, str]], | ||||
|         itm_model_type: str, | ||||
|         image_keys: list, | ||||
|         sorted_lists: list[list], | ||||
|         batch_size: int = 1, | ||||
|         need_grad_cam: bool = False, | ||||
|     ) -> tuple: | ||||
|         """ | ||||
|         Reorder images with itm model. | ||||
| 
 | ||||
|         Args: | ||||
|             search_query (list): list of queries. | ||||
|             itm_model_type (str): type of the model. | ||||
|             image_keys (list): sorted list of image keys. | ||||
|             sorted_lists (list): sorted list of similarity. | ||||
|             batch_size (int): batch size. Default: 1. | ||||
|             need_grad_cam (bool): whether gradcam is needed. Default: False. The blip2_coco model does not yet work with gradcam. | ||||
| 
 | ||||
|         Returns: | ||||
|             itm_scores2: list of itm scores. | ||||
|             image_gradcam_with_itm: dict of image names and gradcam. | ||||
|         """ | ||||
|         if itm_model_type == "blip2_coco" and need_grad_cam is True: | ||||
|             raise SyntaxError( | ||||
|                 "The blip2_coco model does not yet work with gradcam. Please set need_grad_cam to False" | ||||
| @ -624,7 +932,17 @@ class MultimodalSearch(AnalysisMethod): | ||||
|         torch.cuda.empty_cache() | ||||
|         return itm_scores2, image_gradcam_with_itm | ||||
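A hedged call of the ITM re-ranking step, reusing the outputs of the search sketch above; the itm_model_type strings are assumed to mirror the upload_model_* helpers ("blip_base", "blip_large", "blip2_coco"):

itm_scores, image_gradcam_with_itm = my_obj.image_text_match_reordering(
    search_query,
    "blip_base",          # assumed model type string
    image_keys,
    sorted_lists,
    batch_size=1,
    need_grad_cam=True,   # must stay False for "blip2_coco"
)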
| 
 | ||||
|     def show_results(self, query, itm=False, image_gradcam_with_itm=False): | ||||
|     def show_results( | ||||
|         self, query: dict, itm=False, image_gradcam_with_itm=False | ||||
|     ) -> None: | ||||
|         """ | ||||
|         Show results of search. | ||||
| 
 | ||||
|         Args: | ||||
|             query (dict): query. | ||||
|             itm (bool): use itm model. Default: False. | ||||
|             image_gradcam_with_itm (bool): use gradcam. Default: False. | ||||
|         """ | ||||
|         if "image" in query.keys(): | ||||
|             pic = Image.open(query["image"]).convert("RGB") | ||||
|             pic.thumbnail((400, 400)) | ||||
|  | ||||
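Finally, a hedged sketch of displaying the results for one of the queries from the examples above:

# Ranking from the plain multimodal search:
my_obj.show_results(search_query[0])

# Ranking after ITM reordering, with gradcam overlays where they were computed:
my_obj.show_results(search_query[0], itm=True, image_gradcam_with_itm=image_gradcam_with_itm)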