Mirror of https://github.com/ssciwr/AMMICO.git, synced 2025-10-29 21:16:06 +02:00

fixed saving path, added tmp_path (#67)

This commit is contained in:
parent f89fa4e519
commit 14a1a03597
@@ -3,6 +3,7 @@ import torch
 import torch.nn.functional as Func
 import requests
 import lavis
+import os
 from PIL import Image
 from IPython.display import display
 from lavis.models import load_model_and_preprocess
@@ -118,15 +119,27 @@ class MultimodalSearch(AnalysisMethod):
         return features_image_stacked

     def save_tensors(
-        self, model_type, features_image_stacked, name="saved_features_image.pt"
+        self,
+        model_type,
+        features_image_stacked,
+        name="saved_features_image.pt",
+        path="./saved_tensors/",
     ):
+        if not os.path.exists(path):
+            os.makedirs(path)
         with open(
-            str(len(features_image_stacked)) + "_" + model_type + "_" + name, "wb"
+            str(path)
+            + str(len(features_image_stacked))
+            + "_"
+            + model_type
+            + "_"
+            + name,
+            "wb",
         ) as f:
             torch.save(features_image_stacked, f)
         return name

-    def load_tensors(self, name="saved_features_image.pt"):
+    def load_tensors(self, name):
         features_image_stacked = torch.load(name)
         return features_image_stacked

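For orientation, a minimal sketch of the reworked save/load round trip. The instance ms_instance, the 18-row feature tensor, and the "blip" model type are illustrative assumptions, not values from this commit; the methods are called unbound with an explicit self, matching the style used throughout this codebase.

    # Hypothetical round trip under the new signatures; ms_instance,
    # the 18-row tensor, and "blip" are illustrative only.
    import torch

    features = torch.rand(18, 512)  # stands in for features_image_stacked
    name = MultimodalSearch.save_tensors(
        ms_instance, "blip", features, path="./saved_tensors/"
    )
    # save_tensors writes path + "<rows>_<model_type>_<name>", so:
    features_again = MultimodalSearch.load_tensors(
        ms_instance, "./saved_tensors/18_blip_saved_features_image.pt"
    )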
@@ -136,7 +149,12 @@ class MultimodalSearch(AnalysisMethod):

         return features_text

-    def parsing_images(self, model_type, path_to_saved_tensors=None):
+    def parsing_images(
+        self,
+        model_type,
+        path_to_saved_tensors="./saved_tensors/",
+        path_to_load_tensors=None,
+    ):
         if model_type in ("clip_base", "clip_vitl14_336", "clip_vitl14"):
             path_to_lib = lavis.__file__[:-11] + "models/clip_models/"
             url = "https://raw.githubusercontent.com/salesforce/LAVIS/main/lavis/models/clip_models/bpe_simple_vocab_16e6.txt.gz"
@@ -180,15 +198,17 @@ class MultimodalSearch(AnalysisMethod):
         _, images_tensors = MultimodalSearch.read_and_process_images(
             self, image_names, vis_processors
         )
-        if path_to_saved_tensors is None:
+        if path_to_load_tensors is None:
             with torch.no_grad():
                 features_image_stacked = select_extract_image_features[model_type](
                     self, model, images_tensors
                 )
-            MultimodalSearch.save_tensors(self, model_type, features_image_stacked)
+            MultimodalSearch.save_tensors(
+                self, model_type, features_image_stacked, path=path_to_saved_tensors
+            )
         else:
             features_image_stacked = MultimodalSearch.load_tensors(
-                self, str(path_to_saved_tensors)
+                self, str(path_to_load_tensors)
             )

         return (
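The renamed flag separates the two modes cleanly: path_to_saved_tensors now only controls where freshly extracted features are written, while a non-None path_to_load_tensors skips extraction and reads the saved tensor file instead. A hedged sketch of both call sites, with mydict and model_type taken from the notebook changed below; the load file name is illustrative.

    # Mode 1 (extract and save): features are computed and written
    # under ./saved_tensors/.
    (model, vis_processors, txt_processors, image_keys, image_names,
     features_image_stacked) = ms.MultimodalSearch.parsing_images(
        mydict, model_type, path_to_saved_tensors="./saved_tensors/"
    )

    # Mode 2 (load): a previously saved tensor file is reused, so no
    # re-extraction happens.
    (model, vis_processors, txt_processors, image_keys, image_names,
     features_image_stacked) = ms.MultimodalSearch.parsing_images(
        mydict, model_type,
        path_to_load_tensors="./saved_tensors/18_blip_saved_features_image.pt",
    )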
@@ -303,7 +323,6 @@ class MultimodalSearch(AnalysisMethod):

         for q in range(len(search_query)):
             max_val = similarity[sorted_lists[q][0]][q].item()
-            print(max_val)
             for i, key in zip(range(len(image_keys)), sorted_lists[q]):
                 if (
                     i < filter_number_of_images
@@ -354,6 +354,7 @@ def test_parsing_images(
     pre_extracted_feature_text,
     pre_simularity,
     pre_sorted,
+    tmp_path,
 ):
     ms.MultimodalSearch.multimodal_device = pre_multimodal_device
     (
@@ -363,7 +364,9 @@ def test_parsing_images(
         image_keys,
         _,
         features_image_stacked,
-    ) = ms.MultimodalSearch.parsing_images(testdict, pre_model)
+    ) = ms.MultimodalSearch.parsing_images(
+        testdict, pre_model, path_to_saved_tensors=tmp_path
+    )

     for i, num in zip(range(10), features_image_stacked[0, 10:12].tolist()):
         assert (
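tmp_path is pytest's built-in fixture: each test receives a unique, freshly created pathlib.Path directory, so the saved tensors no longer land in the repository checkout. A minimal sketch of the pattern, independent of this codebase:

    # Minimal pytest sketch; tmp_path is injected by pytest itself.
    def test_saves_into_tmp_dir(tmp_path):
        out = tmp_path / "saved_features_image.pt"
        out.write_bytes(b"\x00")  # stand-in for torch.save
        assert out.exists()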
52 notebooks/multimodal_search.ipynb (generated)
@@ -20,7 +20,9 @@
    "cell_type": "code",
    "execution_count": null,
    "id": "f10ad6c9-b1a0-4043-8c5d-ed660d77be37",
-   "metadata": {},
+   "metadata": {
+    "tags": []
+   },
    "outputs": [],
    "source": [
     "import misinformation\n",
@@ -39,7 +41,9 @@
    "cell_type": "code",
    "execution_count": null,
    "id": "8d3fe589-ff3c-4575-b8f5-650db85596bc",
-   "metadata": {},
+   "metadata": {
+    "tags": []
+   },
    "outputs": [],
    "source": [
     "images = misinformation.utils.find_files(\n",
@@ -52,7 +56,9 @@
    "cell_type": "code",
    "execution_count": null,
    "id": "adf3db21-1f8b-4d44-bbef-ef0acf4623a0",
-   "metadata": {},
+   "metadata": {
+    "tags": []
+   },
    "outputs": [],
    "source": [
     "mydict = misinformation.utils.initialize_dict(images)"
@@ -62,7 +68,9 @@
    "cell_type": "code",
    "execution_count": null,
    "id": "d98b6227-886d-41b8-a377-896dd8ab3c2a",
-   "metadata": {},
+   "metadata": {
+    "tags": []
+   },
    "outputs": [],
    "source": [
     "mydict"
@@ -88,7 +96,9 @@
    "cell_type": "code",
    "execution_count": null,
    "id": "7bbca1f0-d4b0-43cd-8e05-ee39d37c328e",
-   "metadata": {},
+   "metadata": {
+    "tags": []
+   },
    "outputs": [],
    "source": [
     "model_type = \"blip\"\n",
@@ -103,7 +113,9 @@
    "cell_type": "code",
    "execution_count": null,
    "id": "ca095404-57d0-4f5d-aeb0-38c232252b17",
-   "metadata": {},
+   "metadata": {
+    "tags": []
+   },
    "outputs": [],
    "source": [
     "(\n",
@@ -113,7 +125,9 @@
     " image_keys,\n",
     " image_names,\n",
     " features_image_stacked,\n",
-    ") = ms.MultimodalSearch.parsing_images(mydict, model_type)"
+    ") = ms.MultimodalSearch.parsing_images(\n",
+    " mydict, model_type, path_to_saved_tensors=\"./saved_tensors/\"\n",
+    ")"
    ]
   },
   {
@@ -128,17 +142,23 @@
    "cell_type": "code",
    "execution_count": null,
    "id": "56c6d488-f093-4661-835a-5c73a329c874",
-   "metadata": {},
+   "metadata": {
+    "tags": []
+   },
    "outputs": [],
    "source": [
     "# (\n",
     "# model,\n",
     "# vis_processors,\n",
     "# txt_processors,\n",
     "# image_keys,\n",
     "# image_names,\n",
     "# features_image_stacked,\n",
-    "# ) = ms.MultimodalSearch.parsing_images(mydict, model_type,\"18_clip_base_saved_features_image.pt\")"
+    "# ) = ms.MultimodalSearch.parsing_images(\n",
+    "# mydict,\n",
+    "# model_type,\n",
+    "# path_to_load_tensors=\"./saved_tensors/18_blip_saved_features_image.pt\",\n",
+    "# )"
    ]
   },
   {
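The commented-out cell also documents the on-disk naming scheme: save_tensors prefixes the file with the row count of the stacked tensor and the model type, so 18 images embedded with the blip model end up as 18_blip_saved_features_image.pt. A small check of that concatenation, with illustrative values:

    # Reproduces the filename concatenation from save_tensors;
    # 18 images and "blip" are illustrative values.
    path, n_rows, model_type = "./saved_tensors/", 18, "blip"
    name = "saved_features_image.pt"
    full = str(path) + str(n_rows) + "_" + model_type + "_" + name
    assert full == "./saved_tensors/18_blip_saved_features_image.pt"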
@@ -328,7 +348,7 @@
    "name": "python",
    "nbconvert_exporter": "python",
    "pygments_lexer": "ipython3",
-   "version": "3.10.8"
+   "version": "3.9.0"
   }
  },
  "nbformat": 4,