зеркало из
https://github.com/ssciwr/AMMICO.git
synced 2025-10-30 13:36:04 +02:00
Коммит
80665e5f82
2
.github/workflows/ci.yml
поставляемый
2
.github/workflows/ci.yml
поставляемый
@ -32,7 +32,7 @@ jobs:
|
|||||||
- name: Run pytest
|
- name: Run pytest
|
||||||
run: |
|
run: |
|
||||||
cd misinformation
|
cd misinformation
|
||||||
python -m pytest -m "not gcv" -svv --cov=. --cov-report=xml
|
python -m pytest -m "not gcv and not long" -svv --cov=. --cov-report=xml
|
||||||
- name: Upload coverage
|
- name: Upload coverage
|
||||||
if: matrix.os == 'ubuntu-22.04' && matrix.python-version == '3.9'
|
if: matrix.os == 'ubuntu-22.04' && matrix.python-version == '3.9'
|
||||||
uses: codecov/codecov-action@v3
|
uses: codecov/codecov-action@v3
|
||||||
|
|||||||
@ -113,7 +113,7 @@
|
|||||||
" image_keys,\n",
|
" image_keys,\n",
|
||||||
" image_names,\n",
|
" image_names,\n",
|
||||||
" features_image_stacked,\n",
|
" features_image_stacked,\n",
|
||||||
") = ms.MultimodalSearch.parsing_images(mydict, model_type)"
|
") = ms.MultimodalSearch.parsing_images(mydict, model_type, path_to_saved_tensors=\".\")"
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
@ -128,7 +128,9 @@
|
|||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": null,
|
"execution_count": null,
|
||||||
"id": "56c6d488-f093-4661-835a-5c73a329c874",
|
"id": "56c6d488-f093-4661-835a-5c73a329c874",
|
||||||
"metadata": {},
|
"metadata": {
|
||||||
|
"tags": []
|
||||||
|
},
|
||||||
"outputs": [],
|
"outputs": [],
|
||||||
"source": [
|
"source": [
|
||||||
"# (\n",
|
"# (\n",
|
||||||
|
|||||||
@ -4,9 +4,14 @@ import torch.nn.functional as Func
|
|||||||
import requests
|
import requests
|
||||||
import lavis
|
import lavis
|
||||||
import os
|
import os
|
||||||
|
import numpy as np
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
|
from skimage import transform as skimage_transform
|
||||||
|
from scipy.ndimage import filters
|
||||||
|
from matplotlib import pyplot as plt
|
||||||
from IPython.display import display
|
from IPython.display import display
|
||||||
from lavis.models import load_model_and_preprocess
|
from lavis.models import load_model_and_preprocess, load_model, BlipBase
|
||||||
|
from lavis.processors import load_processor
|
||||||
|
|
||||||
|
|
||||||
class MultimodalSearch(AnalysisMethod):
|
class MultimodalSearch(AnalysisMethod):
|
||||||
@ -233,7 +238,7 @@ class MultimodalSearch(AnalysisMethod):
|
|||||||
}
|
}
|
||||||
|
|
||||||
for query in search_query:
|
for query in search_query:
|
||||||
if not (len(query) == 1) and (query in ("image", "text_input")):
|
if len(query) != 1 and (query in ("image", "text_input")):
|
||||||
raise SyntaxError(
|
raise SyntaxError(
|
||||||
'Each query must contain either an "image" or a "text_input"'
|
'Each query must contain either an "image" or a "text_input"'
|
||||||
)
|
)
|
||||||
@ -343,7 +348,288 @@ class MultimodalSearch(AnalysisMethod):
|
|||||||
self[image_keys[key]][list(search_query[q].values())[0]] = 0
|
self[image_keys[key]][list(search_query[q].values())[0]] = 0
|
||||||
return similarity, sorted_lists
|
return similarity, sorted_lists
|
||||||
|
|
||||||
def show_results(self, query):
|
def itm_text_precessing(self, search_query):
    """Validate the search queries and locate the text queries among them.

    Args:
        search_query (list[dict]): Queries, each expected to be a dict with
            exactly one key, either ``"image"`` or ``"text_input"``.

    Returns:
        list[int]: Positions in ``search_query`` of the queries that carry a
        ``"text_input"`` key, in their original order.

    Raises:
        SyntaxError: If a query does not consist of exactly one ``"image"``
            or ``"text_input"`` entry.
    """
    for query in search_query:
        # Inspect the single key itself: comparing the whole query dict
        # against the key names can never match, which would let malformed
        # queries pass silently.
        if len(query) != 1 or next(iter(query)) not in ("image", "text_input"):
            raise SyntaxError(
                'Each querry must contain either an "image" or a "text_input"'
            )

    return [
        index for index, query in enumerate(search_query) if "text_input" in query
    ]
|
||||||
|
|
||||||
|
def get_pathes_from_query(self, query):
    """Collect file paths and names of the images ranked for ``query``.

    The entries of ``self`` are ordered by their value stored under the
    query's first value, highest first. Collection stops at the first entry
    whose ``"rank <query value>"`` is ``None``.

    Args:
        query (dict): A single search query; its first value is used as the
            lookup key inside every entry of ``self``.

    Returns:
        tuple[list, list]: The ``"filename"`` values and the matching keys
        of ``self``, both in descending order of the query value.
    """
    entries = list(self.items())
    if not entries:
        # Nothing to rank; the query is deliberately not inspected here.
        return [], []

    query_value = list(query.values())[0]
    rank_key = "rank " + query_value
    entries.sort(key=lambda item: item[1][query_value], reverse=True)

    paths, image_names = [], []
    for name, entry in entries:
        if entry[rank_key] is None:
            break
        paths.append(entry["filename"])
        image_names.append(name)
    return paths, image_names
