Этот коммит содержится в:
Petr Andriushchenko 2023-03-09 15:45:06 +01:00
родитель 2266df5919
Коммит 24ee78c23e
Не найден ключ, соответствующий данной подписи
Идентификатор ключа GPG: 4C4A5DCF634115B6
3 изменённых файлов: 49 добавлений и 7 удалений

Просмотреть файл

@ -273,7 +273,17 @@ class MultimodalSearch(AnalysisMethod):
image_keys,
features_image_stacked,
search_query,
filter_number_of_images=None,
filter_val_limit=None,
filter_rel_error=None,
):
if filter_number_of_images is None:
filter_number_of_images = len(self)
if filter_val_limit is None:
filter_val_limit = 0
if filter_rel_error is None:
filter_rel_error = 1e10
features_image_stacked.to(MultimodalSearch.multimodal_device)
with torch.no_grad():
@ -290,10 +300,26 @@ class MultimodalSearch(AnalysisMethod):
places = [[item.index(i) for i in range(len(item))] for item in sorted_lists]
for q in range(len(search_query)):
for i, key in zip(range(len(image_keys)), image_keys):
self[key]["rank " + list(search_query[q].values())[0]] = places[q][i]
self[key][list(search_query[q].values())[0]] = similarity[i][q].item()
max_val = similarity[sorted_lists[q][0]][q].item()
print(max_val)
for i, key in zip(range(len(image_keys)), sorted_lists[q]):
if (
i < filter_number_of_images
and similarity[key][q].item() > filter_val_limit
and 100 * abs(max_val - similarity[key][q].item()) / max_val
< filter_rel_error
):
self[image_keys[key]][
"rank " + list(search_query[q].values())[0]
] = places[q][key]
self[image_keys[key]][
list(search_query[q].values())[0]
] = similarity[key][q].item()
else:
self[image_keys[key]][
"rank " + list(search_query[q].values())[0]
] = None
self[image_keys[key]][list(search_query[q].values())[0]] = 0
return similarity, sorted_lists
def show_results(self, query):
@ -315,12 +341,18 @@ class MultimodalSearch(AnalysisMethod):
for s in sorted(
self.items(), key=lambda t: t[1][list(query.values())[0]], reverse=True
):
if s[1]["rank " + list(query.values())[0]] is None:
break
p1 = Image.open(s[1]["filename"]).convert("RGB")
p1.thumbnail((400, 400))
display(
p1,
"Rank: "
+ str(s[1]["rank " + list(query.values())[0]])
+ " Val: "
+ str(s[1][list(query.values())[0]]),
s[0],
p1,
)
display(
"--------------------------------------------------",
)

Просмотреть файл

@ -2,8 +2,7 @@ import pytest
import math
from PIL import Image
import numpy
from torch import device, cuda, no_grad
from lavis.models import load_model_and_preprocess
from torch import device, cuda
import misinformation.multimodal_search as ms
testdict = {

11
notebooks/multimodal_search.ipynb сгенерированный
Просмотреть файл

@ -174,6 +174,17 @@
"]"
]
},
{
"cell_type": "markdown",
"id": "8bcf3127-3dfd-4ff4-b9e7-a043099b1418",
"metadata": {},
"source": [
"You can filter your results in 3 different ways:\n",
"- `filter_number_of_images` limits the number of images found. That is, if the parameter `filter_number_of_images = 10`, then the first 10 images that best match the query will be shown. The other images ranks will be set to `None` and the similarity value to `0`.\n",
"- `filter_val_limit` limits the output of images with a similarity value not bigger than `filter_val_limit`. That is, if the parameter `filter_val_limit = 0.2`, all images with similarity less than 0.2 will be discarded.\n",
"- `filter_rel_error` (percentage) limits the output of images with a similarity value not bigger than `100 * abs(current_simularity_value - best_simularity_value_in_current_search)/best_simularity_value_in_current_search < filter_rel_error`. That is, if we set filter_rel_error = 30, it means that if the top1 image have 0.5 similarity value, we discard all image with similarity less than 0.35."
]
},
{
"cell_type": "code",
"execution_count": null,