Mirror of
https://github.com/ssciwr/AMMICO.git
synced 2025-10-29 21:16:06 +02:00
[pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
This commit is contained in:
parent
70866dfc69
commit
b9158d4947
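
The diff below is the black-style reformatting that pre-commit.ci pushed to the MultimodalSearch class: no behavior changes, only over-long lines wrapped and blank lines normalized. As a rough illustration, the same rewrite can be reproduced locally; this sketch assumes the hook is black with its default 88-column limit, since the repository's actual hook configuration is not shown on this page.

# Minimal sketch: reproduce the kind of rewrite pre-commit.ci applied here.
# Assumption: the hook is black with default settings; AMMICO's actual
# pre-commit configuration may pin a different version or line length.
import black

# One of the over-long lines from the diff below, as a standalone snippet.
source = (
    'display("Your search query: ", pic,'
    '"--------------------------------------------------", "Results:")\n'
)

# format_str parses the snippet and re-emits it wrapped at 88 columns,
# splitting the call across several lines just as in the diff.
print(black.format_str(source, mode=black.Mode()))
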
@@ -84,7 +84,9 @@ class MultimodalSearch(AnalysisMethod):
             enabled=(MultimodalSearch.multimodal_device != torch.device("cpu"))
         ):
             features_image = [
-                model.extract_features({"image": ten, "text_input": ""}, mode="image")
+                model.extract_features(
+                    {"image": ten, "text_input": ""}, mode="image"
+                )
                 for ten in images_tensors
             ]
         else:
@@ -113,8 +115,7 @@ class MultimodalSearch(AnalysisMethod):
         features_text = model.extract_features(sample_text, mode="text")
 
         return features_text
-
-
+
     def parsing_images(self, model_type):
         image_keys = sorted(self.keys())
         image_names = [self[k]["filename"] for k in image_keys]
@@ -133,16 +134,32 @@ class MultimodalSearch(AnalysisMethod):
         )
         MultimodalSearch.save_tensors(features_image_stacked)
 
-        return model, vis_processors, txt_processors, image_keys, image_names, features_image_stacked
+        return (
+            model,
+            vis_processors,
+            txt_processors,
+            image_keys,
+            image_names,
+            features_image_stacked,
+        )
 
     def multimodal_search(
-        self, model, vis_processors, txt_processors, model_type, image_keys, features_image_stacked, search_query
+        self,
+        model,
+        vis_processors,
+        txt_processors,
+        model_type,
+        image_keys,
+        features_image_stacked,
+        search_query,
     ):
         features_image_stacked.to(MultimodalSearch.multimodal_device)
 
         for query in search_query:
-            if (len(query)!=1):
-                raise SyntaxError('Each querry must contain either an "image" or a "text_input"')
+            if len(query) != 1:
+                raise SyntaxError(
+                    'Each querry must contain either an "image" or a "text_input"'
+                )
 
         multi_sample = []
         for query in search_query:
@@ -150,21 +167,34 @@ class MultimodalSearch(AnalysisMethod):
                 text_processing = txt_processors["eval"](query["text_input"])
                 image_processing = ""
             elif "image" in query.keys():
-                _, image_processing = MultimodalSearch.read_and_process_images([query["image"]], vis_processors)
+                _, image_processing = MultimodalSearch.read_and_process_images(
+                    [query["image"]], vis_processors
+                )
                 text_processing = ""
-            multi_sample.append({"image": image_processing, "text_input": text_processing})
+            multi_sample.append(
+                {"image": image_processing, "text_input": text_processing}
+            )
 
         multi_features_query = []
         for query in multi_sample:
             if query["image"] == "":
                 features = model.extract_features(query, mode="text")
-                features_squeeze = features.text_embeds_proj[:, 0, :].squeeze(0).to(MultimodalSearch.multimodal_device)
+                features_squeeze = (
+                    features.text_embeds_proj[:, 0, :]
+                    .squeeze(0)
+                    .to(MultimodalSearch.multimodal_device)
+                )
                 multi_features_query.append(features_squeeze)
             if query["text_input"] == "":
-                multi_features_query.append( MultimodalSearch.extract_image_features(
-                    model, query["image"], model_type))
-
-        multi_features_stacked = torch.stack([query.squeeze(0) for query in multi_features_query]).to(MultimodalSearch.multimodal_device)
+                multi_features_query.append(
+                    MultimodalSearch.extract_image_features(
+                        model, query["image"], model_type
+                    )
+                )
+
+        multi_features_stacked = torch.stack(
+            [query.squeeze(0) for query in multi_features_query]
+        ).to(MultimodalSearch.multimodal_device)
 
         similarity = features_image_stacked @ multi_features_stacked.t()
         sorted_lists = [
@@ -184,10 +214,27 @@ class MultimodalSearch(AnalysisMethod):
             if "image" in query.keys():
                 pic = Image.open(query["image"]).convert("RGB")
                 pic.thumbnail((400, 400))
-                display("Your search query: ", pic,"--------------------------------------------------", "Results:")
+                display(
+                    "Your search query: ",
+                    pic,
+                    "--------------------------------------------------",
+                    "Results:",
+                )
             elif "text_input" in query.keys():
-                display("Your search query: " + query["text_input"], "--------------------------------------------------", "Results:")
-            for s in sorted(self.items(), key=lambda t: t[1][list(query.values())[0]], reverse=True):
+                display(
+                    "Your search query: " + query["text_input"],
+                    "--------------------------------------------------",
+                    "Results:",
+                )
+            for s in sorted(
+                self.items(), key=lambda t: t[1][list(query.values())[0]], reverse=True
+            ):
                 p1 = Image.open(s[1]["filename"]).convert("RGB")
                 p1.thumbnail((400, 400))
-                display(p1, "Rank: " + str(s[1]["rank " + list(query.values())[0]]) + " Val: " + str(s[1][list(query.values())[0]]))
+                display(
+                    p1,
+                    "Rank: "
+                    + str(s[1]["rank " + list(query.values())[0]])
+                    + " Val: "
+                    + str(s[1][list(query.values())[0]]),
+                )
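
Taken together, the reformatted methods trace the search flow: parsing_images loads a model and pre-computes a stacked tensor of image features, while multimodal_search embeds each query and compares it against that stack. Below is a hedged usage sketch based only on the signatures visible in this diff; the import path, the dict-backed constructor, and the model_type value are assumptions, and the method's return value is not captured because it does not appear in these hunks.

# Sketch of the calling sequence implied by the diff. The import path, the
# dict-of-images constructor, and model_type="blip" are assumptions; only
# the method signatures and the query-dict rule come from the commit itself.
from ammico.multimodal_search import MultimodalSearch  # hypothetical path

image_dict = {
    "img0": {"filename": "data/img0.png"},  # hypothetical entries; the class
    "img1": {"filename": "data/img1.png"},  # reads self[k]["filename"]
}
my_search = MultimodalSearch(image_dict)

# parsing_images returns the six values unpacked below, matching the
# reformatted return statement in the diff.
(
    model,
    vis_processors,
    txt_processors,
    image_keys,
    image_names,
    features_image_stacked,
) = my_search.parsing_images("blip")

# Each query dict must hold exactly one key, "image" or "text_input";
# anything else raises SyntaxError (the len(query) != 1 check above).
search_query = [
    {"text_input": "a cat on a sofa"},
    {"image": "data/query_image.png"},
]
my_search.multimodal_search(
    model,
    vis_processors,
    txt_processors,
    "blip",
    image_keys,
    features_image_stacked,
    search_query,
)
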
801 notebooks/multimodal_search.ipynb (generated)
File diff hidden because one or more lines are too long