From f2c97e26ff42f817fe87d4a1206a4d3ce003a653 Mon Sep 17 00:00:00 2001
From: Inga Ulusoy
Date: Mon, 8 Sep 2025 15:21:49 +0200
Subject: [PATCH] maintain: remove multimodal search module

---
 ammico/multimodal_search.py           | 992 --------------------------------
 ammico/test/test_multimodal_search.py | 521 --------------
 2 files changed, 1513 deletions(-)
 delete mode 100644 ammico/multimodal_search.py
 delete mode 100644 ammico/test/test_multimodal_search.py

diff --git a/ammico/multimodal_search.py b/ammico/multimodal_search.py
deleted file mode 100644
index 864a59c..0000000
--- a/ammico/multimodal_search.py
+++ /dev/null
@@ -1,992 +0,0 @@
-from ammico.utils import AnalysisMethod
-import torch
-import torch.nn.functional as Func
-import requests
-import lavis
-import os
-import numpy as np
-from PIL import Image
-from skimage import transform as skimage_transform
-from scipy.ndimage import filters
-from matplotlib import pyplot as plt
-from IPython.display import display
-from lavis.models import load_model_and_preprocess, load_model, BlipBase
-from lavis.processors import load_processor
-
-
-class MultimodalSearch(AnalysisMethod):
-    def __init__(self, subdict: dict) -> None:
-        super().__init__(subdict)
-
-    multimodal_device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
-
-    def load_feature_extractor_model_blip2(self, device: str = "cpu"):
-        """
-        Load pretrain blip2_feature_extractor model and preprocessors for visual and text inputs from lavis.models.
-
-        Args:
-            device (str): device to use. Can be "cpu" or "cuda". Default: "cpu".
-
-        Returns:
-            model (torch.nn.Module): model.
-            vis_processors (dict): preprocessors for visual inputs.
-            txt_processors (dict): preprocessors for text inputs.
-        """
-        model, vis_processors, txt_processors = load_model_and_preprocess(
-            name="blip2_feature_extractor",
-            model_type="pretrain",
-            is_eval=True,
-            device=device,
-        )
-        return model, vis_processors, txt_processors
-
-    def load_feature_extractor_model_blip(self, device: str = "cpu"):
-        """
-        Load base blip_feature_extractor model and preprocessors for visual and text inputs from lavis.models.
-
-        Args:
-            device (str): device to use. Can be "cpu" or "cuda". Default: "cpu".
-
-        Returns:
-            model (torch.nn.Module): model.
-            vis_processors (dict): preprocessors for visual inputs.
-            txt_processors (dict): preprocessors for text inputs.
-        """
-        model, vis_processors, txt_processors = load_model_and_preprocess(
-            name="blip_feature_extractor",
-            model_type="base",
-            is_eval=True,
-            device=device,
-        )
-        return model, vis_processors, txt_processors
-
-    def load_feature_extractor_model_albef(self, device: str = "cpu"):
-        """
-        Load base albef_feature_extractor model and preprocessors for visual and text inputs from lavis.models.
-
-        Args:
-            device (str): device to use. Can be "cpu" or "cuda". Default: "cpu".
-
-        Returns:
-            model (torch.nn.Module): model.
-            vis_processors (dict): preprocessors for visual inputs.
-            txt_processors (dict): preprocessors for text inputs.
-        """
-        model, vis_processors, txt_processors = load_model_and_preprocess(
-            name="albef_feature_extractor",
-            model_type="base",
-            is_eval=True,
-            device=device,
-        )
-        return model, vis_processors, txt_processors
-
-    def load_feature_extractor_model_clip_base(self, device: str = "cpu"):
-        """
-        Load base clip_feature_extractor model and preprocessors for visual and text inputs from lavis.models.
-
-        Args:
-            device (str): device to use. Can be "cpu" or "cuda". Default: "cpu".
-
-        Returns:
-            model (torch.nn.Module): model.
- vis_processors (dict): preprocessors for visual inputs. - txt_processors (dict): preprocessors for text inputs. - """ - model, vis_processors, txt_processors = load_model_and_preprocess( - name="clip_feature_extractor", - model_type="base", - is_eval=True, - device=device, - ) - return model, vis_processors, txt_processors - - def load_feature_extractor_model_clip_vitl14(self, device: str = "cpu"): - """ - Load ViT-L-14 clip_feature_extractor model and preprocessors for visual and text inputs from lavis.models. - - Args: - device (str): device to use. Can be "cpu" or "cuda". Default: "cpu". - - Returns: - model (torch.nn.Module): model. - vis_processors (dict): preprocessors for visual inputs. - txt_processors (dict): preprocessors for text inputs. - """ - model, vis_processors, txt_processors = load_model_and_preprocess( - name="clip_feature_extractor", - model_type="ViT-L-14", - is_eval=True, - device=device, - ) - return model, vis_processors, txt_processors - - def load_feature_extractor_model_clip_vitl14_336(self, device: str = "cpu"): - """ - Load ViT-L-14-336 clip_feature_extractor model and preprocessors for visual and text inputs from lavis.models. - - Args: - device (str): device to use. Can be "cpu" or "cuda". Default: "cpu". - - Returns: - model (torch.nn.Module): model. - vis_processors (dict): preprocessors for visual inputs. - txt_processors (dict): preprocessors for text inputs. - """ - model, vis_processors, txt_processors = load_model_and_preprocess( - name="clip_feature_extractor", - model_type="ViT-L-14-336", - is_eval=True, - device=device, - ) - return model, vis_processors, txt_processors - - def read_img(self, filepath: str) -> Image: - """ - Load Image from filepath. - - Args: - filepath (str): path to image. - - Returns: - raw_image (PIL.Image): image. - """ - raw_image = Image.open(filepath).convert("RGB") - return raw_image - - def read_and_process_images(self, image_paths: list, vis_processor) -> tuple: - """ - Read and process images with vis_processor. - - Args: - image_paths (str): paths to images. - vis_processor (dict): preprocessors for visual inputs. - - Returns: - raw_images (list): list of images. - images_tensors (torch.Tensor): tensors of images stacked in device. - """ - raw_images = [MultimodalSearch.read_img(self, path) for path in image_paths] - images = [ - vis_processor["eval"](r_img) - .unsqueeze(0) - .to(MultimodalSearch.multimodal_device) - for r_img in raw_images - ] - images_tensors = torch.stack(images) - - return raw_images, images_tensors - - def extract_image_features_blip2( - self, model, images_tensors: torch.Tensor - ) -> torch.Tensor: - """ - Extract image features from images_tensors with blip2_feature_extractor model. - - Args: - model (torch.nn.Module): model. - images_tensors (torch.Tensor): tensors of images stacked in device. - - Returns: - features_image_stacked (torch.Tensor): tensors of images features stacked in device. - """ - with torch.cuda.amp.autocast( - enabled=(MultimodalSearch.multimodal_device != torch.device("cpu")) - ): - features_image = [ - model.extract_features({"image": ten, "text_input": ""}, mode="image") - for ten in images_tensors - ] - features_image_stacked = torch.stack( - [feat.image_embeds_proj[:, 0, :].squeeze(0) for feat in features_image] - ) - return features_image_stacked - - def extract_image_features_clip( - self, model, images_tensors: torch.Tensor - ) -> torch.Tensor: - """ - Extract image features from images_tensors with clip_feature_extractor model. 
- - Args: - model (torch.nn.Module): model. - images_tensors (torch.Tensor): tensors of images stacked in device. - - Returns: - features_image_stacked (torch.Tensor): tensors of images features stacked in device. - """ - features_image = [ - model.extract_features({"image": ten}) for ten in images_tensors - ] - features_image_stacked = torch.stack( - [Func.normalize(feat.float(), dim=-1).squeeze(0) for feat in features_image] - ) - return features_image_stacked - - def extract_image_features_basic( - self, model, images_tensors: torch.Tensor - ) -> torch.Tensor: - """ - Extract image features from images_tensors with blip_feature_extractor or albef_feature_extractor model. - - Args: - model (torch.nn.Module): model. - images_tensors (torch.Tensor): tensors of images stacked in device. - - Returns: - features_image_stacked (torch.Tensor): tensors of images features stacked in device. - """ - features_image = [ - model.extract_features({"image": ten, "text_input": ""}, mode="image") - for ten in images_tensors - ] - features_image_stacked = torch.stack( - [feat.image_embeds_proj[:, 0, :].squeeze(0) for feat in features_image] - ) - return features_image_stacked - - def save_tensors( - self, - model_type: str, - features_image_stacked: torch.Tensor, - name: str = "saved_features_image.pt", - path: str = "./saved_tensors/", - ) -> str: - """ - Save tensors as binary to given path. - - Args: - model_type (str): type of the model. - features_image_stacked (torch.Tensor): tensors of images features stacked in device. - name (str): name of the file. Default: "saved_features_image.pt". - path (str): path to save the file. Default: "./saved_tensors/". - - Returns: - name (str): name of the file. - """ - if not os.path.exists(path): - os.makedirs(path) - with open( - str(path) - + str(len(features_image_stacked)) - + "_" - + model_type - + "_" - + name, - "wb", - ) as f: - torch.save(features_image_stacked, f) - return name - - def load_tensors(self, name: str) -> torch.Tensor: - """ - Load tensors from given path. - - Args: - name (str): name of the file. - - Returns: - features_image_stacked (torch.Tensor): tensors of images features. - """ - features_image_stacked = torch.load(name, weights_only=True) - return features_image_stacked - - def extract_text_features(self, model, text_input: str) -> torch.Tensor: - """ - Extract text features from text_input with feature_extractor model. - - Args: - model (torch.nn.Module): model. - text_input (str): text. - - Returns: - features_text (torch.Tensor): tensors of text features. - """ - sample_text = {"text_input": [text_input]} - features_text = model.extract_features(sample_text, mode="text") - - return features_text - - def parsing_images( - self, - model_type: str, - path_to_save_tensors: str = "./saved_tensors/", - path_to_load_tensors: str = None, - ) -> tuple: - """ - Parsing images with feature_extractor model. - - Args: - model_type (str): type of the model. - path_to_save_tensors (str): path to save the tensors. Default: "./saved_tensors/". - path_to_load_tensors (str): path to load the tesors. Default: None. - - Returns: - model (torch.nn.Module): model. - vis_processors (dict): preprocessors for visual inputs. - txt_processors (dict): preprocessors for text inputs. - image_keys (list): sorted list of image keys. - image_names (list): sorted list of image names. - features_image_stacked (torch.Tensor): tensors of images features stacked in device. 
- """ - if model_type in ("clip_base", "clip_vitl14_336", "clip_vitl14"): - path_to_lib = lavis.__file__[:-11] + "models/clip_models/" - url = "https://raw.githubusercontent.com/salesforce/LAVIS/main/lavis/models/clip_models/bpe_simple_vocab_16e6.txt.gz" - r = requests.get(url, allow_redirects=False) - with open(path_to_lib + "bpe_simple_vocab_16e6.txt.gz", "wb") as f: - f.write(r.content) - - image_keys = sorted(self.subdict.keys()) - image_names = [self.subdict[k]["filename"] for k in image_keys] - - select_model = { - "blip2": MultimodalSearch.load_feature_extractor_model_blip2, - "blip": MultimodalSearch.load_feature_extractor_model_blip, - "albef": MultimodalSearch.load_feature_extractor_model_albef, - "clip_base": MultimodalSearch.load_feature_extractor_model_clip_base, - "clip_vitl14": MultimodalSearch.load_feature_extractor_model_clip_vitl14, - "clip_vitl14_336": MultimodalSearch.load_feature_extractor_model_clip_vitl14_336, - } - - select_extract_image_features = { - "blip2": MultimodalSearch.extract_image_features_blip2, - "blip": MultimodalSearch.extract_image_features_basic, - "albef": MultimodalSearch.extract_image_features_basic, - "clip_base": MultimodalSearch.extract_image_features_clip, - "clip_vitl14": MultimodalSearch.extract_image_features_clip, - "clip_vitl14_336": MultimodalSearch.extract_image_features_clip, - } - - if model_type in select_model.keys(): - ( - model, - vis_processors, - txt_processors, - ) = select_model[ - model_type - ](self, MultimodalSearch.multimodal_device) - else: - raise SyntaxError( - "Please, use one of the following models: blip2, blip, albef, clip_base, clip_vitl14, clip_vitl14_336" - ) - - _, images_tensors = MultimodalSearch.read_and_process_images( - self, image_names, vis_processors - ) - if path_to_load_tensors is None: - with torch.no_grad(): - features_image_stacked = select_extract_image_features[model_type]( - self, model, images_tensors - ) - MultimodalSearch.save_tensors( - self, model_type, features_image_stacked, path=path_to_save_tensors - ) - else: - features_image_stacked = MultimodalSearch.load_tensors( - self, str(path_to_load_tensors) - ) - - return ( - model, - vis_processors, - txt_processors, - image_keys, - image_names, - features_image_stacked, - ) - - def querys_processing( - self, search_query: list, model, txt_processors, vis_processors, model_type: str - ) -> torch.Tensor: - """ - Process querys. - - Args: - search_query (list): list of querys. - model (torch.nn.Module): model. - txt_processors (dict): preprocessors for text inputs. - vis_processors (dict): preprocessors for visual inputs. - model_type (str): type of the model. - - Returns: - multi_features_stacked (torch.Tensor): tensors of querys features. 
- """ - select_extract_image_features = { - "blip2": MultimodalSearch.extract_image_features_blip2, - "blip": MultimodalSearch.extract_image_features_basic, - "albef": MultimodalSearch.extract_image_features_basic, - "clip_base": MultimodalSearch.extract_image_features_clip, - "clip_vitl14": MultimodalSearch.extract_image_features_clip, - "clip_vitl14_336": MultimodalSearch.extract_image_features_clip, - } - - for query in search_query: - if len(query) != 1 and (query in ("image", "text_input")): - raise SyntaxError( - 'Each query must contain either an "image" or a "text_input"' - ) - multi_sample = [] - for query in search_query: - if "text_input" in query.keys(): - text_processing = txt_processors["eval"](query["text_input"]) - images_tensors = "" - elif "image" in query.keys(): - _, images_tensors = MultimodalSearch.read_and_process_images( - self, [query["image"]], vis_processors - ) - text_processing = "" - multi_sample.append( - {"image": images_tensors, "text_input": text_processing} - ) - - multi_features_query = [] - for query in multi_sample: - if query["image"] == "": - if model_type in ("clip_base", "clip_vitl14_336", "clip_vitl14"): - features = model.extract_features( - {"text_input": query["text_input"]} - ) - - features_squeeze = features.squeeze(0).to( - MultimodalSearch.multimodal_device - ) - multi_features_query.append( - Func.normalize(features_squeeze, dim=-1) - ) - else: - features = model.extract_features(query, mode="text") - features_squeeze = ( - features.text_embeds_proj[:, 0, :] - .squeeze(0) - .to(MultimodalSearch.multimodal_device) - ) - multi_features_query.append(features_squeeze) - if query["text_input"] == "": - multi_features_query.append( - select_extract_image_features[model_type]( - self, model, query["image"] - ) - ) - - multi_features_stacked = torch.stack( - [query.squeeze(0) for query in multi_features_query] - ).to(MultimodalSearch.multimodal_device) - - return multi_features_stacked - - def multimodal_search( - self, - model, - vis_processors, - txt_processors, - model_type: str, - image_keys: list, - features_image_stacked: torch.Tensor, - search_query: list, - filter_number_of_images: str = None, - filter_val_limit: str = None, - filter_rel_error: str = None, - ) -> tuple: - """ - Search for images with given querys. - - Args: - model (torch.nn.Module): model. - vis_processors (dict): preprocessors for visual inputs. - txt_processors (dict): preprocessors for text inputs. - model_type (str): type of the model. - image_keys (list): sorted list of image keys. - features_image_stacked (torch.Tensor): tensors of images features stacked in device. - search_query (list): list of querys. - filter_number_of_images (str): number of images to show. Default: None. - filter_val_limit (str): limit of similarity value. Default: None. - filter_rel_error (str): limit of relative error. Default: None. - - Returns: - similarity (torch.Tensor): similarity between images and querys. - sorted_lists (list): sorted list of similarity. 
- """ - if filter_number_of_images is None: - filter_number_of_images = len(self.subdict) - if filter_val_limit is None: - filter_val_limit = 0 - if filter_rel_error is None: - filter_rel_error = 1e10 - - features_image_stacked.to(MultimodalSearch.multimodal_device) - - with torch.no_grad(): - multi_features_stacked = MultimodalSearch.querys_processing( - self, search_query, model, txt_processors, vis_processors, model_type - ) - - similarity = features_image_stacked @ multi_features_stacked.t() - sorted_lists = torch.argsort(similarity, dim=0, descending=True).T.tolist() - places = [[item.index(i) for i in range(len(item))] for item in sorted_lists] - - for q in range(len(search_query)): - max_val = similarity[sorted_lists[q][0]][q].item() - for i, key in zip(range(len(image_keys)), sorted_lists[q]): - if ( - i < filter_number_of_images - and similarity[key][q].item() > filter_val_limit - and 100 * abs(max_val - similarity[key][q].item()) / max_val - < filter_rel_error - ): - self.subdict[image_keys[key]][ - "rank " + list(search_query[q].values())[0] - ] = places[q][key] - self.subdict[image_keys[key]][list(search_query[q].values())[0]] = ( - similarity[key][q].item() - ) - else: - self.subdict[image_keys[key]][ - "rank " + list(search_query[q].values())[0] - ] = None - self.subdict[image_keys[key]][list(search_query[q].values())[0]] = 0 - return similarity, sorted_lists - - def itm_text_precessing(self, search_query: list[dict[str, str]]) -> list: - """ - Process text querys for itm model. - - Args: - search_query (list): list of querys. - - Returns: - text_query_index (list): list of indexes of text querys. - """ - for query in search_query: - if (len(query) != 1) and (query in ("image", "text_input")): - raise SyntaxError( - 'Each querry must contain either an "image" or a "text_input"' - ) - text_query_index = [] - for i, query in zip(range(len(search_query)), search_query): - if "text_input" in query.keys(): - text_query_index.append(i) - - return text_query_index - - def get_pathes_from_query(self, query: dict[str, str]) -> tuple: - """ - Get pathes and image names from query. - - Args: - query (dict): query. - - Returns: - paths (list): list of pathes. - image_names (list): list of image names. - """ - paths = [] - image_names = [] - for s in sorted( - self.subdict.items(), - key=lambda t: t[1][list(query.values())[0]], - reverse=True, - ): - if s[1]["rank " + list(query.values())[0]] is None: - break - paths.append(s[1]["filename"]) - image_names.append(s[0]) - return paths, image_names - - def read_and_process_images_itm(self, image_paths: list, vis_processor) -> tuple: - """ - Read and process images with vis_processor for itm model. - - Args: - image_paths (list): paths to images. - vis_processor (dict): preprocessors for visual inputs. - - Returns: - raw_images (list): list of images. - images_tensors (torch.Tensor): tensors of images stacked in device. - """ - raw_images = [MultimodalSearch.read_img(self, path) for path in image_paths] - images = [vis_processor(r_img) for r_img in raw_images] - images_tensors = torch.stack(images).to(MultimodalSearch.multimodal_device) - - return raw_images, images_tensors - - def compute_gradcam_batch( - self, - model: torch.nn.Module, - visual_input: torch.Tensor, - text_input: str, - tokenized_text: torch.Tensor, - block_num: str = 6, - ) -> tuple: - """ - Compute gradcam for itm model. - - Args: - model (torch.nn.Module): model. - visual_input (torch.Tensor): tensors of images features stacked in device. - text_input (str): text. 
- tokenized_text (torch.Tensor): tokenized text. - block_num (int): number of block. Default: 6. - - Returns: - gradcam (torch.Tensor): gradcam. - output (torch.Tensor): output of model. - """ - model.text_encoder.base_model.base_model.encoder.layer[ - block_num - ].crossattention.self.save_attention = True - - output = model( - {"image": visual_input, "text_input": text_input}, match_head="itm" - ) - loss = output[:, 1].sum() - - model.zero_grad() - loss.backward() - with torch.no_grad(): - mask = tokenized_text.attention_mask.view( - tokenized_text.attention_mask.size(0), 1, -1, 1, 1 - ) # (bsz,1,token_len, 1,1) - token_length = mask.sum() - 2 - token_length = token_length.cpu() - # grads and cams [bsz, num_head, seq_len, image_patch] - grads = model.text_encoder.base_model.base_model.encoder.layer[ - block_num - ].crossattention.self.get_attn_gradients() - cams = model.text_encoder.base_model.base_model.encoder.layer[ - block_num - ].crossattention.self.get_attention_map() - - # assume using vit large with 576 num image patch - cams = ( - cams[:, :, :, 1:].reshape(visual_input.size(0), 12, -1, 24, 24) * mask - ) - grads = ( - grads[:, :, :, 1:] - .clamp(0) - .reshape(visual_input.size(0), 12, -1, 24, 24) - * mask - ) - - gradcam = cams * grads - # [enc token gradcam, average gradcam across token, gradcam for individual token] - # gradcam = torch.cat((gradcam[0:1,:], gradcam[1:token_length+1, :].sum(dim=0, keepdim=True)/token_length, gradcam[1:, :])) - gradcam = gradcam.mean(1).cpu().detach() - gradcam = ( - gradcam[:, 1 : token_length + 1, :].sum(dim=1, keepdim=True) - / token_length - ) - - return gradcam, output - - def resize_img(self, raw_img: Image): - """ - Proportional resize image to 240 p width. - - Args: - raw_img (PIL.Image): image. - - Returns: - resized_image (PIL.Image): proportional resized image to 240p. - """ - w, h = raw_img.size - scaling_factor = 240 / w - resized_image = raw_img.resize( - (int(w * scaling_factor), int(h * scaling_factor)) - ) - return resized_image - - def get_att_map( - self, - img: np.ndarray, - att_map: np.ndarray, - blur: bool = True, - overlap: bool = True, - ) -> np.ndarray: - """ - Get attention map. - - Args: - img (np.ndarray): image. - att_map (np.ndarray): attention map. - blur (bool): blur attention map. Default: True. - overlap (bool): overlap attention map with image. Default: True. - - Returns: - att_map (np.ndarray): attention map. - """ - att_map -= att_map.min() - if att_map.max() > 0: - att_map /= att_map.max() - att_map = skimage_transform.resize( - att_map, (img.shape[:2]), order=3, mode="constant" - ) - if blur: - att_map = filters.gaussian_filter(att_map, 0.02 * max(img.shape[:2])) - att_map -= att_map.min() - att_map /= att_map.max() - cmap = plt.get_cmap("jet") - att_mapv = cmap(att_map) - att_mapv = np.delete(att_mapv, 3, 2) - if overlap: - att_map = ( - 1 * (1 - att_map**0.7).reshape(att_map.shape + (1,)) * img - + (att_map**0.7).reshape(att_map.shape + (1,)) * att_mapv - ) - return att_map - - def upload_model_blip2_coco(self) -> tuple: - """ - Load coco blip2_image_text_matching model and preprocessors for visual inputs from lavis.models. - - Returns: - itm_model (torch.nn.Module): model. - vis_processor (dict): preprocessors for visual inputs. 
- """ - itm_model = load_model( - "blip2_image_text_matching", - "coco", - is_eval=True, - device=MultimodalSearch.multimodal_device, - ) - vis_processor = load_processor("blip_image_eval").build(image_size=364) - return itm_model, vis_processor - - def upload_model_blip_base(self) -> tuple: - """ - Load base blip_image_text_matching model and preprocessors for visual input from lavis.models. - - Returns: - itm_model (torch.nn.Module): model. - vis_processor (dict): preprocessors for visual inputs. - """ - itm_model = load_model( - "blip_image_text_matching", - "base", - is_eval=True, - device=MultimodalSearch.multimodal_device, - ) - vis_processor = load_processor("blip_image_eval").build(image_size=384) - return itm_model, vis_processor - - def upload_model_blip_large(self) -> tuple: - """ - Load large blip_image_text_matching model and preprocessors for visual input from lavis.models. - - Returns: - itm_model (torch.nn.Module): model. - vis_processor (dict): preprocessors for visual inputs. - """ - itm_model = load_model( - "blip_image_text_matching", - "large", - is_eval=True, - device=MultimodalSearch.multimodal_device, - ) - vis_processor = load_processor("blip_image_eval").build(image_size=384) - return itm_model, vis_processor - - def image_text_match_reordering( - self, - search_query: list[dict[str, str]], - itm_model_type: str, - image_keys: list, - sorted_lists: list[list], - batch_size: int = 1, - need_grad_cam: bool = False, - ) -> tuple: - """ - Reorder images with itm model. - - Args: - search_query (list): list of querys. - itm_model_type (str): type of the model. - image_keys (list): sorted list of image keys. - sorted_lists (list): sorted list of similarity. - batch_size (int): batch size. Default: 1. - need_grad_cam (bool): need gradcam. Default: False. blip2_coco model does not yet work with gradcam. - - Returns: - itm_scores2: list of itm scores. - image_gradcam_with_itm: dict of image names and gradcam. - """ - if itm_model_type == "blip2_coco" and need_grad_cam is True: - raise SyntaxError( - "The blip2_coco model does not yet work with gradcam. 
Please set need_grad_cam to False" - ) - - choose_model = { - "blip_base": MultimodalSearch.upload_model_blip_base, - "blip_large": MultimodalSearch.upload_model_blip_large, - "blip2_coco": MultimodalSearch.upload_model_blip2_coco, - } - - itm_model, vis_processor_itm = choose_model[itm_model_type](self) - text_processor = load_processor("blip_caption") - tokenizer = BlipBase.init_tokenizer() - - if itm_model_type == "blip2_coco": - need_grad_cam = False - - text_query_index = MultimodalSearch.itm_text_precessing(self, search_query) - - avg_gradcams = [] - itm_scores = [] - itm_scores2 = [] - image_gradcam_with_itm = {} - - for index_text_query in text_query_index: - query = search_query[index_text_query] - pathes, image_names = MultimodalSearch.get_pathes_from_query(self, query) - num_batches = int(len(pathes) / batch_size) - num_batches_residue = len(pathes) % batch_size - - local_itm_scores = [] - local_avg_gradcams = [] - - if num_batches_residue != 0: - num_batches = num_batches + 1 - for i in range(num_batches): - filenames_in_batch = pathes[i * batch_size : (i + 1) * batch_size] - current_len = len(filenames_in_batch) - raw_images, images = MultimodalSearch.read_and_process_images_itm( - self, filenames_in_batch, vis_processor_itm - ) - queries_batch = [text_processor(query["text_input"])] * current_len - queries_tok_batch = tokenizer(queries_batch, return_tensors="pt").to( - MultimodalSearch.multimodal_device - ) - - if need_grad_cam: - gradcam, itm_output = MultimodalSearch.compute_gradcam_batch( - self, - itm_model, - images, - queries_batch, - queries_tok_batch, - ) - norm_imgs = [np.float32(r_img) / 255 for r_img in raw_images] - - for norm_img, grad_cam in zip(norm_imgs, gradcam): - avg_gradcam = MultimodalSearch.get_att_map( - self, norm_img, np.float32(grad_cam[0]), blur=True - ) - local_avg_gradcams.append(avg_gradcam) - - else: - itm_output = itm_model( - {"image": images, "text_input": queries_batch}, match_head="itm" - ) - - with torch.no_grad(): - itm_score = torch.nn.functional.softmax(itm_output, dim=1) - - local_itm_scores.append(itm_score) - - local_itm_scores2 = torch.cat(local_itm_scores)[:, 1] - if need_grad_cam: - localimage_gradcam_with_itm = { - n: i * 255 for n, i in zip(image_names, local_avg_gradcams) - } - else: - localimage_gradcam_with_itm = "" - image_names_with_itm = { - n: i.item() for n, i in zip(image_names, local_itm_scores2) - } - itm_rank = torch.argsort(local_itm_scores2, descending=True) - image_names_with_new_rank = { - image_names[i.item()]: rank - for i, rank in zip(itm_rank, range(len(itm_rank))) - } - for i, key in zip(range(len(image_keys)), sorted_lists[index_text_query]): - if image_keys[key] in image_names: - self.subdict[image_keys[key]][ - "itm " + list(search_query[index_text_query].values())[0] - ] = image_names_with_itm[image_keys[key]] - self.subdict[image_keys[key]][ - "itm_rank " + list(search_query[index_text_query].values())[0] - ] = image_names_with_new_rank[image_keys[key]] - else: - self.subdict[image_keys[key]][ - "itm " + list(search_query[index_text_query].values())[0] - ] = 0 - self.subdict[image_keys[key]][ - "itm_rank " + list(search_query[index_text_query].values())[0] - ] = None - - avg_gradcams.append(local_avg_gradcams) - itm_scores.append(local_itm_scores) - itm_scores2.append(local_itm_scores2) - image_gradcam_with_itm[list(search_query[index_text_query].values())[0]] = ( - localimage_gradcam_with_itm - ) - del ( - itm_model, - vis_processor_itm, - text_processor, - raw_images, - images, - tokenizer, - 
queries_batch, - queries_tok_batch, - itm_score, - ) - if need_grad_cam: - del itm_output, gradcam, norm_img, grad_cam, avg_gradcam - torch.cuda.empty_cache() - return itm_scores2, image_gradcam_with_itm - - def show_results( - self, query: dict, itm: bool = False, image_gradcam_with_itm: dict = {} - ) -> None: - """ - Show results of search. - - Args: - query (dict): query. - itm (bool): use itm model. Default: False. - image_gradcam_with_itm (dict): use gradcam. Default: empty. - """ - if "image" in query.keys(): - pic = Image.open(query["image"]).convert("RGB") - pic.thumbnail((400, 400)) - display( - "Your search query: ", - pic, - "--------------------------------------------------", - "Results:", - ) - elif "text_input" in query.keys(): - display( - "Your search query: " + query["text_input"], - "--------------------------------------------------", - "Results:", - ) - if itm: - current_querry_val = "itm " + list(query.values())[0] - current_querry_rank = "itm_rank " + list(query.values())[0] - else: - current_querry_val = list(query.values())[0] - current_querry_rank = "rank " + list(query.values())[0] - - for s in sorted( - self.subdict.items(), key=lambda t: t[1][current_querry_val], reverse=True - ): - if s[1][current_querry_rank] is None: - break - if bool(image_gradcam_with_itm) is True and itm is True: - image = image_gradcam_with_itm[list(query.values())[0]][s[0]] - p1 = Image.fromarray(image.astype("uint8"), "RGB") - else: - p1 = Image.open(s[1]["filename"]).convert("RGB") - p1.thumbnail((400, 400)) - display( - "Rank: " - + str(s[1][current_querry_rank]) - + " Val: " - + str(s[1][current_querry_val]), - s[0], - p1, - ) - display( - "--------------------------------------------------", - ) diff --git a/ammico/test/test_multimodal_search.py b/ammico/test/test_multimodal_search.py deleted file mode 100644 index e3ae791..0000000 --- a/ammico/test/test_multimodal_search.py +++ /dev/null @@ -1,521 +0,0 @@ -import pytest -import math -from PIL import Image -import numpy -from torch import device, cuda -import ammico.multimodal_search as ms - -related_error = 1e-2 -gpu_is_not_available = not cuda.is_available() - -cuda.empty_cache() - - -def test_read_img(get_testdict): - my_dict = {} - test_img = ms.MultimodalSearch.read_img( - my_dict, get_testdict["IMG_2746"]["filename"] - ) - assert list(numpy.array(test_img)[257][34]) == [70, 66, 63] - - -pre_proc_pic_blip2_blip_albef = [ - -1.0039474964141846, - -1.0039474964141846, -] -pre_proc_pic_clip_vitl14 = [ - -0.7995694875717163, - -0.7849710583686829, -] - -pre_proc_pic_clip_vitl14_336 = [ - -0.7995694875717163, - -0.7849710583686829, -] - -pre_proc_text_blip2_blip_albef = ( - "the bird sat on a tree located at the intersection of 23rd and 43rd streets" -) - -pre_proc_text_clip_clip_vitl14_clip_vitl14_336 = ( - "The bird sat on a tree located at the intersection of 23rd and 43rd streets." 
-) - -pre_extracted_feature_img_blip2 = [ - 0.04566730558872223, - -0.042554520070552826, -] - -pre_extracted_feature_img_blip = [ - -0.02480311505496502, - 0.05037587881088257, -] - -pre_extracted_feature_img_albef = [ - 0.08971136063337326, - -0.10915573686361313, -] - -pre_extracted_feature_img_clip = [ - 0.01621132344007492, - -0.004035486374050379, -] - -pre_extracted_feature_img_parsing_clip = [ - 0.01621132344007492, - -0.004035486374050379, -] - -pre_extracted_feature_img_clip_vitl14 = [ - -0.023943455889821053, - -0.021703708916902542, -] - -pre_extracted_feature_img_clip_vitl14_336 = [ - -0.009511193260550499, - -0.012618942186236382, -] - -pre_extracted_feature_text_blip2 = [ - -0.1384204626083374, - -0.008662976324558258, -] - -pre_extracted_feature_text_blip = [ - 0.0118643119931221, - -0.01291718054562807, -] - -pre_extracted_feature_text_albef = [ - -0.06229640915989876, - 0.11278597265481949, -] - -pre_extracted_feature_text_clip = [ - 0.018169036135077477, - 0.03634127229452133, -] - -pre_extracted_feature_text_clip_vitl14 = [ - -0.0055463071912527084, - 0.006908962037414312, -] - -pre_extracted_feature_text_clip_vitl14_336 = [ - -0.008720514364540577, - 0.005284308455884457, -] - -simularity_blip2 = [ - [0.05826476216316223, -0.02717375010251999], - [0.06297147274017334, 0.47339022159576416], -] - -sorted_blip2 = [ - [1, 0], - [1, 0], -] - -simularity_blip = [ - [0.15640679001808167, 0.752173662185669], - [0.17233705520629883, 0.8448910117149353], -] - -sorted_blip = [ - [1, 0], - [1, 0], -] - -simularity_albef = [ - [0.12321824580430984, 0.35511350631713867], - [0.10870333760976791, 0.5143978595733643], -] - -sorted_albef = [ - [0, 1], - [1, 0], -] - -simularity_clip = [ - [0.23923014104366302, 0.5325412750244141], - [0.2310466319322586, 0.5910375714302063], -] - -sorted_clip = [ - [1, 0], - [1, 0], -] - -simularity_clip_vitl14 = [ - [0.1051270067691803, 0.5184808373451233], - [0.1277746558189392, 0.6841973662376404], -] - -sorted_clip_vitl14 = [ - [1, 0], - [1, 0], -] - -simularity_clip_vitl14_336 = [ - [0.09391091763973236, 0.49337542057037354], - [0.13700757920742035, 0.7003108263015747], -] - -sorted_clip_vitl14_336 = [ - [1, 0], - [1, 0], -] - -dict_itm_scores_for_blib = { - "blip_base": [ - 0.07107225805521011, - 0.004100032616406679, - ], - "blip_large": [ - 0.07890705019235611, - 0.00271016638725996, - ], - "blip2_coco": [ - 0.0833505243062973, - 0.004216152708977461, - ], -} - -dict_image_gradcam_with_itm_for_blip = { - "blip_base": [123.36285799741745, 132.31662154197693, 53.38280035299249], - "blip_large": [119.99512910842896, 128.7044593691826, 55.552959859540515], -} - - -@pytest.mark.long -@pytest.mark.parametrize( - ( - "pre_multimodal_device", - "pre_model", - "pre_proc_pic", - "pre_proc_text", - "pre_extracted_feature_img", - "pre_extracted_feature_text", - "pre_simularity", - "pre_sorted", - ), - [ - ( - device("cpu"), - "blip2", - pre_proc_pic_blip2_blip_albef, - pre_proc_text_blip2_blip_albef, - pre_extracted_feature_img_blip2, - pre_extracted_feature_text_blip2, - simularity_blip2, - sorted_blip2, - ), - pytest.param( - device("cuda"), - "blip2", - pre_proc_pic_blip2_blip_albef, - pre_proc_text_blip2_blip_albef, - pre_extracted_feature_img_blip2, - pre_extracted_feature_text_blip2, - simularity_blip2, - sorted_blip2, - marks=pytest.mark.skipif( - gpu_is_not_available, reason="gpu_is_not_availible" - ), - ), - ( - device("cpu"), - "blip", - pre_proc_pic_blip2_blip_albef, - pre_proc_text_blip2_blip_albef, - pre_extracted_feature_img_blip, - 
pre_extracted_feature_text_blip, - simularity_blip, - sorted_blip, - ), - pytest.param( - device("cuda"), - "blip", - pre_proc_pic_blip2_blip_albef, - pre_proc_text_blip2_blip_albef, - pre_extracted_feature_img_blip, - pre_extracted_feature_text_blip, - simularity_blip, - sorted_blip, - marks=pytest.mark.skipif( - gpu_is_not_available, reason="gpu_is_not_availible" - ), - ), - ( - device("cpu"), - "albef", - pre_proc_pic_blip2_blip_albef, - pre_proc_text_blip2_blip_albef, - pre_extracted_feature_img_albef, - pre_extracted_feature_text_albef, - simularity_albef, - sorted_albef, - ), - pytest.param( - device("cuda"), - "albef", - pre_proc_pic_blip2_blip_albef, - pre_proc_text_blip2_blip_albef, - pre_extracted_feature_img_albef, - pre_extracted_feature_text_albef, - simularity_albef, - sorted_albef, - marks=pytest.mark.skipif( - gpu_is_not_available, reason="gpu_is_not_availible" - ), - ), - ( - device("cpu"), - "clip_base", - pre_proc_pic_clip_vitl14, - pre_proc_text_clip_clip_vitl14_clip_vitl14_336, - pre_extracted_feature_img_clip, - pre_extracted_feature_text_clip, - simularity_clip, - sorted_clip, - ), - pytest.param( - device("cuda"), - "clip_base", - pre_proc_pic_clip_vitl14, - pre_proc_text_clip_clip_vitl14_clip_vitl14_336, - pre_extracted_feature_img_clip, - pre_extracted_feature_text_clip, - simularity_clip, - sorted_clip, - marks=pytest.mark.skipif( - gpu_is_not_available, reason="gpu_is_not_availible" - ), - ), - ( - device("cpu"), - "clip_vitl14", - pre_proc_pic_clip_vitl14, - pre_proc_text_clip_clip_vitl14_clip_vitl14_336, - pre_extracted_feature_img_clip_vitl14, - pre_extracted_feature_text_clip_vitl14, - simularity_clip_vitl14, - sorted_clip_vitl14, - ), - pytest.param( - device("cuda"), - "clip_vitl14", - pre_proc_pic_clip_vitl14, - pre_proc_text_clip_clip_vitl14_clip_vitl14_336, - pre_extracted_feature_img_clip_vitl14, - pre_extracted_feature_text_clip_vitl14, - simularity_clip_vitl14, - sorted_clip_vitl14, - marks=pytest.mark.skipif( - gpu_is_not_available, reason="gpu_is_not_availible" - ), - ), - ( - device("cpu"), - "clip_vitl14_336", - pre_proc_pic_clip_vitl14_336, - pre_proc_text_clip_clip_vitl14_clip_vitl14_336, - pre_extracted_feature_img_clip_vitl14_336, - pre_extracted_feature_text_clip_vitl14_336, - simularity_clip_vitl14_336, - sorted_clip_vitl14_336, - ), - pytest.param( - device("cuda"), - "clip_vitl14_336", - pre_proc_pic_clip_vitl14_336, - pre_proc_text_clip_clip_vitl14_clip_vitl14_336, - pre_extracted_feature_img_clip_vitl14_336, - pre_extracted_feature_text_clip_vitl14_336, - simularity_clip_vitl14_336, - sorted_clip_vitl14_336, - marks=pytest.mark.skipif( - gpu_is_not_available, reason="gpu_is_not_availible" - ), - ), - ], -) -def test_parsing_images( - pre_multimodal_device, - pre_model, - pre_proc_pic, - pre_proc_text, - pre_extracted_feature_img, - pre_extracted_feature_text, - pre_simularity, - pre_sorted, - get_path, - get_testdict, - tmp_path, -): - ms.MultimodalSearch.multimodal_device = pre_multimodal_device - my_obj = ms.MultimodalSearch(get_testdict) - ( - model, - vis_processor, - txt_processor, - image_keys, - _, - features_image_stacked, - ) = my_obj.parsing_images(pre_model, path_to_save_tensors=tmp_path) - - for i, num in zip(range(10), features_image_stacked[0, 10:12].tolist()): - assert ( - math.isclose(num, pre_extracted_feature_img[i], rel_tol=related_error) - is True - ) - - test_pic = Image.open(my_obj.subdict["IMG_2746"]["filename"]).convert("RGB") - test_querry = ( - "The bird sat on a tree located at the intersection of 23rd and 
43rd streets." - ) - processed_pic = ( - vis_processor["eval"](test_pic).unsqueeze(0).to(pre_multimodal_device) - ) - processed_text = txt_processor["eval"](test_querry) - - for i, num in zip(range(10), processed_pic[0, 0, 0, 25:27].tolist()): - assert math.isclose(num, pre_proc_pic[i], rel_tol=related_error) is True - - assert processed_text == pre_proc_text - - search_query = [ - {"text_input": test_querry}, - {"image": my_obj.subdict["IMG_2746"]["filename"]}, - ] - multi_features_stacked = my_obj.querys_processing( - search_query, model, txt_processor, vis_processor, pre_model - ) - - for i, num in zip(range(10), multi_features_stacked[0, 10:12].tolist()): - assert ( - math.isclose(num, pre_extracted_feature_text[i], rel_tol=related_error) - is True - ) - - for i, num in zip(range(10), multi_features_stacked[1, 10:12].tolist()): - assert ( - math.isclose(num, pre_extracted_feature_img[i], rel_tol=related_error) - is True - ) - - search_query2 = [ - {"text_input": "A bus"}, - {"image": get_path + "IMG_3758.png"}, - ] - - similarity, sorted_list = my_obj.multimodal_search( - model, - vis_processor, - txt_processor, - pre_model, - image_keys, - features_image_stacked, - search_query2, - ) - - for i, num in zip(range(len(pre_simularity)), similarity.tolist()): - for j, num2 in zip(range(len(num)), num): - assert ( - math.isclose(num2, pre_simularity[i][j], rel_tol=100 * related_error) - is True - ) - - for i, num in zip(range(len(pre_sorted)), sorted_list): - for j, num2 in zip(range(2), num): - assert num2 == pre_sorted[i][j] - - del ( - model, - vis_processor, - txt_processor, - similarity, - features_image_stacked, - processed_pic, - multi_features_stacked, - my_obj, - ) - cuda.empty_cache() - - -@pytest.mark.long -def test_itm(get_test_my_dict, get_path): - search_query3 = [ - {"text_input": "A bus"}, - {"image": get_path + "IMG_3758.png"}, - ] - image_keys = ["IMG_2746", "IMG_2809"] - sorted_list = [[1, 0], [1, 0]] - my_obj = ms.MultimodalSearch(get_test_my_dict) - for itm_model in ["blip_base", "blip_large"]: - ( - itm_scores, - image_gradcam_with_itm, - ) = my_obj.image_text_match_reordering( - search_query3, - itm_model, - image_keys, - sorted_list, - batch_size=1, - need_grad_cam=True, - ) - for i, itm in zip( - range(len(dict_itm_scores_for_blib[itm_model])), - dict_itm_scores_for_blib[itm_model], - ): - assert ( - math.isclose(itm_scores[0].tolist()[i], itm, rel_tol=10 * related_error) - is True - ) - for i, grad_cam in zip( - range(len(dict_image_gradcam_with_itm_for_blip[itm_model])), - dict_image_gradcam_with_itm_for_blip[itm_model], - ): - assert ( - math.isclose( - image_gradcam_with_itm["A bus"]["IMG_2809"][0][0].tolist()[i], - grad_cam, - rel_tol=10 * related_error, - ) - is True - ) - del itm_scores, image_gradcam_with_itm - cuda.empty_cache() - - -@pytest.mark.long -def test_itm_blip2_coco(get_test_my_dict, get_path): - search_query3 = [ - {"text_input": "A bus"}, - {"image": get_path + "IMG_3758.png"}, - ] - image_keys = ["IMG_2746", "IMG_2809"] - sorted_list = [[1, 0], [1, 0]] - my_obj = ms.MultimodalSearch(get_test_my_dict) - - ( - itm_scores, - image_gradcam_with_itm, - ) = my_obj.image_text_match_reordering( - search_query3, - "blip2_coco", - image_keys, - sorted_list, - batch_size=1, - need_grad_cam=False, - ) - for i, itm in zip( - range(len(dict_itm_scores_for_blib["blip2_coco"])), - dict_itm_scores_for_blib["blip2_coco"], - ): - assert ( - math.isclose(itm_scores[0].tolist()[i], itm, rel_tol=10 * related_error) - is True - ) - del itm_scores, 
image_gradcam_with_itm
-    cuda.empty_cache()
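
For reference, the snippet below is an illustrative sketch, not part of the patch: it reconstructs from the deleted test file how the removed MultimodalSearch API was typically driven end to end, so the scope of the removal is easy to see. The image keys and query strings come from the deleted tests; the file paths under data/ and the choice of the "blip" feature extractor are placeholder assumptions.

# Illustrative sketch of the removed MultimodalSearch workflow (pre-removal API),
# reconstructed from the deleted tests above. Paths under "data/" and the "blip"
# model type are placeholder assumptions.
import ammico.multimodal_search as ms

# Each entry maps an image key to a sub-dict that must carry a "filename" path.
mydict = {
    "IMG_2746": {"filename": "data/IMG_2746.png"},  # hypothetical path
    "IMG_2809": {"filename": "data/IMG_2809.png"},  # hypothetical path
}

my_obj = ms.MultimodalSearch(mydict)

# Load the chosen feature extractor, preprocess all images and stack their
# features; tensors are cached under ./saved_tensors/ unless a load path is given.
(
    model,
    vis_processors,
    txt_processors,
    image_keys,
    image_names,
    features_image_stacked,
) = my_obj.parsing_images("blip")

# A query is a dict with either a "text_input" or an "image" key.
search_query = [
    {"text_input": "A bus"},
    {"image": "data/IMG_3758.png"},  # hypothetical path
]

similarity, sorted_lists = my_obj.multimodal_search(
    model,
    vis_processors,
    txt_processors,
    "blip",
    image_keys,
    features_image_stacked,
    search_query,
)

# Ranks and similarity values are written back into mydict per image;
# show_results() renders them (intended for notebook use via IPython.display).
my_obj.show_results(search_query[0])

Downstream code that relied on this interface, including the optional image_text_match_reordering re-ranking step, will need an alternative once this patch is applied.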