Merge pull request #60 from ssciwr/add_itm

Add itm functionality
Этот коммит содержится в:
Petr Andriushchenko 2023-04-19 13:12:53 +02:00 коммит произвёл GitHub
родитель 3b1c3ef1ed fae982bc8b
Коммит 80665e5f82
Не найден ключ, соответствующий данной подписи
Идентификатор ключа GPG: 4AEE18F83AFDEB23
6 изменённых файлов: 552 добавлений и 65 удалений

2
.github/workflows/ci.yml поставляемый
Просмотреть файл

@ -32,7 +32,7 @@ jobs:
- name: Run pytest - name: Run pytest
run: | run: |
cd misinformation cd misinformation
python -m pytest -m "not gcv" -svv --cov=. --cov-report=xml python -m pytest -m "not gcv and not long" -svv --cov=. --cov-report=xml
- name: Upload coverage - name: Upload coverage
if: matrix.os == 'ubuntu-22.04' && matrix.python-version == '3.9' if: matrix.os == 'ubuntu-22.04' && matrix.python-version == '3.9'
uses: codecov/codecov-action@v3 uses: codecov/codecov-action@v3

Просмотреть файл

@ -113,7 +113,7 @@
" image_keys,\n", " image_keys,\n",
" image_names,\n", " image_names,\n",
" features_image_stacked,\n", " features_image_stacked,\n",
") = ms.MultimodalSearch.parsing_images(mydict, model_type)" ") = ms.MultimodalSearch.parsing_images(mydict, model_type, path_to_saved_tensors=\".\")"
] ]
}, },
{ {
@ -128,7 +128,9 @@
"cell_type": "code", "cell_type": "code",
"execution_count": null, "execution_count": null,
"id": "56c6d488-f093-4661-835a-5c73a329c874", "id": "56c6d488-f093-4661-835a-5c73a329c874",
"metadata": {}, "metadata": {
"tags": []
},
"outputs": [], "outputs": [],
"source": [ "source": [
"# (\n", "# (\n",

Просмотреть файл

@ -4,9 +4,14 @@ import torch.nn.functional as Func
import requests import requests
import lavis import lavis
import os import os
import numpy as np
from PIL import Image from PIL import Image
from skimage import transform as skimage_transform
from scipy.ndimage import filters
from matplotlib import pyplot as plt
from IPython.display import display from IPython.display import display
from lavis.models import load_model_and_preprocess from lavis.models import load_model_and_preprocess, load_model, BlipBase
from lavis.processors import load_processor
class MultimodalSearch(AnalysisMethod): class MultimodalSearch(AnalysisMethod):
@ -233,7 +238,7 @@ class MultimodalSearch(AnalysisMethod):
} }
for query in search_query: for query in search_query:
if not (len(query) == 1) and (query in ("image", "text_input")): if len(query) != 1 and (query in ("image", "text_input")):
raise SyntaxError( raise SyntaxError(
'Each query must contain either an "image" or a "text_input"' 'Each query must contain either an "image" or a "text_input"'
) )
@ -343,7 +348,288 @@ class MultimodalSearch(AnalysisMethod):
self[image_keys[key]][list(search_query[q].values())[0]] = 0 self[image_keys[key]][list(search_query[q].values())[0]] = 0
return similarity, sorted_lists return similarity, sorted_lists
def show_results(self, query): def itm_text_precessing(self, search_query):
for query in search_query:
if not (len(query) == 1) and (query in ("image", "text_input")):
raise SyntaxError(
'Each querry must contain either an "image" or a "text_input"'
)
text_query_index = []
for i, query in zip(range(len(search_query)), search_query):
if "text_input" in query.keys():
text_query_index.append(i)
return text_query_index
def get_pathes_from_query(self, query):
paths = []
image_names = []
for s in sorted(
self.items(), key=lambda t: t[1][list(query.values())[0]], reverse=True
):
if s[1]["rank " + list(query.values())[0]] is None:
break
paths.append(s[1]["filename"])
image_names.append(s[0])
return paths, image_names
def read_and_process_images_itm(self, image_paths, vis_processor):
raw_images = [MultimodalSearch.read_img(self, path) for path in image_paths]
images = [vis_processor(r_img) for r_img in raw_images]
images_tensors = torch.stack(images).to(MultimodalSearch.multimodal_device)
return raw_images, images_tensors
def compute_gradcam_batch(
self,
model,
visual_input,
text_input,
tokenized_text,
block_num=6,
):
model.text_encoder.base_model.base_model.encoder.layer[
block_num
].crossattention.self.save_attention = True
output = model(
{"image": visual_input, "text_input": text_input}, match_head="itm"
)
loss = output[:, 1].sum()
model.zero_grad()
loss.backward()
with torch.no_grad():
mask = tokenized_text.attention_mask.view(
tokenized_text.attention_mask.size(0), 1, -1, 1, 1
) # (bsz,1,token_len, 1,1)
token_length = mask.sum() - 2
token_length = token_length.cpu()
# grads and cams [bsz, num_head, seq_len, image_patch]
grads = model.text_encoder.base_model.base_model.encoder.layer[
block_num
].crossattention.self.get_attn_gradients()
cams = model.text_encoder.base_model.base_model.encoder.layer[
block_num
].crossattention.self.get_attention_map()
# assume using vit large with 576 num image patch
cams = (
cams[:, :, :, 1:].reshape(visual_input.size(0), 12, -1, 24, 24) * mask
)
grads = (
grads[:, :, :, 1:]
.clamp(0)
.reshape(visual_input.size(0), 12, -1, 24, 24)
* mask
)
gradcam = cams * grads
# [enc token gradcam, average gradcam across token, gradcam for individual token]
# gradcam = torch.cat((gradcam[0:1,:], gradcam[1:token_length+1, :].sum(dim=0, keepdim=True)/token_length, gradcam[1:, :]))
gradcam = gradcam.mean(1).cpu().detach()
gradcam = (
gradcam[:, 1 : token_length + 1, :].sum(dim=1, keepdim=True)
/ token_length
)
return gradcam, output
def resize_img(self, raw_img):
w, h = raw_img.size
scaling_factor = 240 / w
resized_image = raw_img.resize(
(int(w * scaling_factor), int(h * scaling_factor))
)
return resized_image
def get_att_map(self, img, att_map, blur=True, overlap=True):
att_map -= att_map.min()
if att_map.max() > 0:
att_map /= att_map.max()
att_map = skimage_transform.resize(
att_map, (img.shape[:2]), order=3, mode="constant"
)
if blur:
att_map = filters.gaussian_filter(att_map, 0.02 * max(img.shape[:2]))
att_map -= att_map.min()
att_map /= att_map.max()
cmap = plt.get_cmap("jet")
att_mapv = cmap(att_map)
att_mapv = np.delete(att_mapv, 3, 2)
if overlap:
att_map = (
1 * (1 - att_map**0.7).reshape(att_map.shape + (1,)) * img
+ (att_map**0.7).reshape(att_map.shape + (1,)) * att_mapv
)
return att_map
def upload_model_blip2_coco(self):
itm_model = load_model(
"blip2_image_text_matching",
"coco",
is_eval=True,
device=MultimodalSearch.multimodal_device,
)
vis_processor = load_processor("blip_image_eval").build(image_size=364)
return itm_model, vis_processor
def upload_model_blip_base(self):
itm_model = load_model(
"blip_image_text_matching",
"base",
is_eval=True,
device=MultimodalSearch.multimodal_device,
)
vis_processor = load_processor("blip_image_eval").build(image_size=384)
return itm_model, vis_processor
def upload_model_blip_large(self):
itm_model = load_model(
"blip_image_text_matching",
"large",
is_eval=True,
device=MultimodalSearch.multimodal_device,
)
vis_processor = load_processor("blip_image_eval").build(image_size=384)
return itm_model, vis_processor
def image_text_match_reordering(
self,
search_query,
itm_model_type,
image_keys,
sorted_lists,
batch_size=1,
need_grad_cam=False,
):
if itm_model_type == "blip2_coco" and need_grad_cam is True:
raise SyntaxError(
"The blip2_coco model does not yet work with gradcam. Please set need_grad_cam to False"
)
choose_model = {
"blip_base": MultimodalSearch.upload_model_blip_base,
"blip_large": MultimodalSearch.upload_model_blip_large,
"blip2_coco": MultimodalSearch.upload_model_blip2_coco,
}
itm_model, vis_processor_itm = choose_model[itm_model_type](self)
text_processor = load_processor("blip_caption")
tokenizer = BlipBase.init_tokenizer()
if itm_model_type == "blip2_coco":
need_grad_cam = False
text_query_index = MultimodalSearch.itm_text_precessing(self, search_query)
avg_gradcams = []
itm_scores = []
itm_scores2 = []
image_gradcam_with_itm = {}
for index_text_query in text_query_index:
query = search_query[index_text_query]
pathes, image_names = MultimodalSearch.get_pathes_from_query(self, query)
num_batches = int(len(pathes) / batch_size)
num_batches_residue = len(pathes) % batch_size
local_itm_scores = []
local_avg_gradcams = []
if num_batches_residue != 0:
num_batches = num_batches + 1
for i in range(num_batches):
filenames_in_batch = pathes[i * batch_size : (i + 1) * batch_size]
current_len = len(filenames_in_batch)
raw_images, images = MultimodalSearch.read_and_process_images_itm(
self, filenames_in_batch, vis_processor_itm
)
queries_batch = [text_processor(query["text_input"])] * current_len
queries_tok_batch = tokenizer(queries_batch, return_tensors="pt").to(
MultimodalSearch.multimodal_device
)
if need_grad_cam:
gradcam, itm_output = MultimodalSearch.compute_gradcam_batch(
self,
itm_model,
images,
queries_batch,
queries_tok_batch,
)
norm_imgs = [np.float32(r_img) / 255 for r_img in raw_images]
for norm_img, grad_cam in zip(norm_imgs, gradcam):
avg_gradcam = MultimodalSearch.get_att_map(
self, norm_img, np.float32(grad_cam[0]), blur=True
)
local_avg_gradcams.append(avg_gradcam)
else:
itm_output = itm_model(
{"image": images, "text_input": queries_batch}, match_head="itm"
)
with torch.no_grad():
itm_score = torch.nn.functional.softmax(itm_output, dim=1)
local_itm_scores.append(itm_score)
local_itm_scores2 = torch.cat(local_itm_scores)[:, 1]
if need_grad_cam:
localimage_gradcam_with_itm = {
n: i * 255 for n, i in zip(image_names, local_avg_gradcams)
}
else:
localimage_gradcam_with_itm = ""
image_names_with_itm = {
n: i.item() for n, i in zip(image_names, local_itm_scores2)
}
itm_rank = torch.argsort(local_itm_scores2, descending=True)
image_names_with_new_rank = {
image_names[i.item()]: rank
for i, rank in zip(itm_rank, range(len(itm_rank)))
}
for i, key in zip(range(len(image_keys)), sorted_lists[index_text_query]):
if image_keys[key] in image_names:
self[image_keys[key]][
"itm " + list(search_query[index_text_query].values())[0]
] = image_names_with_itm[image_keys[key]]
self[image_keys[key]][
"itm_rank " + list(search_query[index_text_query].values())[0]
] = image_names_with_new_rank[image_keys[key]]
else:
self[image_keys[key]][
"itm " + list(search_query[index_text_query].values())[0]
] = 0
self[image_keys[key]][
"itm_rank " + list(search_query[index_text_query].values())[0]
] = None
avg_gradcams.append(local_avg_gradcams)
itm_scores.append(local_itm_scores)
itm_scores2.append(local_itm_scores2)
image_gradcam_with_itm[
list(search_query[index_text_query].values())[0]
] = localimage_gradcam_with_itm
del (
itm_model,
vis_processor_itm,
text_processor,
raw_images,
images,
tokenizer,
queries_batch,
queries_tok_batch,
itm_score,
)
if need_grad_cam:
del itm_output, gradcam, norm_img, grad_cam, avg_gradcam
torch.cuda.empty_cache()
return itm_scores2, image_gradcam_with_itm
def show_results(self, query, itm=False, image_gradcam_with_itm=False):
if "image" in query.keys(): if "image" in query.keys():
pic = Image.open(query["image"]).convert("RGB") pic = Image.open(query["image"]).convert("RGB")
pic.thumbnail((400, 400)) pic.thumbnail((400, 400))
@ -359,18 +645,29 @@ class MultimodalSearch(AnalysisMethod):
"--------------------------------------------------", "--------------------------------------------------",
"Results:", "Results:",
) )
if itm:
current_querry_val = "itm " + list(query.values())[0]
current_querry_rank = "itm_rank " + list(query.values())[0]
else:
current_querry_val = list(query.values())[0]
current_querry_rank = "rank " + list(query.values())[0]
for s in sorted( for s in sorted(
self.items(), key=lambda t: t[1][list(query.values())[0]], reverse=True self.items(), key=lambda t: t[1][current_querry_val], reverse=True
): ):
if s[1]["rank " + list(query.values())[0]] is None: if s[1][current_querry_rank] is None:
break break
p1 = Image.open(s[1]["filename"]).convert("RGB") if image_gradcam_with_itm is False:
p1 = Image.open(s[1]["filename"]).convert("RGB")
else:
image = image_gradcam_with_itm[list(query.values())[0]][s[0]]
p1 = Image.fromarray(image.astype("uint8"), "RGB")
p1.thumbnail((400, 400)) p1.thumbnail((400, 400))
display( display(
"Rank: " "Rank: "
+ str(s[1]["rank " + list(query.values())[0]]) + str(s[1][current_querry_rank])
+ " Val: " + " Val: "
+ str(s[1][list(query.values())[0]]), + str(s[1][current_querry_val]),
s[0], s[0],
p1, p1,
) )