|
||||||
|
|
||||||
|
def read_and_process_images_itm(self, image_paths, vis_processor):
    """Read the given images and stack their processed versions in a batch.

    Args:
        image_paths: Iterable of image locations passed one by one to
            ``MultimodalSearch.read_img``.
        vis_processor: Callable applied to every raw image; its results
            must be stackable with ``torch.stack``.

    Returns:
        tuple: The list of raw images and the stacked tensor moved to
        ``MultimodalSearch.multimodal_device``.
    """
    raw_images = []
    for image_path in image_paths:
        raw_images.append(MultimodalSearch.read_img(self, image_path))

    # All images are read before any of them is processed.
    processed = list(map(vis_processor, raw_images))
    batch = torch.stack(processed)
    images_tensors = batch.to(MultimodalSearch.multimodal_device)

    return raw_images, images_tensors
|
||||||
|
|
||||||
|
def compute_gradcam_batch(
    self,
    model,
    visual_input,
    text_input,
    tokenized_text,
    block_num=6,
):
    """Compute token-averaged Grad-CAM maps for a batch of image/text pairs.

    Runs ``model`` with the ITM head, back-propagates the summed score of
    output column 1 and combines the attention maps and attention gradients
    saved by the cross-attention of text-encoder layer ``block_num``.

    Args:
        model: Image-text matching model exposing
            ``text_encoder.base_model.base_model.encoder.layer`` and callable
            as ``model(samples, match_head="itm")``.
        visual_input: Batch of processed images; only its first dimension
            (the batch size) is read here.
        text_input: Text inputs forwarded to the model unchanged.
        tokenized_text: Tokenizer output providing ``attention_mask`` with
            one row per sample.
        block_num (int): Index of the text-encoder layer whose
            cross-attention is inspected.

    Returns:
        tuple: ``(gradcam, output)``. ``gradcam`` is a CPU tensor of shape
        ``(batch, 1, 24, 24)`` holding, per sample, the Grad-CAM averaged
        over that sample's text tokens; ``output`` is the raw model output.
    """
    cross_attention = (
        model.text_encoder.base_model.base_model.encoder.layer[block_num]
        .crossattention.self
    )
    cross_attention.save_attention = True

    output = model(
        {"image": visual_input, "text_input": text_input}, match_head="itm"
    )
    loss = output[:, 1].sum()

    model.zero_grad()
    loss.backward()
    with torch.no_grad():
        attention_mask = tokenized_text.attention_mask
        batch_size = attention_mask.size(0)
        mask = attention_mask.view(batch_size, 1, -1, 1, 1)  # (bsz,1,token_len,1,1)
        # Token count per sample, minus 2 (presumably the two special tokens
        # - TODO confirm). Summing the mask over the whole batch would
        # inflate the count as soon as the batch holds more than one sample.
        token_length = (attention_mask.sum(dim=-1) - 2).cpu().view(
            batch_size, 1, 1, 1
        )
        # grads and cams [bsz, num_head, seq_len, image_patch]
        grads = cross_attention.get_attn_gradients()
        cams = cross_attention.get_attention_map()

        # assumes 12 attention heads and 576 (24 x 24) image patches
        cams = (
            cams[:, :, :, 1:].reshape(visual_input.size(0), 12, -1, 24, 24) * mask
        )
        grads = (
            grads[:, :, :, 1:]
            .clamp(0)
            .reshape(visual_input.size(0), 12, -1, 24, 24)
            * mask
        )

        # Average over heads -> (bsz, seq_len, 24, 24)
        gradcam = (cams * grads).mean(1).cpu().detach()
        # Keep positions 1..token_length of every sample (position 0 is
        # skipped) and average the maps over exactly those tokens.
        positions = torch.arange(gradcam.size(1)).view(1, -1, 1, 1)
        is_text_token = (positions >= 1) & (positions <= token_length)
        gradcam = (gradcam * is_text_token).sum(dim=1, keepdim=True) / token_length

    return gradcam, output
|
||||||
|
|
||||||
|
def resize_img(self, raw_img):
    """Scale an image to a width of 240 while keeping its aspect ratio.

    Args:
        raw_img: Image object exposing ``size`` as ``(width, height)`` and a
            ``resize`` method taking the target ``(width, height)``.

    Returns:
        Whatever ``raw_img.resize`` returns for the scaled size.
    """
    width, height = raw_img.size
    scale = 240 / width
    target_size = (int(width * scale), int(height * scale))
    return raw_img.resize(target_size)
