Mirror of https://github.com/ssciwr/AMMICO.git, synced 2025-10-29 21:16:06 +02:00
Change input format for multimodal search
This commit is contained in: parent b709f69d58, commit 70866dfc69
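Summary of the diff below: search queries become one-key dicts, either {"text_input": ...} or {"image": ...}, instead of plain strings; parsing_images additionally returns the loaded model and processors; multimodal_search takes them as arguments instead of reloading the feature extractor; and show_results learns to render both text and image queries.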
@@ -84,12 +84,12 @@ class MultimodalSearch(AnalysisMethod):
                 enabled=(MultimodalSearch.multimodal_device != torch.device("cpu"))
             ):
                 features_image = [
-                    model.extract_features({"image": ten}, mode="image")
+                    model.extract_features({"image": ten, "text_input": ""}, mode="image")
                     for ten in images_tensors
                 ]
         else:
             features_image = [
-                model.extract_features({"image": ten}, mode="image")
+                model.extract_features({"image": ten, "text_input": ""}, mode="image")
                 for ten in images_tensors
             ]

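Functionally this hunk only adds an empty "text_input" key to the image-mode extract_features calls; presumably some LAVIS feature extractors read samples["text_input"] unconditionally, and the empty string keeps image-only extraction from failing. A minimal sketch of the call pattern, assuming a LAVIS install; the checkpoint choice and the random stand-in tensor are illustrative, not taken from this commit:

    import torch
    from lavis.models import load_model_and_preprocess

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # One of the LAVIS feature-extractor checkpoints; the choice here is illustrative.
    model, vis_processors, txt_processors = load_model_and_preprocess(
        name="blip_feature_extractor", model_type="base", device=device, is_eval=True
    )

    image = torch.rand(1, 3, 224, 224).to(device)  # stand-in for a preprocessed image
    # As in the commit, pass an empty text_input even for pure image extraction:
    features = model.extract_features({"image": image, "text_input": ""}, mode="image")
    print(features.image_embeds_proj.shape)  # projected embeddings used for the search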
@@ -113,7 +113,8 @@ class MultimodalSearch(AnalysisMethod):
         features_text = model.extract_features(sample_text, mode="text")

         return features_text

+
     def parsing_images(self, model_type):
         image_keys = sorted(self.keys())
         image_names = [self[k]["filename"] for k in image_keys]
@@ -132,31 +133,40 @@ class MultimodalSearch(AnalysisMethod):
         )
         MultimodalSearch.save_tensors(features_image_stacked)

-        return image_keys, image_names, features_image_stacked
+        return model, vis_processors, txt_processors, image_keys, image_names, features_image_stacked

     def multimodal_search(
-        self, model_type, image_keys, features_image_stacked, search_query
+        self, model, vis_processors, txt_processors, model_type, image_keys, features_image_stacked, search_query
     ):
         features_image_stacked.to(MultimodalSearch.multimodal_device)
-        (
-            model,
-            vis_processors,
-            txt_processors,
-        ) = MultimodalSearch.load_feature_extractor_model(
-            MultimodalSearch.multimodal_device, model_type
-        )
-        multi_text_input = [txt_processors["eval"](query) for query in search_query]
-        multi_sample = [{"text_input": [query]} for query in multi_text_input]
-        multi_features_text = [
-            model.extract_features(sample, mode="text") for sample in multi_sample
-        ]
-        multi_features_text_stacked = torch.stack(
-            [
-                features.text_embeds_proj[:, 0, :].squeeze(0)
-                for features in multi_features_text
-            ]
-        ).to(MultimodalSearch.multimodal_device)
-        similarity = features_image_stacked @ multi_features_text_stacked.t()
+
+        for query in search_query:
+            if (len(query)!=1):
+                raise SyntaxError('Each querry must contain either an "image" or a "text_input"')
+
+        multi_sample = []
+        for query in search_query:
+            if "text_input" in query.keys():
+                text_processing = txt_processors["eval"](query["text_input"])
+                image_processing = ""
+            elif "image" in query.keys():
+                _, image_processing = MultimodalSearch.read_and_process_images([query["image"]], vis_processors)
+                text_processing = ""
+            multi_sample.append({"image": image_processing, "text_input": text_processing})
+
+        multi_features_query = []
+        for query in multi_sample:
+            if query["image"] == "":
+                features = model.extract_features(query, mode="text")
+                features_squeeze = features.text_embeds_proj[:, 0, :].squeeze(0).to(MultimodalSearch.multimodal_device)
+                multi_features_query.append(features_squeeze)
+            if query["text_input"] == "":
+                multi_features_query.append(MultimodalSearch.extract_image_features(
+                    model, query["image"], model_type))
+
+        multi_features_stacked = torch.stack([query.squeeze(0) for query in multi_features_query]).to(MultimodalSearch.multimodal_device)
+
+        similarity = features_image_stacked @ multi_features_stacked.t()
         sorted_lists = [
             sorted(range(len(similarity)), key=lambda k: similarity[k, i], reverse=True)
             for i in range(len(similarity[0]))
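With the new signature, the model and processors returned by parsing_images are passed straight into multimodal_search, and each query is a dict with exactly one key. A usage sketch, assuming a MultimodalSearch instance built over the usual dict of {"filename": ...} image entries; the construction, the "blip" model_type string, and the file paths are illustrative, not taken from this diff:

    # Hypothetical setup: mydict maps image keys to {"filename": <path>} entries.
    my_obj = MultimodalSearch(mydict)

    (
        model,
        vis_processors,
        txt_processors,
        image_keys,
        image_names,
        features_image_stacked,
    ) = my_obj.parsing_images("blip")

    # Each query is a one-key dict, either "text_input" or "image".
    search_query = [
        {"text_input": "a politician at a press conference"},
        {"image": "queries/example.png"},  # illustrative path
    ]

    similarity = my_obj.multimodal_search(
        model,
        vis_processors,
        txt_processors,
        "blip",
        image_keys,
        features_image_stacked,
        search_query,
    )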
@@ -165,13 +175,19 @@ class MultimodalSearch(AnalysisMethod):

         for q in range(len(search_query)):
             for i, key in zip(range(len(image_keys)), image_keys):
-                self[key]["rank " + search_query[q]] = places[q][i]
-                self[key][search_query[q]] = similarity[i][q].item()
+                self[key]["rank " + list(search_query[q].values())[0]] = places[q][i]
+                self[key][list(search_query[q].values())[0]] = similarity[i][q].item()

-        return self
+        return similarity

     def show_results(self, query):
-        for s in sorted(self.items(), key=lambda t: t[1][query], reverse=True):
+        if "image" in query.keys():
+            pic = Image.open(query["image"]).convert("RGB")
+            pic.thumbnail((400, 400))
+            display("Your search query: ", pic, "--------------------------------------------------", "Results:")
+        elif "text_input" in query.keys():
+            display("Your search query: " + query["text_input"], "--------------------------------------------------", "Results:")
+        for s in sorted(self.items(), key=lambda t: t[1][list(query.values())[0]], reverse=True):
             p1 = Image.open(s[1]["filename"]).convert("RGB")
             p1.thumbnail((400, 400))
-            display(p1, s[1][query])
+            display(p1, "Rank: " + str(s[1]["rank " + list(query.values())[0]]) + " Val: " + str(s[1][list(query.values())[0]]))
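show_results now expects the same one-key query dict and ranks entries by the per-query keys that multimodal_search wrote back into the object. A short sketch for a notebook session, continuing from the example above; display is IPython's notebook helper:

    from IPython.display import display  # show_results relies on notebook display

    # Show ranked matches for the text query; an image query works the same way:
    my_obj.show_results({"text_input": "a politician at a press conference"})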
notebooks/multimodal_search.ipynb (generated): 944 lines changed. File diff suppressed because one or more lines are too long.