Просмотреть файл

@ -16,3 +16,33 @@ def set_environ(request):
mypath + "/../../data/seismic-bonfire-329406-412821a70264.json" mypath + "/../../data/seismic-bonfire-329406-412821a70264.json"
) )
print(os.environ.get("GOOGLE_APPLICATION_CREDENTIALS")) print(os.environ.get("GOOGLE_APPLICATION_CREDENTIALS"))
@pytest.fixture
def get_testdict(get_path):
testdict = {
"IMG_2746": {"filename": get_path + "IMG_2746.png"},
"IMG_2809": {"filename": get_path + "IMG_2809.png"},
}
return testdict
@pytest.fixture
def get_test_my_dict(get_path):
test_my_dict = {
"IMG_2746": {
"filename": get_path + "IMG_2746.png",
"rank A bus": 1,
"A bus": 0.15640679001808167,
"rank " + get_path + "IMG_3758.png": 1,
get_path + "IMG_3758.png": 0.7533495426177979,
},
"IMG_2809": {
"filename": get_path + "IMG_2809.png",
"rank A bus": 0,
"A bus": 0.1970970332622528,
"rank " + get_path + "IMG_3758.png": 0,
get_path + "IMG_3758.png": 0.8907483816146851,
},
}
return test_my_dict

Просмотреть файл

