Mirror of
https://github.com/ssciwr/AMMICO.git
synced 2025-10-29 21:16:06 +02:00
fixed saving path, added tmp_path (#67)
This commit is contained in:
parent
f89fa4e519
commit
14a1a03597
@@ -3,6 +3,7 @@ import torch
 import torch.nn.functional as Func
 import requests
 import lavis
+import os
 from PIL import Image
 from IPython.display import display
 from lavis.models import load_model_and_preprocess
@@ -118,15 +119,27 @@ class MultimodalSearch(AnalysisMethod):
         return features_image_stacked

     def save_tensors(
-        self, model_type, features_image_stacked, name="saved_features_image.pt"
+        self,
+        model_type,
+        features_image_stacked,
+        name="saved_features_image.pt",
+        path="./saved_tensors/",
     ):
+        if not os.path.exists(path):
+            os.makedirs(path)
         with open(
-            str(len(features_image_stacked)) + "_" + model_type + "_" + name, "wb"
+            str(path)
+            + str(len(features_image_stacked))
+            + "_"
+            + model_type
+            + "_"
+            + name,
+            "wb",
         ) as f:
             torch.save(features_image_stacked, f)
         return name

-    def load_tensors(self, name="saved_features_image.pt"):
+    def load_tensors(self, name):
         features_image_stacked = torch.load(name)
         return features_image_stacked
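The reworked save_tensors writes into a configurable directory, creating it on demand, and the file name encodes the number of stacked feature vectors plus the model type. A minimal standalone sketch of that naming scheme (the random tensor and the "blip" model type below are illustrative assumptions, not values from the commit):

```python
import os
import torch

# Sketch of the "<path><N>_<model_type>_<name>" naming scheme,
# where N is the number of stacked image feature vectors.
features = torch.rand(18, 512)  # hypothetical stacked features for 18 images
path = "./saved_tensors/"
os.makedirs(path, exist_ok=True)
file_name = path + str(len(features)) + "_" + "blip" + "_" + "saved_features_image.pt"
with open(file_name, "wb") as f:
    torch.save(features, f)
print(file_name)  # ./saved_tensors/18_blip_saved_features_image.pt
```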
@@ -136,7 +149,12 @@ class MultimodalSearch(AnalysisMethod):

         return features_text

-    def parsing_images(self, model_type, path_to_saved_tensors=None):
+    def parsing_images(
+        self,
+        model_type,
+        path_to_saved_tensors="./saved_tensors/",
+        path_to_load_tensors=None,
+    ):
         if model_type in ("clip_base", "clip_vitl14_336", "clip_vitl14"):
             path_to_lib = lavis.__file__[:-11] + "models/clip_models/"
             url = "https://raw.githubusercontent.com/salesforce/LAVIS/main/lavis/models/clip_models/bpe_simple_vocab_16e6.txt.gz"
@@ -180,15 +198,17 @@ class MultimodalSearch(AnalysisMethod):
         _, images_tensors = MultimodalSearch.read_and_process_images(
             self, image_names, vis_processors
         )
-        if path_to_saved_tensors is None:
+        if path_to_load_tensors is None:
             with torch.no_grad():
                 features_image_stacked = select_extract_image_features[model_type](
                     self, model, images_tensors
                 )
-            MultimodalSearch.save_tensors(self, model_type, features_image_stacked)
+            MultimodalSearch.save_tensors(
+                self, model_type, features_image_stacked, path=path_to_saved_tensors
+            )
         else:
             features_image_stacked = MultimodalSearch.load_tensors(
-                self, str(path_to_saved_tensors)
+                self, str(path_to_load_tensors)
             )

         return (
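With this change, parsing_images separates the save location from the load location: omit path_to_load_tensors to extract and cache features, or pass it to reuse a cached file. A hedged usage sketch (the full return tuple is abbreviated in the diff, so only its last element is indexed here; the ms alias and the variables mydict and model_type follow the notebook changes further down):

```python
import misinformation.multimodal_search as ms  # assumed alias, matching the notebook

# First run: extract image features and cache them under ./saved_tensors/
features_image_stacked = ms.MultimodalSearch.parsing_images(
    mydict, model_type, path_to_saved_tensors="./saved_tensors/"
)[-1]

# Later runs: skip extraction and load the cached tensor file instead
features_image_stacked = ms.MultimodalSearch.parsing_images(
    mydict,
    model_type,
    path_to_load_tensors="./saved_tensors/18_blip_saved_features_image.pt",
)[-1]
```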
@@ -303,7 +323,6 @@ class MultimodalSearch(AnalysisMethod):

         for q in range(len(search_query)):
             max_val = similarity[sorted_lists[q][0]][q].item()
-            print(max_val)
             for i, key in zip(range(len(image_keys)), sorted_lists[q]):
                 if (
                     i < filter_number_of_images
@@ -354,6 +354,7 @@ def test_parsing_images(
     pre_extracted_feature_text,
     pre_simularity,
     pre_sorted,
+    tmp_path,
 ):
     ms.MultimodalSearch.multimodal_device = pre_multimodal_device
     (
@@ -363,7 +364,9 @@ def test_parsing_images(
         image_keys,
         _,
         features_image_stacked,
-    ) = ms.MultimodalSearch.parsing_images(testdict, pre_model)
+    ) = ms.MultimodalSearch.parsing_images(
+        testdict, pre_model, path_to_saved_tensors=tmp_path
+    )

     for i, num in zip(range(10), features_image_stacked[0, 10:12].tolist()):
         assert (
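The test now takes pytest's built-in tmp_path fixture and passes it as path_to_saved_tensors, so cached tensors land in a per-test temporary directory instead of the working tree. A minimal illustration of the fixture (this round-trip test is hypothetical, not part of the suite):

```python
import torch


def test_tensor_roundtrip(tmp_path):
    # tmp_path is a pathlib.Path to a unique temporary directory that
    # pytest creates per test and cleans up automatically.
    features = torch.rand(3, 4)  # hypothetical feature tensor
    target = tmp_path / "saved_features_image.pt"
    torch.save(features, target)
    assert torch.equal(torch.load(target), features)
```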
notebooks/multimodal_search.ipynb (generated, 40 lines changed)
@@ -20,7 +20,9 @@
    "cell_type": "code",
    "execution_count": null,
    "id": "f10ad6c9-b1a0-4043-8c5d-ed660d77be37",
-   "metadata": {},
+   "metadata": {
+    "tags": []
+   },
    "outputs": [],
    "source": [
     "import misinformation\n",
@@ -39,7 +41,9 @@
    "cell_type": "code",
    "execution_count": null,
    "id": "8d3fe589-ff3c-4575-b8f5-650db85596bc",
-   "metadata": {},
+   "metadata": {
+    "tags": []
+   },
    "outputs": [],
    "source": [
     "images = misinformation.utils.find_files(\n",
@@ -52,7 +56,9 @@
    "cell_type": "code",
    "execution_count": null,
    "id": "adf3db21-1f8b-4d44-bbef-ef0acf4623a0",
-   "metadata": {},
+   "metadata": {
+    "tags": []
+   },
    "outputs": [],
    "source": [
     "mydict = misinformation.utils.initialize_dict(images)"
@@ -62,7 +68,9 @@
    "cell_type": "code",
    "execution_count": null,
    "id": "d98b6227-886d-41b8-a377-896dd8ab3c2a",
-   "metadata": {},
+   "metadata": {
+    "tags": []
+   },
    "outputs": [],
    "source": [
     "mydict"
@@ -88,7 +96,9 @@
    "cell_type": "code",
    "execution_count": null,
    "id": "7bbca1f0-d4b0-43cd-8e05-ee39d37c328e",
-   "metadata": {},
+   "metadata": {
+    "tags": []
+   },
    "outputs": [],
    "source": [
     "model_type = \"blip\"\n",
@@ -103,7 +113,9 @@
    "cell_type": "code",
    "execution_count": null,
    "id": "ca095404-57d0-4f5d-aeb0-38c232252b17",
-   "metadata": {},
+   "metadata": {
+    "tags": []
+   },
    "outputs": [],
    "source": [
     "(\n",
@@ -113,7 +125,9 @@
     " image_keys,\n",
     " image_names,\n",
     " features_image_stacked,\n",
-    ") = ms.MultimodalSearch.parsing_images(mydict, model_type)"
+    ") = ms.MultimodalSearch.parsing_images(\n",
+    " mydict, model_type, path_to_saved_tensors=\"./saved_tensors/\"\n",
+    ")"
    ]
   },
   {
@@ -128,7 +142,9 @@
    "cell_type": "code",
    "execution_count": null,
    "id": "56c6d488-f093-4661-835a-5c73a329c874",
-   "metadata": {},
+   "metadata": {
+    "tags": []
+   },
    "outputs": [],
    "source": [
     "# (\n",
@@ -138,7 +154,11 @@
     "# image_keys,\n",
     "# image_names,\n",
     "# features_image_stacked,\n",
-    "# ) = ms.MultimodalSearch.parsing_images(mydict, model_type,\"18_clip_base_saved_features_image.pt\")"
+    "# ) = ms.MultimodalSearch.parsing_images(\n",
+    "# mydict,\n",
+    "# model_type,\n",
+    "# path_to_load_tensors=\"./saved_tensors/18_blip_saved_features_image.pt\",\n",
+    "# )"
    ]
   },
   {
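As a sanity check on the cached file name used in the commented-out cell above: save_tensors builds it as str(path) + str(len(features_image_stacked)) + "_" + model_type + "_" + name, so the example name implies 18 images parsed with the "blip" model (the image count is an inference from the file name, not stated in the diff):

```python
# Reconstruct the expected cache file name from the save_tensors scheme.
path, model_type, name = "./saved_tensors/", "blip", "saved_features_image.pt"
n_images = 18  # assumption: the example dataset contains 18 images
print(path + str(n_images) + "_" + model_type + "_" + name)
# -> ./saved_tensors/18_blip_saved_features_image.pt
```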
@@ -328,7 +348,7 @@
    "name": "python",
    "nbconvert_exporter": "python",
    "pygments_lexer": "ipython3",
-   "version": "3.10.8"
+   "version": "3.9.0"
   }
  },
  "nbformat": 4,