Mirror of https://github.com/ssciwr/AMMICO.git
Synced 2025-10-29 13:06:04 +02:00
maintain: remove summary module (VQA)
This commit is contained in:
Parent: 95d80135e2
Commit: f2c97e26ff
ammico/multimodal_search.py
@@ -1,992 +0,0 @@
from ammico.utils import AnalysisMethod
import torch
import torch.nn.functional as Func
import requests
import lavis
import os
import numpy as np
from PIL import Image
from skimage import transform as skimage_transform
from scipy.ndimage import filters
from matplotlib import pyplot as plt
from IPython.display import display
from lavis.models import load_model_and_preprocess, load_model, BlipBase
from lavis.processors import load_processor


class MultimodalSearch(AnalysisMethod):
    def __init__(self, subdict: dict) -> None:
        super().__init__(subdict)

    multimodal_device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    def load_feature_extractor_model_blip2(self, device: str = "cpu"):
        """
        Load pretrained blip2_feature_extractor model and preprocessors for visual and text inputs from lavis.models.

        Args:
            device (str): device to use. Can be "cpu" or "cuda". Default: "cpu".

        Returns:
            model (torch.nn.Module): model.
            vis_processors (dict): preprocessors for visual inputs.
            txt_processors (dict): preprocessors for text inputs.
        """
        model, vis_processors, txt_processors = load_model_and_preprocess(
            name="blip2_feature_extractor",
            model_type="pretrain",
            is_eval=True,
            device=device,
        )
        return model, vis_processors, txt_processors

    def load_feature_extractor_model_blip(self, device: str = "cpu"):
        """
        Load base blip_feature_extractor model and preprocessors for visual and text inputs from lavis.models.

        Args:
            device (str): device to use. Can be "cpu" or "cuda". Default: "cpu".

        Returns:
            model (torch.nn.Module): model.
            vis_processors (dict): preprocessors for visual inputs.
            txt_processors (dict): preprocessors for text inputs.
        """
        model, vis_processors, txt_processors = load_model_and_preprocess(
            name="blip_feature_extractor",
            model_type="base",
            is_eval=True,
            device=device,
        )
        return model, vis_processors, txt_processors

    def load_feature_extractor_model_albef(self, device: str = "cpu"):
        """
        Load base albef_feature_extractor model and preprocessors for visual and text inputs from lavis.models.

        Args:
            device (str): device to use. Can be "cpu" or "cuda". Default: "cpu".

        Returns:
            model (torch.nn.Module): model.
            vis_processors (dict): preprocessors for visual inputs.
            txt_processors (dict): preprocessors for text inputs.
        """
        model, vis_processors, txt_processors = load_model_and_preprocess(
            name="albef_feature_extractor",
            model_type="base",
            is_eval=True,
            device=device,
        )
        return model, vis_processors, txt_processors

    def load_feature_extractor_model_clip_base(self, device: str = "cpu"):
        """
        Load base clip_feature_extractor model and preprocessors for visual and text inputs from lavis.models.

        Args:
            device (str): device to use. Can be "cpu" or "cuda". Default: "cpu".

        Returns:
            model (torch.nn.Module): model.
            vis_processors (dict): preprocessors for visual inputs.
            txt_processors (dict): preprocessors for text inputs.
        """
        model, vis_processors, txt_processors = load_model_and_preprocess(
            name="clip_feature_extractor",
            model_type="base",
            is_eval=True,
            device=device,
        )
        return model, vis_processors, txt_processors

    def load_feature_extractor_model_clip_vitl14(self, device: str = "cpu"):
        """
        Load ViT-L-14 clip_feature_extractor model and preprocessors for visual and text inputs from lavis.models.

        Args:
            device (str): device to use. Can be "cpu" or "cuda". Default: "cpu".

        Returns:
            model (torch.nn.Module): model.
            vis_processors (dict): preprocessors for visual inputs.
            txt_processors (dict): preprocessors for text inputs.
        """
        model, vis_processors, txt_processors = load_model_and_preprocess(
            name="clip_feature_extractor",
            model_type="ViT-L-14",
            is_eval=True,
            device=device,
        )
        return model, vis_processors, txt_processors

    def load_feature_extractor_model_clip_vitl14_336(self, device: str = "cpu"):
        """
        Load ViT-L-14-336 clip_feature_extractor model and preprocessors for visual and text inputs from lavis.models.

        Args:
            device (str): device to use. Can be "cpu" or "cuda". Default: "cpu".

        Returns:
            model (torch.nn.Module): model.
            vis_processors (dict): preprocessors for visual inputs.
            txt_processors (dict): preprocessors for text inputs.
        """
        model, vis_processors, txt_processors = load_model_and_preprocess(
            name="clip_feature_extractor",
            model_type="ViT-L-14-336",
            is_eval=True,
            device=device,
        )
        return model, vis_processors, txt_processors

    def read_img(self, filepath: str) -> Image:
        """
        Load Image from filepath.

        Args:
            filepath (str): path to image.

        Returns:
            raw_image (PIL.Image): image.
        """
        raw_image = Image.open(filepath).convert("RGB")
        return raw_image

    def read_and_process_images(self, image_paths: list, vis_processor) -> tuple:
        """
        Read and process images with vis_processor.

        Args:
            image_paths (str): paths to images.
            vis_processor (dict): preprocessors for visual inputs.

        Returns:
            raw_images (list): list of images.
            images_tensors (torch.Tensor): tensors of images stacked in device.
        """
        raw_images = [MultimodalSearch.read_img(self, path) for path in image_paths]
        images = [
            vis_processor["eval"](r_img)
            .unsqueeze(0)
            .to(MultimodalSearch.multimodal_device)
            for r_img in raw_images
        ]
        images_tensors = torch.stack(images)

        return raw_images, images_tensors

    def extract_image_features_blip2(
        self, model, images_tensors: torch.Tensor
    ) -> torch.Tensor:
        """
        Extract image features from images_tensors with blip2_feature_extractor model.

        Args:
            model (torch.nn.Module): model.
            images_tensors (torch.Tensor): tensors of images stacked in device.

        Returns:
            features_image_stacked (torch.Tensor): tensors of images features stacked in device.
        """
        with torch.cuda.amp.autocast(
            enabled=(MultimodalSearch.multimodal_device != torch.device("cpu"))
        ):
            features_image = [
                model.extract_features({"image": ten, "text_input": ""}, mode="image")
                for ten in images_tensors
            ]
            features_image_stacked = torch.stack(
                [feat.image_embeds_proj[:, 0, :].squeeze(0) for feat in features_image]
            )
        return features_image_stacked

    def extract_image_features_clip(
        self, model, images_tensors: torch.Tensor
    ) -> torch.Tensor:
        """
        Extract image features from images_tensors with clip_feature_extractor model.

        Args:
            model (torch.nn.Module): model.
            images_tensors (torch.Tensor): tensors of images stacked in device.

        Returns:
            features_image_stacked (torch.Tensor): tensors of images features stacked in device.
        """
        features_image = [
            model.extract_features({"image": ten}) for ten in images_tensors
        ]
        features_image_stacked = torch.stack(
            [Func.normalize(feat.float(), dim=-1).squeeze(0) for feat in features_image]
        )
        return features_image_stacked

    def extract_image_features_basic(
        self, model, images_tensors: torch.Tensor
    ) -> torch.Tensor:
        """
        Extract image features from images_tensors with blip_feature_extractor or albef_feature_extractor model.

        Args:
            model (torch.nn.Module): model.
            images_tensors (torch.Tensor): tensors of images stacked in device.

        Returns:
            features_image_stacked (torch.Tensor): tensors of images features stacked in device.
        """
        features_image = [
            model.extract_features({"image": ten, "text_input": ""}, mode="image")
            for ten in images_tensors
        ]
        features_image_stacked = torch.stack(
            [feat.image_embeds_proj[:, 0, :].squeeze(0) for feat in features_image]
        )
        return features_image_stacked

    def save_tensors(
        self,
        model_type: str,
        features_image_stacked: torch.Tensor,
        name: str = "saved_features_image.pt",
        path: str = "./saved_tensors/",
    ) -> str:
        """
        Save tensors as binary to given path.

        Args:
            model_type (str): type of the model.
            features_image_stacked (torch.Tensor): tensors of images features stacked in device.
            name (str): name of the file. Default: "saved_features_image.pt".
            path (str): path to save the file. Default: "./saved_tensors/".

        Returns:
            name (str): name of the file.
        """
        if not os.path.exists(path):
            os.makedirs(path)
        with open(
            str(path)
            + str(len(features_image_stacked))
            + "_"
            + model_type
            + "_"
            + name,
            "wb",
        ) as f:
            torch.save(features_image_stacked, f)
        return name

    def load_tensors(self, name: str) -> torch.Tensor:
        """
        Load tensors from given path.

        Args:
            name (str): name of the file.

        Returns:
            features_image_stacked (torch.Tensor): tensors of images features.
        """
        features_image_stacked = torch.load(name, weights_only=True)
        return features_image_stacked

    def extract_text_features(self, model, text_input: str) -> torch.Tensor:
        """
        Extract text features from text_input with feature_extractor model.

        Args:
            model (torch.nn.Module): model.
            text_input (str): text.

        Returns:
            features_text (torch.Tensor): tensors of text features.
        """
        sample_text = {"text_input": [text_input]}
        features_text = model.extract_features(sample_text, mode="text")

        return features_text

    def parsing_images(
        self,
        model_type: str,
        path_to_save_tensors: str = "./saved_tensors/",
        path_to_load_tensors: str = None,
    ) -> tuple:
        """
        Parse images with feature_extractor model.

        Args:
            model_type (str): type of the model.
            path_to_save_tensors (str): path to save the tensors. Default: "./saved_tensors/".
            path_to_load_tensors (str): path to load the tensors. Default: None.

        Returns:
            model (torch.nn.Module): model.
            vis_processors (dict): preprocessors for visual inputs.
            txt_processors (dict): preprocessors for text inputs.
            image_keys (list): sorted list of image keys.
            image_names (list): sorted list of image names.
            features_image_stacked (torch.Tensor): tensors of images features stacked in device.
        """
        if model_type in ("clip_base", "clip_vitl14_336", "clip_vitl14"):
            path_to_lib = lavis.__file__[:-11] + "models/clip_models/"
            url = "https://raw.githubusercontent.com/salesforce/LAVIS/main/lavis/models/clip_models/bpe_simple_vocab_16e6.txt.gz"
            r = requests.get(url, allow_redirects=False)
            with open(path_to_lib + "bpe_simple_vocab_16e6.txt.gz", "wb") as f:
                f.write(r.content)

        image_keys = sorted(self.subdict.keys())
        image_names = [self.subdict[k]["filename"] for k in image_keys]

        select_model = {
            "blip2": MultimodalSearch.load_feature_extractor_model_blip2,
            "blip": MultimodalSearch.load_feature_extractor_model_blip,
            "albef": MultimodalSearch.load_feature_extractor_model_albef,
            "clip_base": MultimodalSearch.load_feature_extractor_model_clip_base,
            "clip_vitl14": MultimodalSearch.load_feature_extractor_model_clip_vitl14,
            "clip_vitl14_336": MultimodalSearch.load_feature_extractor_model_clip_vitl14_336,
        }

        select_extract_image_features = {
            "blip2": MultimodalSearch.extract_image_features_blip2,
            "blip": MultimodalSearch.extract_image_features_basic,
            "albef": MultimodalSearch.extract_image_features_basic,
            "clip_base": MultimodalSearch.extract_image_features_clip,
            "clip_vitl14": MultimodalSearch.extract_image_features_clip,
            "clip_vitl14_336": MultimodalSearch.extract_image_features_clip,
        }

        if model_type in select_model.keys():
            (
                model,
                vis_processors,
                txt_processors,
            ) = select_model[
                model_type
            ](self, MultimodalSearch.multimodal_device)
        else:
            raise SyntaxError(
                "Please, use one of the following models: blip2, blip, albef, clip_base, clip_vitl14, clip_vitl14_336"
            )

        _, images_tensors = MultimodalSearch.read_and_process_images(
            self, image_names, vis_processors
        )
        if path_to_load_tensors is None:
            with torch.no_grad():
                features_image_stacked = select_extract_image_features[model_type](
                    self, model, images_tensors
                )
                MultimodalSearch.save_tensors(
                    self, model_type, features_image_stacked, path=path_to_save_tensors
                )
        else:
            features_image_stacked = MultimodalSearch.load_tensors(
                self, str(path_to_load_tensors)
            )

        return (
            model,
            vis_processors,
            txt_processors,
            image_keys,
            image_names,
            features_image_stacked,
        )

    def querys_processing(
        self, search_query: list, model, txt_processors, vis_processors, model_type: str
    ) -> torch.Tensor:
        """
        Process queries.

        Args:
            search_query (list): list of queries.
            model (torch.nn.Module): model.
            txt_processors (dict): preprocessors for text inputs.
            vis_processors (dict): preprocessors for visual inputs.
            model_type (str): type of the model.

        Returns:
            multi_features_stacked (torch.Tensor): tensors of query features.
        """
        select_extract_image_features = {
            "blip2": MultimodalSearch.extract_image_features_blip2,
            "blip": MultimodalSearch.extract_image_features_basic,
            "albef": MultimodalSearch.extract_image_features_basic,
            "clip_base": MultimodalSearch.extract_image_features_clip,
            "clip_vitl14": MultimodalSearch.extract_image_features_clip,
            "clip_vitl14_336": MultimodalSearch.extract_image_features_clip,
        }

        for query in search_query:
            if len(query) != 1 and (query in ("image", "text_input")):
                raise SyntaxError(
                    'Each query must contain either an "image" or a "text_input"'
                )
        multi_sample = []
        for query in search_query:
            if "text_input" in query.keys():
                text_processing = txt_processors["eval"](query["text_input"])
                images_tensors = ""
            elif "image" in query.keys():
                _, images_tensors = MultimodalSearch.read_and_process_images(
                    self, [query["image"]], vis_processors
                )
                text_processing = ""
            multi_sample.append(
                {"image": images_tensors, "text_input": text_processing}
            )

        multi_features_query = []
        for query in multi_sample:
            if query["image"] == "":
                if model_type in ("clip_base", "clip_vitl14_336", "clip_vitl14"):
                    features = model.extract_features(
                        {"text_input": query["text_input"]}
                    )

                    features_squeeze = features.squeeze(0).to(
                        MultimodalSearch.multimodal_device
                    )
                    multi_features_query.append(
                        Func.normalize(features_squeeze, dim=-1)
                    )
                else:
                    features = model.extract_features(query, mode="text")
                    features_squeeze = (
                        features.text_embeds_proj[:, 0, :]
                        .squeeze(0)
                        .to(MultimodalSearch.multimodal_device)
                    )
                    multi_features_query.append(features_squeeze)
            if query["text_input"] == "":
                multi_features_query.append(
                    select_extract_image_features[model_type](
                        self, model, query["image"]
                    )
                )

        multi_features_stacked = torch.stack(
            [query.squeeze(0) for query in multi_features_query]
        ).to(MultimodalSearch.multimodal_device)

        return multi_features_stacked

    def multimodal_search(
        self,
        model,
        vis_processors,
        txt_processors,
        model_type: str,
        image_keys: list,
        features_image_stacked: torch.Tensor,
        search_query: list,
        filter_number_of_images: str = None,
        filter_val_limit: str = None,
        filter_rel_error: str = None,
    ) -> tuple:
        """
        Search for images with given queries.

        Args:
            model (torch.nn.Module): model.
            vis_processors (dict): preprocessors for visual inputs.
            txt_processors (dict): preprocessors for text inputs.
            model_type (str): type of the model.
            image_keys (list): sorted list of image keys.
            features_image_stacked (torch.Tensor): tensors of images features stacked in device.
            search_query (list): list of queries.
            filter_number_of_images (str): number of images to show. Default: None.
            filter_val_limit (str): limit of similarity value. Default: None.
            filter_rel_error (str): limit of relative error. Default: None.

        Returns:
            similarity (torch.Tensor): similarity between images and queries.
            sorted_lists (list): sorted list of similarity.
        """
        if filter_number_of_images is None:
            filter_number_of_images = len(self.subdict)
        if filter_val_limit is None:
            filter_val_limit = 0
        if filter_rel_error is None:
            filter_rel_error = 1e10

        features_image_stacked.to(MultimodalSearch.multimodal_device)

        with torch.no_grad():
            multi_features_stacked = MultimodalSearch.querys_processing(
                self, search_query, model, txt_processors, vis_processors, model_type
            )

        similarity = features_image_stacked @ multi_features_stacked.t()
        sorted_lists = torch.argsort(similarity, dim=0, descending=True).T.tolist()
        places = [[item.index(i) for i in range(len(item))] for item in sorted_lists]

        for q in range(len(search_query)):
            max_val = similarity[sorted_lists[q][0]][q].item()
            for i, key in zip(range(len(image_keys)), sorted_lists[q]):
                if (
                    i < filter_number_of_images
                    and similarity[key][q].item() > filter_val_limit
                    and 100 * abs(max_val - similarity[key][q].item()) / max_val
                    < filter_rel_error
                ):
                    self.subdict[image_keys[key]][
                        "rank " + list(search_query[q].values())[0]
                    ] = places[q][key]
                    self.subdict[image_keys[key]][list(search_query[q].values())[0]] = (
                        similarity[key][q].item()
                    )
                else:
                    self.subdict[image_keys[key]][
                        "rank " + list(search_query[q].values())[0]
                    ] = None
                    self.subdict[image_keys[key]][list(search_query[q].values())[0]] = 0
        return similarity, sorted_lists

    def itm_text_precessing(self, search_query: list[dict[str, str]]) -> list:
        """
        Process text queries for itm model.

        Args:
            search_query (list): list of queries.

        Returns:
            text_query_index (list): list of indexes of text queries.
        """
        for query in search_query:
            if (len(query) != 1) and (query in ("image", "text_input")):
                raise SyntaxError(
                    'Each query must contain either an "image" or a "text_input"'
                )
        text_query_index = []
        for i, query in zip(range(len(search_query)), search_query):
            if "text_input" in query.keys():
                text_query_index.append(i)

        return text_query_index

    def get_pathes_from_query(self, query: dict[str, str]) -> tuple:
        """
        Get paths and image names from query.

        Args:
            query (dict): query.

        Returns:
            paths (list): list of paths.
            image_names (list): list of image names.
        """
        paths = []
        image_names = []
        for s in sorted(
            self.subdict.items(),
            key=lambda t: t[1][list(query.values())[0]],
            reverse=True,
        ):
            if s[1]["rank " + list(query.values())[0]] is None:
                break
            paths.append(s[1]["filename"])
            image_names.append(s[0])
        return paths, image_names

    def read_and_process_images_itm(self, image_paths: list, vis_processor) -> tuple:
        """
        Read and process images with vis_processor for itm model.

        Args:
            image_paths (list): paths to images.
            vis_processor (dict): preprocessors for visual inputs.

        Returns:
            raw_images (list): list of images.
            images_tensors (torch.Tensor): tensors of images stacked in device.
        """
        raw_images = [MultimodalSearch.read_img(self, path) for path in image_paths]
        images = [vis_processor(r_img) for r_img in raw_images]
        images_tensors = torch.stack(images).to(MultimodalSearch.multimodal_device)

        return raw_images, images_tensors

    def compute_gradcam_batch(
        self,
        model: torch.nn.Module,
        visual_input: torch.Tensor,
        text_input: str,
        tokenized_text: torch.Tensor,
        block_num: str = 6,
    ) -> tuple:
        """
        Compute gradcam for itm model.

        Args:
            model (torch.nn.Module): model.
            visual_input (torch.Tensor): tensors of images features stacked in device.
            text_input (str): text.
            tokenized_text (torch.Tensor): tokenized text.
            block_num (int): number of block. Default: 6.

        Returns:
            gradcam (torch.Tensor): gradcam.
            output (torch.Tensor): output of model.
        """
        model.text_encoder.base_model.base_model.encoder.layer[
            block_num
        ].crossattention.self.save_attention = True

        output = model(
            {"image": visual_input, "text_input": text_input}, match_head="itm"
        )
        loss = output[:, 1].sum()

        model.zero_grad()
        loss.backward()
        with torch.no_grad():
            mask = tokenized_text.attention_mask.view(
                tokenized_text.attention_mask.size(0), 1, -1, 1, 1
            )  # (bsz,1,token_len, 1,1)
            token_length = mask.sum() - 2
            token_length = token_length.cpu()
            # grads and cams [bsz, num_head, seq_len, image_patch]
            grads = model.text_encoder.base_model.base_model.encoder.layer[
                block_num
            ].crossattention.self.get_attn_gradients()
            cams = model.text_encoder.base_model.base_model.encoder.layer[
                block_num
            ].crossattention.self.get_attention_map()

            # assume using vit large with 576 num image patch
            cams = (
                cams[:, :, :, 1:].reshape(visual_input.size(0), 12, -1, 24, 24) * mask
            )
            grads = (
                grads[:, :, :, 1:]
                .clamp(0)
                .reshape(visual_input.size(0), 12, -1, 24, 24)
                * mask
            )

            gradcam = cams * grads
            # [enc token gradcam, average gradcam across token, gradcam for individual token]
            # gradcam = torch.cat((gradcam[0:1,:], gradcam[1:token_length+1, :].sum(dim=0, keepdim=True)/token_length, gradcam[1:, :]))
            gradcam = gradcam.mean(1).cpu().detach()
            gradcam = (
                gradcam[:, 1 : token_length + 1, :].sum(dim=1, keepdim=True)
                / token_length
            )

        return gradcam, output

    def resize_img(self, raw_img: Image):
        """
        Proportionally resize image to 240 px width.

        Args:
            raw_img (PIL.Image): image.

        Returns:
            resized_image (PIL.Image): image proportionally resized to 240 px width.
        """
        w, h = raw_img.size
        scaling_factor = 240 / w
        resized_image = raw_img.resize(
            (int(w * scaling_factor), int(h * scaling_factor))
        )
        return resized_image

    def get_att_map(
        self,
        img: np.ndarray,
        att_map: np.ndarray,
        blur: bool = True,
        overlap: bool = True,
    ) -> np.ndarray:
        """
        Get attention map.

        Args:
            img (np.ndarray): image.
            att_map (np.ndarray): attention map.
            blur (bool): blur attention map. Default: True.
            overlap (bool): overlap attention map with image. Default: True.

        Returns:
            att_map (np.ndarray): attention map.
        """
        att_map -= att_map.min()
        if att_map.max() > 0:
            att_map /= att_map.max()
        att_map = skimage_transform.resize(
            att_map, (img.shape[:2]), order=3, mode="constant"
        )
        if blur:
            att_map = filters.gaussian_filter(att_map, 0.02 * max(img.shape[:2]))
            att_map -= att_map.min()
            att_map /= att_map.max()
        cmap = plt.get_cmap("jet")
        att_mapv = cmap(att_map)
        att_mapv = np.delete(att_mapv, 3, 2)
        if overlap:
            att_map = (
                1 * (1 - att_map**0.7).reshape(att_map.shape + (1,)) * img
                + (att_map**0.7).reshape(att_map.shape + (1,)) * att_mapv
            )
        return att_map

    def upload_model_blip2_coco(self) -> tuple:
        """
        Load coco blip2_image_text_matching model and preprocessors for visual inputs from lavis.models.

        Returns:
            itm_model (torch.nn.Module): model.
            vis_processor (dict): preprocessors for visual inputs.
        """
        itm_model = load_model(
            "blip2_image_text_matching",
            "coco",
            is_eval=True,
            device=MultimodalSearch.multimodal_device,
        )
        vis_processor = load_processor("blip_image_eval").build(image_size=364)
        return itm_model, vis_processor

    def upload_model_blip_base(self) -> tuple:
        """
        Load base blip_image_text_matching model and preprocessors for visual input from lavis.models.

        Returns:
            itm_model (torch.nn.Module): model.
            vis_processor (dict): preprocessors for visual inputs.
        """
        itm_model = load_model(
            "blip_image_text_matching",
            "base",
            is_eval=True,
            device=MultimodalSearch.multimodal_device,
        )
        vis_processor = load_processor("blip_image_eval").build(image_size=384)
        return itm_model, vis_processor

    def upload_model_blip_large(self) -> tuple:
        """
        Load large blip_image_text_matching model and preprocessors for visual input from lavis.models.

        Returns:
            itm_model (torch.nn.Module): model.
            vis_processor (dict): preprocessors for visual inputs.
        """
        itm_model = load_model(
            "blip_image_text_matching",
            "large",
            is_eval=True,
            device=MultimodalSearch.multimodal_device,
        )
        vis_processor = load_processor("blip_image_eval").build(image_size=384)
        return itm_model, vis_processor

    def image_text_match_reordering(
        self,
        search_query: list[dict[str, str]],
        itm_model_type: str,
        image_keys: list,
        sorted_lists: list[list],
        batch_size: int = 1,
        need_grad_cam: bool = False,
    ) -> tuple:
        """
        Reorder images with itm model.

        Args:
            search_query (list): list of queries.
            itm_model_type (str): type of the model.
            image_keys (list): sorted list of image keys.
            sorted_lists (list): sorted list of similarity.
            batch_size (int): batch size. Default: 1.
            need_grad_cam (bool): need gradcam. Default: False. blip2_coco model does not yet work with gradcam.

        Returns:
            itm_scores2: list of itm scores.
            image_gradcam_with_itm: dict of image names and gradcam.
        """
        if itm_model_type == "blip2_coco" and need_grad_cam is True:
            raise SyntaxError(
                "The blip2_coco model does not yet work with gradcam. Please set need_grad_cam to False"
            )

        choose_model = {
            "blip_base": MultimodalSearch.upload_model_blip_base,
            "blip_large": MultimodalSearch.upload_model_blip_large,
            "blip2_coco": MultimodalSearch.upload_model_blip2_coco,
        }

        itm_model, vis_processor_itm = choose_model[itm_model_type](self)
        text_processor = load_processor("blip_caption")
        tokenizer = BlipBase.init_tokenizer()

        if itm_model_type == "blip2_coco":
            need_grad_cam = False

        text_query_index = MultimodalSearch.itm_text_precessing(self, search_query)

        avg_gradcams = []
        itm_scores = []
        itm_scores2 = []
        image_gradcam_with_itm = {}

        for index_text_query in text_query_index:
            query = search_query[index_text_query]
            pathes, image_names = MultimodalSearch.get_pathes_from_query(self, query)
            num_batches = int(len(pathes) / batch_size)
            num_batches_residue = len(pathes) % batch_size

            local_itm_scores = []
            local_avg_gradcams = []

            if num_batches_residue != 0:
                num_batches = num_batches + 1
            for i in range(num_batches):
                filenames_in_batch = pathes[i * batch_size : (i + 1) * batch_size]
                current_len = len(filenames_in_batch)
                raw_images, images = MultimodalSearch.read_and_process_images_itm(
                    self, filenames_in_batch, vis_processor_itm
                )
                queries_batch = [text_processor(query["text_input"])] * current_len
                queries_tok_batch = tokenizer(queries_batch, return_tensors="pt").to(
                    MultimodalSearch.multimodal_device
                )

                if need_grad_cam:
                    gradcam, itm_output = MultimodalSearch.compute_gradcam_batch(
                        self,
                        itm_model,
                        images,
                        queries_batch,
                        queries_tok_batch,
                    )
                    norm_imgs = [np.float32(r_img) / 255 for r_img in raw_images]

                    for norm_img, grad_cam in zip(norm_imgs, gradcam):
                        avg_gradcam = MultimodalSearch.get_att_map(
                            self, norm_img, np.float32(grad_cam[0]), blur=True
                        )
                        local_avg_gradcams.append(avg_gradcam)

                else:
                    itm_output = itm_model(
                        {"image": images, "text_input": queries_batch}, match_head="itm"
                    )

                with torch.no_grad():
                    itm_score = torch.nn.functional.softmax(itm_output, dim=1)

                local_itm_scores.append(itm_score)

            local_itm_scores2 = torch.cat(local_itm_scores)[:, 1]
            if need_grad_cam:
                localimage_gradcam_with_itm = {
                    n: i * 255 for n, i in zip(image_names, local_avg_gradcams)
                }
            else:
                localimage_gradcam_with_itm = ""
            image_names_with_itm = {
                n: i.item() for n, i in zip(image_names, local_itm_scores2)
            }
            itm_rank = torch.argsort(local_itm_scores2, descending=True)
            image_names_with_new_rank = {
                image_names[i.item()]: rank
                for i, rank in zip(itm_rank, range(len(itm_rank)))
            }
            for i, key in zip(range(len(image_keys)), sorted_lists[index_text_query]):
                if image_keys[key] in image_names:
                    self.subdict[image_keys[key]][
                        "itm " + list(search_query[index_text_query].values())[0]
                    ] = image_names_with_itm[image_keys[key]]
                    self.subdict[image_keys[key]][
                        "itm_rank " + list(search_query[index_text_query].values())[0]
                    ] = image_names_with_new_rank[image_keys[key]]
                else:
                    self.subdict[image_keys[key]][
                        "itm " + list(search_query[index_text_query].values())[0]
                    ] = 0
                    self.subdict[image_keys[key]][
                        "itm_rank " + list(search_query[index_text_query].values())[0]
                    ] = None

            avg_gradcams.append(local_avg_gradcams)
            itm_scores.append(local_itm_scores)
            itm_scores2.append(local_itm_scores2)
            image_gradcam_with_itm[list(search_query[index_text_query].values())[0]] = (
                localimage_gradcam_with_itm
            )
        del (
            itm_model,
            vis_processor_itm,
            text_processor,
            raw_images,
            images,
            tokenizer,
            queries_batch,
            queries_tok_batch,
            itm_score,
        )
        if need_grad_cam:
            del itm_output, gradcam, norm_img, grad_cam, avg_gradcam
        torch.cuda.empty_cache()
        return itm_scores2, image_gradcam_with_itm

    def show_results(
        self, query: dict, itm: bool = False, image_gradcam_with_itm: dict = {}
    ) -> None:
        """
        Show results of search.

        Args:
            query (dict): query.
            itm (bool): use itm model. Default: False.
            image_gradcam_with_itm (dict): use gradcam. Default: empty.
        """
        if "image" in query.keys():
            pic = Image.open(query["image"]).convert("RGB")
            pic.thumbnail((400, 400))
            display(
                "Your search query: ",
                pic,
                "--------------------------------------------------",
                "Results:",
            )
        elif "text_input" in query.keys():
            display(
                "Your search query: " + query["text_input"],
                "--------------------------------------------------",
                "Results:",
            )
        if itm:
            current_querry_val = "itm " + list(query.values())[0]
            current_querry_rank = "itm_rank " + list(query.values())[0]
        else:
            current_querry_val = list(query.values())[0]
            current_querry_rank = "rank " + list(query.values())[0]

        for s in sorted(
            self.subdict.items(), key=lambda t: t[1][current_querry_val], reverse=True
        ):
            if s[1][current_querry_rank] is None:
                break
            if bool(image_gradcam_with_itm) is True and itm is True:
                image = image_gradcam_with_itm[list(query.values())[0]][s[0]]
                p1 = Image.fromarray(image.astype("uint8"), "RGB")
            else:
                p1 = Image.open(s[1]["filename"]).convert("RGB")
            p1.thumbnail((400, 400))
            display(
                "Rank: "
                + str(s[1][current_querry_rank])
                + " Val: "
                + str(s[1][current_querry_val]),
                s[0],
                p1,
            )
        display(
            "--------------------------------------------------",
        )
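
For reference, a minimal usage sketch of the removed MultimodalSearch class, assembled from the method signatures above. The image dictionary keys, file paths, and the choice of "clip_base" are illustrative assumptions, not taken from this commit.

# Minimal usage sketch (assumed inputs; "image1"/"image2" and the paths are hypothetical).
import ammico.multimodal_search as ms

mydict = {
    "image1": {"filename": "images/image1.png"},
    "image2": {"filename": "images/image2.png"},
}
my_obj = ms.MultimodalSearch(mydict)

# Extract and cache image features with one of the supported feature extractors.
(
    model,
    vis_processors,
    txt_processors,
    image_keys,
    image_names,
    features_image_stacked,
) = my_obj.parsing_images("clip_base", path_to_save_tensors="./saved_tensors/")

# Rank the images against a text query, then display the ranked results.
search_query = [{"text_input": "a politician giving a speech"}]
similarity, sorted_lists = my_obj.multimodal_search(
    model,
    vis_processors,
    txt_processors,
    "clip_base",
    image_keys,
    features_image_stacked,
    search_query,
)
my_obj.show_results(search_query[0])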
@@ -1,521 +0,0 @@
import pytest
import math
from PIL import Image
import numpy
from torch import device, cuda
import ammico.multimodal_search as ms

related_error = 1e-2
gpu_is_not_available = not cuda.is_available()

cuda.empty_cache()


def test_read_img(get_testdict):
    my_dict = {}
    test_img = ms.MultimodalSearch.read_img(
        my_dict, get_testdict["IMG_2746"]["filename"]
    )
    assert list(numpy.array(test_img)[257][34]) == [70, 66, 63]

pre_proc_pic_blip2_blip_albef = [
    -1.0039474964141846,
    -1.0039474964141846,
]

pre_proc_pic_clip_vitl14 = [
    -0.7995694875717163,
    -0.7849710583686829,
]

pre_proc_pic_clip_vitl14_336 = [
    -0.7995694875717163,
    -0.7849710583686829,
]

pre_proc_text_blip2_blip_albef = (
    "the bird sat on a tree located at the intersection of 23rd and 43rd streets"
)

pre_proc_text_clip_clip_vitl14_clip_vitl14_336 = (
    "The bird sat on a tree located at the intersection of 23rd and 43rd streets."
)

pre_extracted_feature_img_blip2 = [
    0.04566730558872223,
    -0.042554520070552826,
]

pre_extracted_feature_img_blip = [
    -0.02480311505496502,
    0.05037587881088257,
]

pre_extracted_feature_img_albef = [
    0.08971136063337326,
    -0.10915573686361313,
]

pre_extracted_feature_img_clip = [
    0.01621132344007492,
    -0.004035486374050379,
]

pre_extracted_feature_img_parsing_clip = [
    0.01621132344007492,
    -0.004035486374050379,
]

pre_extracted_feature_img_clip_vitl14 = [
    -0.023943455889821053,
    -0.021703708916902542,
]

pre_extracted_feature_img_clip_vitl14_336 = [
    -0.009511193260550499,
    -0.012618942186236382,
]

pre_extracted_feature_text_blip2 = [
    -0.1384204626083374,
    -0.008662976324558258,
]

pre_extracted_feature_text_blip = [
    0.0118643119931221,
    -0.01291718054562807,
]

pre_extracted_feature_text_albef = [
    -0.06229640915989876,
    0.11278597265481949,
]

pre_extracted_feature_text_clip = [
    0.018169036135077477,
    0.03634127229452133,
]

pre_extracted_feature_text_clip_vitl14 = [
    -0.0055463071912527084,
    0.006908962037414312,
]

pre_extracted_feature_text_clip_vitl14_336 = [
    -0.008720514364540577,
    0.005284308455884457,
]

simularity_blip2 = [
    [0.05826476216316223, -0.02717375010251999],
    [0.06297147274017334, 0.47339022159576416],
]

sorted_blip2 = [
    [1, 0],
    [1, 0],
]

simularity_blip = [
    [0.15640679001808167, 0.752173662185669],
    [0.17233705520629883, 0.8448910117149353],
]

sorted_blip = [
    [1, 0],
    [1, 0],
]

simularity_albef = [
    [0.12321824580430984, 0.35511350631713867],
    [0.10870333760976791, 0.5143978595733643],
]

sorted_albef = [
    [0, 1],
    [1, 0],
]

simularity_clip = [
    [0.23923014104366302, 0.5325412750244141],
    [0.2310466319322586, 0.5910375714302063],
]

sorted_clip = [
    [1, 0],
    [1, 0],
]

simularity_clip_vitl14 = [
    [0.1051270067691803, 0.5184808373451233],
    [0.1277746558189392, 0.6841973662376404],
]

sorted_clip_vitl14 = [
    [1, 0],
    [1, 0],
]

simularity_clip_vitl14_336 = [
    [0.09391091763973236, 0.49337542057037354],
    [0.13700757920742035, 0.7003108263015747],
]

sorted_clip_vitl14_336 = [
    [1, 0],
    [1, 0],
]

dict_itm_scores_for_blib = {
    "blip_base": [
        0.07107225805521011,
        0.004100032616406679,
    ],
    "blip_large": [
        0.07890705019235611,
        0.00271016638725996,
    ],
    "blip2_coco": [
        0.0833505243062973,
        0.004216152708977461,
    ],
}

dict_image_gradcam_with_itm_for_blip = {
    "blip_base": [123.36285799741745, 132.31662154197693, 53.38280035299249],
    "blip_large": [119.99512910842896, 128.7044593691826, 55.552959859540515],
}


@pytest.mark.long
@pytest.mark.parametrize(
    (
        "pre_multimodal_device",
        "pre_model",
        "pre_proc_pic",
        "pre_proc_text",
        "pre_extracted_feature_img",
        "pre_extracted_feature_text",
        "pre_simularity",
        "pre_sorted",
    ),
    [
        (
            device("cpu"),
            "blip2",
            pre_proc_pic_blip2_blip_albef,
            pre_proc_text_blip2_blip_albef,
            pre_extracted_feature_img_blip2,
            pre_extracted_feature_text_blip2,
            simularity_blip2,
            sorted_blip2,
        ),
        pytest.param(
            device("cuda"),
            "blip2",
            pre_proc_pic_blip2_blip_albef,
            pre_proc_text_blip2_blip_albef,
            pre_extracted_feature_img_blip2,
            pre_extracted_feature_text_blip2,
            simularity_blip2,
            sorted_blip2,
            marks=pytest.mark.skipif(
                gpu_is_not_available, reason="gpu_is_not_availible"
            ),
        ),
        (
            device("cpu"),
            "blip",
            pre_proc_pic_blip2_blip_albef,
            pre_proc_text_blip2_blip_albef,
            pre_extracted_feature_img_blip,
            pre_extracted_feature_text_blip,
            simularity_blip,
            sorted_blip,
        ),
        pytest.param(
            device("cuda"),
            "blip",
            pre_proc_pic_blip2_blip_albef,
            pre_proc_text_blip2_blip_albef,
            pre_extracted_feature_img_blip,
            pre_extracted_feature_text_blip,
            simularity_blip,
            sorted_blip,
            marks=pytest.mark.skipif(
                gpu_is_not_available, reason="gpu_is_not_availible"
            ),
        ),
        (
            device("cpu"),
            "albef",
            pre_proc_pic_blip2_blip_albef,
            pre_proc_text_blip2_blip_albef,
            pre_extracted_feature_img_albef,
            pre_extracted_feature_text_albef,
            simularity_albef,
            sorted_albef,
        ),
        pytest.param(
            device("cuda"),
            "albef",
            pre_proc_pic_blip2_blip_albef,
            pre_proc_text_blip2_blip_albef,
            pre_extracted_feature_img_albef,
            pre_extracted_feature_text_albef,
            simularity_albef,
            sorted_albef,
            marks=pytest.mark.skipif(
                gpu_is_not_available, reason="gpu_is_not_availible"
            ),
        ),
        (
            device("cpu"),
            "clip_base",
            pre_proc_pic_clip_vitl14,
            pre_proc_text_clip_clip_vitl14_clip_vitl14_336,
            pre_extracted_feature_img_clip,
            pre_extracted_feature_text_clip,
            simularity_clip,
            sorted_clip,
        ),
        pytest.param(
            device("cuda"),
            "clip_base",
            pre_proc_pic_clip_vitl14,
            pre_proc_text_clip_clip_vitl14_clip_vitl14_336,
            pre_extracted_feature_img_clip,
            pre_extracted_feature_text_clip,
            simularity_clip,
            sorted_clip,
            marks=pytest.mark.skipif(
                gpu_is_not_available, reason="gpu_is_not_availible"
            ),
        ),
        (
            device("cpu"),
            "clip_vitl14",
            pre_proc_pic_clip_vitl14,
            pre_proc_text_clip_clip_vitl14_clip_vitl14_336,
            pre_extracted_feature_img_clip_vitl14,
            pre_extracted_feature_text_clip_vitl14,
            simularity_clip_vitl14,
            sorted_clip_vitl14,
        ),
        pytest.param(
            device("cuda"),
            "clip_vitl14",
            pre_proc_pic_clip_vitl14,
            pre_proc_text_clip_clip_vitl14_clip_vitl14_336,
            pre_extracted_feature_img_clip_vitl14,
            pre_extracted_feature_text_clip_vitl14,
            simularity_clip_vitl14,
            sorted_clip_vitl14,
            marks=pytest.mark.skipif(
                gpu_is_not_available, reason="gpu_is_not_availible"
            ),
        ),
        (
            device("cpu"),
            "clip_vitl14_336",
            pre_proc_pic_clip_vitl14_336,
            pre_proc_text_clip_clip_vitl14_clip_vitl14_336,
            pre_extracted_feature_img_clip_vitl14_336,
            pre_extracted_feature_text_clip_vitl14_336,
            simularity_clip_vitl14_336,
            sorted_clip_vitl14_336,
        ),
        pytest.param(
            device("cuda"),
            "clip_vitl14_336",
            pre_proc_pic_clip_vitl14_336,
            pre_proc_text_clip_clip_vitl14_clip_vitl14_336,
            pre_extracted_feature_img_clip_vitl14_336,
            pre_extracted_feature_text_clip_vitl14_336,
            simularity_clip_vitl14_336,
            sorted_clip_vitl14_336,
            marks=pytest.mark.skipif(
                gpu_is_not_available, reason="gpu_is_not_availible"
            ),
        ),
    ],
)
def test_parsing_images(
    pre_multimodal_device,
    pre_model,
    pre_proc_pic,
    pre_proc_text,
    pre_extracted_feature_img,
    pre_extracted_feature_text,
    pre_simularity,
    pre_sorted,
    get_path,
    get_testdict,
    tmp_path,
):
    ms.MultimodalSearch.multimodal_device = pre_multimodal_device
    my_obj = ms.MultimodalSearch(get_testdict)
    (
        model,
        vis_processor,
        txt_processor,
        image_keys,
        _,
        features_image_stacked,
    ) = my_obj.parsing_images(pre_model, path_to_save_tensors=tmp_path)

    for i, num in zip(range(10), features_image_stacked[0, 10:12].tolist()):
        assert (
            math.isclose(num, pre_extracted_feature_img[i], rel_tol=related_error)
            is True
        )

    test_pic = Image.open(my_obj.subdict["IMG_2746"]["filename"]).convert("RGB")
    test_querry = (
        "The bird sat on a tree located at the intersection of 23rd and 43rd streets."
    )
    processed_pic = (
        vis_processor["eval"](test_pic).unsqueeze(0).to(pre_multimodal_device)
    )
    processed_text = txt_processor["eval"](test_querry)

    for i, num in zip(range(10), processed_pic[0, 0, 0, 25:27].tolist()):
        assert math.isclose(num, pre_proc_pic[i], rel_tol=related_error) is True

    assert processed_text == pre_proc_text

    search_query = [
        {"text_input": test_querry},
        {"image": my_obj.subdict["IMG_2746"]["filename"]},
    ]
    multi_features_stacked = my_obj.querys_processing(
        search_query, model, txt_processor, vis_processor, pre_model
    )

    for i, num in zip(range(10), multi_features_stacked[0, 10:12].tolist()):
        assert (
            math.isclose(num, pre_extracted_feature_text[i], rel_tol=related_error)
            is True
        )

    for i, num in zip(range(10), multi_features_stacked[1, 10:12].tolist()):
        assert (
            math.isclose(num, pre_extracted_feature_img[i], rel_tol=related_error)
            is True
        )

    search_query2 = [
        {"text_input": "A bus"},
        {"image": get_path + "IMG_3758.png"},
    ]

    similarity, sorted_list = my_obj.multimodal_search(
        model,
        vis_processor,
        txt_processor,
        pre_model,
        image_keys,
        features_image_stacked,
        search_query2,
    )

    for i, num in zip(range(len(pre_simularity)), similarity.tolist()):
        for j, num2 in zip(range(len(num)), num):
            assert (
                math.isclose(num2, pre_simularity[i][j], rel_tol=100 * related_error)
                is True
            )

    for i, num in zip(range(len(pre_sorted)), sorted_list):
        for j, num2 in zip(range(2), num):
            assert num2 == pre_sorted[i][j]

    del (
        model,
        vis_processor,
        txt_processor,
        similarity,
        features_image_stacked,
        processed_pic,
        multi_features_stacked,
        my_obj,
    )
    cuda.empty_cache()


@pytest.mark.long
def test_itm(get_test_my_dict, get_path):
    search_query3 = [
        {"text_input": "A bus"},
        {"image": get_path + "IMG_3758.png"},
    ]
    image_keys = ["IMG_2746", "IMG_2809"]
    sorted_list = [[1, 0], [1, 0]]
    my_obj = ms.MultimodalSearch(get_test_my_dict)
    for itm_model in ["blip_base", "blip_large"]:
        (
            itm_scores,
            image_gradcam_with_itm,
        ) = my_obj.image_text_match_reordering(
            search_query3,
            itm_model,
            image_keys,
            sorted_list,
            batch_size=1,
            need_grad_cam=True,
        )
        for i, itm in zip(
            range(len(dict_itm_scores_for_blib[itm_model])),
            dict_itm_scores_for_blib[itm_model],
        ):
            assert (
                math.isclose(itm_scores[0].tolist()[i], itm, rel_tol=10 * related_error)
                is True
            )
        for i, grad_cam in zip(
            range(len(dict_image_gradcam_with_itm_for_blip[itm_model])),
            dict_image_gradcam_with_itm_for_blip[itm_model],
        ):
            assert (
                math.isclose(
                    image_gradcam_with_itm["A bus"]["IMG_2809"][0][0].tolist()[i],
                    grad_cam,
                    rel_tol=10 * related_error,
                )
                is True
            )
    del itm_scores, image_gradcam_with_itm
    cuda.empty_cache()


@pytest.mark.long
def test_itm_blip2_coco(get_test_my_dict, get_path):
    search_query3 = [
        {"text_input": "A bus"},
        {"image": get_path + "IMG_3758.png"},
    ]
    image_keys = ["IMG_2746", "IMG_2809"]
    sorted_list = [[1, 0], [1, 0]]
    my_obj = ms.MultimodalSearch(get_test_my_dict)

    (
        itm_scores,
        image_gradcam_with_itm,
    ) = my_obj.image_text_match_reordering(
        search_query3,
        "blip2_coco",
        image_keys,
        sorted_list,
        batch_size=1,
        need_grad_cam=False,
    )
    for i, itm in zip(
        range(len(dict_itm_scores_for_blib["blip2_coco"])),
        dict_itm_scores_for_blib["blip2_coco"],
    ):
        assert (
            math.isclose(itm_scores[0].tolist()[i], itm, rel_tol=10 * related_error)
            is True
        )
    del itm_scores, image_gradcam_with_itm
    cuda.empty_cache()
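
The removed tests are gated behind the `long` pytest marker and rely on project fixtures (get_path, get_testdict, get_test_my_dict) defined in the repository's conftest. A hedged sketch of selecting them programmatically follows; the test file path is an assumption about the repository layout, not taken from this commit.

# Hedged sketch: run only the "long"-marked tests of the (now removed) module via pytest's API.
# The path below is an assumed location of the deleted test file.
import pytest

if __name__ == "__main__":
    raise SystemExit(
        pytest.main(["-m", "long", "-v", "ammico/test/test_multimodal_search.py"])
    )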