зеркало из
https://github.com/ssciwr/AMMICO.git
synced 2025-10-29 13:06:04 +02:00
Коммит
80665e5f82
2
.github/workflows/ci.yml
поставляемый
2
.github/workflows/ci.yml
поставляемый
@ -32,7 +32,7 @@ jobs:
|
||||
- name: Run pytest
|
||||
run: |
|
||||
cd misinformation
|
||||
python -m pytest -m "not gcv" -svv --cov=. --cov-report=xml
|
||||
python -m pytest -m "not gcv and not long" -svv --cov=. --cov-report=xml
|
||||
- name: Upload coverage
|
||||
if: matrix.os == 'ubuntu-22.04' && matrix.python-version == '3.9'
|
||||
uses: codecov/codecov-action@v3
|
||||
|
||||
@ -113,7 +113,7 @@
|
||||
" image_keys,\n",
|
||||
" image_names,\n",
|
||||
" features_image_stacked,\n",
|
||||
") = ms.MultimodalSearch.parsing_images(mydict, model_type)"
|
||||
") = ms.MultimodalSearch.parsing_images(mydict, model_type, path_to_saved_tensors=\".\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
@ -128,7 +128,9 @@
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "56c6d488-f093-4661-835a-5c73a329c874",
|
||||
"metadata": {},
|
||||
"metadata": {
|
||||
"tags": []
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# (\n",
|
||||
|
||||
@ -4,9 +4,14 @@ import torch.nn.functional as Func
|
||||
import requests
|
||||
import lavis
|
||||
import os
|
||||
import numpy as np
|
||||
from PIL import Image
|
||||
from skimage import transform as skimage_transform
|
||||
from scipy.ndimage import filters
|
||||
from matplotlib import pyplot as plt
|
||||
from IPython.display import display
|
||||
from lavis.models import load_model_and_preprocess
|
||||
from lavis.models import load_model_and_preprocess, load_model, BlipBase
|
||||
from lavis.processors import load_processor
|
||||
|
||||
|
||||
class MultimodalSearch(AnalysisMethod):
|
||||
@ -233,7 +238,7 @@ class MultimodalSearch(AnalysisMethod):
|
||||
}
|
||||
|
||||
for query in search_query:
|
||||
if not (len(query) == 1) and (query in ("image", "text_input")):
|
||||
if len(query) != 1 and (query in ("image", "text_input")):
|
||||
raise SyntaxError(
|
||||
'Each query must contain either an "image" or a "text_input"'
|
||||
)
|
||||
@ -343,7 +348,288 @@ class MultimodalSearch(AnalysisMethod):
|
||||
self[image_keys[key]][list(search_query[q].values())[0]] = 0
|
||||
return similarity, sorted_lists
|
||||
|
||||
def show_results(self, query):
|
||||
def itm_text_precessing(self, search_query):
|
||||
for query in search_query:
|
||||
if not (len(query) == 1) and (query in ("image", "text_input")):
|
||||
raise SyntaxError(
|
||||
'Each querry must contain either an "image" or a "text_input"'
|
||||
)
|
||||
text_query_index = []
|
||||
for i, query in zip(range(len(search_query)), search_query):
|
||||
if "text_input" in query.keys():
|
||||
text_query_index.append(i)
|
||||
|
||||
return text_query_index
|
||||
|
||||
def get_pathes_from_query(self, query):
|
||||
paths = []
|
||||
image_names = []
|
||||
for s in sorted(
|
||||
self.items(), key=lambda t: t[1][list(query.values())[0]], reverse=True
|
||||
):
|
||||
if s[1]["rank " + list(query.values())[0]] is None:
|
||||
break
|
||||
paths.append(s[1]["filename"])
|
||||
image_names.append(s[0])
|
||||
return paths, image_names
|
||||
|
||||
def read_and_process_images_itm(self, image_paths, vis_processor):
|
||||
raw_images = [MultimodalSearch.read_img(self, path) for path in image_paths]
|
||||
images = [vis_processor(r_img) for r_img in raw_images]
|
||||
images_tensors = torch.stack(images).to(MultimodalSearch.multimodal_device)
|
||||
|
||||
return raw_images, images_tensors
|
||||
|
||||
def compute_gradcam_batch(
|
||||
self,
|
||||
model,
|
||||
visual_input,
|
||||
text_input,
|
||||
tokenized_text,
|
||||
block_num=6,
|
||||
):
|
||||
model.text_encoder.base_model.base_model.encoder.layer[
|
||||
block_num
|
||||
].crossattention.self.save_attention = True
|
||||
|
||||
output = model(
|
||||
{"image": visual_input, "text_input": text_input}, match_head="itm"
|
||||
)
|
||||
loss = output[:, 1].sum()
|
||||
|
||||
model.zero_grad()
|
||||
loss.backward()
|
||||
with torch.no_grad():
|
||||
mask = tokenized_text.attention_mask.view(
|
||||
tokenized_text.attention_mask.size(0), 1, -1, 1, 1
|
||||
) # (bsz,1,token_len, 1,1)
|
||||
token_length = mask.sum() - 2
|
||||
token_length = token_length.cpu()
|
||||
# grads and cams [bsz, num_head, seq_len, image_patch]
|
||||
grads = model.text_encoder.base_model.base_model.encoder.layer[
|
||||
block_num
|
||||
].crossattention.self.get_attn_gradients()
|
||||
cams = model.text_encoder.base_model.base_model.encoder.layer[
|
||||
block_num
|
||||
].crossattention.self.get_attention_map()
|
||||
|
||||
# assume using vit large with 576 num image patch
|
||||
cams = (
|
||||
cams[:, :, :, 1:].reshape(visual_input.size(0), 12, -1, 24, 24) * mask
|
||||
)
|
||||
grads = (
|
||||
grads[:, :, :, 1:]
|
||||
.clamp(0)
|
||||
.reshape(visual_input.size(0), 12, -1, 24, 24)
|
||||
* mask
|
||||
)
|
||||
|
||||
gradcam = cams * grads
|
||||
# [enc token gradcam, average gradcam across token, gradcam for individual token]
|
||||
# gradcam = torch.cat((gradcam[0:1,:], gradcam[1:token_length+1, :].sum(dim=0, keepdim=True)/token_length, gradcam[1:, :]))
|
||||
gradcam = gradcam.mean(1).cpu().detach()
|
||||
gradcam = (
|
||||
gradcam[:, 1 : token_length + 1, :].sum(dim=1, keepdim=True)
|
||||
/ token_length
|
||||
)
|
||||
|
||||
return gradcam, output
|
||||
|
||||
def resize_img(self, raw_img):
|
||||
w, h = raw_img.size
|
||||
scaling_factor = 240 / w
|
||||
resized_image = raw_img.resize(
|
||||
(int(w * scaling_factor), int(h * scaling_factor))
|
||||
)
|
||||
return resized_image
|
||||
|
||||
def get_att_map(self, img, att_map, blur=True, overlap=True):
|
||||
att_map -= att_map.min()
|
||||
if att_map.max() > 0:
|
||||
att_map /= att_map.max()
|
||||
att_map = skimage_transform.resize(
|
||||
att_map, (img.shape[:2]), order=3, mode="constant"
|
||||
)
|
||||
if blur:
|
||||
att_map = filters.gaussian_filter(att_map, 0.02 * max(img.shape[:2]))
|
||||
att_map -= att_map.min()
|
||||
att_map /= att_map.max()
|
||||
cmap = plt.get_cmap("jet")
|
||||
att_mapv = cmap(att_map)
|
||||
att_mapv = np.delete(att_mapv, 3, 2)
|
||||
if overlap:
|
||||
att_map = (
|
||||
1 * (1 - att_map**0.7).reshape(att_map.shape + (1,)) * img
|
||||
+ (att_map**0.7).reshape(att_map.shape + (1,)) * att_mapv
|
||||
)
|
||||
return att_map
|
||||
|
||||
def upload_model_blip2_coco(self):
|
||||
itm_model = load_model(
|
||||
"blip2_image_text_matching",
|
||||
"coco",
|
||||
is_eval=True,
|
||||
device=MultimodalSearch.multimodal_device,
|
||||
)
|
||||
vis_processor = load_processor("blip_image_eval").build(image_size=364)
|
||||
return itm_model, vis_processor
|
||||
|
||||
def upload_model_blip_base(self):
|
||||
itm_model = load_model(
|
||||
"blip_image_text_matching",
|
||||
"base",
|
||||
is_eval=True,
|
||||
device=MultimodalSearch.multimodal_device,
|
||||
)
|
||||
vis_processor = load_processor("blip_image_eval").build(image_size=384)
|
||||
return itm_model, vis_processor
|
||||
|
||||
def upload_model_blip_large(self):
|
||||
itm_model = load_model(
|
||||
"blip_image_text_matching",
|
||||
"large",
|
||||
is_eval=True,
|
||||
device=MultimodalSearch.multimodal_device,
|
||||
)
|
||||
vis_processor = load_processor("blip_image_eval").build(image_size=384)
|
||||
return itm_model, vis_processor
|
||||
|
||||
def image_text_match_reordering(
|
||||
self,
|
||||
search_query,
|
||||
itm_model_type,
|
||||
image_keys,
|
||||
sorted_lists,
|
||||
batch_size=1,
|
||||
need_grad_cam=False,
|
||||
):
|
||||
if itm_model_type == "blip2_coco" and need_grad_cam is True:
|
||||
raise SyntaxError(
|
||||
"The blip2_coco model does not yet work with gradcam. Please set need_grad_cam to False"
|
||||
)
|
||||
|
||||
choose_model = {
|
||||
"blip_base": MultimodalSearch.upload_model_blip_base,
|
||||
"blip_large": MultimodalSearch.upload_model_blip_large,
|
||||
"blip2_coco": MultimodalSearch.upload_model_blip2_coco,
|
||||
}
|
||||
|
||||
itm_model, vis_processor_itm = choose_model[itm_model_type](self)
|
||||
text_processor = load_processor("blip_caption")
|
||||
tokenizer = BlipBase.init_tokenizer()
|
||||
|
||||
if itm_model_type == "blip2_coco":
|
||||
need_grad_cam = False
|
||||
|
||||
text_query_index = MultimodalSearch.itm_text_precessing(self, search_query)
|
||||
|
||||
avg_gradcams = []
|
||||
itm_scores = []
|
||||
itm_scores2 = []
|
||||
image_gradcam_with_itm = {}
|
||||
|
||||
for index_text_query in text_query_index:
|
||||
query = search_query[index_text_query]
|
||||
pathes, image_names = MultimodalSearch.get_pathes_from_query(self, query)
|
||||
num_batches = int(len(pathes) / batch_size)
|
||||
num_batches_residue = len(pathes) % batch_size
|
||||
|
||||
local_itm_scores = []
|
||||
local_avg_gradcams = []
|
||||
|
||||
if num_batches_residue != 0:
|
||||
num_batches = num_batches + 1
|
||||
for i in range(num_batches):
|
||||
filenames_in_batch = pathes[i * batch_size : (i + 1) * batch_size]
|
||||
current_len = len(filenames_in_batch)
|
||||
raw_images, images = MultimodalSearch.read_and_process_images_itm(
|
||||
self, filenames_in_batch, vis_processor_itm
|
||||
)
|
||||
queries_batch = [text_processor(query["text_input"])] * current_len
|
||||
queries_tok_batch = tokenizer(queries_batch, return_tensors="pt").to(
|
||||
MultimodalSearch.multimodal_device
|
||||
)
|
||||
|
||||
if need_grad_cam:
|
||||
gradcam, itm_output = MultimodalSearch.compute_gradcam_batch(
|
||||
self,
|
||||
itm_model,
|
||||
images,
|
||||
queries_batch,
|
||||
queries_tok_batch,
|
||||
)
|
||||
norm_imgs = [np.float32(r_img) / 255 for r_img in raw_images]
|
||||
|
||||
for norm_img, grad_cam in zip(norm_imgs, gradcam):
|
||||
avg_gradcam = MultimodalSearch.get_att_map(
|
||||
self, norm_img, np.float32(grad_cam[0]), blur=True
|
||||
)
|
||||
local_avg_gradcams.append(avg_gradcam)
|
||||
|
||||
else:
|
||||
itm_output = itm_model(
|
||||
{"image": images, "text_input": queries_batch}, match_head="itm"
|
||||
)
|
||||
|
||||
with torch.no_grad():
|
||||
itm_score = torch.nn.functional.softmax(itm_output, dim=1)
|
||||
|
||||
local_itm_scores.append(itm_score)
|
||||
|
||||
local_itm_scores2 = torch.cat(local_itm_scores)[:, 1]
|
||||
if need_grad_cam:
|
||||
localimage_gradcam_with_itm = {
|
||||
n: i * 255 for n, i in zip(image_names, local_avg_gradcams)
|
||||
}
|
||||
else:
|
||||
localimage_gradcam_with_itm = ""
|
||||
image_names_with_itm = {
|
||||
n: i.item() for n, i in zip(image_names, local_itm_scores2)
|
||||
}
|
||||
itm_rank = torch.argsort(local_itm_scores2, descending=True)
|
||||
image_names_with_new_rank = {
|
||||
image_names[i.item()]: rank
|
||||
for i, rank in zip(itm_rank, range(len(itm_rank)))
|
||||
}
|
||||
for i, key in zip(range(len(image_keys)), sorted_lists[index_text_query]):
|
||||
if image_keys[key] in image_names:
|
||||
self[image_keys[key]][
|
||||
"itm " + list(search_query[index_text_query].values())[0]
|
||||
] = image_names_with_itm[image_keys[key]]
|
||||
self[image_keys[key]][
|
||||
"itm_rank " + list(search_query[index_text_query].values())[0]
|
||||
] = image_names_with_new_rank[image_keys[key]]
|
||||
else:
|
||||
self[image_keys[key]][
|
||||
"itm " + list(search_query[index_text_query].values())[0]
|
||||
] = 0
|
||||
self[image_keys[key]][
|
||||
"itm_rank " + list(search_query[index_text_query].values())[0]
|
||||
] = None
|
||||
|
||||
avg_gradcams.append(local_avg_gradcams)
|
||||
itm_scores.append(local_itm_scores)
|
||||
itm_scores2.append(local_itm_scores2)
|
||||
image_gradcam_with_itm[
|
||||
list(search_query[index_text_query].values())[0]
|
||||
] = localimage_gradcam_with_itm
|
||||
del (
|
||||
itm_model,
|
||||
vis_processor_itm,
|
||||
text_processor,
|
||||
raw_images,
|
||||
images,
|
||||
tokenizer,
|
||||
queries_batch,
|
||||
queries_tok_batch,
|
||||
itm_score,
|
||||
)
|
||||
if need_grad_cam:
|
||||
del itm_output, gradcam, norm_img, grad_cam, avg_gradcam
|
||||
torch.cuda.empty_cache()
|
||||
return itm_scores2, image_gradcam_with_itm
|
||||
|
||||
def show_results(self, query, itm=False, image_gradcam_with_itm=False):
|
||||
if "image" in query.keys():
|
||||
pic = Image.open(query["image"]).convert("RGB")
|
||||
pic.thumbnail((400, 400))
|
||||
@ -359,18 +645,29 @@ class MultimodalSearch(AnalysisMethod):
|
||||
"--------------------------------------------------",
|
||||
"Results:",
|
||||
)
|
||||
if itm:
|
||||
current_querry_val = "itm " + list(query.values())[0]
|
||||
current_querry_rank = "itm_rank " + list(query.values())[0]
|
||||
else:
|
||||
current_querry_val = list(query.values())[0]
|
||||
current_querry_rank = "rank " + list(query.values())[0]
|
||||
|
||||
for s in sorted(
|
||||
self.items(), key=lambda t: t[1][list(query.values())[0]], reverse=True
|
||||
self.items(), key=lambda t: t[1][current_querry_val], reverse=True
|
||||
):
|
||||
if s[1]["rank " + list(query.values())[0]] is None:
|
||||
if s[1][current_querry_rank] is None:
|
||||
break
|
||||
p1 = Image.open(s[1]["filename"]).convert("RGB")
|
||||
if image_gradcam_with_itm is False:
|
||||
p1 = Image.open(s[1]["filename"]).convert("RGB")
|
||||
else:
|
||||
image = image_gradcam_with_itm[list(query.values())[0]][s[0]]
|
||||
p1 = Image.fromarray(image.astype("uint8"), "RGB")
|
||||
p1.thumbnail((400, 400))
|
||||
display(
|
||||
"Rank: "
|
||||
+ str(s[1]["rank " + list(query.values())[0]])
|
||||
+ str(s[1][current_querry_rank])
|
||||
+ " Val: "
|
||||
+ str(s[1][list(query.values())[0]]),
|
||||
+ str(s[1][current_querry_val]),
|
||||
s[0],
|
||||
p1,
|
||||
)
|
||||
|
||||
@ -16,3 +16,33 @@ def set_environ(request):
|
||||
mypath + "/../../data/seismic-bonfire-329406-412821a70264.json"
|
||||
)
|
||||
print(os.environ.get("GOOGLE_APPLICATION_CREDENTIALS"))
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def get_testdict(get_path):
|
||||
testdict = {
|
||||
"IMG_2746": {"filename": get_path + "IMG_2746.png"},
|
||||
"IMG_2809": {"filename": get_path + "IMG_2809.png"},
|
||||
}
|
||||
return testdict
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def get_test_my_dict(get_path):
|
||||
test_my_dict = {
|
||||
"IMG_2746": {
|
||||
"filename": get_path + "IMG_2746.png",
|
||||
"rank A bus": 1,
|
||||
"A bus": 0.15640679001808167,
|
||||
"rank " + get_path + "IMG_3758.png": 1,
|
||||
get_path + "IMG_3758.png": 0.7533495426177979,
|
||||
},
|
||||
"IMG_2809": {
|
||||
"filename": get_path + "IMG_2809.png",
|
||||
"rank A bus": 0,
|
||||
"A bus": 0.1970970332622528,
|
||||
"rank " + get_path + "IMG_3758.png": 0,
|
||||
get_path + "IMG_3758.png": 0.8907483816146851,
|
||||
},
|
||||
}
|
||||
return test_my_dict
|
||||
|
||||
@ -5,22 +5,17 @@ import numpy
|
||||
from torch import device, cuda
|
||||
import misinformation.multimodal_search as ms
|
||||
|
||||
|
||||
testdict = {
|
||||
"IMG_2746": {"filename": "./test/data/IMG_2746.png"},
|
||||
"IMG_2809": {"filename": "./test/data/IMG_2809.png"},
|
||||
}
|
||||
|
||||
related_error = 1e-2
|
||||
gpu_is_not_available = not cuda.is_available()
|
||||
|
||||
|
||||
cuda.empty_cache()
|
||||
|
||||
|
||||
def test_read_img():
|
||||
def test_read_img(get_testdict):
|
||||
my_dict = {}
|
||||
test_img = ms.MultimodalSearch.read_img(my_dict, testdict["IMG_2746"]["filename"])
|
||||
test_img = ms.MultimodalSearch.read_img(
|
||||
my_dict, get_testdict["IMG_2746"]["filename"]
|
||||
)
|
||||
assert list(numpy.array(test_img)[257][34]) == [70, 66, 63]
|
||||
|
||||
|
||||
@ -205,29 +200,29 @@ dict_image_gradcam_with_itm_for_blip = {
|
||||
"pre_sorted",
|
||||
),
|
||||
[
|
||||
# (
|
||||
# device("cpu"),
|
||||
# "blip2",
|
||||
# pre_proc_pic_blip2_blip_albef,
|
||||
# pre_proc_text_blip2_blip_albef,
|
||||
# pre_extracted_feature_img_blip2,
|
||||
# pre_extracted_feature_text_blip2,
|
||||
# simularity_blip2,
|
||||
# sorted_blip2,
|
||||
# ),
|
||||
# pytest.param(
|
||||
# device("cuda"),
|
||||
# "blip2",
|
||||
# pre_proc_pic_blip2_blip_albef,
|
||||
# pre_proc_text_blip2_blip_albef,
|
||||
# pre_extracted_feature_img_blip2,
|
||||
# pre_extracted_feature_text_blip2,
|
||||
# simularity_blip2,
|
||||
# sorted_blip2,
|
||||
# marks=pytest.mark.skipif(
|
||||
# gpu_is_not_available, reason="gpu_is_not_availible"
|
||||
# ),
|
||||
# ),
|
||||
(
|
||||
device("cpu"),
|
||||
"blip2",
|
||||
pre_proc_pic_blip2_blip_albef,
|
||||
pre_proc_text_blip2_blip_albef,
|
||||
pre_extracted_feature_img_blip2,
|
||||
pre_extracted_feature_text_blip2,
|
||||
simularity_blip2,
|
||||
sorted_blip2,
|
||||
),
|
||||
pytest.param(
|
||||
device("cuda"),
|
||||
"blip2",
|
||||
pre_proc_pic_blip2_blip_albef,
|
||||
pre_proc_text_blip2_blip_albef,
|
||||
pre_extracted_feature_img_blip2,
|
||||
pre_extracted_feature_text_blip2,
|
||||
simularity_blip2,
|
||||
sorted_blip2,
|
||||
marks=pytest.mark.skipif(
|
||||
gpu_is_not_available, reason="gpu_is_not_availible"
|
||||
),
|
||||
),
|
||||
(
|
||||
device("cpu"),
|
||||
"blip",
|
||||
@ -354,6 +349,8 @@ def test_parsing_images(
|
||||
pre_extracted_feature_text,
|
||||
pre_simularity,
|
||||
pre_sorted,
|
||||
get_path,
|
||||
get_testdict,
|
||||
tmp_path,
|
||||
):
|
||||
ms.MultimodalSearch.multimodal_device = pre_multimodal_device
|
||||
@ -365,7 +362,7 @@ def test_parsing_images(
|
||||
_,
|
||||
features_image_stacked,
|
||||
) = ms.MultimodalSearch.parsing_images(
|
||||
testdict, pre_model, path_to_saved_tensors=tmp_path
|
||||
get_testdict, pre_model, path_to_saved_tensors=tmp_path
|
||||
)
|
||||
|
||||
for i, num in zip(range(10), features_image_stacked[0, 10:12].tolist()):
|
||||
@ -374,7 +371,7 @@ def test_parsing_images(
|
||||
is True
|
||||
)
|
||||
|
||||
test_pic = Image.open(testdict["IMG_2746"]["filename"]).convert("RGB")
|
||||
test_pic = Image.open(get_testdict["IMG_2746"]["filename"]).convert("RGB")
|
||||
test_querry = (
|
||||
"The bird sat on a tree located at the intersection of 23rd and 43rd streets."
|
||||
)
|
||||
@ -390,10 +387,10 @@ def test_parsing_images(
|
||||
|
||||
search_query = [
|
||||
{"text_input": test_querry},
|
||||
{"image": testdict["IMG_2746"]["filename"]},
|
||||
{"image": get_testdict["IMG_2746"]["filename"]},
|
||||
]
|
||||
multi_features_stacked = ms.MultimodalSearch.querys_processing(
|
||||
testdict, search_query, model, txt_processor, vis_processor, pre_model
|
||||
get_testdict, search_query, model, txt_processor, vis_processor, pre_model
|
||||
)
|
||||
|
||||
for i, num in zip(range(10), multi_features_stacked[0, 10:12].tolist()):
|
||||
@ -410,11 +407,11 @@ def test_parsing_images(
|
||||
|
||||
search_query2 = [
|
||||
{"text_input": "A bus"},
|
||||
{"image": "../misinformation/test/data/IMG_3758.png"},
|
||||
{"image": get_path + "IMG_3758.png"},
|
||||
]
|
||||
|
||||
similarity, sorted_list = ms.MultimodalSearch.multimodal_search(
|
||||
testdict,
|
||||
get_testdict,
|
||||
model,
|
||||
vis_processor,
|
||||
txt_processor,
|
||||
@ -445,3 +442,81 @@ def test_parsing_images(
|
||||
multi_features_stacked,
|
||||
)
|
||||
cuda.empty_cache()
|
||||
|
||||
|
||||
@pytest.mark.long
|
||||
def test_itm(get_test_my_dict, get_path):
|
||||
search_query3 = [
|
||||
{"text_input": "A bus"},
|
||||
{"image": get_path + "IMG_3758.png"},
|
||||
]
|
||||
image_keys = ["IMG_2746", "IMG_2809"]
|
||||
sorted_list = [[1, 0], [1, 0]]
|
||||
for itm_model in ["blip_base", "blip_large"]:
|
||||
(
|
||||
itm_scores,
|
||||
image_gradcam_with_itm,
|
||||
) = ms.MultimodalSearch.image_text_match_reordering(
|
||||
get_test_my_dict,
|
||||
search_query3,
|
||||
itm_model,
|
||||
image_keys,
|
||||
sorted_list,
|
||||
batch_size=1,
|
||||
need_grad_cam=True,
|
||||
)
|
||||
for i, itm in zip(
|
||||
range(len(dict_itm_scores_for_blib[itm_model])),
|
||||
dict_itm_scores_for_blib[itm_model],
|
||||
):
|
||||
assert (
|
||||
math.isclose(itm_scores[0].tolist()[i], itm, rel_tol=10 * related_error)
|
||||
is True
|
||||
)
|
||||
for i, grad_cam in zip(
|
||||
range(len(dict_image_gradcam_with_itm_for_blip[itm_model])),
|
||||
dict_image_gradcam_with_itm_for_blip[itm_model],
|
||||
):
|
||||
assert (
|
||||
math.isclose(
|
||||
image_gradcam_with_itm["A bus"]["IMG_2809"][0][0].tolist()[i],
|
||||
grad_cam,
|
||||
rel_tol=10 * related_error,
|
||||
)
|
||||
is True
|
||||
)
|
||||
del itm_scores, image_gradcam_with_itm
|
||||
cuda.empty_cache()
|
||||
|
||||
|
||||
@pytest.mark.long
|
||||
def test_itm_blip2_coco(get_test_my_dict, get_path):
|
||||
search_query3 = [
|
||||
{"text_input": "A bus"},
|
||||
{"image": get_path + "IMG_3758.png"},
|
||||
]
|
||||
image_keys = ["IMG_2746", "IMG_2809"]
|
||||
sorted_list = [[1, 0], [1, 0]]
|
||||
|
||||
(
|
||||
itm_scores,
|
||||
image_gradcam_with_itm,
|
||||
) = ms.MultimodalSearch.image_text_match_reordering(
|
||||
get_test_my_dict,
|
||||
search_query3,
|
||||
"blip2_coco",
|
||||
image_keys,
|
||||
sorted_list,
|
||||
batch_size=1,
|
||||
need_grad_cam=False,
|
||||
)
|
||||
for i, itm in zip(
|
||||
range(len(dict_itm_scores_for_blib["blip2_coco"])),
|
||||
dict_itm_scores_for_blib["blip2_coco"],
|
||||
):
|
||||
assert (
|
||||
math.isclose(itm_scores[0].tolist()[i], itm, rel_tol=10 * related_error)
|
||||
is True
|
||||
)
|
||||
del itm_scores, image_gradcam_with_itm
|
||||
cuda.empty_cache()
|
||||
|
||||
115
notebooks/multimodal_search.ipynb
сгенерированный
115
notebooks/multimodal_search.ipynb
сгенерированный
@ -138,9 +138,7 @@
|
||||
" image_keys,\n",
|
||||
" image_names,\n",
|
||||
" features_image_stacked,\n",
|
||||
") = ms.MultimodalSearch.parsing_images(\n",
|
||||
" mydict, model_type, path_to_saved_tensors=\"./saved_tensors/\"\n",
|
||||
")"
|
||||
") = ms.MultimodalSearch.parsing_images(mydict, model_type, path_to_saved_tensors=\".\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
@ -170,7 +168,7 @@
|
||||
"# ) = ms.MultimodalSearch.parsing_images(\n",
|
||||
"# mydict,\n",
|
||||
"# model_type,\n",
|
||||
"# path_to_load_tensors=\"./saved_tensors/18_blip_saved_features_image.pt\",\n",
|
||||
"# path_to_load_tensors=\".5_blip_saved_features_image.pt\",\n",
|
||||
"# )"
|
||||
]
|
||||
},
|
||||
@ -194,15 +192,14 @@
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "c4196a52-d01e-42e4-8674-5712f7d6f792",
|
||||
"metadata": {},
|
||||
"metadata": {
|
||||
"tags": []
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"search_query3 = [\n",
|
||||
" {\"text_input\": \"politician press conference\"},\n",
|
||||
" {\"text_input\": \"a world map\"},\n",
|
||||
" {\"image\": \"../data/haos.png\"},\n",
|
||||
" {\"image\": \"../data/image-34098-800.png\"},\n",
|
||||
" {\"image\": \"../data/LeonPresserMorocco20032015_600.png\"},\n",
|
||||
" {\"text_input\": \"a dog\"},\n",
|
||||
"]"
|
||||
]
|
||||
@ -222,10 +219,12 @@
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "7f7dc52f-7ee9-4590-96b7-e0d9d3b82378",
|
||||
"metadata": {},
|
||||
"metadata": {
|
||||
"tags": []
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"similarity = ms.MultimodalSearch.multimodal_search(\n",
|
||||
"similarity, sorted_lists = ms.MultimodalSearch.multimodal_search(\n",
|
||||
" mydict,\n",
|
||||
" model,\n",
|
||||
" vis_processors,\n",
|
||||
@ -234,6 +233,7 @@
|
||||
" image_keys,\n",
|
||||
" features_image_stacked,\n",
|
||||
" search_query3,\n",
|
||||
" filter_number_of_images=20,\n",
|
||||
")"
|
||||
]
|
||||
},
|
||||
@ -249,10 +249,12 @@
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "9ad74b21-6187-4a58-9ed8-fd3e80f5a4ed",
|
||||
"metadata": {},
|
||||
"metadata": {
|
||||
"tags": []
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"mydict[\"100127S_ara\"]"
|
||||
"mydict[\"109237S_spa\"]"
|
||||
]
|
||||
},
|
||||
{
|
||||
@ -267,10 +269,79 @@
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "4324e4fd-e9aa-4933-bb12-074d54e0c510",
|
||||
"metadata": {},
|
||||
"metadata": {
|
||||
"tags": []
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"ms.MultimodalSearch.show_results(mydict, search_query3[4])"
|
||||
"ms.MultimodalSearch.show_results(\n",
|
||||
" mydict,\n",
|
||||
" search_query3[0],\n",
|
||||
")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "0b750e9f-fe64-4028-9caf-52d7187462f1",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"For even better results, a slightly different approach has been prepared that can improve search results. It is quite resource-intensive, so it is applied after the main algorithm has found the most relevant images. This approach works only with text queries. Among the parameters you can choose 3 models: `\"blip_base\"`, `\"blip_large\"`, `\"blip2_coco\"`. If you get the Out of Memory error, try reducing the batch_size value (minimum = 1), which is the number of images being processed simultaneously. With the parameter `need_grad_cam = True/False` you can enable the calculation of the heat map of each image to be processed. Thus the `image_text_match_reordering` function calculates new similarity values and new ranks for each image. The resulting values are added to the general dictionary."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "b3af7b39-6d0d-4da3-9b8f-7dfd3f5779be",
|
||||
"metadata": {
|
||||
"tags": []
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"itm_model = \"blip_base\"\n",
|
||||
"# itm_model = \"blip_large\"\n",
|
||||
"# itm_model = \"blip2_coco\""
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "caf1f4ae-4b37-4954-800e-7120f0419de5",
|
||||
"metadata": {
|
||||
"tags": []
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"itm_scores, image_gradcam_with_itm = ms.MultimodalSearch.image_text_match_reordering(\n",
|
||||
" mydict,\n",
|
||||
" search_query3,\n",
|
||||
" itm_model,\n",
|
||||
" image_keys,\n",
|
||||
" sorted_lists,\n",
|
||||
" batch_size=1,\n",
|
||||
" need_grad_cam=True,\n",
|
||||
")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "9e98c150-5fab-4251-bce7-0d8fc7b385b9",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"Then using the same output function you can add the `ITM=True` arguments to output the new image order. You can also add the `image_gradcam_with_itm` argument to output the heat maps of the calculated images. "
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "6a829b99-5230-463a-8b11-30ffbb67fc3a",
|
||||
"metadata": {
|
||||
"tags": []
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"ms.MultimodalSearch.show_results(\n",
|
||||
" mydict, search_query3[0], itm=True, image_gradcam_with_itm=image_gradcam_with_itm\n",
|
||||
")"
|
||||
]
|
||||
},
|
||||
{
|
||||
@ -320,7 +391,9 @@
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "e78646d6-80be-4d3e-8123-3360957bcaa8",
|
||||
"metadata": {},
|
||||
"metadata": {
|
||||
"tags": []
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"df.head(10)"
|
||||
@ -338,11 +411,21 @@
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "185f7dde-20dc-44d8-9ab0-de41f9b5734d",
|
||||
"metadata": {},
|
||||
"metadata": {
|
||||
"tags": []
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"df.to_csv(\"./data_out.csv\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "b6a79201-7c17-496c-a6a1-b8ecfd3dd1e8",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": []
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
|
||||
Загрузка…
x
Ссылка в новой задаче
Block a user