|
||||||
|
|
||||||
|
def get_att_map(self, img, att_map, blur=True, overlap=True):
    """Turn an attention map into a heat map, optionally blended with ``img``.

    Args:
        img: Image array; its first two dimensions define the output size
            and it is blended with the coloured map when ``overlap`` is set.
            NOTE(review): the caller in this file passes a float image
            scaled to [0, 1] - confirm for other callers.
        att_map: 2-D attention map. It is shifted and scaled with in-place
            operators, so a NumPy array passed in is modified.
        blur (bool): Smooth the resized map with a Gaussian filter whose
            sigma is 2 % of the larger image side, then re-normalise it.
        overlap (bool): Blend the "jet"-coloured map with ``img``; otherwise
            the normalised single-channel map is returned.

    Returns:
        The blended image when ``overlap`` is true, else the normalised map.
    """
    att_map -= att_map.min()
    if att_map.max() > 0:
        att_map /= att_map.max()
    att_map = skimage_transform.resize(
        att_map, (img.shape[:2]), order=3, mode="constant"
    )
    if blur:
        att_map = filters.gaussian_filter(att_map, 0.02 * max(img.shape[:2]))
        att_map -= att_map.min()
        # Same guard as the first normalisation: a constant map has a
        # maximum of 0 here and 0 / 0 would fill the heat map with NaN.
        if att_map.max() > 0:
            att_map /= att_map.max()
    cmap = plt.get_cmap("jet")
    att_mapv = cmap(att_map)
    # The colormap yields RGBA; drop the alpha channel.
    att_mapv = np.delete(att_mapv, 3, 2)
    if overlap:
        # Stronger attention gives more weight to the colour map and less
        # to the underlying image.
        weight = (att_map**0.7).reshape(att_map.shape + (1,))
        att_map = (1 - weight) * img + weight * att_mapv
    return att_map
|
||||||
|
|
||||||
|
def upload_model_blip2_coco(self):
    """Load the ``"blip2_image_text_matching"`` model of type ``"coco"``.

    The model is loaded in evaluation mode on
    ``MultimodalSearch.multimodal_device``.

    Returns:
        tuple: The model and a ``"blip_image_eval"`` image processor built
        with ``image_size=364``.
    """
    device = MultimodalSearch.multimodal_device
    matcher = load_model(
        "blip2_image_text_matching", "coco", is_eval=True, device=device
    )
    image_processor = load_processor("blip_image_eval").build(image_size=364)
    return matcher, image_processor
|
||||||
|
|
||||||
|
def upload_model_blip_base(self):
    """Load the ``"blip_image_text_matching"`` model of type ``"base"``.

    The model is loaded in evaluation mode on
    ``MultimodalSearch.multimodal_device``.

    Returns:
        tuple: The model and a ``"blip_image_eval"`` image processor built
        with ``image_size=384``.
    """
    device = MultimodalSearch.multimodal_device
    matcher = load_model(
        "blip_image_text_matching", "base", is_eval=True, device=device
    )
    image_processor = load_processor("blip_image_eval").build(image_size=384)
    return matcher, image_processor
|
||||||
|
|
||||||
|
def upload_model_blip_large(self):
    """Load the ``"blip_image_text_matching"`` model of type ``"large"``.

    The model is loaded in evaluation mode on
    ``MultimodalSearch.multimodal_device``.

    Returns:
        tuple: The model and a ``"blip_image_eval"`` image processor built
        with ``image_size=384``.
    """
    device = MultimodalSearch.multimodal_device
    matcher = load_model(
        "blip_image_text_matching", "large", is_eval=True, device=device
    )
    image_processor = load_processor("blip_image_eval").build(image_size=384)
    return matcher, image_processor
