зеркало из
https://github.com/ssciwr/AMMICO.git
synced 2025-10-29 21:16:06 +02:00
added filtering function
Этот коммит содержится в:
родитель
2266df5919
Коммит
24ee78c23e
@ -273,7 +273,17 @@ class MultimodalSearch(AnalysisMethod):
|
||||
image_keys,
|
||||
features_image_stacked,
|
||||
search_query,
|
||||
filter_number_of_images=None,
|
||||
filter_val_limit=None,
|
||||
filter_rel_error=None,
|
||||
):
|
||||
if filter_number_of_images is None:
|
||||
filter_number_of_images = len(self)
|
||||
if filter_val_limit is None:
|
||||
filter_val_limit = 0
|
||||
if filter_rel_error is None:
|
||||
filter_rel_error = 1e10
|
||||
|
||||
features_image_stacked.to(MultimodalSearch.multimodal_device)
|
||||
|
||||
with torch.no_grad():
|
||||
@ -290,10 +300,26 @@ class MultimodalSearch(AnalysisMethod):
|
||||
places = [[item.index(i) for i in range(len(item))] for item in sorted_lists]
|
||||
|
||||
for q in range(len(search_query)):
|
||||
for i, key in zip(range(len(image_keys)), image_keys):
|
||||
self[key]["rank " + list(search_query[q].values())[0]] = places[q][i]
|
||||
self[key][list(search_query[q].values())[0]] = similarity[i][q].item()
|
||||
|
||||
max_val = similarity[sorted_lists[q][0]][q].item()
|
||||
print(max_val)
|
||||
for i, key in zip(range(len(image_keys)), sorted_lists[q]):
|
||||
if (
|
||||
i < filter_number_of_images
|
||||
and similarity[key][q].item() > filter_val_limit
|
||||
and 100 * abs(max_val - similarity[key][q].item()) / max_val
|
||||
< filter_rel_error
|
||||
):
|
||||
self[image_keys[key]][
|
||||
"rank " + list(search_query[q].values())[0]
|
||||
] = places[q][key]
|
||||
self[image_keys[key]][
|
||||
list(search_query[q].values())[0]
|
||||
] = similarity[key][q].item()
|
||||
else:
|
||||
self[image_keys[key]][
|
||||
"rank " + list(search_query[q].values())[0]
|
||||
] = None
|
||||
self[image_keys[key]][list(search_query[q].values())[0]] = 0
|
||||
return similarity, sorted_lists
|
||||
|
||||
def show_results(self, query):
|
||||
@ -315,12 +341,18 @@ class MultimodalSearch(AnalysisMethod):
|
||||
for s in sorted(
|
||||
self.items(), key=lambda t: t[1][list(query.values())[0]], reverse=True
|
||||
):
|
||||
if s[1]["rank " + list(query.values())[0]] is None:
|
||||
break
|
||||
p1 = Image.open(s[1]["filename"]).convert("RGB")
|
||||
p1.thumbnail((400, 400))
|
||||
display(
|
||||
p1,
|
||||
"Rank: "
|
||||
+ str(s[1]["rank " + list(query.values())[0]])
|
||||
+ " Val: "
|
||||
+ str(s[1][list(query.values())[0]]),
|
||||
s[0],
|
||||
p1,
|
||||
)
|
||||
display(
|
||||
"--------------------------------------------------",
|
||||
)
|
||||
|
||||
@ -2,8 +2,7 @@ import pytest
|
||||
import math
|
||||
from PIL import Image
|
||||
import numpy
|
||||
from torch import device, cuda, no_grad
|
||||
from lavis.models import load_model_and_preprocess
|
||||
from torch import device, cuda
|
||||
import misinformation.multimodal_search as ms
|
||||
|
||||
testdict = {
|
||||
|
||||
11
notebooks/multimodal_search.ipynb
сгенерированный
11
notebooks/multimodal_search.ipynb
сгенерированный
@ -174,6 +174,17 @@
|
||||
"]"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "8bcf3127-3dfd-4ff4-b9e7-a043099b1418",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"You can filter your results in 3 different ways:\n",
|
||||
"- `filter_number_of_images` limits the number of images found. That is, if the parameter `filter_number_of_images = 10`, then the first 10 images that best match the query will be shown. The other images' ranks will be set to `None` and their similarity values to `0`.\n",
|
||||
"- `filter_val_limit` discards images whose similarity value is not greater than `filter_val_limit`. That is, if the parameter `filter_val_limit = 0.2`, all images with a similarity of 0.2 or less will be discarded.\n",
|
||||
"- `filter_rel_error` (percentage) keeps only the images that satisfy `100 * abs(current_similarity_value - best_similarity_value_in_current_search) / best_similarity_value_in_current_search < filter_rel_error`. That is, if we set `filter_rel_error = 30` and the top-1 image has a similarity value of 0.5, we discard all images with a similarity of 0.35 or less."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
|
||||
Загрузка…
x
Ссылка в новой задаче
Block a user