maintain: remove summary module (VQA)

This commit is contained in:
Inga Ulusoy 2025-09-08 15:21:49 +02:00
parent 95d80135e2
commit f2c97e26ff
No key found matching this signature
GPG key ID: BDC64F2E85CF8272
2 changed files: 0 additions and 1513 deletions

View File

@@ -1,992 +0,0 @@
from ammico.utils import AnalysisMethod
import torch
import torch.nn.functional as Func
import requests
import lavis
import os
import numpy as np
from PIL import Image
from skimage import transform as skimage_transform
from scipy.ndimage import filters
from matplotlib import pyplot as plt
from IPython.display import display
from lavis.models import load_model_and_preprocess, load_model, BlipBase
from lavis.processors import load_processor
class MultimodalSearch(AnalysisMethod):
multimodal_device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
def __init__(self, subdict: dict) -> None:
super().__init__(subdict)
def load_feature_extractor_model_blip2(self, device: str = "cpu"):
"""
Load the pretrained blip2_feature_extractor model and preprocessors for visual and text inputs from lavis.models.
Args:
device (str): device to use. Can be "cpu" or "cuda". Default: "cpu".
Returns:
model (torch.nn.Module): model.
vis_processors (dict): preprocessors for visual inputs.
txt_processors (dict): preprocessors for text inputs.
"""
model, vis_processors, txt_processors = load_model_and_preprocess(
name="blip2_feature_extractor",
model_type="pretrain",
is_eval=True,
device=device,
)
return model, vis_processors, txt_processors
def load_feature_extractor_model_blip(self, device: str = "cpu"):
"""
Load base blip_feature_extractor model and preprocessors for visual and text inputs from lavis.models.
Args:
device (str): device to use. Can be "cpu" or "cuda". Default: "cpu".
Returns:
model (torch.nn.Module): model.
vis_processors (dict): preprocessors for visual inputs.
txt_processors (dict): preprocessors for text inputs.
"""
model, vis_processors, txt_processors = load_model_and_preprocess(
name="blip_feature_extractor",
model_type="base",
is_eval=True,
device=device,
)
return model, vis_processors, txt_processors
def load_feature_extractor_model_albef(self, device: str = "cpu"):
"""
Load base albef_feature_extractor model and preprocessors for visual and text inputs from lavis.models.
Args:
device (str): device to use. Can be "cpu" or "cuda". Default: "cpu".
Returns:
model (torch.nn.Module): model.
vis_processors (dict): preprocessors for visual inputs.
txt_processors (dict): preprocessors for text inputs.
"""
model, vis_processors, txt_processors = load_model_and_preprocess(
name="albef_feature_extractor",
model_type="base",
is_eval=True,
device=device,
)
return model, vis_processors, txt_processors
def load_feature_extractor_model_clip_base(self, device: str = "cpu"):
"""
Load base clip_feature_extractor model and preprocessors for visual and text inputs from lavis.models.
Args:
device (str): device to use. Can be "cpu" or "cuda". Default: "cpu".
Returns:
model (torch.nn.Module): model.
vis_processors (dict): preprocessors for visual inputs.
txt_processors (dict): preprocessors for text inputs.
"""
model, vis_processors, txt_processors = load_model_and_preprocess(
name="clip_feature_extractor",
model_type="base",
is_eval=True,
device=device,
)
return model, vis_processors, txt_processors
def load_feature_extractor_model_clip_vitl14(self, device: str = "cpu"):
"""
Load ViT-L-14 clip_feature_extractor model and preprocessors for visual and text inputs from lavis.models.
Args:
device (str): device to use. Can be "cpu" or "cuda". Default: "cpu".
Returns:
model (torch.nn.Module): model.
vis_processors (dict): preprocessors for visual inputs.
txt_processors (dict): preprocessors for text inputs.
"""
model, vis_processors, txt_processors = load_model_and_preprocess(
name="clip_feature_extractor",
model_type="ViT-L-14",
is_eval=True,
device=device,
)
return model, vis_processors, txt_processors
def load_feature_extractor_model_clip_vitl14_336(self, device: str = "cpu"):
"""
Load ViT-L-14-336 clip_feature_extractor model and preprocessors for visual and text inputs from lavis.models.
Args:
device (str): device to use. Can be "cpu" or "cuda". Default: "cpu".
Returns:
model (torch.nn.Module): model.
vis_processors (dict): preprocessors for visual inputs.
txt_processors (dict): preprocessors for text inputs.
"""
model, vis_processors, txt_processors = load_model_and_preprocess(
name="clip_feature_extractor",
model_type="ViT-L-14-336",
is_eval=True,
device=device,
)
return model, vis_processors, txt_processors
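# All six loaders above share one calling convention; a minimal usage sketch
# ("image_dict" is an illustrative placeholder for a subdict of images as used
# elsewhere in ammico):
#
#     my_obj = MultimodalSearch(image_dict)
#     model, vis_processors, txt_processors = (
#         my_obj.load_feature_extractor_model_blip2(device="cpu")
#     )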
def read_img(self, filepath: str) -> Image.Image:
"""
Load Image from filepath.
Args:
filepath (str): path to image.
Returns:
raw_image (PIL.Image): image.
"""
raw_image = Image.open(filepath).convert("RGB")
return raw_image
def read_and_process_images(self, image_paths: list, vis_processor) -> tuple:
"""
Read and process images with vis_processor.
Args:
image_paths (list): paths to images.
vis_processor (dict): preprocessors for visual inputs.
Returns:
raw_images (list): list of images.
images_tensors (torch.Tensor): image tensors stacked on the device.
"""
raw_images = [MultimodalSearch.read_img(self, path) for path in image_paths]
images = [
vis_processor["eval"](r_img)
.unsqueeze(0)
.to(MultimodalSearch.multimodal_device)
for r_img in raw_images
]
images_tensors = torch.stack(images)
return raw_images, images_tensors
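# Sketch: turning a list of image files into one stacked tensor (the file
# paths are illustrative assumptions):
#
#     raw_images, images_tensors = my_obj.read_and_process_images(
#         ["img1.png", "img2.png"], vis_processors
#     )
#     # images_tensors has shape (n_images, 1, C, H, W): each processed image
#     # keeps its unsqueezed batch dimension before torch.stack.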
def extract_image_features_blip2(
self, model, images_tensors: torch.Tensor
) -> torch.Tensor:
"""
Extract image features from images_tensors with blip2_feature_extractor model.
Args:
model (torch.nn.Module): model.
images_tensors (torch.Tensor): image tensors stacked on the device.
Returns:
features_image_stacked (torch.Tensor): image feature tensors stacked on the device.
"""
with torch.cuda.amp.autocast(
enabled=(MultimodalSearch.multimodal_device != torch.device("cpu"))
):
features_image = [
model.extract_features({"image": ten, "text_input": ""}, mode="image")
for ten in images_tensors
]
features_image_stacked = torch.stack(
[feat.image_embeds_proj[:, 0, :].squeeze(0) for feat in features_image]
)
return features_image_stacked
def extract_image_features_clip(
self, model, images_tensors: torch.Tensor
) -> torch.Tensor:
"""
Extract image features from images_tensors with clip_feature_extractor model.
Args:
model (torch.nn.Module): model.
images_tensors (torch.Tensor): image tensors stacked on the device.
Returns:
features_image_stacked (torch.Tensor): image feature tensors stacked on the device.
"""
features_image = [
model.extract_features({"image": ten}) for ten in images_tensors
]
features_image_stacked = torch.stack(
[Func.normalize(feat.float(), dim=-1).squeeze(0) for feat in features_image]
)
return features_image_stacked
def extract_image_features_basic(
self, model, images_tensors: torch.Tensor
) -> torch.Tensor:
"""
Extract image features from images_tensors with blip_feature_extractor or albef_feature_extractor model.
Args:
model (torch.nn.Module): model.
images_tensors (torch.Tensor): image tensors stacked on the device.
Returns:
features_image_stacked (torch.Tensor): image feature tensors stacked on the device.
"""
features_image = [
model.extract_features({"image": ten, "text_input": ""}, mode="image")
for ten in images_tensors
]
features_image_stacked = torch.stack(
[feat.image_embeds_proj[:, 0, :].squeeze(0) for feat in features_image]
)
return features_image_stacked
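# Note on the three extractors above: the blip2/blip/albef variants keep the
# projected CLS token (image_embeds_proj[:, 0, :]), while the CLIP variant
# L2-normalizes the raw feature vector instead. Either way the result is one
# embedding row per image:
#
#     feats = my_obj.extract_image_features_basic(model, images_tensors)
#     # feats.shape == (n_images, embedding_dim)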
def save_tensors(
self,
model_type: str,
features_image_stacked: torch.Tensor,
name: str = "saved_features_image.pt",
path: str = "./saved_tensors/",
) -> str:
"""
Save tensors as binary to given path.
Args:
model_type (str): type of the model.
features_image_stacked (torch.Tensor): image feature tensors stacked on the device.
name (str): name of the file. Default: "saved_features_image.pt".
path (str): path to save the file. Default: "./saved_tensors/".
Returns:
name (str): name of the file.
"""
if not os.path.exists(path):
os.makedirs(path)
with open(
str(path)
+ str(len(features_image_stacked))
+ "_"
+ model_type
+ "_"
+ name,
"wb",
) as f:
torch.save(features_image_stacked, f)
return name
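# The output file name is assembled from the tensor count, the model type and
# the given name, e.g. two blip2 feature rows saved with the defaults land in
# "./saved_tensors/2_blip2_saved_features_image.pt".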
def load_tensors(self, name: str) -> torch.Tensor:
"""
Load tensors from given path.
Args:
name (str): name of the file.
Returns:
features_image_stacked (torch.Tensor): tensor of image features.
"""
features_image_stacked = torch.load(name, weights_only=True)
return features_image_stacked
def extract_text_features(self, model, text_input: str) -> torch.Tensor:
"""
Extract text features from text_input with feature_extractor model.
Args:
model (torch.nn.Module): model.
text_input (str): text.
Returns:
features_text (torch.Tensor): tensors of text features.
"""
sample_text = {"text_input": [text_input]}
features_text = model.extract_features(sample_text, mode="text")
return features_text
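# Sketch for the text side of the same API (the query string is taken from
# the tests below):
#
#     features_text = my_obj.extract_text_features(model, "A bus")
#     # For blip/albef/blip2 the projected embedding sits at
#     # features_text.text_embeds_proj[:, 0, :] (see queries_processing below).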
def parsing_images(
self,
model_type: str,
path_to_save_tensors: str = "./saved_tensors/",
path_to_load_tensors: str = None,
) -> tuple:
"""
Parse images with the selected feature_extractor model and compute (or load) their stacked image features.
Args:
model_type (str): type of the model.
path_to_save_tensors (str): path to save the tensors. Default: "./saved_tensors/".
path_to_load_tensors (str): path to load the tensors. Default: None.
Returns:
model (torch.nn.Module): model.
vis_processors (dict): preprocessors for visual inputs.
txt_processors (dict): preprocessors for text inputs.
image_keys (list): sorted list of image keys.
image_names (list): sorted list of image names.
features_image_stacked (torch.Tensor): image feature tensors stacked on the device.
"""
if model_type in ("clip_base", "clip_vitl14_336", "clip_vitl14"):
path_to_lib = lavis.__file__[:-11] + "models/clip_models/"
url = "https://raw.githubusercontent.com/salesforce/LAVIS/main/lavis/models/clip_models/bpe_simple_vocab_16e6.txt.gz"
r = requests.get(url, allow_redirects=False)
with open(path_to_lib + "bpe_simple_vocab_16e6.txt.gz", "wb") as f:
f.write(r.content)
image_keys = sorted(self.subdict.keys())
image_names = [self.subdict[k]["filename"] for k in image_keys]
select_model = {
"blip2": MultimodalSearch.load_feature_extractor_model_blip2,
"blip": MultimodalSearch.load_feature_extractor_model_blip,
"albef": MultimodalSearch.load_feature_extractor_model_albef,
"clip_base": MultimodalSearch.load_feature_extractor_model_clip_base,
"clip_vitl14": MultimodalSearch.load_feature_extractor_model_clip_vitl14,
"clip_vitl14_336": MultimodalSearch.load_feature_extractor_model_clip_vitl14_336,
}
select_extract_image_features = {
"blip2": MultimodalSearch.extract_image_features_blip2,
"blip": MultimodalSearch.extract_image_features_basic,
"albef": MultimodalSearch.extract_image_features_basic,
"clip_base": MultimodalSearch.extract_image_features_clip,
"clip_vitl14": MultimodalSearch.extract_image_features_clip,
"clip_vitl14_336": MultimodalSearch.extract_image_features_clip,
}
if model_type in select_model.keys():
(
model,
vis_processors,
txt_processors,
) = select_model[
model_type
](self, MultimodalSearch.multimodal_device)
else:
raise SyntaxError(
"Please, use one of the following models: blip2, blip, albef, clip_base, clip_vitl14, clip_vitl14_336"
)
_, images_tensors = MultimodalSearch.read_and_process_images(
self, image_names, vis_processors
)
if path_to_load_tensors is None:
with torch.no_grad():
features_image_stacked = select_extract_image_features[model_type](
self, model, images_tensors
)
MultimodalSearch.save_tensors(
self, model_type, features_image_stacked, path=path_to_save_tensors
)
else:
features_image_stacked = MultimodalSearch.load_tensors(
self, str(path_to_load_tensors)
)
return (
model,
vis_processors,
txt_processors,
image_keys,
image_names,
features_image_stacked,
)
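# End-to-end sketch, mirroring the test-suite usage further down (model type
# and tensor path are illustrative):
#
#     (
#         model,
#         vis_processors,
#         txt_processors,
#         image_keys,
#         image_names,
#         features_image_stacked,
#     ) = my_obj.parsing_images("blip2", path_to_save_tensors="./saved_tensors/")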
def queries_processing(
self, search_query: list, model, txt_processors, vis_processors, model_type: str
) -> torch.Tensor:
"""
Process queries.
Args:
search_query (list): list of queries.
model (torch.nn.Module): model.
txt_processors (dict): preprocessors for text inputs.
vis_processors (dict): preprocessors for visual inputs.
model_type (str): type of the model.
Returns:
multi_features_stacked (torch.Tensor): tensors of query features.
"""
select_extract_image_features = {
"blip2": MultimodalSearch.extract_image_features_blip2,
"blip": MultimodalSearch.extract_image_features_basic,
"albef": MultimodalSearch.extract_image_features_basic,
"clip_base": MultimodalSearch.extract_image_features_clip,
"clip_vitl14": MultimodalSearch.extract_image_features_clip,
"clip_vitl14_336": MultimodalSearch.extract_image_features_clip,
}
for query in search_query:
if len(query) != 1 or not any(key in ("image", "text_input") for key in query):
raise SyntaxError(
'Each query must contain either an "image" or a "text_input"'
)
multi_sample = []
for query in search_query:
if "text_input" in query.keys():
text_processing = txt_processors["eval"](query["text_input"])
images_tensors = ""
elif "image" in query.keys():
_, images_tensors = MultimodalSearch.read_and_process_images(
self, [query["image"]], vis_processors
)
text_processing = ""
multi_sample.append(
{"image": images_tensors, "text_input": text_processing}
)
multi_features_query = []
for query in multi_sample:
if query["image"] == "":
if model_type in ("clip_base", "clip_vitl14_336", "clip_vitl14"):
features = model.extract_features(
{"text_input": query["text_input"]}
)
features_squeeze = features.squeeze(0).to(
MultimodalSearch.multimodal_device
)
multi_features_query.append(
Func.normalize(features_squeeze, dim=-1)
)
else:
features = model.extract_features(query, mode="text")
features_squeeze = (
features.text_embeds_proj[:, 0, :]
.squeeze(0)
.to(MultimodalSearch.multimodal_device)
)
multi_features_query.append(features_squeeze)
if query["text_input"] == "":
multi_features_query.append(
select_extract_image_features[model_type](
self, model, query["image"]
)
)
multi_features_stacked = torch.stack(
[query.squeeze(0) for query in multi_features_query]
).to(MultimodalSearch.multimodal_device)
return multi_features_stacked
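# Sketch: queries may mix free text and reference images (query values taken
# from the tests below; the file name is illustrative):
#
#     search_query = [
#         {"text_input": "A bus"},
#         {"image": "IMG_3758.png"},
#     ]
#     multi_features_stacked = my_obj.queries_processing(
#         search_query, model, txt_processors, vis_processors, "blip2"
#     )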
def multimodal_search(
self,
model,
vis_processors,
txt_processors,
model_type: str,
image_keys: list,
features_image_stacked: torch.Tensor,
search_query: list,
filter_number_of_images: int = None,
filter_val_limit: float = None,
filter_rel_error: float = None,
) -> tuple:
"""
Search for images matching the given queries.
Args:
model (torch.nn.Module): model.
vis_processors (dict): preprocessors for visual inputs.
txt_processors (dict): preprocessors for text inputs.
model_type (str): type of the model.
image_keys (list): sorted list of image keys.
features_image_stacked (torch.Tensor): image feature tensors stacked on the device.
search_query (list): list of queries.
filter_number_of_images (int): number of images to show. Default: None.
filter_val_limit (float): lower limit on the similarity value. Default: None.
filter_rel_error (float): limit on the relative error in percent. Default: None.
Returns:
similarity (torch.Tensor): similarity between images and queries.
sorted_lists (list): for each query, image indices sorted by descending similarity.
"""
if filter_number_of_images is None:
filter_number_of_images = len(self.subdict)
if filter_val_limit is None:
filter_val_limit = 0
if filter_rel_error is None:
filter_rel_error = 1e10
features_image_stacked = features_image_stacked.to(MultimodalSearch.multimodal_device)
with torch.no_grad():
multi_features_stacked = MultimodalSearch.queries_processing(
self, search_query, model, txt_processors, vis_processors, model_type
)
similarity = features_image_stacked @ multi_features_stacked.t()
sorted_lists = torch.argsort(similarity, dim=0, descending=True).T.tolist()
places = [[item.index(i) for i in range(len(item))] for item in sorted_lists]
for q in range(len(search_query)):
max_val = similarity[sorted_lists[q][0]][q].item()
for i, key in zip(range(len(image_keys)), sorted_lists[q]):
if (
i < filter_number_of_images
and similarity[key][q].item() > filter_val_limit
and 100 * abs(max_val - similarity[key][q].item()) / max_val
< filter_rel_error
):
self.subdict[image_keys[key]][
"rank " + list(search_query[q].values())[0]
] = places[q][key]
self.subdict[image_keys[key]][list(search_query[q].values())[0]] = (
similarity[key][q].item()
)
else:
self.subdict[image_keys[key]][
"rank " + list(search_query[q].values())[0]
] = None
self.subdict[image_keys[key]][list(search_query[q].values())[0]] = 0
return similarity, sorted_lists
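# Sketch: ranking all images against the queries above. similarity is an
# (n_images, n_queries) matrix and sorted_lists holds one ranking of image
# indices per query:
#
#     similarity, sorted_lists = my_obj.multimodal_search(
#         model,
#         vis_processors,
#         txt_processors,
#         "blip2",
#         image_keys,
#         features_image_stacked,
#         search_query,
#     )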
def itm_text_processing(self, search_query: list[dict[str, str]]) -> list:
"""
Process text queries for the itm model.
Args:
search_query (list): list of queries.
Returns:
text_query_index (list): list of indices of text queries.
"""
for query in search_query:
if len(query) != 1 or not any(key in ("image", "text_input") for key in query):
raise SyntaxError(
'Each query must contain either an "image" or a "text_input"'
)
text_query_index = []
for i, query in zip(range(len(search_query)), search_query):
if "text_input" in query.keys():
text_query_index.append(i)
return text_query_index
def get_paths_from_query(self, query: dict[str, str]) -> tuple:
"""
Get paths and image names from query, sorted by descending similarity.
Args:
query (dict): query.
Returns:
paths (list): list of paths.
image_names (list): list of image names.
"""
paths = []
image_names = []
for s in sorted(
self.subdict.items(),
key=lambda t: t[1][list(query.values())[0]],
reverse=True,
):
if s[1]["rank " + list(query.values())[0]] is None:
break
paths.append(s[1]["filename"])
image_names.append(s[0])
return paths, image_names
def read_and_process_images_itm(self, image_paths: list, vis_processor) -> tuple:
"""
Read and process images with vis_processor for itm model.
Args:
image_paths (list): paths to images.
vis_processor (dict): preprocessors for visual inputs.
Returns:
raw_images (list): list of images.
images_tensors (torch.Tensor): image tensors stacked on the device.
"""
raw_images = [MultimodalSearch.read_img(self, path) for path in image_paths]
images = [vis_processor(r_img) for r_img in raw_images]
images_tensors = torch.stack(images).to(MultimodalSearch.multimodal_device)
return raw_images, images_tensors
def compute_gradcam_batch(
self,
model: torch.nn.Module,
visual_input: torch.Tensor,
text_input: str,
tokenized_text: torch.Tensor,
block_num: int = 6,
) -> tuple:
"""
Compute gradcam for itm model.
Args:
model (torch.nn.Module): model.
visual_input (torch.Tensor): image tensors stacked on the device.
text_input (str): text.
tokenized_text (torch.Tensor): tokenized text.
block_num (int): index of the cross-attention block. Default: 6.
Returns:
gradcam (torch.Tensor): gradcam.
output (torch.Tensor): output of model.
"""
model.text_encoder.base_model.base_model.encoder.layer[
block_num
].crossattention.self.save_attention = True
output = model(
{"image": visual_input, "text_input": text_input}, match_head="itm"
)
loss = output[:, 1].sum()
model.zero_grad()
loss.backward()
with torch.no_grad():
mask = tokenized_text.attention_mask.view(
tokenized_text.attention_mask.size(0), 1, -1, 1, 1
) # (bsz,1,token_len, 1,1)
token_length = mask.sum() - 2
token_length = token_length.cpu()
# grads and cams [bsz, num_head, seq_len, image_patch]
grads = model.text_encoder.base_model.base_model.encoder.layer[
block_num
].crossattention.self.get_attn_gradients()
cams = model.text_encoder.base_model.base_model.encoder.layer[
block_num
].crossattention.self.get_attention_map()
# assume using vit large with 576 num image patch
cams = (
cams[:, :, :, 1:].reshape(visual_input.size(0), 12, -1, 24, 24) * mask
)
grads = (
grads[:, :, :, 1:]
.clamp(0)
.reshape(visual_input.size(0), 12, -1, 24, 24)
* mask
)
gradcam = cams * grads
# [enc token gradcam, average gradcam across token, gradcam for individual token]
# gradcam = torch.cat((gradcam[0:1,:], gradcam[1:token_length+1, :].sum(dim=0, keepdim=True)/token_length, gradcam[1:, :]))
gradcam = gradcam.mean(1).cpu().detach()
gradcam = (
gradcam[:, 1 : token_length + 1, :].sum(dim=1, keepdim=True)
/ token_length
)
return gradcam, output
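# Note: the 24 x 24 reshape above assumes a ViT-L backbone with 576 image
# patches (see the inline comment); other backbones would need a different
# patch grid.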
def resize_img(self, raw_img: Image.Image):
"""
Proportionally resize an image to 240 px width.
Args:
raw_img (PIL.Image): image.
Returns:
resized_image (PIL.Image): image resized proportionally to 240 px width.
"""
w, h = raw_img.size
scaling_factor = 240 / w
resized_image = raw_img.resize(
(int(w * scaling_factor), int(h * scaling_factor))
)
return resized_image
def get_att_map(
self,
img: np.ndarray,
att_map: np.ndarray,
blur: bool = True,
overlap: bool = True,
) -> np.ndarray:
"""
Get attention map.
Args:
img (np.ndarray): image.
att_map (np.ndarray): attention map.
blur (bool): blur attention map. Default: True.
overlap (bool): overlap attention map with image. Default: True.
Returns:
att_map (np.ndarray): attention map.
"""
att_map -= att_map.min()
if att_map.max() > 0:
att_map /= att_map.max()
att_map = skimage_transform.resize(
att_map, (img.shape[:2]), order=3, mode="constant"
)
if blur:
att_map = filters.gaussian_filter(att_map, 0.02 * max(img.shape[:2]))
att_map -= att_map.min()
att_map /= att_map.max()
cmap = plt.get_cmap("jet")
att_mapv = cmap(att_map)
att_mapv = np.delete(att_mapv, 3, 2)
if overlap:
att_map = (
1 * (1 - att_map**0.7).reshape(att_map.shape + (1,)) * img
+ (att_map**0.7).reshape(att_map.shape + (1,)) * att_mapv
)
return att_map
def upload_model_blip2_coco(self) -> tuple:
"""
Load coco blip2_image_text_matching model and preprocessors for visual inputs from lavis.models.
Returns:
itm_model (torch.nn.Module): model.
vis_processor (dict): preprocessors for visual inputs.
"""
itm_model = load_model(
"blip2_image_text_matching",
"coco",
is_eval=True,
device=MultimodalSearch.multimodal_device,
)
vis_processor = load_processor("blip_image_eval").build(image_size=364)
return itm_model, vis_processor
def upload_model_blip_base(self) -> tuple:
"""
Load base blip_image_text_matching model and preprocessors for visual input from lavis.models.
Returns:
itm_model (torch.nn.Module): model.
vis_processor (dict): preprocessors for visual inputs.
"""
itm_model = load_model(
"blip_image_text_matching",
"base",
is_eval=True,
device=MultimodalSearch.multimodal_device,
)
vis_processor = load_processor("blip_image_eval").build(image_size=384)
return itm_model, vis_processor
def upload_model_blip_large(self) -> tuple:
"""
Load large blip_image_text_matching model and preprocessors for visual input from lavis.models.
Returns:
itm_model (torch.nn.Module): model.
vis_processor (dict): preprocessors for visual inputs.
"""
itm_model = load_model(
"blip_image_text_matching",
"large",
is_eval=True,
device=MultimodalSearch.multimodal_device,
)
vis_processor = load_processor("blip_image_eval").build(image_size=384)
return itm_model, vis_processor
def image_text_match_reordering(
self,
search_query: list[dict[str, str]],
itm_model_type: str,
image_keys: list,
sorted_lists: list[list],
batch_size: int = 1,
need_grad_cam: bool = False,
) -> tuple:
"""
Reorder images with itm model.
Args:
search_query (list): list of queries.
itm_model_type (str): type of the model.
image_keys (list): sorted list of image keys.
sorted_lists (list): sorted list of similarity.
batch_size (int): batch size. Default: 1.
need_grad_cam (bool): whether to compute gradcam. Default: False. The blip2_coco model does not yet work with gradcam.
Returns:
itm_scores2 (list): list of itm scores.
image_gradcam_with_itm (dict): gradcam images keyed by text query and image name.
"""
if itm_model_type == "blip2_coco" and need_grad_cam is True:
raise SyntaxError(
"The blip2_coco model does not yet work with gradcam. Please set need_grad_cam to False"
)
choose_model = {
"blip_base": MultimodalSearch.upload_model_blip_base,
"blip_large": MultimodalSearch.upload_model_blip_large,
"blip2_coco": MultimodalSearch.upload_model_blip2_coco,
}
itm_model, vis_processor_itm = choose_model[itm_model_type](self)
text_processor = load_processor("blip_caption")
tokenizer = BlipBase.init_tokenizer()
if itm_model_type == "blip2_coco":
need_grad_cam = False
text_query_index = MultimodalSearch.itm_text_processing(self, search_query)
avg_gradcams = []
itm_scores = []
itm_scores2 = []
image_gradcam_with_itm = {}
for index_text_query in text_query_index:
query = search_query[index_text_query]
paths, image_names = MultimodalSearch.get_paths_from_query(self, query)
num_batches = len(paths) // batch_size
num_batches_residue = len(paths) % batch_size
local_itm_scores = []
local_avg_gradcams = []
if num_batches_residue != 0:
num_batches = num_batches + 1
for i in range(num_batches):
filenames_in_batch = paths[i * batch_size : (i + 1) * batch_size]
current_len = len(filenames_in_batch)
raw_images, images = MultimodalSearch.read_and_process_images_itm(
self, filenames_in_batch, vis_processor_itm
)
queries_batch = [text_processor(query["text_input"])] * current_len
queries_tok_batch = tokenizer(queries_batch, return_tensors="pt").to(
MultimodalSearch.multimodal_device
)
if need_grad_cam:
gradcam, itm_output = MultimodalSearch.compute_gradcam_batch(
self,
itm_model,
images,
queries_batch,
queries_tok_batch,
)
norm_imgs = [np.float32(r_img) / 255 for r_img in raw_images]
for norm_img, grad_cam in zip(norm_imgs, gradcam):
avg_gradcam = MultimodalSearch.get_att_map(
self, norm_img, np.float32(grad_cam[0]), blur=True
)
local_avg_gradcams.append(avg_gradcam)
else:
itm_output = itm_model(
{"image": images, "text_input": queries_batch}, match_head="itm"
)
with torch.no_grad():
itm_score = torch.nn.functional.softmax(itm_output, dim=1)
local_itm_scores.append(itm_score)
local_itm_scores2 = torch.cat(local_itm_scores)[:, 1]
if need_grad_cam:
localimage_gradcam_with_itm = {
n: i * 255 for n, i in zip(image_names, local_avg_gradcams)
}
else:
localimage_gradcam_with_itm = ""
image_names_with_itm = {
n: i.item() for n, i in zip(image_names, local_itm_scores2)
}
itm_rank = torch.argsort(local_itm_scores2, descending=True)
image_names_with_new_rank = {
image_names[i.item()]: rank
for i, rank in zip(itm_rank, range(len(itm_rank)))
}
for i, key in zip(range(len(image_keys)), sorted_lists[index_text_query]):
if image_keys[key] in image_names:
self.subdict[image_keys[key]][
"itm " + list(search_query[index_text_query].values())[0]
] = image_names_with_itm[image_keys[key]]
self.subdict[image_keys[key]][
"itm_rank " + list(search_query[index_text_query].values())[0]
] = image_names_with_new_rank[image_keys[key]]
else:
self.subdict[image_keys[key]][
"itm " + list(search_query[index_text_query].values())[0]
] = 0
self.subdict[image_keys[key]][
"itm_rank " + list(search_query[index_text_query].values())[0]
] = None
avg_gradcams.append(local_avg_gradcams)
itm_scores.append(local_itm_scores)
itm_scores2.append(local_itm_scores2)
image_gradcam_with_itm[list(search_query[index_text_query].values())[0]] = (
localimage_gradcam_with_itm
)
del (
itm_model,
vis_processor_itm,
text_processor,
raw_images,
images,
tokenizer,
queries_batch,
queries_tok_batch,
itm_score,
)
if need_grad_cam:
del itm_output, gradcam, norm_img, grad_cam, avg_gradcam
torch.cuda.empty_cache()
return itm_scores2, image_gradcam_with_itm
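# Sketch: re-ranking the coarse search results with an image-text matching
# head (mirrors the tests below; note that blip2_coco requires
# need_grad_cam=False):
#
#     itm_scores, image_gradcam_with_itm = my_obj.image_text_match_reordering(
#         search_query,
#         "blip_base",
#         image_keys,
#         sorted_lists,
#         batch_size=1,
#         need_grad_cam=True,
#     )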
def show_results(
self, query: dict, itm: bool = False, image_gradcam_with_itm: dict = {}
) -> None:
"""
Show results of search.
Args:
query (dict): query.
itm (bool): use itm model. Default: False.
image_gradcam_with_itm (dict): gradcam images keyed by query and image name. Default: empty.
"""
if "image" in query.keys():
pic = Image.open(query["image"]).convert("RGB")
pic.thumbnail((400, 400))
display(
"Your search query: ",
pic,
"--------------------------------------------------",
"Results:",
)
elif "text_input" in query.keys():
display(
"Your search query: " + query["text_input"],
"--------------------------------------------------",
"Results:",
)
if itm:
current_query_val = "itm " + list(query.values())[0]
current_query_rank = "itm_rank " + list(query.values())[0]
else:
current_query_val = list(query.values())[0]
current_query_rank = "rank " + list(query.values())[0]
for s in sorted(
self.subdict.items(), key=lambda t: t[1][current_query_val], reverse=True
):
if s[1][current_query_rank] is None:
break
if image_gradcam_with_itm and itm:
image = image_gradcam_with_itm[list(query.values())[0]][s[0]]
p1 = Image.fromarray(image.astype("uint8"), "RGB")
else:
p1 = Image.open(s[1]["filename"]).convert("RGB")
p1.thumbnail((400, 400))
display(
"Rank: "
+ str(s[1][current_query_rank])
+ " Val: "
+ str(s[1][current_query_val]),
s[0],
p1,
)
display(
"--------------------------------------------------",
)
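# Putting it together, a minimal sketch assuming "image_dict" maps image keys
# to {"filename": path} entries (as the subdict is used throughout the class):
#
#     my_obj = MultimodalSearch(image_dict)
#     (model, vis_processors, txt_processors,
#      image_keys, image_names, features) = my_obj.parsing_images("blip2")
#     search_query = [{"text_input": "A bus"}]
#     similarity, sorted_lists = my_obj.multimodal_search(
#         model, vis_processors, txt_processors, "blip2",
#         image_keys, features, search_query,
#     )
#     my_obj.show_results(search_query[0])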

View File

@@ -1,521 +0,0 @@
import pytest
import math
from PIL import Image
import numpy
from torch import device, cuda
import ammico.multimodal_search as ms
related_error = 1e-2
gpu_is_not_available = not cuda.is_available()
cuda.empty_cache()
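# The tests below are marked "long"; assuming the marker is registered in the
# project's pytest configuration, they can be selected with `pytest -m long`
# or skipped with `pytest -m "not long"`.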
def test_read_img(get_testdict):
my_dict = {}
test_img = ms.MultimodalSearch.read_img(
my_dict, get_testdict["IMG_2746"]["filename"]
)
assert list(numpy.array(test_img)[257][34]) == [70, 66, 63]
pre_proc_pic_blip2_blip_albef = [
-1.0039474964141846,
-1.0039474964141846,
]
pre_proc_pic_clip_vitl14 = [
-0.7995694875717163,
-0.7849710583686829,
]
pre_proc_pic_clip_vitl14_336 = [
-0.7995694875717163,
-0.7849710583686829,
]
pre_proc_text_blip2_blip_albef = (
"the bird sat on a tree located at the intersection of 23rd and 43rd streets"
)
pre_proc_text_clip_clip_vitl14_clip_vitl14_336 = (
"The bird sat on a tree located at the intersection of 23rd and 43rd streets."
)
pre_extracted_feature_img_blip2 = [
0.04566730558872223,
-0.042554520070552826,
]
pre_extracted_feature_img_blip = [
-0.02480311505496502,
0.05037587881088257,
]
pre_extracted_feature_img_albef = [
0.08971136063337326,
-0.10915573686361313,
]
pre_extracted_feature_img_clip = [
0.01621132344007492,
-0.004035486374050379,
]
pre_extracted_feature_img_parsing_clip = [
0.01621132344007492,
-0.004035486374050379,
]
pre_extracted_feature_img_clip_vitl14 = [
-0.023943455889821053,
-0.021703708916902542,
]
pre_extracted_feature_img_clip_vitl14_336 = [
-0.009511193260550499,
-0.012618942186236382,
]
pre_extracted_feature_text_blip2 = [
-0.1384204626083374,
-0.008662976324558258,
]
pre_extracted_feature_text_blip = [
0.0118643119931221,
-0.01291718054562807,
]
pre_extracted_feature_text_albef = [
-0.06229640915989876,
0.11278597265481949,
]
pre_extracted_feature_text_clip = [
0.018169036135077477,
0.03634127229452133,
]
pre_extracted_feature_text_clip_vitl14 = [
-0.0055463071912527084,
0.006908962037414312,
]
pre_extracted_feature_text_clip_vitl14_336 = [
-0.008720514364540577,
0.005284308455884457,
]
simularity_blip2 = [
[0.05826476216316223, -0.02717375010251999],
[0.06297147274017334, 0.47339022159576416],
]
sorted_blip2 = [
[1, 0],
[1, 0],
]
simularity_blip = [
[0.15640679001808167, 0.752173662185669],
[0.17233705520629883, 0.8448910117149353],
]
sorted_blip = [
[1, 0],
[1, 0],
]
simularity_albef = [
[0.12321824580430984, 0.35511350631713867],
[0.10870333760976791, 0.5143978595733643],
]
sorted_albef = [
[0, 1],
[1, 0],
]
simularity_clip = [
[0.23923014104366302, 0.5325412750244141],
[0.2310466319322586, 0.5910375714302063],
]
sorted_clip = [
[1, 0],
[1, 0],
]
simularity_clip_vitl14 = [
[0.1051270067691803, 0.5184808373451233],
[0.1277746558189392, 0.6841973662376404],
]
sorted_clip_vitl14 = [
[1, 0],
[1, 0],
]
simularity_clip_vitl14_336 = [
[0.09391091763973236, 0.49337542057037354],
[0.13700757920742035, 0.7003108263015747],
]
sorted_clip_vitl14_336 = [
[1, 0],
[1, 0],
]
dict_itm_scores_for_blip = {
"blip_base": [
0.07107225805521011,
0.004100032616406679,
],
"blip_large": [
0.07890705019235611,
0.00271016638725996,
],
"blip2_coco": [
0.0833505243062973,
0.004216152708977461,
],
}
dict_image_gradcam_with_itm_for_blip = {
"blip_base": [123.36285799741745, 132.31662154197693, 53.38280035299249],
"blip_large": [119.99512910842896, 128.7044593691826, 55.552959859540515],
}
@pytest.mark.long
@pytest.mark.parametrize(
(
"pre_multimodal_device",
"pre_model",
"pre_proc_pic",
"pre_proc_text",
"pre_extracted_feature_img",
"pre_extracted_feature_text",
"pre_simularity",
"pre_sorted",
),
[
(
device("cpu"),
"blip2",
pre_proc_pic_blip2_blip_albef,
pre_proc_text_blip2_blip_albef,
pre_extracted_feature_img_blip2,
pre_extracted_feature_text_blip2,
simularity_blip2,
sorted_blip2,
),
pytest.param(
device("cuda"),
"blip2",
pre_proc_pic_blip2_blip_albef,
pre_proc_text_blip2_blip_albef,
pre_extracted_feature_img_blip2,
pre_extracted_feature_text_blip2,
simularity_blip2,
sorted_blip2,
marks=pytest.mark.skipif(
gpu_is_not_available, reason="GPU is not available"
),
),
(
device("cpu"),
"blip",
pre_proc_pic_blip2_blip_albef,
pre_proc_text_blip2_blip_albef,
pre_extracted_feature_img_blip,
pre_extracted_feature_text_blip,
simularity_blip,
sorted_blip,
),
pytest.param(
device("cuda"),
"blip",
pre_proc_pic_blip2_blip_albef,
pre_proc_text_blip2_blip_albef,
pre_extracted_feature_img_blip,
pre_extracted_feature_text_blip,
simularity_blip,
sorted_blip,
marks=pytest.mark.skipif(
gpu_is_not_available, reason="GPU is not available"
),
),
(
device("cpu"),
"albef",
pre_proc_pic_blip2_blip_albef,
pre_proc_text_blip2_blip_albef,
pre_extracted_feature_img_albef,
pre_extracted_feature_text_albef,
simularity_albef,
sorted_albef,
),
pytest.param(
device("cuda"),
"albef",
pre_proc_pic_blip2_blip_albef,
pre_proc_text_blip2_blip_albef,
pre_extracted_feature_img_albef,
pre_extracted_feature_text_albef,
simularity_albef,
sorted_albef,
marks=pytest.mark.skipif(
gpu_is_not_available, reason="GPU is not available"
),
),
(
device("cpu"),
"clip_base",
pre_proc_pic_clip_vitl14,
pre_proc_text_clip_clip_vitl14_clip_vitl14_336,
pre_extracted_feature_img_clip,
pre_extracted_feature_text_clip,
simularity_clip,
sorted_clip,
),
pytest.param(
device("cuda"),
"clip_base",
pre_proc_pic_clip_vitl14,
pre_proc_text_clip_clip_vitl14_clip_vitl14_336,
pre_extracted_feature_img_clip,
pre_extracted_feature_text_clip,
simularity_clip,
sorted_clip,
marks=pytest.mark.skipif(
gpu_is_not_available, reason="GPU is not available"
),
),
(
device("cpu"),
"clip_vitl14",
pre_proc_pic_clip_vitl14,
pre_proc_text_clip_clip_vitl14_clip_vitl14_336,
pre_extracted_feature_img_clip_vitl14,
pre_extracted_feature_text_clip_vitl14,
simularity_clip_vitl14,
sorted_clip_vitl14,
),
pytest.param(
device("cuda"),
"clip_vitl14",
pre_proc_pic_clip_vitl14,
pre_proc_text_clip_clip_vitl14_clip_vitl14_336,
pre_extracted_feature_img_clip_vitl14,
pre_extracted_feature_text_clip_vitl14,
simularity_clip_vitl14,
sorted_clip_vitl14,
marks=pytest.mark.skipif(
gpu_is_not_available, reason="GPU is not available"
),
),
(
device("cpu"),
"clip_vitl14_336",
pre_proc_pic_clip_vitl14_336,
pre_proc_text_clip_clip_vitl14_clip_vitl14_336,
pre_extracted_feature_img_clip_vitl14_336,
pre_extracted_feature_text_clip_vitl14_336,
simularity_clip_vitl14_336,
sorted_clip_vitl14_336,
),
pytest.param(
device("cuda"),
"clip_vitl14_336",
pre_proc_pic_clip_vitl14_336,
pre_proc_text_clip_clip_vitl14_clip_vitl14_336,
pre_extracted_feature_img_clip_vitl14_336,
pre_extracted_feature_text_clip_vitl14_336,
simularity_clip_vitl14_336,
sorted_clip_vitl14_336,
marks=pytest.mark.skipif(
gpu_is_not_available, reason="GPU is not available"
),
),
],
)
def test_parsing_images(
pre_multimodal_device,
pre_model,
pre_proc_pic,
pre_proc_text,
pre_extracted_feature_img,
pre_extracted_feature_text,
pre_simularity,
pre_sorted,
get_path,
get_testdict,
tmp_path,
):
ms.MultimodalSearch.multimodal_device = pre_multimodal_device
my_obj = ms.MultimodalSearch(get_testdict)
(
model,
vis_processor,
txt_processor,
image_keys,
_,
features_image_stacked,
) = my_obj.parsing_images(pre_model, path_to_save_tensors=tmp_path)
for i, num in zip(range(10), features_image_stacked[0, 10:12].tolist()):
assert (
math.isclose(num, pre_extracted_feature_img[i], rel_tol=related_error)
is True
)
test_pic = Image.open(my_obj.subdict["IMG_2746"]["filename"]).convert("RGB")
test_query = (
"The bird sat on a tree located at the intersection of 23rd and 43rd streets."
)
processed_pic = (
vis_processor["eval"](test_pic).unsqueeze(0).to(pre_multimodal_device)
)
processed_text = txt_processor["eval"](test_query)
for i, num in zip(range(10), processed_pic[0, 0, 0, 25:27].tolist()):
assert math.isclose(num, pre_proc_pic[i], rel_tol=related_error) is True
assert processed_text == pre_proc_text
search_query = [
{"text_input": test_querry},
{"image": my_obj.subdict["IMG_2746"]["filename"]},
]
multi_features_stacked = my_obj.queries_processing(
search_query, model, txt_processor, vis_processor, pre_model
)
for i, num in zip(range(10), multi_features_stacked[0, 10:12].tolist()):
assert (
math.isclose(num, pre_extracted_feature_text[i], rel_tol=related_error)
is True
)
for i, num in zip(range(10), multi_features_stacked[1, 10:12].tolist()):
assert (
math.isclose(num, pre_extracted_feature_img[i], rel_tol=related_error)
is True
)
search_query2 = [
{"text_input": "A bus"},
{"image": get_path + "IMG_3758.png"},
]
similarity, sorted_list = my_obj.multimodal_search(
model,
vis_processor,
txt_processor,
pre_model,
image_keys,
features_image_stacked,
search_query2,
)
for i, num in zip(range(len(pre_simularity)), similarity.tolist()):
for j, num2 in zip(range(len(num)), num):
assert (
math.isclose(num2, pre_simularity[i][j], rel_tol=100 * related_error)
is True
)
for i, num in zip(range(len(pre_sorted)), sorted_list):
for j, num2 in zip(range(2), num):
assert num2 == pre_sorted[i][j]
del (
model,
vis_processor,
txt_processor,
similarity,
features_image_stacked,
processed_pic,
multi_features_stacked,
my_obj,
)
cuda.empty_cache()
@pytest.mark.long
def test_itm(get_test_my_dict, get_path):
search_query3 = [
{"text_input": "A bus"},
{"image": get_path + "IMG_3758.png"},
]
image_keys = ["IMG_2746", "IMG_2809"]
sorted_list = [[1, 0], [1, 0]]
my_obj = ms.MultimodalSearch(get_test_my_dict)
for itm_model in ["blip_base", "blip_large"]:
(
itm_scores,
image_gradcam_with_itm,
) = my_obj.image_text_match_reordering(
search_query3,
itm_model,
image_keys,
sorted_list,
batch_size=1,
need_grad_cam=True,
)
for i, itm in zip(
range(len(dict_itm_scores_for_blip[itm_model])),
dict_itm_scores_for_blip[itm_model],
):
assert (
math.isclose(itm_scores[0].tolist()[i], itm, rel_tol=10 * related_error)
is True
)
for i, grad_cam in zip(
range(len(dict_image_gradcam_with_itm_for_blip[itm_model])),
dict_image_gradcam_with_itm_for_blip[itm_model],
):
assert (
math.isclose(
image_gradcam_with_itm["A bus"]["IMG_2809"][0][0].tolist()[i],
grad_cam,
rel_tol=10 * related_error,
)
is True
)
del itm_scores, image_gradcam_with_itm
cuda.empty_cache()
@pytest.mark.long
def test_itm_blip2_coco(get_test_my_dict, get_path):
search_query3 = [
{"text_input": "A bus"},
{"image": get_path + "IMG_3758.png"},
]
image_keys = ["IMG_2746", "IMG_2809"]
sorted_list = [[1, 0], [1, 0]]
my_obj = ms.MultimodalSearch(get_test_my_dict)
(
itm_scores,
image_gradcam_with_itm,
) = my_obj.image_text_match_reordering(
search_query3,
"blip2_coco",
image_keys,
sorted_list,
batch_size=1,
need_grad_cam=False,
)
for i, itm in zip(
range(len(dict_itm_scores_for_blip["blip2_coco"])),
dict_itm_scores_for_blip["blip2_coco"],
):
assert (
math.isclose(itm_scores[0].tolist()[i], itm, rel_tol=10 * related_error)
is True
)
del itm_scores, image_gradcam_with_itm
cuda.empty_cache()