This commit is contained in:
Petr Andriushchenko 2023-03-15 16:04:43 +01:00
parent 65d916921b
commit 65dbf28eef
No key found matching this signature
GPG key ID: 4C4A5DCF634115B6
2 changed files with 358 additions and 10 deletions

View file

@@ -3,9 +3,14 @@ import torch
import torch.nn.functional as Func
import requests
import lavis
import numpy as np
from PIL import Image
from skimage import transform as skimage_transform
from scipy.ndimage import filters
from matplotlib import pyplot as plt
from IPython.display import display
from lavis.models import load_model_and_preprocess, load_model, BlipBase
from lavis.processors import load_processor
class MultimodalSearch(AnalysisMethod):
@@ -301,7 +306,6 @@ class MultimodalSearch(AnalysisMethod):
for q in range(len(search_query)):
max_val = similarity[sorted_lists[q][0]][q].item()
for i, key in zip(range(len(image_keys)), sorted_lists[q]):
if (
i < filter_number_of_images
@@ -322,7 +326,278 @@ class MultimodalSearch(AnalysisMethod):
self[image_keys[key]][list(search_query[q].values())[0]] = 0
return similarity, sorted_lists
def itm_text_precessing(self, search_query):
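# Check that every query dict holds exactly one "image" or "text_input" key and
# collect the indices of the text queries; ITM reranking only applies to text queries.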
for query in search_query:
if not (len(query) == 1 and list(query.keys())[0] in ("image", "text_input")):
raise SyntaxError(
'Each query must contain either an "image" or a "text_input"'
)
text_query_index = []
for i, query in zip(range(len(search_query)), search_query):
if "text_input" in query.keys():
text_query_index.append(i)
return text_query_index
def itm_images_processing(self, image_paths, vis_processor):
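# Read the raw images and stack their preprocessed versions into one tensor batch on the multimodal device.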
raw_images = [MultimodalSearch.read_img(self, path) for path in image_paths]
images = [vis_processor(r_img) for r_img in raw_images]
images_tensors = torch.stack(images).to(MultimodalSearch.multimodal_device)
return raw_images, images_tensors
def get_pathes_from_query(self, query):
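# Walk the images in descending order of the similarity score for this query and collect
# their filenames and keys, stopping at the first image that carries no rank for the query.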
paths = []
image_names = []
for s in sorted(
self.items(), key=lambda t: t[1][list(query.values())[0]], reverse=True
):
if s[1]["rank " + list(query.values())[0]] is None:
break
paths.append(s[1]["filename"])
image_names.append(s[0])
return paths, image_names
def read_and_process_images_itm(self, image_paths, vis_processor):
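# Same helper as itm_images_processing: return the raw images together with a batched, preprocessed tensor.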
raw_images = [MultimodalSearch.read_img(self, path) for path in image_paths]
images = [vis_processor(r_img) for r_img in raw_images]
images_tensors = torch.stack(images).to(MultimodalSearch.multimodal_device)
return raw_images, images_tensors
def compute_gradcam_batch(
self,
itm_model_type,
model,
visual_input,
text_input,
tokenized_text,
block_num=6,
):
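# Score the image/text batch with the ITM head and build a GradCAM map from the gradients and
# attention maps of cross-attention layer `block_num` (the reshape assumes 24x24 ViT image patches).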
if itm_model_type != "blip2_coco":
model.text_encoder.base_model.base_model.encoder.layer[
block_num
].crossattention.self.save_attention = True
output = model(
{"image": visual_input, "text_input": text_input}, match_head="itm"
)
loss = output[:, 1].sum()
model.zero_grad()
loss.backward()
with torch.no_grad():
mask = tokenized_text.attention_mask.view(
tokenized_text.attention_mask.size(0), 1, -1, 1, 1
) # (bsz,1,token_len, 1,1)
token_length = mask.sum() - 2
token_length = token_length.cpu()
# grads and cams [bsz, num_head, seq_len, image_patch]
grads = model.text_encoder.base_model.base_model.encoder.layer[
block_num
].crossattention.self.get_attn_gradients()
cams = model.text_encoder.base_model.base_model.encoder.layer[
block_num
].crossattention.self.get_attention_map()
# assume using vit large with 576 num image patch
cams = (
cams[:, :, :, 1:].reshape(visual_input.size(0), 12, -1, 24, 24) * mask
)
grads = (
grads[:, :, :, 1:]
.clamp(0)
.reshape(visual_input.size(0), 12, -1, 24, 24)
* mask
)
gradcam = cams * grads
# [enc token gradcam, average gradcam across token, gradcam for individual token]
# gradcam = torch.cat((gradcam[0:1,:], gradcam[1:token_length+1, :].sum(dim=0, keepdim=True)/token_length, gradcam[1:, :]))
gradcam = gradcam.mean(1).cpu().detach()
gradcam = (
gradcam[:, 1 : token_length + 1, :].sum(dim=1, keepdim=True)
/ token_length
)
return gradcam, output
def resize_img(self, raw_img):
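# Downscale the image to a width of 240 pixels while keeping its aspect ratio.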
w, h = raw_img.size
scaling_factor = 240 / w
resized_image = raw_img.resize(
(int(w * scaling_factor), int(h * scaling_factor))
)
return resized_image
def getAttMap(self, img, attMap, blur=True, overlap=True):
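# Normalise the attention map, resize it to the image resolution, optionally blur it,
# and (if overlap is set) blend it over the image using the "jet" colormap.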
attMap -= attMap.min()
if attMap.max() > 0:
attMap /= attMap.max()
attMap = skimage_transform.resize(
attMap, (img.shape[:2]), order=3, mode="constant"
)
if blur:
attMap = filters.gaussian_filter(attMap, 0.02 * max(img.shape[:2]))
attMap -= attMap.min()
attMap /= attMap.max()
cmap = plt.get_cmap("jet")
attMapV = cmap(attMap)
attMapV = np.delete(attMapV, 3, 2)
if overlap:
attMap = (
1 * (1 - attMap**0.7).reshape(attMap.shape + (1,)) * img
+ (attMap**0.7).reshape(attMap.shape + (1,)) * attMapV
)
return attMap
def upload_model_blip2_coco(self):
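# Load the BLIP-2 (COCO) image-text-matching model and its evaluation image processor (image size 364).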
itm_model = load_model(
"blip2_image_text_matching",
"coco",
is_eval=True,
device=MultimodalSearch.multimodal_device,
)
vis_processor = load_processor("blip_image_eval").build(image_size=364)
return itm_model, vis_processor
def upload_model_blip_base(self):
itm_model = load_model(
"blip_image_text_matching",
"base",
is_eval=True,
device=MultimodalSearch.multimodal_device,
)
vis_processor = load_processor("blip_image_eval").build(image_size=384)
return itm_model, vis_processor
def upload_model_blip_large(self):
itm_model = load_model(
"blip_image_text_matching",
"large",
is_eval=True,
device=MultimodalSearch.multimodal_device,
)
vis_processor = load_processor("blip_image_eval").build(image_size=384)
return itm_model, vis_processor
def image_text_match_reordering(
self,
search_query,
itm_model_type,
image_keys,
sorted_lists,
batch_size=1,
need_grad_cam=False,
):
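# Rerank the images retrieved by multimodal_search with an image-text-matching (ITM) model:
# for every text query the previously ranked images are scored again in batches, optional
# GradCAM heat maps are computed, and the new ITM scores and ranks are written back into self.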
choose_model = {
" ": MultimodalSearch.upload_model_blip_base,
"blip_large": MultimodalSearch.upload_model_blip_large,
"blip2_coco": MultimodalSearch.upload_model_blip2_coco,
}
itm_model, vis_processor = choose_model[itm_model_type](self)
text_processor = load_processor("blip_caption")
tokenizer = BlipBase.init_tokenizer()
text_query_index = MultimodalSearch.itm_text_precessing(self, search_query)
avg_gradcams = []
itm_scores = []
itm_scores2 = []
image_gradcam_with_itm = {}
for index_text_query in text_query_index:
query = search_query[index_text_query]
pathes, image_names = MultimodalSearch.get_pathes_from_query(self, query)
num_batches = int(len(pathes) / batch_size)
num_batches_residue = len(pathes) % batch_size
local_itm_scores = []
local_avg_gradcams = []
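# Process the ranked images batch by batch; an extra batch picks up the remainder
# when the number of images is not divisible by batch_size.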
if num_batches_residue != 0:
num_batches = num_batches + 1
for i in range(num_batches):
filenames_in_batch = pathes[i * batch_size : (i + 1) * batch_size]
current_len = len(filenames_in_batch)
raw_images, images = MultimodalSearch.read_and_process_images_itm(
self, filenames_in_batch, vis_processor
)
queries_batch = [text_processor(query["text_input"])] * current_len
queries_tok_batch = tokenizer(queries_batch, return_tensors="pt").to(
MultimodalSearch.multimodal_device
)
if need_grad_cam:
gradcam, itm_output = MultimodalSearch.compute_gradcam_batch(
self,
itm_model_type,
itm_model,
images,
queries_batch,
queries_tok_batch,
)
norm_imgs = [np.float32(r_img) / 255 for r_img in raw_images]
for norm_img, grad_cam in zip(norm_imgs, gradcam):
avg_gradcam = MultimodalSearch.getAttMap(
self, norm_img, grad_cam[0], blur=True
)
local_avg_gradcams.append(avg_gradcam)
else:
itm_output = itm_model(
{"image": images, "text_input": queries_batch}, match_head="itm"
)
with torch.no_grad():
itm_score = torch.nn.functional.softmax(itm_output, dim=1)
local_itm_scores.append(itm_score)
local_itm_scores2 = torch.cat(local_itm_scores)[:, 1]
if need_grad_cam:
localimage_gradcam_with_itm = {
n: i * 255 for n, i in zip(image_names, local_avg_gradcams)
}
else:
localimage_gradcam_with_itm = ""
image_names_with_itm = {
n: i.item() for n, i in zip(image_names, local_itm_scores2)
}
itm_rank = torch.argsort(local_itm_scores2, descending=True)
image_names_with_new_rank = {
image_names[i.item()]: rank
for i, rank in zip(itm_rank, range(len(itm_rank)))
}
for i, key in zip(range(len(image_keys)), sorted_lists[index_text_query]):
if image_keys[key] in image_names:
self[image_keys[key]][
"itm " + list(search_query[index_text_query].values())[0]
] = image_names_with_itm[image_keys[key]]
self[image_keys[key]][
"itm_rank " + list(search_query[index_text_query].values())[0]
] = image_names_with_new_rank[image_keys[key]]
else:
self[image_keys[key]][
"itm " + list(search_query[index_text_query].values())[0]
] = 0
self[image_keys[key]][
"itm_rank " + list(search_query[index_text_query].values())[0]
] = None
avg_gradcams.append(local_avg_gradcams)
itm_scores.append(local_itm_scores)
itm_scores2.append(local_itm_scores2)
image_gradcam_with_itm[
list(search_query[index_text_query].values())[0]
] = localimage_gradcam_with_itm
return itm_scores2, image_gradcam_with_itm
def show_results(self, query, itm=False, image_gradcam_with_itm=False):
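# With itm=True the results are sorted and labelled by the ITM score and rank instead of the
# cosine similarity; image_gradcam_with_itm supplies precomputed GradCAM overlays to display
# instead of the original images.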
if "image" in query.keys(): if "image" in query.keys():
pic = Image.open(query["image"]).convert("RGB") pic = Image.open(query["image"]).convert("RGB")
pic.thumbnail((400, 400)) pic.thumbnail((400, 400))
@@ -338,18 +613,29 @@ class MultimodalSearch(AnalysisMethod):
"--------------------------------------------------", "--------------------------------------------------",
"Results:", "Results:",
) )
if itm:
current_querry_val = "itm " + list(query.values())[0]
current_querry_rank = "itm_rank " + list(query.values())[0]
else:
current_querry_val = list(query.values())[0]
current_querry_rank = "rank " + list(query.values())[0]
for s in sorted(
self.items(), key=lambda t: t[1][current_querry_val], reverse=True
):
if s[1][current_querry_rank] is None:
break
if image_gradcam_with_itm is False:
p1 = Image.open(s[1]["filename"]).convert("RGB")
else:
image = image_gradcam_with_itm[list(query.values())[0]][s[0]]
p1 = Image.fromarray(image.astype("uint8"), "RGB")
p1.thumbnail((400, 400))
display(
"Rank: "
+ str(s[1][current_querry_rank])
+ " Val: "
+ str(s[1][current_querry_val]),
s[0],
p1,
)

66 notebooks/multimodal_search.ipynb generated
View file

@@ -192,7 +192,7 @@
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
"similarity = ms.MultimodalSearch.multimodal_search(\n", "similarity, sorted_lists = ms.MultimodalSearch.multimodal_search(\n",
" mydict,\n", " mydict,\n",
" model,\n", " model,\n",
" vis_processors,\n", " vis_processors,\n",
@@ -201,6 +201,7 @@
" image_keys,\n", " image_keys,\n",
" features_image_stacked,\n", " features_image_stacked,\n",
" search_query3,\n", " search_query3,\n",
" filter_number_of_images=20,\n",
")" ")"
] ]
}, },
@@ -237,7 +238,68 @@
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
"ms.MultimodalSearch.show_results(mydict, search_query3[4])" "ms.MultimodalSearch.show_results(\n",
" mydict,\n",
" search_query3[5],\n",
")"
]
},
{
"cell_type": "markdown",
"id": "0b750e9f-fe64-4028-9caf-52d7187462f1",
"metadata": {},
"source": [
"For even better results, a slightly different approach has been prepared that can improve search results. It is quite resource-intensive, so it is applied after the main algorithm has found the most relevant images. This approach works only with text queries. Among the parameters you can choose 3 models: `\"blip_base\"`, `\"blip_large\"`, `\"blip2_coco\"`. If you get the Out of Memory error, try reducing the batch_size value (minimum = 1), which is the number of images being processed simultaneously. With the parameter `need_grad_cam = True/False` you can enable the calculation of the heat map of each image to be processed. Thus the `image_text_match_reordering` function calculates new similarity values and new ranks for each image. The resulting values are added to the general dictionary."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "b3af7b39-6d0d-4da3-9b8f-7dfd3f5779be",
"metadata": {},
"outputs": [],
"source": [
"itm_model = \"blip_base\"\n",
"# itm_model = \"blip_large\"\n",
"# itm_model = \"blip2_coco\""
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "caf1f4ae-4b37-4954-800e-7120f0419de5",
"metadata": {},
"outputs": [],
"source": [
"itm_scores, image_gradcam_with_itm = ms.MultimodalSearch.image_text_match_reordering(\n",
" mydict,\n",
" search_query3,\n",
" itm_model,\n",
" image_keys,\n",
" sorted_lists,\n",
" batch_size=1,\n",
" need_grad_cam=True,\n",
")"
]
},
{
"cell_type": "markdown",
"id": "9e98c150-5fab-4251-bce7-0d8fc7b385b9",
"metadata": {},
"source": [
"Then using the same output function you can add the `ITM=True` arguments to output the new image order. You can also add the `image_gradcam_with_itm` argument to output the heat maps of the calculated images. "
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "6a829b99-5230-463a-8b11-30ffbb67fc3a",
"metadata": {},
"outputs": [],
"source": [
"ms.MultimodalSearch.show_results(\n",
" mydict, search_query3[0], itm=True, image_gradcam_with_itm=image_gradcam_with_itm\n",
")"
]
},
{