|
||||||
|
|
||||||
|
def image_text_match_reordering(
    self,
    search_query,
    itm_model_type,
    image_keys,
    sorted_lists,
    batch_size=1,
    need_grad_cam=False,
):
    """Re-rank the images of every text query with an image-text-matching model.

    For each ``"text_input"`` query the images returned by
    ``get_pathes_from_query`` are scored in batches. The score (softmax
    column 1 of the ITM output) and the resulting rank are written into the
    entries of ``self`` under ``"itm <query>"`` and ``"itm_rank <query>"``.
    Entries that were not scored get ``0`` and ``None`` respectively.

    Args:
        search_query (list[dict]): All search queries; only those with a
            ``"text_input"`` key are processed.
        itm_model_type (str): ``"blip_base"``, ``"blip_large"`` or
            ``"blip2_coco"``.
        image_keys: Sequence mapping an image index to its key in ``self``.
        sorted_lists: Per query (indexed like ``search_query``), the image
            indices to update.
        batch_size (int): Number of images scored at once.
        need_grad_cam (bool): Also compute a heat map per image; not
            supported by ``"blip2_coco"``.

    Returns:
        tuple: ``(itm_scores2, image_gradcam_with_itm)``. ``itm_scores2``
        holds one 1-D score tensor per text query. ``image_gradcam_with_itm``
        maps each query value to ``{image name: heat map * 255}``, or to
        ``""`` when no heat maps were computed.

    Raises:
        SyntaxError: If heat maps are requested for ``"blip2_coco"`` or a
            query is rejected by ``itm_text_precessing``.
        ValueError: If ``batch_size`` is smaller than 1.
    """
    if itm_model_type == "blip2_coco" and need_grad_cam is True:
        raise SyntaxError(
            "The blip2_coco model does not yet work with gradcam. Please set need_grad_cam to False"
        )
    # Fail before the expensive model load rather than with a division by
    # zero (or silently doing nothing for negative sizes) later on.
    if batch_size < 1:
        raise ValueError("batch_size must be at least 1")

    choose_model = {
        "blip_base": MultimodalSearch.upload_model_blip_base,
        "blip_large": MultimodalSearch.upload_model_blip_large,
        "blip2_coco": MultimodalSearch.upload_model_blip2_coco,
    }

    itm_model, vis_processor_itm = choose_model[itm_model_type](self)
    text_processor = load_processor("blip_caption")
    tokenizer = BlipBase.init_tokenizer()

    if itm_model_type == "blip2_coco":
        need_grad_cam = False

    text_query_index = MultimodalSearch.itm_text_precessing(self, search_query)

    itm_scores2 = []
    image_gradcam_with_itm = {}

    # Pre-bind the per-batch names so the cleanup at the end stays valid
    # even when there is no text query or no image to score.
    raw_images = images = queries_batch = queries_tok_batch = None
    itm_output = itm_score = gradcam = None

    for index_text_query in text_query_index:
        query = search_query[index_text_query]
        query_value = list(query.values())[0]
        pathes, image_names = MultimodalSearch.get_pathes_from_query(self, query)

        local_itm_scores = []
        local_avg_gradcams = []

        for start in range(0, len(pathes), batch_size):
            filenames_in_batch = pathes[start : start + batch_size]
            raw_images, images = MultimodalSearch.read_and_process_images_itm(
                self, filenames_in_batch, vis_processor_itm
            )
            queries_batch = [text_processor(query["text_input"])] * len(
                filenames_in_batch
            )

            if need_grad_cam:
                # The tokenized text is only needed for the Grad-CAM masks.
                queries_tok_batch = tokenizer(queries_batch, return_tensors="pt").to(
                    MultimodalSearch.multimodal_device
                )
                gradcam, itm_output = MultimodalSearch.compute_gradcam_batch(
                    self,
                    itm_model,
                    images,
                    queries_batch,
                    queries_tok_batch,
                )
                norm_imgs = [np.float32(r_img) / 255 for r_img in raw_images]

                for norm_img, grad_cam in zip(norm_imgs, gradcam):
                    avg_gradcam = MultimodalSearch.get_att_map(
                        self, norm_img, np.float32(grad_cam[0]), blur=True
                    )
                    local_avg_gradcams.append(avg_gradcam)
            else:
                # No backward pass follows, so skip building the autograd
                # graph for the forward pass.
                with torch.no_grad():
                    itm_output = itm_model(
                        {"image": images, "text_input": queries_batch},
                        match_head="itm",
                    )

            with torch.no_grad():
                itm_score = torch.nn.functional.softmax(itm_output, dim=1)

            local_itm_scores.append(itm_score)

        if local_itm_scores:
            local_itm_scores2 = torch.cat(local_itm_scores)[:, 1]
        else:
            # torch.cat rejects an empty list; a query without images simply
            # has no scores.
            local_itm_scores2 = torch.empty(0)

        if need_grad_cam:
            localimage_gradcam_with_itm = {
                n: i * 255 for n, i in zip(image_names, local_avg_gradcams)
            }
        else:
            localimage_gradcam_with_itm = ""

        image_names_with_itm = {
            n: i.item() for n, i in zip(image_names, local_itm_scores2)
        }
        itm_rank = torch.argsort(local_itm_scores2, descending=True)
        image_names_with_new_rank = {
            image_names[i.item()]: rank for rank, i in enumerate(itm_rank)
        }

        scored_names = set(image_names)
        itm_key = "itm " + query_value
        itm_rank_key = "itm_rank " + query_value
        # zip() caps the update at len(image_keys) entries of the sorted list.
        for _, key in zip(range(len(image_keys)), sorted_lists[index_text_query]):
            image_name = image_keys[key]
            if image_name in scored_names:
                self[image_name][itm_key] = image_names_with_itm[image_name]
                self[image_name][itm_rank_key] = image_names_with_new_rank[image_name]
            else:
                self[image_name][itm_key] = 0
                self[image_name][itm_rank_key] = None

        itm_scores2.append(local_itm_scores2)
        image_gradcam_with_itm[query_value] = localimage_gradcam_with_itm

    # Drop every reference to the model and the last batch before emptying
    # the CUDA cache so that memory can actually be released.
    del (
        itm_model,
        vis_processor_itm,
        text_processor,
        tokenizer,
        raw_images,
        images,
        queries_batch,
        queries_tok_batch,
        itm_output,
        itm_score,
        gradcam,
    )
    torch.cuda.empty_cache()
    return itm_scores2, image_gradcam_with_itm
