diff --git a/ammico/multimodal_search.py b/ammico/multimodal_search.py
index 4cfd1b2..bec8f50 100644
--- a/ammico/multimodal_search.py
+++ b/ammico/multimodal_search.py
@@ -20,7 +20,18 @@ class MultimodalSearch(AnalysisMethod):
 
     multimodal_device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 
-    def load_feature_extractor_model_blip2(self, device):
+    def load_feature_extractor_model_blip2(self, device: str = "cpu"):
+        """
+        Load pretrained blip2_feature_extractor model and preprocessors for visual and text inputs from lavis.models.
+
+        Args:
+            device (str): device to use. Can be "cpu" or "cuda". Default: "cpu".
+
+        Returns:
+            model (torch.nn.Module): model.
+            vis_processors (dict): preprocessors for visual inputs.
+            txt_processors (dict): preprocessors for text inputs.
+        """
         model, vis_processors, txt_processors = load_model_and_preprocess(
             name="blip2_feature_extractor",
             model_type="pretrain",
@@ -29,7 +40,18 @@ class MultimodalSearch(AnalysisMethod):
         )
         return model, vis_processors, txt_processors
 
-    def load_feature_extractor_model_blip(self, device):
+    def load_feature_extractor_model_blip(self, device: str = "cpu"):
+        """
+        Load base blip_feature_extractor model and preprocessors for visual and text inputs from lavis.models.
+
+        Args:
+            device (str): device to use. Can be "cpu" or "cuda". Default: "cpu".
+
+        Returns:
+            model (torch.nn.Module): model.
+            vis_processors (dict): preprocessors for visual inputs.
+            txt_processors (dict): preprocessors for text inputs.
+        """
         model, vis_processors, txt_processors = load_model_and_preprocess(
             name="blip_feature_extractor",
             model_type="base",
@@ -38,7 +60,18 @@ class MultimodalSearch(AnalysisMethod):
         )
         return model, vis_processors, txt_processors
 
-    def load_feature_extractor_model_albef(self, device):
+    def load_feature_extractor_model_albef(self, device: str = "cpu"):
+        """
+        Load base albef_feature_extractor model and preprocessors for visual and text inputs from lavis.models.
+
+        Args:
+            device (str): device to use. Can be "cpu" or "cuda". Default: "cpu".
+
+        Returns:
+            model (torch.nn.Module): model.
+            vis_processors (dict): preprocessors for visual inputs.
+            txt_processors (dict): preprocessors for text inputs.
+        """
         model, vis_processors, txt_processors = load_model_and_preprocess(
             name="albef_feature_extractor",
             model_type="base",
@@ -47,7 +80,18 @@ class MultimodalSearch(AnalysisMethod):
         )
         return model, vis_processors, txt_processors
 
-    def load_feature_extractor_model_clip_base(self, device):
+    def load_feature_extractor_model_clip_base(self, device: str = "cpu"):
+        """
+        Load base clip_feature_extractor model and preprocessors for visual and text inputs from lavis.models.
+
+        Args:
+            device (str): device to use. Can be "cpu" or "cuda". Default: "cpu".
+
+        Returns:
+            model (torch.nn.Module): model.
+            vis_processors (dict): preprocessors for visual inputs.
+            txt_processors (dict): preprocessors for text inputs.
+        """
         model, vis_processors, txt_processors = load_model_and_preprocess(
             name="clip_feature_extractor",
             model_type="base",
@@ -56,7 +100,18 @@ class MultimodalSearch(AnalysisMethod):
         )
         return model, vis_processors, txt_processors
 
-    def load_feature_extractor_model_clip_vitl14(self, device):
+    def load_feature_extractor_model_clip_vitl14(self, device: str = "cpu"):
+        """
+        Load ViT-L-14 clip_feature_extractor model and preprocessors for visual and text inputs from lavis.models.
+
+        Args:
+            device (str): device to use. Can be "cpu" or "cuda". Default: "cpu".
+
+        Returns:
+            model (torch.nn.Module): model.
+            vis_processors (dict): preprocessors for visual inputs.
+            txt_processors (dict): preprocessors for text inputs.
+        """
         model, vis_processors, txt_processors = load_model_and_preprocess(
             name="clip_feature_extractor",
             model_type="ViT-L-14",
@@ -65,7 +120,18 @@ class MultimodalSearch(AnalysisMethod):
         )
         return model, vis_processors, txt_processors
 
-    def load_feature_extractor_model_clip_vitl14_336(self, device):
+    def load_feature_extractor_model_clip_vitl14_336(self, device: str = "cpu"):
+        """
+        Load ViT-L-14-336 clip_feature_extractor model and preprocessors for visual and text inputs from lavis.models.
+
+        Args:
+            device (str): device to use. Can be "cpu" or "cuda". Default: "cpu".
+
+        Returns:
+            model (torch.nn.Module): model.
+            vis_processors (dict): preprocessors for visual inputs.
+            txt_processors (dict): preprocessors for text inputs.
+        """
         model, vis_processors, txt_processors = load_model_and_preprocess(
             name="clip_feature_extractor",
             model_type="ViT-L-14-336",
@@ -74,11 +140,31 @@ class MultimodalSearch(AnalysisMethod):
         )
         return model, vis_processors, txt_processors
 
-    def read_img(self, filepath):
+    def read_img(self, filepath: str) -> Image:
+        """
+        Load an image from the given filepath.
+
+        Args:
+            filepath (str): path to image.
+
+        Returns:
+            raw_image (PIL.Image): image.
+        """
         raw_image = Image.open(filepath).convert("RGB")
         return raw_image
 
-    def read_and_process_images(self, image_paths, vis_processor):
+    def read_and_process_images(self, image_paths: list, vis_processor) -> tuple:
+        """
+        Read and process images with vis_processor.
+
+        Args:
+            image_paths (list): paths to images.
+            vis_processor (dict): preprocessors for visual inputs.
+
+        Returns:
+            raw_images (list): list of images.
+            images_tensors (torch.Tensor): tensors of images stacked on the device.
+        """
         raw_images = [MultimodalSearch.read_img(self, path) for path in image_paths]
         images = [
             vis_processor["eval"](r_img)
@@ -90,7 +176,19 @@ class MultimodalSearch(AnalysisMethod):
 
         return raw_images, images_tensors
 
-    def extract_image_features_blip2(self, model, images_tensors):
+    def extract_image_features_blip2(
+        self, model, images_tensors: torch.Tensor
+    ) -> torch.Tensor:
+        """
+        Extract image features from images_tensors with the blip2_feature_extractor model.
+
+        Args:
+            model (torch.nn.Module): model.
+            images_tensors (torch.Tensor): tensors of images stacked on the device.
+
+        Returns:
+            features_image_stacked (torch.Tensor): tensors of image features stacked on the device.
+        """
         with torch.cuda.amp.autocast(
             enabled=(MultimodalSearch.multimodal_device != torch.device("cpu"))
         ):
@@ -103,7 +201,19 @@ class MultimodalSearch(AnalysisMethod):
         )
         return features_image_stacked
 
-    def extract_image_features_clip(self, model, images_tensors):
+    def extract_image_features_clip(
+        self, model, images_tensors: torch.Tensor
+    ) -> torch.Tensor:
+        """
+        Extract image features from images_tensors with the clip_feature_extractor model.
+
+        Args:
+            model (torch.nn.Module): model.
+            images_tensors (torch.Tensor): tensors of images stacked on the device.
+
+        Returns:
+            features_image_stacked (torch.Tensor): tensors of image features stacked on the device.
+ """ features_image = [ model.extract_features({"image": ten}) for ten in images_tensors ] @@ -112,7 +222,19 @@ class MultimodalSearch(AnalysisMethod): ) return features_image_stacked - def extract_image_features_basic(self, model, images_tensors): + def extract_image_features_basic( + self, model, images_tensors: torch.Tensor + ) -> torch.Tensor: + """ + Extract image features from images_tensors with blip_feature_extractor or albef_feature_extractor model. + + Args: + model (torch.nn.Module): model. + images_tensors (torch.Tensor): tensors of images stacked in device. + + Returns: + features_image_stacked (torch.Tensor): tensors of images features stacked in device. + """ features_image = [ model.extract_features({"image": ten, "text_input": ""}, mode="image") for ten in images_tensors @@ -124,11 +246,23 @@ class MultimodalSearch(AnalysisMethod): def save_tensors( self, - model_type, - features_image_stacked, - name="saved_features_image.pt", - path="./saved_tensors/", - ): + model_type: str, + features_image_stacked: torch.Tensor, + name: str = "saved_features_image.pt", + path: str = "./saved_tensors/", + ) -> str: + """ + Save tensors as binary to given path. + + Args: + model_type (str): type of the model. + features_image_stacked (torch.Tensor): tensors of images features stacked in device. + name (str): name of the file. Default: "saved_features_image.pt". + path (str): path to save the file. Default: "./saved_tensors/". + + Returns: + name (str): name of the file. + """ if not os.path.exists(path): os.makedirs(path) with open( @@ -143,11 +277,30 @@ class MultimodalSearch(AnalysisMethod): torch.save(features_image_stacked, f) return name - def load_tensors(self, name): + def load_tensors(self, name: str) -> torch.Tensor: + """ + Load tensors from given path. + + Args: + name (str): name of the file. + + Returns: + features_image_stacked (torch.Tensor): tensors of images features. + """ features_image_stacked = torch.load(name) return features_image_stacked - def extract_text_features(self, model, text_input): + def extract_text_features(self, model, text_input: str) -> torch.Tensor: + """ + Extract text features from text_input with feature_extractor model. + + Args: + model (torch.nn.Module): model. + text_input (str): text. + + Returns: + features_text (torch.Tensor): tensors of text features. + """ sample_text = {"text_input": [text_input]} features_text = model.extract_features(sample_text, mode="text") @@ -155,10 +308,26 @@ class MultimodalSearch(AnalysisMethod): def parsing_images( self, - model_type, - path_to_saved_tensors="./saved_tensors/", - path_to_load_tensors=None, - ): + model_type: str, + path_to_save_tensors: str = "./saved_tensors/", + path_to_load_tensors: str = None, + ) -> tuple: + """ + Parsing images with feature_extractor model. + + Args: + model_type (str): type of the model. + path_to_save_tensors (str): path to save the tensors. Default: "./saved_tensors/". + path_to_load_tensors (str): path to load the tesors. Default: None. + + Returns: + model (torch.nn.Module): model. + vis_processors (dict): preprocessors for visual inputs. + txt_processors (dict): preprocessors for text inputs. + image_keys (list): sorted list of image keys. + image_names (list): sorted list of image names. + features_image_stacked (torch.Tensor): tensors of images features stacked in device. 
+ """ if model_type in ("clip_base", "clip_vitl14_336", "clip_vitl14"): path_to_lib = lavis.__file__[:-11] + "models/clip_models/" url = "https://raw.githubusercontent.com/salesforce/LAVIS/main/lavis/models/clip_models/bpe_simple_vocab_16e6.txt.gz" @@ -208,7 +377,7 @@ class MultimodalSearch(AnalysisMethod): self, model, images_tensors ) MultimodalSearch.save_tensors( - self, model_type, features_image_stacked, path=path_to_saved_tensors + self, model_type, features_image_stacked, path=path_to_save_tensors ) else: features_image_stacked = MultimodalSearch.load_tensors( @@ -225,8 +394,21 @@ class MultimodalSearch(AnalysisMethod): ) def querys_processing( - self, search_query, model, txt_processors, vis_processors, model_type - ): + self, search_query: list, model, txt_processors, vis_processors, model_type: str + ) -> torch.Tensor: + """ + Process querys. + + Args: + search_query (list): list of querys. + model (torch.nn.Module): model. + txt_processors (dict): preprocessors for text inputs. + vis_processors (dict): preprocessors for visual inputs. + model_type (str): type of the model. + + Returns: + multi_features_stacked (torch.Tensor): tensors of querys features. + """ select_extract_image_features = { "blip2": MultimodalSearch.extract_image_features_blip2, "blip": MultimodalSearch.extract_image_features_basic, @@ -295,14 +477,33 @@ class MultimodalSearch(AnalysisMethod): model, vis_processors, txt_processors, - model_type, - image_keys, - features_image_stacked, - search_query, - filter_number_of_images=None, - filter_val_limit=None, - filter_rel_error=None, - ): + model_type: str, + image_keys: list, + features_image_stacked: torch.Tensor, + search_query: list, + filter_number_of_images: str = None, + filter_val_limit: str = None, + filter_rel_error: str = None, + ) -> tuple: + """ + Search for images with given querys. + + Args: + model (torch.nn.Module): model. + vis_processors (dict): preprocessors for visual inputs. + txt_processors (dict): preprocessors for text inputs. + model_type (str): type of the model. + image_keys (list): sorted list of image keys. + features_image_stacked (torch.Tensor): tensors of images features stacked in device. + search_query (list): list of querys. + filter_number_of_images (str): number of images to show. Default: None. + filter_val_limit (str): limit of similarity value. Default: None. + filter_rel_error (str): limit of relative error. Default: None. + + Returns: + similarity (torch.Tensor): similarity between images and querys. + sorted_lists (list): sorted list of similarity. + """ if filter_number_of_images is None: filter_number_of_images = len(self) if filter_val_limit is None: @@ -343,7 +544,16 @@ class MultimodalSearch(AnalysisMethod): self[image_keys[key]][list(search_query[q].values())[0]] = 0 return similarity, sorted_lists - def itm_text_precessing(self, search_query): + def itm_text_precessing(self, search_query: list[dict[str, str]]) -> list: + """ + Process text querys for itm model. + + Args: + search_query (list): list of querys. + + Returns: + text_query_index (list): list of indexes of text querys. + """ for query in search_query: if (len(query) != 1) and (query in ("image", "text_input")): raise SyntaxError( @@ -356,7 +566,17 @@ class MultimodalSearch(AnalysisMethod): return text_query_index - def get_pathes_from_query(self, query): + def get_pathes_from_query(self, query: dict[str, str]) -> tuple: + """ + Get pathes and image names from query. + + Args: + query (dict): query. + + Returns: + paths (list): list of pathes. 
+            image_names (list): list of image names.
+        """
         paths = []
         image_names = []
         for s in sorted(
@@ -368,7 +588,18 @@ class MultimodalSearch(AnalysisMethod):
             image_names.append(s[0])
         return paths, image_names
 
-    def read_and_process_images_itm(self, image_paths, vis_processor):
+    def read_and_process_images_itm(self, image_paths: list, vis_processor) -> tuple:
+        """
+        Read and process images with vis_processor for the itm model.
+
+        Args:
+            image_paths (list): paths to images.
+            vis_processor (dict): preprocessors for visual inputs.
+
+        Returns:
+            raw_images (list): list of images.
+            images_tensors (torch.Tensor): tensors of images stacked on the device.
+        """
         raw_images = [MultimodalSearch.read_img(self, path) for path in image_paths]
         images = [vis_processor(r_img) for r_img in raw_images]
         images_tensors = torch.stack(images).to(MultimodalSearch.multimodal_device)
@@ -377,12 +608,26 @@ class MultimodalSearch(AnalysisMethod):
 
     def compute_gradcam_batch(
         self,
-        model,
-        visual_input,
-        text_input,
-        tokenized_text,
-        block_num=6,
-    ):
+        model: torch.nn.Module,
+        visual_input: torch.Tensor,
+        text_input: str,
+        tokenized_text: torch.Tensor,
+        block_num: int = 6,
+    ) -> tuple:
+        """
+        Compute gradcam for the itm model.
+
+        Args:
+            model (torch.nn.Module): model.
+            visual_input (torch.Tensor): tensors of images stacked on the device.
+            text_input (str): text.
+            tokenized_text (torch.Tensor): tokenized text.
+            block_num (int): index of the cross-attention block. Default: 6.
+
+        Returns:
+            gradcam (torch.Tensor): gradcam.
+            output (torch.Tensor): output of the model.
+        """
         model.text_encoder.base_model.base_model.encoder.layer[
             block_num
         ].crossattention.self.save_attention = True
@@ -430,7 +675,16 @@ class MultimodalSearch(AnalysisMethod):
 
         return gradcam, output
 
-    def resize_img(self, raw_img):
+    def resize_img(self, raw_img: Image):
+        """
+        Proportionally resize an image to 240 px width.
+
+        Args:
+            raw_img (PIL.Image): image.
+
+        Returns:
+            resized_image (PIL.Image): image proportionally resized to 240 px width.
+        """
         w, h = raw_img.size
         scaling_factor = 240 / w
         resized_image = raw_img.resize(
@@ -438,7 +692,25 @@ class MultimodalSearch(AnalysisMethod):
         )
         return resized_image
 
-    def get_att_map(self, img, att_map, blur=True, overlap=True):
+    def get_att_map(
+        self,
+        img: np.ndarray,
+        att_map: np.ndarray,
+        blur: bool = True,
+        overlap: bool = True,
+    ) -> np.ndarray:
+        """
+        Get attention map.
+
+        Args:
+            img (np.ndarray): image.
+            att_map (np.ndarray): attention map.
+            blur (bool): whether to blur the attention map. Default: True.
+            overlap (bool): whether to overlap the attention map with the image. Default: True.
+
+        Returns:
+            att_map (np.ndarray): attention map.
+        """
         att_map -= att_map.min()
         if att_map.max() > 0:
             att_map /= att_map.max()
@@ -459,7 +731,14 @@ class MultimodalSearch(AnalysisMethod):
         )
         return att_map
 
-    def upload_model_blip2_coco(self):
+    def upload_model_blip2_coco(self) -> tuple:
+        """
+        Load coco blip2_image_text_matching model and preprocessors for visual inputs from lavis.models.
+
+        Returns:
+            itm_model (torch.nn.Module): model.
+            vis_processor (dict): preprocessors for visual inputs.
+        """
         itm_model = load_model(
             "blip2_image_text_matching",
             "coco",
@@ -469,7 +748,14 @@ class MultimodalSearch(AnalysisMethod):
         vis_processor = load_processor("blip_image_eval").build(image_size=364)
         return itm_model, vis_processor
 
-    def upload_model_blip_base(self):
+    def upload_model_blip_base(self) -> tuple:
+        """
+        Load base blip_image_text_matching model and preprocessors for visual inputs from lavis.models.
+
+        Returns:
+            itm_model (torch.nn.Module): model.
+            vis_processor (dict): preprocessors for visual inputs.
+        """
         itm_model = load_model(
             "blip_image_text_matching",
             "base",
@@ -479,7 +765,14 @@ class MultimodalSearch(AnalysisMethod):
         vis_processor = load_processor("blip_image_eval").build(image_size=384)
         return itm_model, vis_processor
 
-    def upload_model_blip_large(self):
+    def upload_model_blip_large(self) -> tuple:
+        """
+        Load large blip_image_text_matching model and preprocessors for visual inputs from lavis.models.
+
+        Returns:
+            itm_model (torch.nn.Module): model.
+            vis_processor (dict): preprocessors for visual inputs.
+        """
         itm_model = load_model(
             "blip_image_text_matching",
             "large",
@@ -491,13 +784,28 @@ class MultimodalSearch(AnalysisMethod):
 
     def image_text_match_reordering(
         self,
-        search_query,
-        itm_model_type,
-        image_keys,
-        sorted_lists,
-        batch_size=1,
-        need_grad_cam=False,
-    ):
+        search_query: list[dict[str, str]],
+        itm_model_type: str,
+        image_keys: list,
+        sorted_lists: list[list],
+        batch_size: int = 1,
+        need_grad_cam: bool = False,
+    ) -> tuple:
+        """
+        Reorder images with the itm model.
+
+        Args:
+            search_query (list): list of queries.
+            itm_model_type (str): type of the model.
+            image_keys (list): sorted list of image keys.
+            sorted_lists (list): sorted lists of similarities.
+            batch_size (int): batch size. Default: 1.
+            need_grad_cam (bool): whether to compute gradcam. Default: False. The blip2_coco model does not yet work with gradcam.
+
+        Returns:
+            itm_scores2 (list): list of itm scores.
+            image_gradcam_with_itm (dict): dict of image names and gradcam.
+        """
         if itm_model_type == "blip2_coco" and need_grad_cam is True:
             raise SyntaxError(
                 "The blip2_coco model does not yet work with gradcam. Please set need_grad_cam to False"
@@ -624,7 +932,17 @@ class MultimodalSearch(AnalysisMethod):
         torch.cuda.empty_cache()
         return itm_scores2, image_gradcam_with_itm
 
-    def show_results(self, query, itm=False, image_gradcam_with_itm=False):
+    def show_results(
+        self, query: dict, itm=False, image_gradcam_with_itm=False
+    ) -> None:
+        """
+        Show the results of the search.
+
+        Args:
+            query (dict): query.
+            itm (bool): use the itm model. Default: False.
+            image_gradcam_with_itm (bool): use gradcam. Default: False.
+        """
        if "image" in query.keys():
            pic = Image.open(query["image"]).convert("RGB")
            pic.thumbnail((400, 400))
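
A minimal usage sketch of the API documented above, assuming MultimodalSearch is constructed from an image dictionary whose values are dicts with a "filename" key (as read by get_pathes_from_query); the model type, file paths, and query text below are illustrative placeholders only.

    from ammico.multimodal_search import MultimodalSearch

    # Hypothetical image dictionary: image name -> sub-dict with at least a "filename" key.
    image_dict = {
        "img0": {"filename": "data/img0.png"},
        "img1": {"filename": "data/img1.png"},
    }
    search_obj = MultimodalSearch(image_dict)

    # Extract image features once; tensors are written to path_to_save_tensors
    # unless path_to_load_tensors points to previously saved features.
    model_type = "blip"
    (
        model,
        vis_processors,
        txt_processors,
        image_keys,
        image_names,
        features_image_stacked,
    ) = search_obj.parsing_images(model_type, path_to_save_tensors="./saved_tensors/")

    # Rank images against text and/or image queries.
    search_query = [{"text_input": "a politician giving a speech"}]
    similarity, sorted_lists = search_obj.multimodal_search(
        model,
        vis_processors,
        txt_processors,
        model_type,
        image_keys,
        features_image_stacked,
        search_query,
        filter_number_of_images=10,
    )

    # Display the ranked results for the first query.
    search_obj.show_results(search_query[0])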