Mirror of
https://github.com/ssciwr/AMMICO.git
Synced 2025-10-30 13:36:04 +02:00
fix if-else, added clip ViT-L-14-336 model
This commit is contained in:
Parent: 779c5227ae
Commit: 4e4b7fac75
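In short: the single `load_feature_extractor_model(device, model_type)` and `extract_image_features(model, images_tensors, model_type)` if/elif chains are split into one small function per model, selected through dictionaries, and a loader for CLIP ViT-L-14 at 336 px input resolution (`clip_vitl14_336`) is added while `clip_rn50` is dropped. A minimal sketch of the resulting dispatch pattern, shown outside the class for brevity (the LAVIS call is the same one used in the diff below):

    import torch
    from lavis.models import load_model_and_preprocess

    def load_feature_extractor_model_clip_vitl14_336(device):
        # Loader added in this commit: CLIP ViT-L-14 at 336 px resolution.
        model, vis_processors, txt_processors = load_model_and_preprocess(
            name="clip_feature_extractor",
            model_type="ViT-L-14-336",
            is_eval=True,
            device=device,
        )
        return model, vis_processors, txt_processors

    # Dictionary dispatch replaces the old if/elif chain; one entry per model.
    select_model = {
        "clip_vitl14_336": load_feature_extractor_model_clip_vitl14_336,
        # ... the other entries mirror the diff below (blip2, blip, albef, ...)
    }

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model, vis_processors, txt_processors = select_model["clip_vitl14_336"](device)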
@@ -15,54 +15,58 @@ class MultimodalSearch(AnalysisMethod):
     multimodal_device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 
-    def load_feature_extractor_model(device, model_type):
-        if model_type == "blip2":
-            model, vis_processors, txt_processors = load_model_and_preprocess(
-                name="blip2_feature_extractor",
-                model_type="pretrain",
-                is_eval=True,
-                device=device,
-            )
-        elif model_type == "blip":
-            model, vis_processors, txt_processors = load_model_and_preprocess(
-                name="blip_feature_extractor",
-                model_type="base",
-                is_eval=True,
-                device=device,
-            )
-        elif model_type == "albef":
-            model, vis_processors, txt_processors = load_model_and_preprocess(
-                name="albef_feature_extractor",
-                model_type="base",
-                is_eval=True,
-                device=device,
-            )
-        elif model_type == "clip_base":
-            model, vis_processors, txt_processors = load_model_and_preprocess(
-                name="clip_feature_extractor",
-                model_type="base",
-                is_eval=True,
-                device=device,
-            )
-        elif model_type == "clip_rn50":
-            model, vis_processors, txt_processors = load_model_and_preprocess(
-                name="clip_feature_extractor",
-                model_type="RN50",
-                is_eval=True,
-                device=device,
-            )
-        elif model_type == "clip_vitl14":
-            model, vis_processors, txt_processors = load_model_and_preprocess(
-                name="clip_feature_extractor",
-                model_type="ViT-L-14",
-                is_eval=True,
-                device=device,
-            )
-        else:
-            print(
-                "Please, use one of the following models: blip2, blip, albef, clip_base, clip_rn50, clip_vitl14"
-            )
-
+    def load_feature_extractor_model_blip2(device):
+        model, vis_processors, txt_processors = load_model_and_preprocess(
+            name="blip2_feature_extractor",
+            model_type="pretrain",
+            is_eval=True,
+            device=device,
+        )
+        return model, vis_processors, txt_processors
+
+    def load_feature_extractor_model_blip(device):
+        model, vis_processors, txt_processors = load_model_and_preprocess(
+            name="blip_feature_extractor",
+            model_type="base",
+            is_eval=True,
+            device=device,
+        )
+        return model, vis_processors, txt_processors
+
+    def load_feature_extractor_model_albef(device):
+        model, vis_processors, txt_processors = load_model_and_preprocess(
+            name="albef_feature_extractor",
+            model_type="base",
+            is_eval=True,
+            device=device,
+        )
+        return model, vis_processors, txt_processors
+
+    def load_feature_extractor_model_clip_base(device):
+        model, vis_processors, txt_processors = load_model_and_preprocess(
+            name="clip_feature_extractor",
+            model_type="base",
+            is_eval=True,
+            device=device,
+        )
+        return model, vis_processors, txt_processors
+
+    def load_feature_extractor_model_clip_vitl14(device):
+        model, vis_processors, txt_processors = load_model_and_preprocess(
+            name="clip_feature_extractor",
+            model_type="ViT-L-14",
+            is_eval=True,
+            device=device,
+        )
+        return model, vis_processors, txt_processors
+
+    def load_feature_extractor_model_clip_vitl14_336(device):
+        model, vis_processors, txt_processors = load_model_and_preprocess(
+            name="clip_feature_extractor",
+            model_type="ViT-L-14-336",
+            is_eval=True,
+            device=device,
+        )
         return model, vis_processors, txt_processors
 
     def read_img(filepath):
@@ -81,34 +85,10 @@ class MultimodalSearch(AnalysisMethod):
         return raw_images, images_tensors
 
-    def extract_image_features(model, images_tensors, model_type):
-        if model_type == "blip2":
-            with torch.cuda.amp.autocast(
-                enabled=(MultimodalSearch.multimodal_device != torch.device("cpu"))
-            ):
-                features_image = [
-                    model.extract_features(
-                        {"image": ten, "text_input": ""}, mode="image"
-                    )
-                    for ten in images_tensors
-                ]
-                features_image_stacked = torch.stack(
-                    [
-                        feat.image_embeds_proj[:, 0, :].squeeze(0)
-                        for feat in features_image
-                    ]
-                )
-        elif model_type in ("clip_base", "clip_rn50", "clip_vitl14"):
-            features_image = [
-                model.extract_features({"image": ten}) for ten in images_tensors
-            ]
-            features_image_stacked = torch.stack(
-                [
-                    Func.normalize(feat.float(), dim=-1).squeeze(0)
-                    for feat in features_image
-                ]
-            )
-        else:
+    def extract_image_features_blip2(model, images_tensors):
+        with torch.cuda.amp.autocast(
+            enabled=(MultimodalSearch.multimodal_device != torch.device("cpu"))
+        ):
             features_image = [
                 model.extract_features({"image": ten, "text_input": ""}, mode="image")
                 for ten in images_tensors
@@ -118,6 +98,25 @@ class MultimodalSearch(AnalysisMethod):
             )
         return features_image_stacked
 
+    def extract_image_features_clip(model, images_tensors):
+        features_image = [
+            model.extract_features({"image": ten}) for ten in images_tensors
+        ]
+        features_image_stacked = torch.stack(
+            [Func.normalize(feat.float(), dim=-1).squeeze(0) for feat in features_image]
+        )
+        return features_image_stacked
+
+    def extract_image_features_basic(model, images_tensors):
+        features_image = [
+            model.extract_features({"image": ten, "text_input": ""}, mode="image")
+            for ten in images_tensors
+        ]
+        features_image_stacked = torch.stack(
+            [feat.image_embeds_proj[:, 0, :].squeeze(0) for feat in features_image]
+        )
+        return features_image_stacked
+
     def save_tensors(
         model_type, features_image_stacked, name="saved_features_image.pt"
     ):
@@ -137,9 +136,9 @@ class MultimodalSearch(AnalysisMethod):
         return features_text
 
-    def parsing_images(self, model_type):
+    def parsing_images(self, model_type, path_to_saved_tensors=None):
 
-        if model_type in ("clip_base", "clip_rn50", "clip_vitl14"):
+        if model_type in ("clip_base", "clip_vitl14_336", "clip_vitl14"):
             path_to_lib = lavis.__file__[:-11] + "models/clip_models/"
             url = "https://raw.githubusercontent.com/salesforce/LAVIS/main/lavis/models/clip_models/bpe_simple_vocab_16e6.txt.gz"
             r = requests.get(url, allow_redirects=False)
@@ -148,20 +147,46 @@ class MultimodalSearch(AnalysisMethod):
         image_keys = sorted(self.keys())
         image_names = [self[k]["filename"] for k in image_keys]
 
-        (
-            model,
-            vis_processors,
-            txt_processors,
-        ) = MultimodalSearch.load_feature_extractor_model(
-            MultimodalSearch.multimodal_device, model_type
-        )
+        select_model = {
+            "blip2": MultimodalSearch.load_feature_extractor_model_blip2,
+            "blip": MultimodalSearch.load_feature_extractor_model_blip,
+            "albef": MultimodalSearch.load_feature_extractor_model_albef,
+            "clip_base": MultimodalSearch.load_feature_extractor_model_clip_base,
+            "clip_vitl14": MultimodalSearch.load_feature_extractor_model_clip_vitl14,
+            "clip_vitl14_336": MultimodalSearch.load_feature_extractor_model_clip_vitl14_336,
+        }
+
+        select_extract_image_features = {
+            "blip2": MultimodalSearch.extract_image_features_blip2,
+            "blip": MultimodalSearch.extract_image_features_basic,
+            "albef": MultimodalSearch.extract_image_features_basic,
+            "clip_base": MultimodalSearch.extract_image_features_clip,
+            "clip_vitl14": MultimodalSearch.extract_image_features_clip,
+            "clip_vitl14_336": MultimodalSearch.extract_image_features_clip,
+        }
+
+        if model_type in select_model.keys():
+            (model, vis_processors, txt_processors,) = select_model[
+                model_type
+            ](MultimodalSearch.multimodal_device)
+        else:
+            raise SyntaxError(
+                "Please, use one of the following models: blip2, blip, albef, clip_base, clip_vitl14, clip_vitl14_336"
+            )
+
         raw_images, images_tensors = MultimodalSearch.read_and_process_images(
             image_names, vis_processors
         )
-        features_image_stacked = MultimodalSearch.extract_image_features(
-            model, images_tensors, model_type
-        )
-        MultimodalSearch.save_tensors(model_type, features_image_stacked)
+        if path_to_saved_tensors is None:
+            with torch.no_grad():
+                features_image_stacked = select_extract_image_features[model_type](
+                    model, images_tensors
+                )
+            MultimodalSearch.save_tensors(model_type, features_image_stacked)
+        else:
+            features_image_stacked = MultimodalSearch.load_tensors(
+                str(path_to_saved_tensors)
+            )
+
         return (
             model,
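For context, the notebook shipped with this commit calls the reworked method as sketched below. `ms` (the imported multimodal_search module) and `mydict` (the dictionary of image entries) come from earlier notebook cells and are assumed here, not defined:

    # First run for a given model and image set: features are extracted under
    # torch.no_grad() and saved to <Number_of_images>_<model_name>_saved_features_image.pt.
    (
        model,
        vis_processors,
        txt_processors,
        image_keys,
        image_names,
        features_image_stacked,
    ) = ms.MultimodalSearch.parsing_images(mydict, model_type)

    # Later runs: pass the saved file as the third argument to skip re-extraction.
    (
        model,
        vis_processors,
        txt_processors,
        image_keys,
        image_names,
        features_image_stacked,
    ) = ms.MultimodalSearch.parsing_images(
        mydict, model_type, "18_clip_base_saved_features_image.pt"
    )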
@@ -175,6 +200,16 @@ class MultimodalSearch(AnalysisMethod):
     def querys_processing(
         self, search_query, model, txt_processors, vis_processors, model_type
     ):
+
+        select_extract_image_features = {
+            "blip2": MultimodalSearch.extract_image_features_blip2,
+            "blip": MultimodalSearch.extract_image_features_basic,
+            "albef": MultimodalSearch.extract_image_features_basic,
+            "clip_base": MultimodalSearch.extract_image_features_clip,
+            "clip_vitl14": MultimodalSearch.extract_image_features_clip,
+            "clip_vitl14_336": MultimodalSearch.extract_image_features_clip,
+        }
+
         for query in search_query:
             if not (len(query) == 1) and (query in ("image", "text_input")):
                 raise SyntaxError(
@@ -194,10 +229,10 @@ class MultimodalSearch(AnalysisMethod):
                 {"image": images_tensors, "text_input": text_processing}
             )
 
-        if model_type in ("clip_base", "clip_rn50", "clip_vitl14"):
-            multi_features_query = []
-            for query in multi_sample:
-                if query["image"] == "":
+        multi_features_query = []
+        for query in multi_sample:
+            if query["image"] == "":
+                if model_type in ("clip_base", "clip_vitl14_336", "clip_vitl14"):
                     features = model.extract_features(
                         {"text_input": query["text_input"]}
                     )
@@ -208,17 +243,7 @@ class MultimodalSearch(AnalysisMethod):
                     multi_features_query.append(
                         Func.normalize(features_squeeze, dim=-1)
                     )
-                if query["text_input"] == "":
-                    multi_features_query.append(
-                        MultimodalSearch.extract_image_features(
-                            model, query["image"], model_type
-                        )
-                    )
-
-        else:
-            multi_features_query = []
-            for query in multi_sample:
-                if query["image"] == "":
+                else:
                     features = model.extract_features(query, mode="text")
                     features_squeeze = (
                         features.text_embeds_proj[:, 0, :]
@@ -226,12 +251,10 @@ class MultimodalSearch(AnalysisMethod):
                         .to(MultimodalSearch.multimodal_device)
                     )
                     multi_features_query.append(features_squeeze)
             if query["text_input"] == "":
                 multi_features_query.append(
-                    MultimodalSearch.extract_image_features(
-                        model, query["image"], model_type
-                    )
-                )
+                    select_extract_image_features[model_type](model, query["image"])
+                )
 
         multi_features_stacked = torch.stack(
             [query.squeeze(0) for query in multi_features_query]
@@ -251,11 +274,13 @@ class MultimodalSearch(AnalysisMethod):
     ):
         features_image_stacked.to(MultimodalSearch.multimodal_device)
 
-        multi_features_stacked = MultimodalSearch.querys_processing(
-            self, search_query, model, txt_processors, vis_processors, model_type
-        )
+        with torch.no_grad():
+            multi_features_stacked = MultimodalSearch.querys_processing(
+                self, search_query, model, txt_processors, vis_processors, model_type
+            )
 
         similarity = features_image_stacked @ multi_features_stacked.t()
+        similarity_soft_max = torch.nn.Softmax(dim=0)(similarity / 0.01)
         sorted_lists = [
             sorted(range(len(similarity)), key=lambda k: similarity[k, i], reverse=True)
             for i in range(len(similarity[0]))
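The added `similarity_soft_max` turns each column of the raw image-query similarity matrix into a probability distribution over the images, sharpened by the 0.01 temperature. A small self-contained illustration of that one line (toy numbers, not from the package):

    import torch

    # Rows are images, columns are queries; dim=0 normalizes over the images,
    # so every column of the result sums to 1. Dividing by 0.01 sharpens the
    # distribution toward the best-matching image per query.
    similarity = torch.tensor([[0.31, 0.12],
                               [0.29, 0.40],
                               [0.05, 0.38]])
    similarity_soft_max = torch.nn.Softmax(dim=0)(similarity / 0.01)
    print(similarity_soft_max.sum(dim=0))  # tensor([1., 1.])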
@@ -267,7 +292,7 @@ class MultimodalSearch(AnalysisMethod):
                 self[key]["rank " + list(search_query[q].values())[0]] = places[q][i]
                 self[key][list(search_query[q].values())[0]] = similarity[i][q].item()
 
-        return similarity
+        return similarity, similarity_soft_max
 
     def show_results(self, query):
         if "image" in query.keys():
notebooks/multimodal_search.ipynb (generated): 29 changed lines
@@ -81,7 +81,7 @@
    "id": "66d6ede4-00bc-4aeb-9a36-e52d7de33fe5",
    "metadata": {},
    "source": [
-    "You can choose one of the following models: blip, blip2, albef, clip_base, clip_rn50, clip_vitl14"
+    "You can choose one of the following models: blip, blip2, albef, clip_base, clip_vitl14, clip_vitl14_336"
    ]
   },
   {
@@ -91,7 +91,7 @@
    "metadata": {},
    "outputs": [],
    "source": [
-    "model_type = \"blip\""
+    "model_type = \"clip_vitl14_336\""
    ]
   },
   {
@@ -116,17 +116,32 @@
    "id": "9ff8a894-566b-4c4f-acca-21c50b5b1f52",
    "metadata": {},
    "source": [
-    "The of all images `features_image_stacked` was saved in `saved_features_image.pt`. If you run it once for current model and set of images you do not need to repeat it again. Instead you can load this features with the command:"
+    "The tensors of all images `features_image_stacked` was saved in `<Number_of_images>_<model_name>_saved_features_image.pt`. If you run it once for current model and current set of images you do not need to repeat it again. Instead you can load this features with the command:"
    ]
   },
   {
    "cell_type": "code",
    "execution_count": null,
-   "id": "c40e93f0-6bea-4886-b904-8b46ed6ec819",
+   "id": "56c6d488-f093-4661-835a-5c73a329c874",
    "metadata": {},
    "outputs": [],
    "source": [
-    "# features_image_stacked = ms.MultimodalSearch.load_tensors('saved_features_image.pt')"
+    "# (\n",
+    "# model,\n",
+    "# vis_processors,\n",
+    "# txt_processors,\n",
+    "# image_keys,\n",
+    "# image_names,\n",
+    "# features_image_stacked,\n",
+    "# ) = ms.MultimodalSearch.parsing_images(mydict, model_type,\"18_clip_base_saved_features_image.pt\")"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "309923c1-d6f8-4424-8fca-bde5f3a98b38",
+   "metadata": {},
+   "source": [
+    "Here we already processed our image folder with 18 images with `clip_base` model. So you need just write the name `18_clip_base_saved_features_image.pt` of the saved file that consists of tensors of all images as a 3rd argument to the previous function. "
    ]
   },
   {
@@ -170,7 +185,7 @@
    "    image_keys,\n",
    "    features_image_stacked,\n",
    "    search_query3,\n",
-    ");"
+    ")"
    ]
   },
   {
@@ -206,7 +221,7 @@
    "metadata": {},
    "outputs": [],
    "source": [
-    "ms.MultimodalSearch.show_results(mydict, search_query3[0])"
+    "ms.MultimodalSearch.show_results(mydict, search_query3[4])"
    ]
   },
   {