Mirror of
https://github.com/ssciwr/AMMICO.git
synced 2025-10-29 13:06:04 +02:00
added new itm functionality
This commit is contained in:
parent
65d916921b
Commit
65dbf28eef
@@ -3,9 +3,14 @@ import torch
 import torch.nn.functional as Func
 import requests
 import lavis
 import numpy as np
 from PIL import Image
 from skimage import transform as skimage_transform
 from scipy.ndimage import filters
 from matplotlib import pyplot as plt
 from IPython.display import display
-from lavis.models import load_model_and_preprocess
+from lavis.models import load_model_and_preprocess, load_model, BlipBase
+from lavis.processors import load_processor


 class MultimodalSearch(AnalysisMethod):
@@ -301,7 +306,6 @@ class MultimodalSearch(AnalysisMethod):

         for q in range(len(search_query)):
             max_val = similarity[sorted_lists[q][0]][q].item()
-            print(max_val)
             for i, key in zip(range(len(image_keys)), sorted_lists[q]):
                 if (
                     i < filter_number_of_images
@@ -322,7 +326,278 @@ class MultimodalSearch(AnalysisMethod):
                     self[image_keys[key]][list(search_query[q].values())[0]] = 0
         return similarity, sorted_lists

-    def show_results(self, query):
+    def itm_text_precessing(self, search_query):
+        for query in search_query:
+            # each query must be a single-key dict keyed "image" or "text_input"
+            if len(query) != 1 or list(query.keys())[0] not in (
+                "image",
+                "text_input",
+            ):
+                raise SyntaxError(
+                    'Each query must contain either an "image" or a "text_input"'
+                )
+        text_query_index = []
+        for i, query in zip(range(len(search_query)), search_query):
+            if "text_input" in query.keys():
+                text_query_index.append(i)
+
+        return text_query_index
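For orientation, a minimal pure-Python sketch of the query structure itm_text_precessing expects; the query values are invented for illustration:

# Minimal sketch: a list of single-key dicts, mixing image and text queries.
search_query = [
    {"text_input": "politician press conference"},
    {"image": "example_image.jpg"},
    {"text_input": "wildfire smoke"},
]

text_query_index = []
for i, query in enumerate(search_query):
    if "text_input" in query.keys():
        text_query_index.append(i)

print(text_query_index)  # [0, 2] -- ITM reranking only applies to text queries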
+
+    def itm_images_processing(self, image_paths, vis_processor):
+        raw_images = [MultimodalSearch.read_img(self, path) for path in image_paths]
+        images = [vis_processor(r_img) for r_img in raw_images]
+        images_tensors = torch.stack(images).to(MultimodalSearch.multimodal_device)
+
+        return raw_images, images_tensors
+
+    def get_pathes_from_query(self, query):
+        paths = []
+        image_names = []
+        for s in sorted(
+            self.items(), key=lambda t: t[1][list(query.values())[0]], reverse=True
+        ):
+            if s[1]["rank " + list(query.values())[0]] is None:
+                break
+            paths.append(s[1]["filename"])
+            image_names.append(s[0])
+        return paths, image_names
+
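A minimal sketch of the per-image entries that get_pathes_from_query sorts over, with invented file names and scores (the "filename" and "rank <query>" keys follow the code above; filtered-out images carry rank None):

# Hypothetical entries, as written by the earlier multimodal_search step.
mydict = {
    "img_a": {"filename": "img_a.png", "politician": 0.91, "rank politician": 0},
    "img_b": {"filename": "img_b.png", "politician": 0.42, "rank politician": 1},
    "img_c": {"filename": "img_c.png", "politician": 0.10, "rank politician": None},
}
query = {"text_input": "politician"}

paths, image_names = [], []
for name, entry in sorted(
    mydict.items(), key=lambda t: t[1][list(query.values())[0]], reverse=True
):
    if entry["rank " + list(query.values())[0]] is None:
        break  # rank None marks filtered images and ends the list
    paths.append(entry["filename"])
    image_names.append(name)

print(paths)        # ['img_a.png', 'img_b.png']
print(image_names)  # ['img_a', 'img_b']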
+    def read_and_process_images_itm(self, image_paths, vis_processor):
+        raw_images = [MultimodalSearch.read_img(self, path) for path in image_paths]
+        images = [vis_processor(r_img) for r_img in raw_images]
+        images_tensors = torch.stack(images).to(MultimodalSearch.multimodal_device)
+
+        return raw_images, images_tensors
+
+    def compute_gradcam_batch(
+        self,
+        itm_model_type,
+        model,
+        visual_input,
+        text_input,
+        tokenized_text,
+        block_num=6,
+    ):
+        if itm_model_type != "blip2_coco":
+            model.text_encoder.base_model.base_model.encoder.layer[
+                block_num
+            ].crossattention.self.save_attention = True
+
+        output = model(
+            {"image": visual_input, "text_input": text_input}, match_head="itm"
+        )
+        loss = output[:, 1].sum()
+
+        model.zero_grad()
+        loss.backward()
+        with torch.no_grad():
+            mask = tokenized_text.attention_mask.view(
+                tokenized_text.attention_mask.size(0), 1, -1, 1, 1
+            )  # (bsz, 1, token_len, 1, 1)
+            token_length = mask.sum() - 2
+            token_length = token_length.cpu()
+            # grads and cams [bsz, num_head, seq_len, image_patch]
+            grads = model.text_encoder.base_model.base_model.encoder.layer[
+                block_num
+            ].crossattention.self.get_attn_gradients()
+            cams = model.text_encoder.base_model.base_model.encoder.layer[
+                block_num
+            ].crossattention.self.get_attention_map()
+
+            # assume using vit large with 576 num image patch
+            cams = (
+                cams[:, :, :, 1:].reshape(visual_input.size(0), 12, -1, 24, 24) * mask
+            )
+            grads = (
+                grads[:, :, :, 1:]
+                .clamp(0)
+                .reshape(visual_input.size(0), 12, -1, 24, 24)
+                * mask
+            )
+
+            gradcam = cams * grads
+            # [enc token gradcam, average gradcam across token, gradcam for individual token]
+            # gradcam = torch.cat((gradcam[0:1,:], gradcam[1:token_length+1, :].sum(dim=0, keepdim=True)/token_length, gradcam[1:, :]))
+            gradcam = gradcam.mean(1).cpu().detach()
+            gradcam = (
+                gradcam[:, 1 : token_length + 1, :].sum(dim=1, keepdim=True)
+                / token_length
+            )
+
+        return gradcam, output
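A minimal shape-check sketch of the GradCAM arithmetic above, using random tensors instead of a real BLIP model; the 12 heads and the 24x24 = 576 ViT-large patch grid follow the assumptions commented in the method:

import torch

bsz, num_heads, token_len, num_patches = 2, 12, 8, 577  # 576 patches + [CLS]
cams = torch.rand(bsz, num_heads, token_len, num_patches)
grads = torch.rand(bsz, num_heads, token_len, num_patches)
mask = torch.ones(bsz, 1, token_len, 1, 1)  # (bsz, 1, token_len, 1, 1)

# drop the [CLS] patch token, fold the 576 patches into a 24x24 grid
cams = cams[:, :, :, 1:].reshape(bsz, num_heads, -1, 24, 24) * mask
grads = grads[:, :, :, 1:].clamp(0).reshape(bsz, num_heads, -1, 24, 24) * mask
gradcam = (cams * grads).mean(1)  # average over attention heads
print(gradcam.shape)              # torch.Size([2, 8, 24, 24])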
+
+    def resize_img(self, raw_img):
+        w, h = raw_img.size
+        scaling_factor = 240 / w
+        resized_image = raw_img.resize(
+            (int(w * scaling_factor), int(h * scaling_factor))
+        )
+        return resized_image
+
+    def getAttMap(self, img, attMap, blur=True, overlap=True):
+        attMap -= attMap.min()
+        if attMap.max() > 0:
+            attMap /= attMap.max()
+        attMap = skimage_transform.resize(
+            attMap, (img.shape[:2]), order=3, mode="constant"
+        )
+        if blur:
+            attMap = filters.gaussian_filter(attMap, 0.02 * max(img.shape[:2]))
+            attMap -= attMap.min()
+            attMap /= attMap.max()
+        cmap = plt.get_cmap("jet")
+        attMapV = cmap(attMap)
+        attMapV = np.delete(attMapV, 3, 2)
+        if overlap:
+            attMap = (
+                1 * (1 - attMap**0.7).reshape(attMap.shape + (1,)) * img
+                + (attMap**0.7).reshape(attMap.shape + (1,)) * attMapV
+            )
+        return attMap
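A self-contained sketch of the overlay logic in getAttMap, run on random data (no model needed): resize a coarse attention map to the image, blur it, and blend it with a jet colormap; the normalization guard is added for the all-zero edge case:

import numpy as np
from matplotlib import pyplot as plt
from scipy.ndimage import filters
from skimage import transform as skimage_transform

img = np.random.rand(240, 320, 3)  # stand-in for a normalized RGB image
attMap = np.random.rand(24, 24)    # stand-in for a 24x24 GradCAM

attMap = (attMap - attMap.min()) / max(attMap.max() - attMap.min(), 1e-8)
attMap = skimage_transform.resize(attMap, img.shape[:2], order=3, mode="constant")
attMap = filters.gaussian_filter(attMap, 0.02 * max(img.shape[:2]))
attMap = (attMap - attMap.min()) / max(attMap.max() - attMap.min(), 1e-8)

attMapV = plt.get_cmap("jet")(attMap)[:, :, :3]  # drop the alpha channel
overlay = (1 - attMap**0.7)[..., None] * img + (attMap**0.7)[..., None] * attMapV
print(overlay.shape)  # (240, 320, 3), ready for Image.fromarray after * 255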
+
+    def upload_model_blip2_coco(self):
+
+        itm_model = load_model(
+            "blip2_image_text_matching",
+            "coco",
+            is_eval=True,
+            device=MultimodalSearch.multimodal_device,
+        )
+        vis_processor = load_processor("blip_image_eval").build(image_size=364)
+        return itm_model, vis_processor
+
+    def upload_model_blip_base(self):
+
+        itm_model = load_model(
+            "blip_image_text_matching",
+            "base",
+            is_eval=True,
+            device=MultimodalSearch.multimodal_device,
+        )
+        vis_processor = load_processor("blip_image_eval").build(image_size=384)
+        return itm_model, vis_processor
+
+    def upload_model_blip_large(self):
+
+        itm_model = load_model(
+            "blip_image_text_matching",
+            "large",
+            is_eval=True,
+            device=MultimodalSearch.multimodal_device,
+        )
+        vis_processor = load_processor("blip_image_eval").build(image_size=384)
+        return itm_model, vis_processor
+
+    def image_text_match_reordering(
+        self,
+        search_query,
+        itm_model_type,
+        image_keys,
+        sorted_lists,
+        batch_size=1,
+        need_grad_cam=False,
+    ):
+        choose_model = {
+            "blip_base": MultimodalSearch.upload_model_blip_base,
+            "blip_large": MultimodalSearch.upload_model_blip_large,
+            "blip2_coco": MultimodalSearch.upload_model_blip2_coco,
+        }
+        itm_model, vis_processor = choose_model[itm_model_type](self)
+        text_processor = load_processor("blip_caption")
+        tokenizer = BlipBase.init_tokenizer()
+
+        text_query_index = MultimodalSearch.itm_text_precessing(self, search_query)
+
+        avg_gradcams = []
+        itm_scores = []
+        itm_scores2 = []
+        image_gradcam_with_itm = {}
+
+        for index_text_query in text_query_index:
+            query = search_query[index_text_query]
+            pathes, image_names = MultimodalSearch.get_pathes_from_query(self, query)
+            num_batches = int(len(pathes) / batch_size)
+            num_batches_residue = len(pathes) % batch_size
+
+            local_itm_scores = []
+            local_avg_gradcams = []
+
+            if num_batches_residue != 0:
+                num_batches = num_batches + 1
+            for i in range(num_batches):
+                filenames_in_batch = pathes[i * batch_size : (i + 1) * batch_size]
+                current_len = len(filenames_in_batch)
+                raw_images, images = MultimodalSearch.read_and_process_images_itm(
+                    self, filenames_in_batch, vis_processor
+                )
+                queries_batch = [text_processor(query["text_input"])] * current_len
+                queries_tok_batch = tokenizer(queries_batch, return_tensors="pt").to(
+                    MultimodalSearch.multimodal_device
+                )
+
+                if need_grad_cam:
+                    gradcam, itm_output = MultimodalSearch.compute_gradcam_batch(
+                        self,
+                        itm_model_type,
+                        itm_model,
+                        images,
+                        queries_batch,
+                        queries_tok_batch,
+                    )
+                    norm_imgs = [np.float32(r_img) / 255 for r_img in raw_images]
+
+                    for norm_img, grad_cam in zip(norm_imgs, gradcam):
+                        avg_gradcam = MultimodalSearch.getAttMap(
+                            self, norm_img, grad_cam[0], blur=True
+                        )
+                        local_avg_gradcams.append(avg_gradcam)
+
+                else:
+                    itm_output = itm_model(
+                        {"image": images, "text_input": queries_batch}, match_head="itm"
+                    )
+
+                with torch.no_grad():
+                    itm_score = torch.nn.functional.softmax(itm_output, dim=1)
+
+                local_itm_scores.append(itm_score)
+
+            local_itm_scores2 = torch.cat(local_itm_scores)[:, 1]
+            if need_grad_cam:
+                localimage_gradcam_with_itm = {
+                    n: i * 255 for n, i in zip(image_names, local_avg_gradcams)
+                }
+            else:
+                localimage_gradcam_with_itm = ""
+            image_names_with_itm = {
+                n: i.item() for n, i in zip(image_names, local_itm_scores2)
+            }
+            itm_rank = torch.argsort(local_itm_scores2, descending=True)
+            image_names_with_new_rank = {
+                image_names[i.item()]: rank
+                for i, rank in zip(itm_rank, range(len(itm_rank)))
+            }
+            for i, key in zip(range(len(image_keys)), sorted_lists[index_text_query]):
+                if image_keys[key] in image_names:
+                    self[image_keys[key]][
+                        "itm " + list(search_query[index_text_query].values())[0]
+                    ] = image_names_with_itm[image_keys[key]]
+                    self[image_keys[key]][
+                        "itm_rank " + list(search_query[index_text_query].values())[0]
+                    ] = image_names_with_new_rank[image_keys[key]]
+                else:
+                    self[image_keys[key]][
+                        "itm " + list(search_query[index_text_query].values())[0]
+                    ] = 0
+                    self[image_keys[key]][
+                        "itm_rank " + list(search_query[index_text_query].values())[0]
+                    ] = None
+
+            avg_gradcams.append(local_avg_gradcams)
+            itm_scores.append(local_itm_scores)
+            itm_scores2.append(local_itm_scores2)
+            image_gradcam_with_itm[
+                list(search_query[index_text_query].values())[0]
+            ] = localimage_gradcam_with_itm
+        return itm_scores2, image_gradcam_with_itm
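To make the score bookkeeping concrete, a small sketch with fabricated logits: column 1 of the softmax is the image-text match probability, and argsort of those scores yields the new ranks, exactly as in the method above:

import torch

itm_output = torch.tensor([[0.2, 1.3], [2.0, -1.0], [0.1, 0.4]])  # fabricated logits
itm_score = torch.nn.functional.softmax(itm_output, dim=1)[:, 1]  # match probability
itm_rank = torch.argsort(itm_score, descending=True)

image_names = ["img_a", "img_b", "img_c"]
new_rank = {image_names[i.item()]: rank for rank, i in enumerate(itm_rank)}
print(itm_score)  # tensor([0.7503, 0.0474, 0.5744])
print(new_rank)   # {'img_a': 0, 'img_c': 1, 'img_b': 2}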
+
+    def show_results(self, query, itm=False, image_gradcam_with_itm=False):
         if "image" in query.keys():
             pic = Image.open(query["image"]).convert("RGB")
             pic.thumbnail((400, 400))
@@ -338,18 +613,29 @@ class MultimodalSearch(AnalysisMethod):
             "--------------------------------------------------",
             "Results:",
         )
+        if itm:
+            current_querry_val = "itm " + list(query.values())[0]
+            current_querry_rank = "itm_rank " + list(query.values())[0]
+        else:
+            current_querry_val = list(query.values())[0]
+            current_querry_rank = "rank " + list(query.values())[0]
+
         for s in sorted(
-            self.items(), key=lambda t: t[1][list(query.values())[0]], reverse=True
+            self.items(), key=lambda t: t[1][current_querry_val], reverse=True
         ):
-            if s[1]["rank " + list(query.values())[0]] is None:
+            if s[1][current_querry_rank] is None:
                 break
-            p1 = Image.open(s[1]["filename"]).convert("RGB")
+            if image_gradcam_with_itm is False:
+                p1 = Image.open(s[1]["filename"]).convert("RGB")
+            else:
+                image = image_gradcam_with_itm[list(query.values())[0]][s[0]]
+                p1 = Image.fromarray(image.astype("uint8"), "RGB")
             p1.thumbnail((400, 400))
             display(
                 "Rank: "
-                + str(s[1]["rank " + list(query.values())[0]])
+                + str(s[1][current_querry_rank])
                 + " Val: "
-                + str(s[1][list(query.values())[0]]),
+                + str(s[1][current_querry_val]),
                 s[0],
                 p1,
             )
notebooks/multimodal_search.ipynb: 66 changes (generated)
@@ -192,7 +192,7 @@
    "metadata": {},
    "outputs": [],
    "source": [
-    "similarity = ms.MultimodalSearch.multimodal_search(\n",
+    "similarity, sorted_lists = ms.MultimodalSearch.multimodal_search(\n",
     "    mydict,\n",
     "    model,\n",
     "    vis_processors,\n",
@@ -201,6 +201,7 @@
     "    image_keys,\n",
     "    features_image_stacked,\n",
     "    search_query3,\n",
+    "    filter_number_of_images=20,\n",
     ")"
    ]
   },
@@ -237,7 +238,68 @@
    "metadata": {},
    "outputs": [],
    "source": [
-    "ms.MultimodalSearch.show_results(mydict, search_query3[4])"
+    "ms.MultimodalSearch.show_results(\n",
+    "    mydict,\n",
+    "    search_query3[5],\n",
+    ")"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "0b750e9f-fe64-4028-9caf-52d7187462f1",
+   "metadata": {},
+   "source": [
+    "For even better results, a second, more resource-intensive approach can rerank the images after the main algorithm has found the most relevant ones. It works only with text queries. You can choose between three models: `\"blip_base\"`, `\"blip_large\"` and `\"blip2_coco\"`. If you run into an out-of-memory error, reduce the `batch_size` value (minimum 1), which is the number of images processed simultaneously. With the parameter `need_grad_cam=True/False` you can enable the calculation of a heat map for each processed image. The `image_text_match_reordering` function then calculates new similarity values and new ranks for each image and adds them to the general dictionary."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "b3af7b39-6d0d-4da3-9b8f-7dfd3f5779be",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "itm_model = \"blip_base\"\n",
+    "# itm_model = \"blip_large\"\n",
+    "# itm_model = \"blip2_coco\""
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "caf1f4ae-4b37-4954-800e-7120f0419de5",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "itm_scores, image_gradcam_with_itm = ms.MultimodalSearch.image_text_match_reordering(\n",
+    "    mydict,\n",
+    "    search_query3,\n",
+    "    itm_model,\n",
+    "    image_keys,\n",
+    "    sorted_lists,\n",
+    "    batch_size=1,\n",
+    "    need_grad_cam=True,\n",
+    ")"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "9e98c150-5fab-4251-bce7-0d8fc7b385b9",
+   "metadata": {},
+   "source": [
+    "Then, using the same output function, you can pass `itm=True` to display the new image order. You can also pass the `image_gradcam_with_itm` argument to display the heat maps of the processed images."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "6a829b99-5230-463a-8b11-30ffbb67fc3a",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "ms.MultimodalSearch.show_results(\n",
+    "    mydict, search_query3[0], itm=True, image_gradcam_with_itm=image_gradcam_with_itm\n",
+    ")"
+   ]
+  },
    ]
   },
  {