@ -5,22 +5,17 @@ import numpy
from torch import device, cuda from torch import device, cuda
import misinformation.multimodal_search as ms import misinformation.multimodal_search as ms
testdict = {
"IMG_2746": {"filename": "./test/data/IMG_2746.png"},
"IMG_2809": {"filename": "./test/data/IMG_2809.png"},
}
related_error = 1e-2 related_error = 1e-2
gpu_is_not_available = not cuda.is_available() gpu_is_not_available = not cuda.is_available()
cuda.empty_cache() cuda.empty_cache()
def test_read_img(): def test_read_img(get_testdict):
my_dict = {} my_dict = {}
test_img = ms.MultimodalSearch.read_img(my_dict, testdict["IMG_2746"]["filename"]) test_img = ms.MultimodalSearch.read_img(
my_dict, get_testdict["IMG_2746"]["filename"]
)
assert list(numpy.array(test_img)[257][34]) == [70, 66, 63] assert list(numpy.array(test_img)[257][34]) == [70, 66, 63]
@ -205,29 +200,29 @@ dict_image_gradcam_with_itm_for_blip = {
"pre_sorted", "pre_sorted",
), ),
[ [
# ( (
# device("cpu"), device("cpu"),
# "blip2", "blip2",
# pre_proc_pic_blip2_blip_albef, pre_proc_pic_blip2_blip_albef,
# pre_proc_text_blip2_blip_albef, pre_proc_text_blip2_blip_albef,
# pre_extracted_feature_img_blip2, pre_extracted_feature_img_blip2,
# pre_extracted_feature_text_blip2, pre_extracted_feature_text_blip2,
# simularity_blip2, simularity_blip2,
# sorted_blip2, sorted_blip2,
# ), ),
# pytest.param( pytest.param(
# device("cuda"), device("cuda"),
# "blip2", "blip2",
# pre_proc_pic_blip2_blip_albef, pre_proc_pic_blip2_blip_albef,
# pre_proc_text_blip2_blip_albef, pre_proc_text_blip2_blip_albef,
# pre_extracted_feature_img_blip2, pre_extracted_feature_img_blip2,
# pre_extracted_feature_text_blip2, pre_extracted_feature_text_blip2,
# simularity_blip2, simularity_blip2,
# sorted_blip2, sorted_blip2,
# marks=pytest.mark.skipif( marks=pytest.mark.skipif(
# gpu_is_not_available, reason="gpu_is_not_availible" gpu_is_not_available, reason="gpu_is_not_availible"
# ), ),
# ), ),
( (
device("cpu"), device("cpu"),
"blip", "blip",
@ -354,6 +349,8 @@ def test_parsing_images(
pre_extracted_feature_text, pre_extracted_feature_text,
pre_simularity, pre_simularity,
pre_sorted, pre_sorted,
get_path,
get_testdict,
tmp_path, tmp_path,
): ):
ms.MultimodalSearch.multimodal_device = pre_multimodal_device ms.MultimodalSearch.multimodal_device = pre_multimodal_device
@ -365,7 +362,7 @@ def test_parsing_images(
_, _,
features_image_stacked, features_image_stacked,
) = ms.MultimodalSearch.parsing_images( ) = ms.MultimodalSearch.parsing_images(
testdict, pre_model, path_to_saved_tensors=tmp_path get_testdict, pre_model, path_to_saved_tensors=tmp_path
) )
for i, num in zip(range(10), features_image_stacked[0, 10:12].tolist()): for i, num in zip(range(10), features_image_stacked[0, 10:12].tolist()):
@ -374,7 +371,7 @@ def test_parsing_images(
is True is True
) )
test_pic = Image.open(testdict["IMG_2746"]["filename"]).convert("RGB") test_pic = Image.open(get_testdict["IMG_2746"]["filename"]).convert("RGB")
test_querry = ( test_querry = (
"The bird sat on a tree located at the intersection of 23rd and 43rd streets." "The bird sat on a tree located at the intersection of 23rd and 43rd streets."
) )
@ -390,10 +387,10 @@ def test_parsing_images(
search_query = [ search_query = [
{"text_input": test_querry}, {"text_input": test_querry},
{"image": testdict["IMG_2746"]["filename"]}, {"image": get_testdict["IMG_2746"]["filename"]},
] ]
multi_features_stacked = ms.MultimodalSearch.querys_processing( multi_features_stacked = ms.MultimodalSearch.querys_processing(
testdict, search_query, model, txt_processor, vis_processor, pre_model get_testdict, search_query, model, txt_processor, vis_processor, pre_model
) )
for i, num in zip(range(10), multi_features_stacked[0, 10:12].tolist()): for i, num in zip(range(10), multi_features_stacked[0, 10:12].tolist()):
@ -410,11 +407,11 @@ def test_parsing_images(
search_query2 = [ search_query2 = [
{"text_input": "A bus"}, {"text_input": "A bus"},
{"image": "../misinformation/test/data/IMG_3758.png"}, {"image": get_path + "IMG_3758.png"},
] ]
similarity, sorted_list = ms.MultimodalSearch.multimodal_search( similarity, sorted_list = ms.MultimodalSearch.multimodal_search(
testdict, get_testdict,
model, model,
vis_processor, vis_processor,
txt_processor, txt_processor,
@ -445,3 +442,81 @@ def test_parsing_images(
multi_features_stacked, multi_features_stacked,
) )
cuda.empty_cache() cuda.empty_cache()
@pytest.mark.long
def test_itm(get_test_my_dict, get_path):
search_query3 = [
{"text_input": "A bus"},
{"image": get_path + "IMG_3758.png"},
]
image_keys = ["IMG_2746", "IMG_2809"]
sorted_list = [[1, 0], [1, 0]]
for itm_model in ["blip_base", "blip_large"]:
(
itm_scores,
image_gradcam_with_itm,
) = ms.MultimodalSearch.image_text_match_reordering(
get_test_my_dict,
search_query3,
itm_model,
image_keys,
sorted_list,
batch_size=1,
need_grad_cam=True,
)
for i, itm in zip(
range(len(dict_itm_scores_for_blib[itm_model])),
dict_itm_scores_for_blib[itm_model],
):
assert (
math.isclose(itm_scores[0].tolist()[i], itm, rel_tol=10 * related_error)
is True
)
for i, grad_cam in zip(
range(len(dict_image_gradcam_with_itm_for_blip[itm_model])),
dict_image_gradcam_with_itm_for_blip[itm_model],
):
assert (
math.isclose(
image_gradcam_with_itm["A bus"]["IMG_2809"][0][0].tolist()[i],
grad_cam,
rel_tol=10 * related_error,
)
is True
)
del itm_scores, image_gradcam_with_itm
cuda.empty_cache()
@pytest.mark.long
def test_itm_blip2_coco(get_test_my_dict, get_path):
search_query3 = [
{"text_input": "A bus"},
{"image": get_path + "IMG_3758.png"},
]
image_keys = ["IMG_2746", "IMG_2809"]
sorted_list = [[1, 0], [1, 0]]
(
itm_scores,
image_gradcam_with_itm,
) = ms.MultimodalSearch.image_text_match_reordering(
get_test_my_dict,
search_query3,
"blip2_coco",
image_keys,
sorted_list,
batch_size=1,
need_grad_cam=False,
)
for i, itm in zip(
range(len(dict_itm_scores_for_blib["blip2_coco"])),
dict_itm_scores_for_blib["blip2_coco"],
):
assert (
math.isclose(itm_scores[0].tolist()[i], itm, rel_tol=10 * related_error)
is True
)
del itm_scores, image_gradcam_with_itm
cuda.empty_cache()

115
notebooks/multimodal_search.ipynb сгенерированный
Просмотреть файл

@ -138,9 +138,7 @@
" image_keys,\n", " image_keys,\n",
" image_names,\n", " image_names,\n",
" features_image_stacked,\n", " features_image_stacked,\n",
") = ms.MultimodalSearch.parsing_images(\n", ") = ms.MultimodalSearch.parsing_images(mydict, model_type, path_to_saved_tensors=\".\")"
" mydict, model_type, path_to_saved_tensors=\"./saved_tensors/\"\n",
")"
] ]
}, },
{ {
@ -170,7 +168,7 @@
"# ) = ms.MultimodalSearch.parsing_images(\n", "# ) = ms.MultimodalSearch.parsing_images(\n",
"# mydict,\n", "# mydict,\n",
"# model_type,\n", "# model_type,\n",
"# path_to_load_tensors=\"./saved_tensors/18_blip_saved_features_image.pt\",\n", "# path_to_load_tensors=\".5_blip_saved_features_image.pt\",\n",
"# )" "# )"
] ]
}, },
@ -194,15 +192,14 @@
"cell_type": "code", "cell_type": "code",
"execution_count": null, "execution_count": null,
"id": "c4196a52-d01e-42e4-8674-5712f7d6f792", "id": "c4196a52-d01e-42e4-8674-5712f7d6f792",
"metadata": {}, "metadata": {
"tags": []
},
"outputs": [], "outputs": [],
"source": [ "source": [
"search_query3 = [\n", "search_query3 = [\n",
" {\"text_input\": \"politician press conference\"},\n", " {\"text_input\": \"politician press conference\"},\n",
" {\"text_input\": \"a world map\"},\n", " {\"text_input\": \"a world map\"},\n",
" {\"image\": \"../data/haos.png\"},\n",
" {\"image\": \"../data/image-34098-800.png\"},\n",
" {\"image\": \"../data/LeonPresserMorocco20032015_600.png\"},\n",
" {\"text_input\": \"a dog\"},\n", " {\"text_input\": \"a dog\"},\n",
"]" "]"
] ]
@ -222,10 +219,12 @@
"cell_type": "code", "cell_type": "code",
"execution_count": null, "execution_count": null,
"id": "7f7dc52f-7ee9-4590-96b7-e0d9d3b82378", "id": "7f7dc52f-7ee9-4590-96b7-e0d9d3b82378",
"metadata": {}, "metadata": {
"tags": []
},
"outputs": [], "outputs": [],
"source": [ "source": [
"similarity = ms.MultimodalSearch.multimodal_search(\n", "similarity, sorted_lists = ms.MultimodalSearch.multimodal_search(\n",
" mydict,\n", " mydict,\n",
" model,\n", " model,\n",
" vis_processors,\n", " vis_processors,\n",
@ -234,6 +233,7 @@
" image_keys,\n", " image_keys,\n",
" features_image_stacked,\n", " features_image_stacked,\n",
" search_query3,\n", " search_query3,\n",
" filter_number_of_images=20,\n",
")" ")"
] ]
}, },
@ -249,10 +249,12 @@
"cell_type": "code", "cell_type": "code",
"execution_count": null, "execution_count": null,
"id": "9ad74b21-6187-4a58-9ed8-fd3e80f5a4ed", "id": "9ad74b21-6187-4a58-9ed8-fd3e80f5a4ed",
"metadata": {}, "metadata": {
"tags": []
},
"outputs": [], "outputs": [],
"source": [ "source": [
"mydict[\"100127S_ara\"]" "mydict[\"109237S_spa\"]"
] ]
}, },
{ {
@ -267,10 +269,79 @@
"cell_type": "code", "cell_type": "code",
"execution_count": null, "execution_count": null,
"id": "4324e4fd-e9aa-4933-bb12-074d54e0c510", "id": "4324e4fd-e9aa-4933-bb12-074d54e0c510",
"metadata": {}, "metadata": {
"tags": []
},
"outputs": [], "outputs": [],
"source": [ "source": [
"ms.MultimodalSearch.show_results(mydict, search_query3[4])" "ms.MultimodalSearch.show_results(\n",
" mydict,\n",
" search_query3[0],\n",
")"
]
},
{
"cell_type": "markdown",
"id": "0b750e9f-fe64-4028-9caf-52d7187462f1",
"metadata": {},
"source": [
"For even better results, a slightly different approach has been prepared that can improve search results. It is quite resource-intensive, so it is applied after the main algorithm has found the most relevant images. This approach works only with text queries. Among the parameters you can choose 3 models: `\"blip_base\"`, `\"blip_large\"`, `\"blip2_coco\"`. If you get the Out of Memory error, try reducing the batch_size value (minimum = 1), which is the number of images being processed simultaneously. With the parameter `need_grad_cam = True/False` you can enable the calculation of the heat map of each image to be processed. Thus the `image_text_match_reordering` function calculates new similarity values and new ranks for each image. The resulting values are added to the general dictionary."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "b3af7b39-6d0d-4da3-9b8f-7dfd3f5779be",
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"itm_model = \"blip_base\"\n",
"# itm_model = \"blip_large\"\n",
"# itm_model = \"blip2_coco\""
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "caf1f4ae-4b37-4954-800e-7120f0419de5",
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"itm_scores, image_gradcam_with_itm = ms.MultimodalSearch.image_text_match_reordering(\n",
" mydict,\n",
" search_query3,\n",
" itm_model,\n",
" image_keys,\n",
" sorted_lists,\n",
" batch_size=1,\n",
" need_grad_cam=True,\n",
")"
]
},
{
"cell_type": "markdown",
"id": "9e98c150-5fab-4251-bce7-0d8fc7b385b9",
"metadata": {},
"source": [
"Then using the same output function you can add the `ITM=True` arguments to output the new image order. You can also add the `image_gradcam_with_itm` argument to output the heat maps of the calculated images. "
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "6a829b99-5230-463a-8b11-30ffbb67fc3a",
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"ms.MultimodalSearch.show_results(\n",
" mydict, search_query3[0], itm=True, image_gradcam_with_itm=image_gradcam_with_itm\n",
")"
] ]
}, },
{ {
@ -320,7 +391,9 @@
"cell_type": "code", "cell_type": "code",
"execution_count": null, "execution_count": null,
"id": "e78646d6-80be-4d3e-8123-3360957bcaa8", "id": "e78646d6-80be-4d3e-8123-3360957bcaa8",
"metadata": {}, "metadata": {
"tags": []
},
"outputs": [], "outputs": [],
"source": [ "source": [
"df.head(10)" "df.head(10)"
@ -338,11 +411,21 @@
"cell_type": "code", "cell_type": "code",
"execution_count": null, "execution_count": null,
"id": "185f7dde-20dc-44d8-9ab0-de41f9b5734d", "id": "185f7dde-20dc-44d8-9ab0-de41f9b5734d",
"metadata": {}, "metadata": {
"tags": []
},
"outputs": [], "outputs": [],
"source": [ "source": [
"df.to_csv(\"./data_out.csv\")" "df.to_csv(\"./data_out.csv\")"
] ]
},
{
"cell_type": "code",
"execution_count": null,
"id": "b6a79201-7c17-496c-a6a1-b8ecfd3dd1e8",
"metadata": {},
"outputs": [],
"source": []
} }
], ],
"metadata": { "metadata": {