зеркало из
https://github.com/ssciwr/AMMICO.git
synced 2025-10-30 05:26:05 +02:00
added new itm functionality
Этот коммит содержится в:
родитель
65d916921b
Коммит
65dbf28eef
@ -3,9 +3,14 @@ import torch
|
|||||||
import torch.nn.functional as Func
|
import torch.nn.functional as Func
|
||||||
import requests
|
import requests
|
||||||
import lavis
|
import lavis
|
||||||
|
import numpy as np
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
|
from skimage import transform as skimage_transform
|
||||||
|
from scipy.ndimage import filters
|
||||||
|
from matplotlib import pyplot as plt
|
||||||
from IPython.display import display
|
from IPython.display import display
|
||||||
from lavis.models import load_model_and_preprocess
|
from lavis.models import load_model_and_preprocess, load_model, BlipBase
|
||||||
|
from lavis.processors import load_processor
|
||||||
|
|
||||||
|
|
||||||
class MultimodalSearch(AnalysisMethod):
|
class MultimodalSearch(AnalysisMethod):
|
||||||
@ -301,7 +306,6 @@ class MultimodalSearch(AnalysisMethod):
|
|||||||
|
|
||||||
for q in range(len(search_query)):
|
for q in range(len(search_query)):
|
||||||
max_val = similarity[sorted_lists[q][0]][q].item()
|
max_val = similarity[sorted_lists[q][0]][q].item()
|
||||||
print(max_val)
|
|
||||||
for i, key in zip(range(len(image_keys)), sorted_lists[q]):
|
for i, key in zip(range(len(image_keys)), sorted_lists[q]):
|
||||||
if (
|
if (
|
||||||
i < filter_number_of_images
|
i < filter_number_of_images
|
||||||
@ -322,7 +326,278 @@ class MultimodalSearch(AnalysisMethod):
|
|||||||
self[image_keys[key]][list(search_query[q].values())[0]] = 0
|
self[image_keys[key]][list(search_query[q].values())[0]] = 0
|
||||||
return similarity, sorted_lists
|
return similarity, sorted_lists
|
||||||
|
|
||||||
def show_results(self, query):
|
def itm_text_precessing(self, search_query):
|
||||||
|
for query in search_query:
|
||||||
|
if not (len(query) == 1) and (query in ("image", "text_input")):
|
||||||
|
raise SyntaxError(
|
||||||
|
'Each querry must contain either an "image" or a "text_input"'
|
||||||
|
)
|
||||||
|
text_query_index = []
|
||||||
|
for i, query in zip(range(len(search_query)), search_query):
|
||||||
|
if "text_input" in query.keys():
|
||||||
|
text_query_index.append(i)
|
||||||
|
|
||||||
|
return text_query_index
|
||||||
|
|
||||||
|
def itm_images_processing(self, image_paths, vis_processor):
|
||||||
|
raw_images = [MultimodalSearch.read_img(self, path) for path in image_paths]
|
||||||
|
images = [vis_processor(r_img) for r_img in raw_images]
|
||||||
|
images_tensors = torch.stack(images).to(MultimodalSearch.multimodal_device)
|
||||||
|
|
||||||
|
return raw_images, images_tensors
|
||||||
|
|
||||||
|
def get_pathes_from_query(self, query):
|
||||||
|
paths = []
|
||||||
|
image_names = []
|
||||||
|
for s in sorted(
|
||||||
|
self.items(), key=lambda t: t[1][list(query.values())[0]], reverse=True
|
||||||
|
):
|
||||||
|
if s[1]["rank " + list(query.values())[0]] is None:
|
||||||
|
break
|
||||||
|
paths.append(s[1]["filename"])
|
||||||
|
image_names.append(s[0])
|
||||||
|
return paths, image_names
|
||||||
|
|
||||||
|
def read_and_process_images_itm(self, image_paths, vis_processor):
|
||||||
|
raw_images = [MultimodalSearch.read_img(self, path) for path in image_paths]
|
||||||
|
images = [vis_processor(r_img) for r_img in raw_images]
|
||||||
|
images_tensors = torch.stack(images).to(MultimodalSearch.multimodal_device)
|
||||||
|
|
||||||
|
return raw_images, images_tensors
|
||||||
|
|
||||||
|
def compute_gradcam_batch(
|
||||||
|
self,
|
||||||
|
itm_model_type,
|
||||||
|
model,
|
||||||
|
visual_input,
|
||||||
|
text_input,
|
||||||
|
tokenized_text,
|
||||||
|
block_num=6,
|
||||||
|
):
|
||||||
|
if itm_model_type != "blip2_coco":
|
||||||
|
model.text_encoder.base_model.base_model.encoder.layer[
|
||||||
|
block_num
|
||||||
|
].crossattention.self.save_attention = True
|
||||||
|
|
||||||
|
output = model(
|
||||||
|
{"image": visual_input, "text_input": text_input}, match_head="itm"
|
||||||
|
)
|
||||||
|
loss = output[:, 1].sum()
|
||||||
|
|
||||||
|
model.zero_grad()
|
||||||
|
loss.backward()
|
||||||
|
with torch.no_grad():
|
||||||
|
mask = tokenized_text.attention_mask.view(
|
||||||
|
tokenized_text.attention_mask.size(0), 1, -1, 1, 1
|
||||||
|
) # (bsz,1,token_len, 1,1)
|
||||||
|
token_length = mask.sum() - 2
|
||||||
|
token_length = token_length.cpu()
|
||||||
|
# grads and cams [bsz, num_head, seq_len, image_patch]
|
||||||
|
grads = model.text_encoder.base_model.base_model.encoder.layer[
|
||||||
|
block_num
|
||||||
|
].crossattention.self.get_attn_gradients()
|
||||||
|
cams = model.text_encoder.base_model.base_model.encoder.layer[
|
||||||
|
block_num
|
||||||
|
].crossattention.self.get_attention_map()
|
||||||
|
|
||||||
|
# assume using vit large with 576 num image patch
|
||||||
|
cams = (
|
||||||
|
cams[:, :, :, 1:].reshape(visual_input.size(0), 12, -1, 24, 24) * mask
|
||||||
|
)
|
||||||
|
grads = (
|
||||||
|
grads[:, :, :, 1:]
|
||||||
|
.clamp(0)
|
||||||
|
.reshape(visual_input.size(0), 12, -1, 24, 24)
|
||||||
|
* mask
|
||||||
|
)
|
||||||
|
|
||||||
|
gradcam = cams * grads
|
||||||
|
# [enc token gradcam, average gradcam across token, gradcam for individual token]
|
||||||
|
# gradcam = torch.cat((gradcam[0:1,:], gradcam[1:token_length+1, :].sum(dim=0, keepdim=True)/token_length, gradcam[1:, :]))
|
||||||
|
gradcam = gradcam.mean(1).cpu().detach()
|
||||||
|
gradcam = (
|
||||||
|
gradcam[:, 1 : token_length + 1, :].sum(dim=1, keepdim=True)
|
||||||
|
/ token_length
|
||||||
|
)
|
||||||
|
|
||||||
|
return gradcam, output
|
||||||
|
|
||||||
|
def resize_img(self, raw_img):
|
||||||
|
w, h = raw_img.size
|
||||||
|
scaling_factor = 240 / w
|
||||||
|
resized_image = raw_img.resize(
|
||||||
|
(int(w * scaling_factor), int(h * scaling_factor))
|
||||||
|
)
|
||||||
|
return resized_image
|
||||||
|
|
||||||
|
def getAttMap(self, img, attMap, blur=True, overlap=True):
|
||||||
|
attMap -= attMap.min()
|
||||||
|
if attMap.max() > 0:
|
||||||
|
attMap /= attMap.max()
|
||||||
|
attMap = skimage_transform.resize(
|
||||||
|
attMap, (img.shape[:2]), order=3, mode="constant"
|
||||||
|
)
|
||||||
|
if blur:
|
||||||
|
attMap = filters.gaussian_filter(attMap, 0.02 * max(img.shape[:2]))
|
||||||
|
attMap -= attMap.min()
|
||||||
|
attMap /= attMap.max()
|
||||||
|
cmap = plt.get_cmap("jet")
|
||||||
|
attMapV = cmap(attMap)
|
||||||
|
attMapV = np.delete(attMapV, 3, 2)
|
||||||
|
if overlap:
|
||||||
|
attMap = (
|
||||||
|
1 * (1 - attMap**0.7).reshape(attMap.shape + (1,)) * img
|
||||||
|
+ (attMap**0.7).reshape(attMap.shape + (1,)) * attMapV
|
||||||
|
)
|
||||||
|
return attMap
|
||||||
|
|
||||||
|
def upload_model_blip2_coco(self):
|
||||||
|
|
||||||
|
itm_model = load_model(
|
||||||
|
"blip2_image_text_matching",
|
||||||
|
"coco",
|
||||||
|
is_eval=True,
|
||||||
|
device=MultimodalSearch.multimodal_device,
|
||||||
|
)
|
||||||
|
vis_processor = load_processor("blip_image_eval").build(image_size=364)
|
||||||
|
return itm_model, vis_processor
|
||||||
|
|
||||||
|
def upload_model_blip_base(self):
|
||||||
|
|
||||||
|
itm_model = load_model(
|
||||||
|
"blip_image_text_matching",
|
||||||
|
"base",
|
||||||
|
is_eval=True,
|
||||||
|
device=MultimodalSearch.multimodal_device,
|
||||||
|
)
|
||||||
|
vis_processor = load_processor("blip_image_eval").build(image_size=384)
|
||||||
|
return itm_model, vis_processor
|
||||||
|
|
||||||
|
def upload_model_blip_large(self):
|
||||||
|
|
||||||
|
itm_model = load_model(
|
||||||
|
"blip_image_text_matching",
|
||||||
|
"large",
|
||||||
|
is_eval=True,
|
||||||
|
device=MultimodalSearch.multimodal_device,
|
||||||
|
)
|
||||||
|
vis_processor = load_processor("blip_image_eval").build(image_size=384)
|
||||||
|
return itm_model, vis_processor
|
||||||
|
|
||||||
|
def image_text_match_reordering(
|
||||||
|
self,
|
||||||
|
search_query,
|
||||||
|
itm_model_type,
|
||||||
|
image_keys,
|
||||||
|
sorted_lists,
|
||||||
|
batch_size=1,
|
||||||
|
need_grad_cam=False,
|
||||||
|
):
|
||||||
|
choose_model = {
|
||||||
|
" ": MultimodalSearch.upload_model_blip_base,
|
||||||
|
"blip_large": MultimodalSearch.upload_model_blip_large,
|
||||||
|
"blip2_coco": MultimodalSearch.upload_model_blip2_coco,
|
||||||
|
}
|
||||||
|
itm_model, vis_processor = choose_model[itm_model_type](self)
|
||||||
|
text_processor = load_processor("blip_caption")
|
||||||
|
tokenizer = BlipBase.init_tokenizer()
|
||||||
|
|
||||||
|
text_query_index = MultimodalSearch.itm_text_precessing(self, search_query)
|
||||||
|
|
||||||
|
avg_gradcams = []
|
||||||
|
itm_scores = []
|
||||||
|
itm_scores2 = []
|
||||||
|
image_gradcam_with_itm = {}
|
||||||
|
|
||||||
|
for index_text_query in text_query_index:
|
||||||
|
query = search_query[index_text_query]
|
||||||
|
pathes, image_names = MultimodalSearch.get_pathes_from_query(self, query)
|
||||||
|
num_batches = int(len(pathes) / batch_size)
|
||||||
|
num_batches_residue = len(pathes) % batch_size
|
||||||
|
|
||||||
|
local_itm_scores = []
|
||||||
|
local_avg_gradcams = []
|
||||||
|
|
||||||
|
if num_batches_residue != 0:
|
||||||
|
num_batches = num_batches + 1
|
||||||
|
for i in range(num_batches):
|
||||||
|
filenames_in_batch = pathes[i * batch_size : (i + 1) * batch_size]
|
||||||
|
current_len = len(filenames_in_batch)
|
||||||
|
raw_images, images = MultimodalSearch.read_and_process_images_itm(
|
||||||
|
self, filenames_in_batch, vis_processor
|
||||||
|
)
|
||||||
|
queries_batch = [text_processor(query["text_input"])] * current_len
|
||||||
|
queries_tok_batch = tokenizer(queries_batch, return_tensors="pt").to(
|
||||||
|
MultimodalSearch.multimodal_device
|
||||||
|
)
|
||||||
|
|
||||||
|
if need_grad_cam:
|
||||||
|
gradcam, itm_output = MultimodalSearch.compute_gradcam_batch(
|
||||||
|
self,
|
||||||
|
itm_model_type,
|
||||||
|
itm_model,
|
||||||
|
images,
|
||||||
|
queries_batch,
|
||||||
|
queries_tok_batch,
|
||||||
|
)
|
||||||
|
norm_imgs = [np.float32(r_img) / 255 for r_img in raw_images]
|
||||||
|
|
||||||
|
for norm_img, grad_cam in zip(norm_imgs, gradcam):
|
||||||
|
avg_gradcam = MultimodalSearch.getAttMap(
|
||||||
|
self, norm_img, grad_cam[0], blur=True
|
||||||
|
)
|
||||||
|
local_avg_gradcams.append(avg_gradcam)
|
||||||
|
|
||||||
|
else:
|
||||||
|
itm_output = itm_model(
|
||||||
|
{"image": images, "text_input": queries_batch}, match_head="itm"
|
||||||
|
)
|
||||||
|
|
||||||
|
with torch.no_grad():
|
||||||
|
itm_score = torch.nn.functional.softmax(itm_output, dim=1)
|
||||||
|
|
||||||
|
local_itm_scores.append(itm_score)
|
||||||
|
|
||||||
|
local_itm_scores2 = torch.cat(local_itm_scores)[:, 1]
|
||||||
|
if need_grad_cam:
|
||||||
|
localimage_gradcam_with_itm = {
|
||||||
|
n: i * 255 for n, i in zip(image_names, local_avg_gradcams)
|
||||||
|
}
|
||||||
|
else:
|
||||||
|
localimage_gradcam_with_itm = ""
|
||||||
|
image_names_with_itm = {
|
||||||
|
n: i.item() for n, i in zip(image_names, local_itm_scores2)
|
||||||
|
}
|
||||||
|
itm_rank = torch.argsort(local_itm_scores2, descending=True)
|
||||||
|
image_names_with_new_rank = {
|
||||||
|
image_names[i.item()]: rank
|
||||||
|
for i, rank in zip(itm_rank, range(len(itm_rank)))
|
||||||
|
}
|
||||||
|
for i, key in zip(range(len(image_keys)), sorted_lists[index_text_query]):
|
||||||
|
if image_keys[key] in image_names:
|
||||||
|
self[image_keys[key]][
|
||||||
|
"itm " + list(search_query[index_text_query].values())[0]
|
||||||
|
] = image_names_with_itm[image_keys[key]]
|
||||||
|
self[image_keys[key]][
|
||||||
|
"itm_rank " + list(search_query[index_text_query].values())[0]
|
||||||
|
] = image_names_with_new_rank[image_keys[key]]
|
||||||
|
else:
|
||||||
|
self[image_keys[key]][
|
||||||
|
"itm " + list(search_query[index_text_query].values())[0]
|
||||||
|
] = 0
|
||||||
|
self[image_keys[key]][
|
||||||
|
"itm_rank " + list(search_query[index_text_query].values())[0]
|
||||||
|
] = None
|
||||||
|
|
||||||
|
avg_gradcams.append(local_avg_gradcams)
|
||||||
|
itm_scores.append(local_itm_scores)
|
||||||
|
itm_scores2.append(local_itm_scores2)
|
||||||
|
image_gradcam_with_itm[
|
||||||
|
list(search_query[index_text_query].values())[0]
|
||||||
|
] = localimage_gradcam_with_itm
|
||||||
|
return itm_scores2, image_gradcam_with_itm
|
||||||
|
|
||||||
|
def show_results(self, query, itm=False, image_gradcam_with_itm=False):
|
||||||
if "image" in query.keys():
|
if "image" in query.keys():
|
||||||
pic = Image.open(query["image"]).convert("RGB")
|
pic = Image.open(query["image"]).convert("RGB")
|
||||||
pic.thumbnail((400, 400))
|
pic.thumbnail((400, 400))
|
||||||
@ -338,18 +613,29 @@ class MultimodalSearch(AnalysisMethod):
|
|||||||
"--------------------------------------------------",
|
"--------------------------------------------------",
|
||||||
"Results:",
|
"Results:",
|
||||||
)
|
)
|
||||||
|
if itm:
|
||||||
|
current_querry_val = "itm " + list(query.values())[0]
|
||||||
|
current_querry_rank = "itm_rank " + list(query.values())[0]
|
||||||
|
else:
|
||||||
|
current_querry_val = list(query.values())[0]
|
||||||
|
current_querry_rank = "rank " + list(query.values())[0]
|
||||||
|
|
||||||
for s in sorted(
|
for s in sorted(
|
||||||
self.items(), key=lambda t: t[1][list(query.values())[0]], reverse=True
|
self.items(), key=lambda t: t[1][current_querry_val], reverse=True
|
||||||
):
|
):
|
||||||
if s[1]["rank " + list(query.values())[0]] is None:
|
if s[1][current_querry_rank] is None:
|
||||||
break
|
break
|
||||||
|
if image_gradcam_with_itm is False:
|
||||||
p1 = Image.open(s[1]["filename"]).convert("RGB")
|
p1 = Image.open(s[1]["filename"]).convert("RGB")
|
||||||
|
else:
|
||||||
|
image = image_gradcam_with_itm[list(query.values())[0]][s[0]]
|
||||||
|
p1 = Image.fromarray(image.astype("uint8"), "RGB")
|
||||||
p1.thumbnail((400, 400))
|
p1.thumbnail((400, 400))
|
||||||
display(
|
display(
|
||||||
"Rank: "
|
"Rank: "
|
||||||
+ str(s[1]["rank " + list(query.values())[0]])
|
+ str(s[1][current_querry_rank])
|
||||||
+ " Val: "
|
+ " Val: "
|
||||||
+ str(s[1][list(query.values())[0]]),
|
+ str(s[1][current_querry_val]),
|
||||||
s[0],
|
s[0],
|
||||||
p1,
|
p1,
|
||||||
)
|
)
|
||||||
|
|||||||
66
notebooks/multimodal_search.ipynb
сгенерированный
66
notebooks/multimodal_search.ipynb
сгенерированный
@ -192,7 +192,7 @@
|
|||||||
"metadata": {},
|
"metadata": {},
|
||||||
"outputs": [],
|
"outputs": [],
|
||||||
"source": [
|
"source": [
|
||||||
"similarity = ms.MultimodalSearch.multimodal_search(\n",
|
"similarity, sorted_lists = ms.MultimodalSearch.multimodal_search(\n",
|
||||||
" mydict,\n",
|
" mydict,\n",
|
||||||
" model,\n",
|
" model,\n",
|
||||||
" vis_processors,\n",
|
" vis_processors,\n",
|
||||||
@ -201,6 +201,7 @@
|
|||||||
" image_keys,\n",
|
" image_keys,\n",
|
||||||
" features_image_stacked,\n",
|
" features_image_stacked,\n",
|
||||||
" search_query3,\n",
|
" search_query3,\n",
|
||||||
|
" filter_number_of_images=20,\n",
|
||||||
")"
|
")"
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
@ -237,7 +238,68 @@
|
|||||||
"metadata": {},
|
"metadata": {},
|
||||||
"outputs": [],
|
"outputs": [],
|
||||||
"source": [
|
"source": [
|
||||||
"ms.MultimodalSearch.show_results(mydict, search_query3[4])"
|
"ms.MultimodalSearch.show_results(\n",
|
||||||
|
" mydict,\n",
|
||||||
|
" search_query3[5],\n",
|
||||||
|
")"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "markdown",
|
||||||
|
"id": "0b750e9f-fe64-4028-9caf-52d7187462f1",
|
||||||
|
"metadata": {},
|
||||||
|
"source": [
|
||||||
|
"For even better results, a slightly different approach has been prepared that can improve search results. It is quite resource-intensive, so it is applied after the main algorithm has found the most relevant images. This approach works only with text queries. Among the parameters you can choose 3 models: `\"blip_base\"`, `\"blip_large\"`, `\"blip2_coco\"`. If you get the Out of Memory error, try reducing the batch_size value (minimum = 1), which is the number of images being processed simultaneously. With the parameter `need_grad_cam = True/False` you can enable the calculation of the heat map of each image to be processed. Thus the `image_text_match_reordering` function calculates new similarity values and new ranks for each image. The resulting values are added to the general dictionary."
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": null,
|
||||||
|
"id": "b3af7b39-6d0d-4da3-9b8f-7dfd3f5779be",
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"itm_model = \"blip_base\"\n",
|
||||||
|
"# itm_model = \"blip_large\"\n",
|
||||||
|
"# itm_model = \"blip2_coco\""
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": null,
|
||||||
|
"id": "caf1f4ae-4b37-4954-800e-7120f0419de5",
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"itm_scores, image_gradcam_with_itm = ms.MultimodalSearch.image_text_match_reordering(\n",
|
||||||
|
" mydict,\n",
|
||||||
|
" search_query3,\n",
|
||||||
|
" itm_model,\n",
|
||||||
|
" image_keys,\n",
|
||||||
|
" sorted_lists,\n",
|
||||||
|
" batch_size=1,\n",
|
||||||
|
" need_grad_cam=True,\n",
|
||||||
|
")"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "markdown",
|
||||||
|
"id": "9e98c150-5fab-4251-bce7-0d8fc7b385b9",
|
||||||
|
"metadata": {},
|
||||||
|
"source": [
|
||||||
|
"Then using the same output function you can add the `ITM=True` arguments to output the new image order. You can also add the `image_gradcam_with_itm` argument to output the heat maps of the calculated images. "
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": null,
|
||||||
|
"id": "6a829b99-5230-463a-8b11-30ffbb67fc3a",
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"ms.MultimodalSearch.show_results(\n",
|
||||||
|
" mydict, search_query3[0], itm=True, image_gradcam_with_itm=image_gradcam_with_itm\n",
|
||||||
|
")"
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
|
|||||||
Загрузка…
x
Ссылка в новой задаче
Block a user