|
||||||
|
|
||||||
|
def show_results(self, query, itm=False, image_gradcam_with_itm=False):
|
||||||
if "image" in query.keys():
|
if "image" in query.keys():
|
||||||
pic = Image.open(query["image"]).convert("RGB")
|
pic = Image.open(query["image"]).convert("RGB")
|
||||||
pic.thumbnail((400, 400))
|
pic.thumbnail((400, 400))
|
||||||
@ -359,18 +645,29 @@ class MultimodalSearch(AnalysisMethod):
|
|||||||
"--------------------------------------------------",
|
"--------------------------------------------------",
|
||||||
"Results:",
|
"Results:",
|
||||||
)
|
)
|
||||||
|
if itm:
|
||||||
|
current_querry_val = "itm " + list(query.values())[0]
|
||||||
|
current_querry_rank = "itm_rank " + list(query.values())[0]
|
||||||
|
else:
|
||||||
|
current_querry_val = list(query.values())[0]
|
||||||
|
current_querry_rank = "rank " + list(query.values())[0]
|
||||||
|
|
||||||
for s in sorted(
|
for s in sorted(
|
||||||
self.items(), key=lambda t: t[1][list(query.values())[0]], reverse=True
|
self.items(), key=lambda t: t[1][current_querry_val], reverse=True
|
||||||
):
|
):
|
||||||
if s[1]["rank " + list(query.values())[0]] is None:
|
if s[1][current_querry_rank] is None:
|
||||||
break
|
break
|
||||||
p1 = Image.open(s[1]["filename"]).convert("RGB")
|
if image_gradcam_with_itm is False:
|
||||||
|
p1 = Image.open(s[1]["filename"]).convert("RGB")
|
||||||
|
else:
|
||||||
|
image = image_gradcam_with_itm[list(query.values())[0]][s[0]]
|
||||||
|
p1 = Image.fromarray(image.astype("uint8"), "RGB")
|
||||||
p1.thumbnail((400, 400))
|
p1.thumbnail((400, 400))
|
||||||
display(
|
display(
|
||||||
"Rank: "
|
"Rank: "
|
||||||
+ str(s[1]["rank " + list(query.values())[0]])
|
+ str(s[1][current_querry_rank])
|
||||||
+ " Val: "
|
+ " Val: "
|
||||||
+ str(s[1][list(query.values())[0]]),
|
+ str(s[1][current_querry_val]),
|
||||||
s[0],
|
s[0],
|
||||||
p1,
|
p1,
|
||||||
)
|
)
|
||||||
|
|||||||
@ -16,3 +16,33 @@ def set_environ(request):
|
|||||||
mypath + "/../../data/seismic-bonfire-329406-412821a70264.json"
|
mypath + "/../../data/seismic-bonfire-329406-412821a70264.json"
|
||||||
)
|
)
|
||||||
print(os.environ.get("GOOGLE_APPLICATION_CREDENTIALS"))
|
print(os.environ.get("GOOGLE_APPLICATION_CREDENTIALS"))
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
def get_testdict(get_path):
    """Map the two test image names to their file paths under ``get_path``."""
    return {
        name: {"filename": get_path + f"{name}.png"}
        for name in ("IMG_2746", "IMG_2809")
    }
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
def get_test_my_dict(get_path):
    """Return the two test images with stored ranks and similarity values.

    Each entry carries the file name plus the rank and value for the text
    query ``"A bus"`` and for the image query ``IMG_3758.png``.
    """
    image_query = get_path + "IMG_3758.png"
    image_query_rank = "rank " + get_path + "IMG_3758.png"

    test_my_dict = {}
    test_my_dict["IMG_2746"] = {
        "filename": get_path + "IMG_2746.png",
        "rank A bus": 1,
        "A bus": 0.15640679001808167,
        image_query_rank: 1,
        image_query: 0.7533495426177979,
    }
    test_my_dict["IMG_2809"] = {
        "filename": get_path + "IMG_2809.png",
        "rank A bus": 0,
        "A bus": 0.1970970332622528,
        image_query_rank: 0,
        image_query: 0.8907483816146851,
    }
    return test_my_dict
