AMMICO/misinformation/multimodal_search.py

from misinformation.utils import AnalysisMethod
import torch
import torch.nn.functional as Func
import requests
import lavis
import os
from PIL import Image
from IPython.display import display
from lavis.models import load_model_and_preprocess


class MultimodalSearch(AnalysisMethod):
    def __init__(self, subdict: dict) -> None:
        super().__init__(subdict)
        # self.subdict.update(self.set_keys())

    # Class-level device so the helpers below can reference
    # MultimodalSearch.multimodal_device without an instance.
    multimodal_device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    def load_feature_extractor_model_blip2(self, device):
        model, vis_processors, txt_processors = load_model_and_preprocess(
            name="blip2_feature_extractor",
            model_type="pretrain",
            is_eval=True,
            device=device,
        )
        return model, vis_processors, txt_processors

    def load_feature_extractor_model_blip(self, device):
        model, vis_processors, txt_processors = load_model_and_preprocess(
            name="blip_feature_extractor",
            model_type="base",
            is_eval=True,
            device=device,
        )
        return model, vis_processors, txt_processors

    def load_feature_extractor_model_albef(self, device):
        model, vis_processors, txt_processors = load_model_and_preprocess(
            name="albef_feature_extractor",
            model_type="base",
            is_eval=True,
            device=device,
        )
        return model, vis_processors, txt_processors

    def load_feature_extractor_model_clip_base(self, device):
        model, vis_processors, txt_processors = load_model_and_preprocess(
            name="clip_feature_extractor",
            model_type="base",
            is_eval=True,
            device=device,
        )
        return model, vis_processors, txt_processors

    def load_feature_extractor_model_clip_vitl14(self, device):
        model, vis_processors, txt_processors = load_model_and_preprocess(
            name="clip_feature_extractor",
            model_type="ViT-L-14",
            is_eval=True,
            device=device,
        )
        return model, vis_processors, txt_processors

    def load_feature_extractor_model_clip_vitl14_336(self, device):
        model, vis_processors, txt_processors = load_model_and_preprocess(
            name="clip_feature_extractor",
            model_type="ViT-L-14-336",
            is_eval=True,
            device=device,
        )
        return model, vis_processors, txt_processors
    def read_img(self, filepath):
        raw_image = Image.open(filepath).convert("RGB")
        return raw_image

    def read_and_process_images(self, image_paths, vis_processor):
        raw_images = [MultimodalSearch.read_img(self, path) for path in image_paths]
        images = [
            vis_processor["eval"](r_img)
            .unsqueeze(0)
            .to(MultimodalSearch.multimodal_device)
            for r_img in raw_images
        ]
        images_tensors = torch.stack(images)
        return raw_images, images_tensors
    def extract_image_features_blip2(self, model, images_tensors):
        with torch.cuda.amp.autocast(
            enabled=(MultimodalSearch.multimodal_device != torch.device("cpu"))
        ):
            features_image = [
                model.extract_features({"image": ten, "text_input": ""}, mode="image")
                for ten in images_tensors
            ]
            features_image_stacked = torch.stack(
                [feat.image_embeds_proj[:, 0, :].squeeze(0) for feat in features_image]
            )
        return features_image_stacked

    def extract_image_features_clip(self, model, images_tensors):
        features_image = [
            model.extract_features({"image": ten}) for ten in images_tensors
        ]
        features_image_stacked = torch.stack(
            [Func.normalize(feat.float(), dim=-1).squeeze(0) for feat in features_image]
        )
        return features_image_stacked

    def extract_image_features_basic(self, model, images_tensors):
        features_image = [
            model.extract_features({"image": ten, "text_input": ""}, mode="image")
            for ten in images_tensors
        ]
        features_image_stacked = torch.stack(
            [feat.image_embeds_proj[:, 0, :].squeeze(0) for feat in features_image]
        )
        return features_image_stacked
    def save_tensors(
        self,
        model_type,
        features_image_stacked,
        name="saved_features_image.pt",
        path="./saved_tensors/",
    ):
        if not os.path.exists(path):
            os.makedirs(path)
        with open(
            str(path)
            + str(len(features_image_stacked))
            + "_"
            + model_type
            + "_"
            + name,
            "wb",
        ) as f:
            torch.save(features_image_stacked, f)
        return name

    def load_tensors(self, name):
        features_image_stacked = torch.load(name)
        return features_image_stacked
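    # Note on the cache file layout (derived from save_tensors above): features are
    # written to "<path><number_of_images>_<model_type>_<name>", e.g. a run over 10
    # images with the "blip" model and the default arguments would produce
    # "./saved_tensors/10_blip_saved_features_image.pt". Pass that full path as
    # path_to_load_tensors in parsing_images() to reuse the cached features.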
    def extract_text_features(self, model, text_input):
        sample_text = {"text_input": [text_input]}
        features_text = model.extract_features(sample_text, mode="text")
        return features_text
    def parsing_images(
        self,
        model_type,
        path_to_saved_tensors="./saved_tensors/",
        path_to_load_tensors=None,
    ):
        if model_type in ("clip_base", "clip_vitl14_336", "clip_vitl14"):
            # Download the CLIP BPE vocabulary into the installed LAVIS package
            # directory so the CLIP tokenizer can find it.
            path_to_lib = lavis.__file__[:-11] + "models/clip_models/"
            url = "https://raw.githubusercontent.com/salesforce/LAVIS/main/lavis/models/clip_models/bpe_simple_vocab_16e6.txt.gz"
            r = requests.get(url, allow_redirects=False)
            with open(path_to_lib + "bpe_simple_vocab_16e6.txt.gz", "wb") as f:
                f.write(r.content)

        image_keys = sorted(self.keys())
        image_names = [self[k]["filename"] for k in image_keys]

        select_model = {
            "blip2": MultimodalSearch.load_feature_extractor_model_blip2,
            "blip": MultimodalSearch.load_feature_extractor_model_blip,
            "albef": MultimodalSearch.load_feature_extractor_model_albef,
            "clip_base": MultimodalSearch.load_feature_extractor_model_clip_base,
            "clip_vitl14": MultimodalSearch.load_feature_extractor_model_clip_vitl14,
            "clip_vitl14_336": MultimodalSearch.load_feature_extractor_model_clip_vitl14_336,
        }

        select_extract_image_features = {
            "blip2": MultimodalSearch.extract_image_features_blip2,
            "blip": MultimodalSearch.extract_image_features_basic,
            "albef": MultimodalSearch.extract_image_features_basic,
            "clip_base": MultimodalSearch.extract_image_features_clip,
            "clip_vitl14": MultimodalSearch.extract_image_features_clip,
            "clip_vitl14_336": MultimodalSearch.extract_image_features_clip,
        }

        if model_type in select_model.keys():
            model, vis_processors, txt_processors = select_model[model_type](
                self, MultimodalSearch.multimodal_device
            )
        else:
            raise SyntaxError(
                "Please use one of the following models: blip2, blip, albef, "
                "clip_base, clip_vitl14, clip_vitl14_336"
            )

        _, images_tensors = MultimodalSearch.read_and_process_images(
            self, image_names, vis_processors
        )
        if path_to_load_tensors is None:
            with torch.no_grad():
                features_image_stacked = select_extract_image_features[model_type](
                    self, model, images_tensors
                )
            MultimodalSearch.save_tensors(
                self, model_type, features_image_stacked, path=path_to_saved_tensors
            )
        else:
            features_image_stacked = MultimodalSearch.load_tensors(
                self, str(path_to_load_tensors)
            )

        return (
            model,
            vis_processors,
            txt_processors,
            image_keys,
            image_names,
            features_image_stacked,
        )
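    # Minimal usage sketch for parsing_images (illustrative only; the key names and
    # file paths are hypothetical, assuming the subdict maps image keys to dicts
    # that contain a "filename" entry):
    #
    #   my_obj = MultimodalSearch({"img0": {"filename": "images/img0.png"}})
    #   (
    #       model,
    #       vis_processors,
    #       txt_processors,
    #       image_keys,
    #       image_names,
    #       features_image_stacked,
    #   ) = my_obj.parsing_images("blip")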
    def querys_processing(
        self, search_query, model, txt_processors, vis_processors, model_type
    ):
        select_extract_image_features = {
            "blip2": MultimodalSearch.extract_image_features_blip2,
            "blip": MultimodalSearch.extract_image_features_basic,
            "albef": MultimodalSearch.extract_image_features_basic,
            "clip_base": MultimodalSearch.extract_image_features_clip,
            "clip_vitl14": MultimodalSearch.extract_image_features_clip,
            "clip_vitl14_336": MultimodalSearch.extract_image_features_clip,
        }

        # Each query must be a single-key dict holding either "image" or "text_input".
        for query in search_query:
            if len(query) != 1 or not ("image" in query or "text_input" in query):
                raise SyntaxError(
                    'Each query must contain either an "image" or a "text_input"'
                )

        multi_sample = []
        for query in search_query:
            if "text_input" in query.keys():
                text_processing = txt_processors["eval"](query["text_input"])
                images_tensors = ""
            elif "image" in query.keys():
                _, images_tensors = MultimodalSearch.read_and_process_images(
                    self, [query["image"]], vis_processors
                )
                text_processing = ""
            multi_sample.append(
                {"image": images_tensors, "text_input": text_processing}
            )

        multi_features_query = []
        for query in multi_sample:
            if query["image"] == "":
                if model_type in ("clip_base", "clip_vitl14_336", "clip_vitl14"):
                    features = model.extract_features(
                        {"text_input": query["text_input"]}
                    )
                    features_squeeze = features.squeeze(0).to(
                        MultimodalSearch.multimodal_device
                    )
                    multi_features_query.append(
                        Func.normalize(features_squeeze, dim=-1)
                    )
                else:
                    features = model.extract_features(query, mode="text")
                    features_squeeze = (
                        features.text_embeds_proj[:, 0, :]
                        .squeeze(0)
                        .to(MultimodalSearch.multimodal_device)
                    )
                    multi_features_query.append(features_squeeze)
            if query["text_input"] == "":
                multi_features_query.append(
                    select_extract_image_features[model_type](
                        self, model, query["image"]
                    )
                )

        multi_features_stacked = torch.stack(
            [query.squeeze(0) for query in multi_features_query]
        ).to(MultimodalSearch.multimodal_device)

        return multi_features_stacked
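    # Expected search_query format (as validated in querys_processing above): a list
    # of single-key dicts, each holding either a text or an image query, for example
    #   [{"text_input": "politician press conference"}, {"image": "images/img0.png"}]
    # (the concrete strings here are illustrative placeholders).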
    def multimodal_search(
        self,
        model,
        vis_processors,
        txt_processors,
        model_type,
        image_keys,
        features_image_stacked,
        search_query,
        filter_number_of_images=None,
        filter_val_limit=None,
        filter_rel_error=None,
    ):
        if filter_number_of_images is None:
            filter_number_of_images = len(self)
        if filter_val_limit is None:
            filter_val_limit = 0
        if filter_rel_error is None:
            filter_rel_error = 1e10

        features_image_stacked = features_image_stacked.to(
            MultimodalSearch.multimodal_device
        )

        with torch.no_grad():
            multi_features_stacked = MultimodalSearch.querys_processing(
                self, search_query, model, txt_processors, vis_processors, model_type
            )

        similarity = features_image_stacked @ multi_features_stacked.t()
        # similarity_soft_max = torch.nn.Softmax(dim=0)(similarity / 0.01)
        sorted_lists = [
            sorted(range(len(similarity)), key=lambda k: similarity[k, i], reverse=True)
            for i in range(len(similarity[0]))
        ]
        places = [[item.index(i) for i in range(len(item))] for item in sorted_lists]

        for q in range(len(search_query)):
            max_val = similarity[sorted_lists[q][0]][q].item()
            for i, key in zip(range(len(image_keys)), sorted_lists[q]):
                if (
                    i < filter_number_of_images
                    and similarity[key][q].item() > filter_val_limit
                    and 100 * abs(max_val - similarity[key][q].item()) / max_val
                    < filter_rel_error
                ):
                    self[image_keys[key]][
                        "rank " + list(search_query[q].values())[0]
                    ] = places[q][key]
                    self[image_keys[key]][
                        list(search_query[q].values())[0]
                    ] = similarity[key][q].item()
                else:
                    self[image_keys[key]][
                        "rank " + list(search_query[q].values())[0]
                    ] = None
                    self[image_keys[key]][list(search_query[q].values())[0]] = 0

        return similarity, sorted_lists
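    # Filtering behaviour of multimodal_search, for reference: filter_number_of_images
    # caps how many top-ranked images keep a rank and value, filter_val_limit drops
    # images whose similarity is at or below the threshold, and filter_rel_error drops
    # images whose similarity deviates from the best match by more than the given
    # percentage; filtered-out images get rank None and value 0 in the subdict.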
    def show_results(self, query):
        if "image" in query.keys():
            pic = Image.open(query["image"]).convert("RGB")
            pic.thumbnail((400, 400))
            display(
                "Your search query: ",
                pic,
                "--------------------------------------------------",
                "Results:",
            )
        elif "text_input" in query.keys():
            display(
                "Your search query: " + query["text_input"],
                "--------------------------------------------------",
                "Results:",
            )
        for s in sorted(
            self.items(), key=lambda t: t[1][list(query.values())[0]], reverse=True
        ):
            if s[1]["rank " + list(query.values())[0]] is None:
                break
            p1 = Image.open(s[1]["filename"]).convert("RGB")
            p1.thumbnail((400, 400))
            display(
                "Rank: "
                + str(s[1]["rank " + list(query.values())[0]])
                + " Val: "
                + str(s[1][list(query.values())[0]]),
                s[0],
                p1,
            )
            display(
                "--------------------------------------------------",
            )
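

# A minimal end-to-end sketch of how this class is meant to be driven (illustrative
# only: the image paths and the dict layout with a "filename" entry per image are
# assumptions, and running it requires the LAVIS model weights to be downloadable).
if __name__ == "__main__":
    images = {
        "img0": {"filename": "images/img0.png"},  # hypothetical paths
        "img1": {"filename": "images/img1.png"},
    }
    searcher = MultimodalSearch(images)
    (
        model,
        vis_processors,
        txt_processors,
        image_keys,
        image_names,
        features_image_stacked,
    ) = searcher.parsing_images("blip")
    search_query = [{"text_input": "a crowd of people"}]  # placeholder query
    similarity, sorted_lists = searcher.multimodal_search(
        model,
        vis_processors,
        txt_processors,
        "blip",
        image_keys,
        features_image_stacked,
        search_query,
    )
    searcher.show_results(search_query[0])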