|
||||||
|
|||||||
@ -5,22 +5,17 @@ import numpy
|
|||||||
from torch import device, cuda
|
from torch import device, cuda
|
||||||
import misinformation.multimodal_search as ms
|
import misinformation.multimodal_search as ms
|
||||||
|
|
||||||
|
|
||||||
testdict = {
|
|
||||||
"IMG_2746": {"filename": "./test/data/IMG_2746.png"},
|
|
||||||
"IMG_2809": {"filename": "./test/data/IMG_2809.png"},
|
|
||||||
}
|
|
||||||
|
|
||||||
related_error = 1e-2
|
related_error = 1e-2
|
||||||
gpu_is_not_available = not cuda.is_available()
|
gpu_is_not_available = not cuda.is_available()
|
||||||
|
|
||||||
|
|
||||||
cuda.empty_cache()
|
cuda.empty_cache()
|
||||||
|
|
||||||
|
|
||||||
def test_read_img():
|
def test_read_img(get_testdict):
|
||||||
my_dict = {}
|
my_dict = {}
|
||||||
test_img = ms.MultimodalSearch.read_img(my_dict, testdict["IMG_2746"]["filename"])
|
test_img = ms.MultimodalSearch.read_img(
|
||||||
|
my_dict, get_testdict["IMG_2746"]["filename"]
|
||||||
|
)
|
||||||
assert list(numpy.array(test_img)[257][34]) == [70, 66, 63]
|
assert list(numpy.array(test_img)[257][34]) == [70, 66, 63]
|
||||||
|
|
||||||
|
|
||||||
@ -205,29 +200,29 @@ dict_image_gradcam_with_itm_for_blip = {
|
|||||||
"pre_sorted",
|
"pre_sorted",
|
||||||
),
|
),
|
||||||
[
|
[
|
||||||
# (
|
(
|
||||||
# device("cpu"),
|
device("cpu"),
|
||||||
# "blip2",
|
"blip2",
|
||||||
# pre_proc_pic_blip2_blip_albef,
|
pre_proc_pic_blip2_blip_albef,
|
||||||
# pre_proc_text_blip2_blip_albef,
|
pre_proc_text_blip2_blip_albef,
|
||||||
# pre_extracted_feature_img_blip2,
|
pre_extracted_feature_img_blip2,
|
||||||
# pre_extracted_feature_text_blip2,
|
pre_extracted_feature_text_blip2,
|
||||||
# simularity_blip2,
|
simularity_blip2,
|
||||||
# sorted_blip2,
|
sorted_blip2,
|
||||||
# ),
|
),
|
||||||
# pytest.param(
|
pytest.param(
|
||||||
# device("cuda"),
|
device("cuda"),
|
||||||
# "blip2",
|
"blip2",
|
||||||
# pre_proc_pic_blip2_blip_albef,
|
pre_proc_pic_blip2_blip_albef,
|
||||||
# pre_proc_text_blip2_blip_albef,
|
pre_proc_text_blip2_blip_albef,
|
||||||
# pre_extracted_feature_img_blip2,
|
pre_extracted_feature_img_blip2,
|
||||||
# pre_extracted_feature_text_blip2,
|
pre_extracted_feature_text_blip2,
|
||||||
# simularity_blip2,
|
simularity_blip2,
|
||||||
# sorted_blip2,
|
sorted_blip2,
|
||||||
# marks=pytest.mark.skipif(
|
marks=pytest.mark.skipif(
|
||||||
# gpu_is_not_available, reason="gpu_is_not_availible"
|
gpu_is_not_available, reason="gpu_is_not_availible"
|
||||||
# ),
|
),
|
||||||
# ),
|
),
|
||||||
(
|
(
|
||||||
device("cpu"),
|
device("cpu"),
|
||||||
"blip",
|
"blip",
|
||||||
@ -354,6 +349,8 @@ def test_parsing_images(
|
|||||||
pre_extracted_feature_text,
|
pre_extracted_feature_text,
|
||||||
pre_simularity,
|
pre_simularity,
|
||||||
pre_sorted,
|
pre_sorted,
|
||||||
|
get_path,
|
||||||
|
get_testdict,
|
||||||
tmp_path,
|
tmp_path,
|
||||||
):
|
):
|
||||||
ms.MultimodalSearch.multimodal_device = pre_multimodal_device
|
ms.MultimodalSearch.multimodal_device = pre_multimodal_device
|
||||||
@ -365,7 +362,7 @@ def test_parsing_images(
|
|||||||
_,
|
_,
|
||||||
features_image_stacked,
|
features_image_stacked,
|
||||||
) = ms.MultimodalSearch.parsing_images(
|
) = ms.MultimodalSearch.parsing_images(
|
||||||
testdict, pre_model, path_to_saved_tensors=tmp_path
|
get_testdict, pre_model, path_to_saved_tensors=tmp_path
|
||||||
)
|
)
|
||||||
|
|
||||||
for i, num in zip(range(10), features_image_stacked[0, 10:12].tolist()):
|
for i, num in zip(range(10), features_image_stacked[0, 10:12].tolist()):
|
||||||
@ -374,7 +371,7 @@ def test_parsing_images(
|
|||||||
is True
|
is True
|
||||||
)
|
)
|
||||||
|
|
||||||
test_pic = Image.open(testdict["IMG_2746"]["filename"]).convert("RGB")
|
test_pic = Image.open(get_testdict["IMG_2746"]["filename"]).convert("RGB")
|
||||||
test_querry = (
|
test_querry = (
|
||||||
"The bird sat on a tree located at the intersection of 23rd and 43rd streets."
|
"The bird sat on a tree located at the intersection of 23rd and 43rd streets."
|
||||||
)
|
)
|
||||||
@ -390,10 +387,10 @@ def test_parsing_images(
|
|||||||
|
|
||||||
search_query = [
|
search_query = [
|
||||||
{"text_input": test_querry},
|
{"text_input": test_querry},
|
||||||
{"image": testdict["IMG_2746"]["filename"]},
|
{"image": get_testdict["IMG_2746"]["filename"]},
|
||||||
]
|
]
|
||||||
multi_features_stacked = ms.MultimodalSearch.querys_processing(
|
multi_features_stacked = ms.MultimodalSearch.querys_processing(
|
||||||
testdict, search_query, model, txt_processor, vis_processor, pre_model
|
get_testdict, search_query, model, txt_processor, vis_processor, pre_model
|
||||||
)
|
)
|
||||||
|
|
||||||
for i, num in zip(range(10), multi_features_stacked[0, 10:12].tolist()):
|
for i, num in zip(range(10), multi_features_stacked[0, 10:12].tolist()):
|
||||||
@ -410,11 +407,11 @@ def test_parsing_images(
|
|||||||
|
|
||||||
search_query2 = [
|
search_query2 = [
|
||||||
{"text_input": "A bus"},
|
{"text_input": "A bus"},
|
||||||
{"image": "../misinformation/test/data/IMG_3758.png"},
|
{"image": get_path + "IMG_3758.png"},
|
||||||
]
|
]
|
||||||
|
|
||||||
similarity, sorted_list = ms.MultimodalSearch.multimodal_search(
|
similarity, sorted_list = ms.MultimodalSearch.multimodal_search(
|
||||||
testdict,
|
get_testdict,
|
||||||
model,
|
model,
|
||||||
vis_processor,
|
vis_processor,
|
||||||
txt_processor,
|
txt_processor,
|
||||||
@ -445,3 +442,81 @@ def test_parsing_images(
|
|||||||
multi_features_stacked,
|
multi_features_stacked,
|
||||||
)
|
)
|
||||||
cuda.empty_cache()
|
cuda.empty_cache()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.long
def test_itm(get_test_my_dict, get_path):
    """Check ITM scores and Grad-CAM values of the BLIP base and large models."""
    image_keys = ["IMG_2746", "IMG_2809"]
    sorted_list = [[1, 0], [1, 0]]
    search_query3 = [
        {"text_input": "A bus"},
        {"image": get_path + "IMG_3758.png"},
    ]
    for itm_model in ["blip_base", "blip_large"]:
        reordered = ms.MultimodalSearch.image_text_match_reordering(
            get_test_my_dict,
            search_query3,
            itm_model,
            image_keys,
            sorted_list,
            batch_size=1,
            need_grad_cam=True,
        )
        itm_scores, image_gradcam_with_itm = reordered

        # Scores of the first (text) query against the stored references.
        for i, itm in enumerate(dict_itm_scores_for_blib[itm_model]):
            assert math.isclose(
                itm_scores[0].tolist()[i], itm, rel_tol=10 * related_error
            )

        # First pixel row of the heat map of IMG_2809 against the references.
        for i, grad_cam in enumerate(dict_image_gradcam_with_itm_for_blip[itm_model]):
            assert math.isclose(
                image_gradcam_with_itm["A bus"]["IMG_2809"][0][0].tolist()[i],
                grad_cam,
                rel_tol=10 * related_error,
            )

        del itm_scores, image_gradcam_with_itm
        cuda.empty_cache()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.long
def test_itm_blip2_coco(get_test_my_dict, get_path):
    """Check the ITM scores of the ``blip2_coco`` model without Grad-CAM."""
    image_keys = ["IMG_2746", "IMG_2809"]
    sorted_list = [[1, 0], [1, 0]]
    search_query3 = [
        {"text_input": "A bus"},
        {"image": get_path + "IMG_3758.png"},
    ]

    reordered = ms.MultimodalSearch.image_text_match_reordering(
        get_test_my_dict,
        search_query3,
        "blip2_coco",
        image_keys,
        sorted_list,
        batch_size=1,
        need_grad_cam=False,
    )
    itm_scores, image_gradcam_with_itm = reordered

    # Scores of the first (text) query against the stored references.
    for i, itm in enumerate(dict_itm_scores_for_blib["blip2_coco"]):
        assert math.isclose(
            itm_scores[0].tolist()[i], itm, rel_tol=10 * related_error
        )

    del itm_scores, image_gradcam_with_itm
    cuda.empty_cache()
|
||||||
|
|||||||
115
notebooks/multimodal_search.ipynb
сгенерированный
115
notebooks/multimodal_search.ipynb
сгенерированный
@ -138,9 +138,7 @@
|
|||||||
" image_keys,\n",
|
" image_keys,\n",
|
||||||
" image_names,\n",
|
" image_names,\n",
|
||||||
" features_image_stacked,\n",
|
" features_image_stacked,\n",
|
||||||
") = ms.MultimodalSearch.parsing_images(\n",
|
") = ms.MultimodalSearch.parsing_images(mydict, model_type, path_to_saved_tensors=\".\")"
|
||||||
" mydict, model_type, path_to_saved_tensors=\"./saved_tensors/\"\n",
|
|
||||||
")"
|
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
@ -170,7 +168,7 @@
|
|||||||
"# ) = ms.MultimodalSearch.parsing_images(\n",
|
"# ) = ms.MultimodalSearch.parsing_images(\n",
|
||||||
"# mydict,\n",
|
"# mydict,\n",
|
||||||
"# model_type,\n",
|
"# model_type,\n",
|
||||||
"# path_to_load_tensors=\"./saved_tensors/18_blip_saved_features_image.pt\",\n",
|
"# path_to_load_tensors=\".5_blip_saved_features_image.pt\",\n",
|
||||||
"# )"
|
"# )"
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
@ -194,15 +192,14 @@
|
|||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": null,
|
"execution_count": null,
|
||||||
"id": "c4196a52-d01e-42e4-8674-5712f7d6f792",
|
"id": "c4196a52-d01e-42e4-8674-5712f7d6f792",
|
||||||
"metadata": {},
|
"metadata": {
|
||||||
|
"tags": []
|
||||||
|
},
|
||||||
"outputs": [],
|
"outputs": [],
|
||||||
"source": [
|
"source": [
|
||||||
"search_query3 = [\n",
|
"search_query3 = [\n",
|
||||||
" {\"text_input\": \"politician press conference\"},\n",
|
" {\"text_input\": \"politician press conference\"},\n",
|
||||||
" {\"text_input\": \"a world map\"},\n",
|
" {\"text_input\": \"a world map\"},\n",
|
||||||
" {\"image\": \"../data/haos.png\"},\n",
|
|
||||||
" {\"image\": \"../data/image-34098-800.png\"},\n",
|
|
||||||
" {\"image\": \"../data/LeonPresserMorocco20032015_600.png\"},\n",
|
|
||||||
" {\"text_input\": \"a dog\"},\n",
|
" {\"text_input\": \"a dog\"},\n",
|
||||||
"]"
|
"]"
|
||||||
]
|
]
|
||||||
@ -222,10 +219,12 @@
|
|||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": null,
|
"execution_count": null,
|
||||||
"id": "7f7dc52f-7ee9-4590-96b7-e0d9d3b82378",
|
"id": "7f7dc52f-7ee9-4590-96b7-e0d9d3b82378",
|
||||||
"metadata": {},
|
"metadata": {
|
||||||
|
"tags": []
|
||||||
|
},
|
||||||
"outputs": [],
|
"outputs": [],
|
||||||
"source": [
|
"source": [
|
||||||
"similarity = ms.MultimodalSearch.multimodal_search(\n",
|
"similarity, sorted_lists = ms.MultimodalSearch.multimodal_search(\n",
|
||||||
" mydict,\n",
|
" mydict,\n",
|
||||||
" model,\n",
|
" model,\n",
|
||||||
" vis_processors,\n",
|
" vis_processors,\n",
|
||||||
@ -234,6 +233,7 @@
|
|||||||
" image_keys,\n",
|
" image_keys,\n",
|
||||||
" features_image_stacked,\n",
|
" features_image_stacked,\n",
|
||||||
" search_query3,\n",
|
" search_query3,\n",
|
||||||
|
" filter_number_of_images=20,\n",
|
||||||
")"
|
")"
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
@ -249,10 +249,12 @@
|
|||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": null,
|
"execution_count": null,
|
||||||
"id": "9ad74b21-6187-4a58-9ed8-fd3e80f5a4ed",
|
"id": "9ad74b21-6187-4a58-9ed8-fd3e80f5a4ed",
|
||||||
"metadata": {},
|
"metadata": {
|
||||||
|
"tags": []
|
||||||
|
},
|
||||||
"outputs": [],
|
"outputs": [],
|
||||||
"source": [
|
"source": [
|
||||||
"mydict[\"100127S_ara\"]"
|
"mydict[\"109237S_spa\"]"
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
@ -267,10 +269,79 @@
|
|||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": null,
|
"execution_count": null,
|
||||||
"id": "4324e4fd-e9aa-4933-bb12-074d54e0c510",
|
"id": "4324e4fd-e9aa-4933-bb12-074d54e0c510",
|
||||||
"metadata": {},
|
"metadata": {
|
||||||
|
"tags": []
|
||||||
|
},
|
||||||
"outputs": [],
|
"outputs": [],
|
||||||
"source": [
|
"source": [
|
||||||
"ms.MultimodalSearch.show_results(mydict, search_query3[4])"
|
"ms.MultimodalSearch.show_results(\n",
|
||||||
|
" mydict,\n",
|
||||||
|
" search_query3[0],\n",
|
||||||
|
")"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "markdown",
|
||||||
|
"id": "0b750e9f-fe64-4028-9caf-52d7187462f1",
|
||||||
|
"metadata": {},
|
||||||
|
"source": [
|
||||||
|
"For even better results, a slightly different approach has been prepared that can improve search results. It is quite resource-intensive, so it is applied after the main algorithm has found the most relevant images. This approach works only with text queries. Among the parameters you can choose 3 models: `\"blip_base\"`, `\"blip_large\"`, `\"blip2_coco\"`. If you get the Out of Memory error, try reducing the batch_size value (minimum = 1), which is the number of images being processed simultaneously. With the parameter `need_grad_cam = True/False` you can enable the calculation of the heat map of each image to be processed. Thus the `image_text_match_reordering` function calculates new similarity values and new ranks for each image. The resulting values are added to the general dictionary."
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": null,
|
||||||
|
"id": "b3af7b39-6d0d-4da3-9b8f-7dfd3f5779be",
|
||||||
|
"metadata": {
|
||||||
|
"tags": []
|
||||||
|
},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"itm_model = \"blip_base\"\n",
|
||||||
|
"# itm_model = \"blip_large\"\n",
|
||||||
|
"# itm_model = \"blip2_coco\""
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": null,
|
||||||
|
"id": "caf1f4ae-4b37-4954-800e-7120f0419de5",
|
||||||
|
"metadata": {
|
||||||
|
"tags": []
|
||||||
|
},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"itm_scores, image_gradcam_with_itm = ms.MultimodalSearch.image_text_match_reordering(\n",
|
||||||
|
" mydict,\n",
|
||||||
|
" search_query3,\n",
|
||||||
|
" itm_model,\n",
|
||||||
|
" image_keys,\n",
|
||||||
|
" sorted_lists,\n",
|
||||||
|
" batch_size=1,\n",
|
||||||
|
" need_grad_cam=True,\n",
|
||||||
|
")"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "markdown",
|
||||||
|
"id": "9e98c150-5fab-4251-bce7-0d8fc7b385b9",
|
||||||
|
"metadata": {},
|
||||||
|
"source": [
|
||||||
|
"Then, using the same output function, you can pass the `itm=True` argument to display the images in their new order. You can also pass the `image_gradcam_with_itm` argument to display the heat maps computed for these images."
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": null,
|
||||||
|
"id": "6a829b99-5230-463a-8b11-30ffbb67fc3a",
|
||||||
|
"metadata": {
|
||||||
|
"tags": []
|
||||||
|
},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"ms.MultimodalSearch.show_results(\n",
|
||||||
|
" mydict, search_query3[0], itm=True, image_gradcam_with_itm=image_gradcam_with_itm\n",
|
||||||
|
")"
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
@ -320,7 +391,9 @@
|
|||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": null,
|
"execution_count": null,
|
||||||
"id": "e78646d6-80be-4d3e-8123-3360957bcaa8",
|
"id": "e78646d6-80be-4d3e-8123-3360957bcaa8",
|
||||||
"metadata": {},
|
"metadata": {
|
||||||
|
"tags": []
|
||||||
|
},
|
||||||
"outputs": [],
|
"outputs": [],
|
||||||
"source": [
|
"source": [
|
||||||
"df.head(10)"
|
"df.head(10)"
|
||||||
@ -338,11 +411,21 @@
|
|||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": null,
|
"execution_count": null,
|
||||||
"id": "185f7dde-20dc-44d8-9ab0-de41f9b5734d",
|
"id": "185f7dde-20dc-44d8-9ab0-de41f9b5734d",
|
||||||
"metadata": {},
|
"metadata": {
|
||||||
|
"tags": []
|
||||||
|
},
|
||||||
"outputs": [],
|
"outputs": [],
|
||||||
"source": [
|
"source": [
|
||||||
"df.to_csv(\"./data_out.csv\")"
|
"df.to_csv(\"./data_out.csv\")"
|
||||||
]
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": null,
|
||||||
|
"id": "b6a79201-7c17-496c-a6a1-b8ecfd3dd1e8",
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": []
|
||||||
}
|
}
|
||||||
],
|
],
|
||||||
"metadata": {
|
"metadata": {
|
||||||
|
|||||||
Загрузка…
x
Ссылка в новой задаче
Block a user