fixed saving path, added tmp_path (#67)

This commit is contained in:
Petr Andriushchenko 2023-03-31 21:36:58 +02:00 committed by GitHub
parent f89fa4e519
commit 14a1a03597
No known key found for this signature
GPG key ID: 4AEE18F83AFDEB23
3 changed files with 67 additions and 25 deletions


@@ -3,6 +3,7 @@ import torch
 import torch.nn.functional as Func
 import requests
 import lavis
+import os
 from PIL import Image
 from IPython.display import display
 from lavis.models import load_model_and_preprocess
@@ -118,15 +119,27 @@ class MultimodalSearch(AnalysisMethod):
         return features_image_stacked

     def save_tensors(
-        self, model_type, features_image_stacked, name="saved_features_image.pt"
+        self,
+        model_type,
+        features_image_stacked,
+        name="saved_features_image.pt",
+        path="./saved_tensors/",
     ):
+        if not os.path.exists(path):
+            os.makedirs(path)
         with open(
-            str(len(features_image_stacked)) + "_" + model_type + "_" + name, "wb"
+            str(path)
+            + str(len(features_image_stacked))
+            + "_"
+            + model_type
+            + "_"
+            + name,
+            "wb",
         ) as f:
             torch.save(features_image_stacked, f)
         return name

-    def load_tensors(self, name="saved_features_image.pt"):
+    def load_tensors(self, name):
         features_image_stacked = torch.load(name)
         return features_image_stacked
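With this change the feature file is written into a dedicated directory (created on demand) instead of the current working directory; the name is still assembled from the row count of the stacked features plus the model type, and the pieces are joined by plain string concatenation, so the trailing slash in the default "./saved_tensors/" is significant. A minimal sketch of the resulting round trip, assuming the `ms` module alias used in the notebook; the tensor shape and the `{}` standing in for the usual `mydict`-as-self are illustrative only:

import torch
from misinformation import multimodal_search as ms

features = torch.rand(18, 768)  # stand-in: 18 images, arbitrary embedding size

# with the defaults this writes ./saved_tensors/18_blip_saved_features_image.pt
ms.MultimodalSearch.save_tensors({}, "blip", features)

loaded = ms.MultimodalSearch.load_tensors(
    {}, "./saved_tensors/18_blip_saved_features_image.pt"
)
assert torch.equal(features, loaded)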
@@ -136,7 +149,12 @@ class MultimodalSearch(AnalysisMethod):
         return features_text

-    def parsing_images(self, model_type, path_to_saved_tensors=None):
+    def parsing_images(
+        self,
+        model_type,
+        path_to_saved_tensors="./saved_tensors/",
+        path_to_load_tensors=None,
+    ):
         if model_type in ("clip_base", "clip_vitl14_336", "clip_vitl14"):
             path_to_lib = lavis.__file__[:-11] + "models/clip_models/"
             url = "https://raw.githubusercontent.com/salesforce/LAVIS/main/lavis/models/clip_models/bpe_simple_vocab_16e6.txt.gz"
@@ -180,15 +198,17 @@
         _, images_tensors = MultimodalSearch.read_and_process_images(
             self, image_names, vis_processors
         )
-        if path_to_saved_tensors is None:
+        if path_to_load_tensors is None:
             with torch.no_grad():
                 features_image_stacked = select_extract_image_features[model_type](
                     self, model, images_tensors
                 )
-            MultimodalSearch.save_tensors(self, model_type, features_image_stacked)
+            MultimodalSearch.save_tensors(
+                self, model_type, features_image_stacked, path=path_to_saved_tensors
+            )
         else:
             features_image_stacked = MultimodalSearch.load_tensors(
-                self, str(path_to_saved_tensors)
+                self, str(path_to_load_tensors)
             )

         return (
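The former single argument did double duty: `path_to_saved_tensors=None` meant "extract and save to the current directory", anything else meant "load this file". The two roles are now separate parameters: `path_to_saved_tensors` is the directory freshly extracted features are cached in, and `path_to_load_tensors` is the full path of an existing feature file that short-circuits extraction. A sketch of both modes, assuming `mydict` and the `ms` alias from the notebook below; the file name in the second call is the one the first call produces for 18 images:

# first run: extract features and cache them under ./saved_tensors/
(
    model,
    vis_processors,
    txt_processors,
    image_keys,
    image_names,
    features_image_stacked,
) = ms.MultimodalSearch.parsing_images(
    mydict, "blip", path_to_saved_tensors="./saved_tensors/"
)

# later runs: reuse the cached file and skip extraction entirely
(
    model,
    vis_processors,
    txt_processors,
    image_keys,
    image_names,
    features_image_stacked,
) = ms.MultimodalSearch.parsing_images(
    mydict,
    "blip",
    path_to_load_tensors="./saved_tensors/18_blip_saved_features_image.pt",
)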
@@ -303,7 +323,6 @@
         for q in range(len(search_query)):
             max_val = similarity[sorted_lists[q][0]][q].item()
-            print(max_val)
             for i, key in zip(range(len(image_keys)), sorted_lists[q]):
                 if (
                     i < filter_number_of_images


@@ -354,6 +354,7 @@ def test_parsing_images(
     pre_extracted_feature_text,
     pre_simularity,
     pre_sorted,
+    tmp_path,
 ):
     ms.MultimodalSearch.multimodal_device = pre_multimodal_device
     (
@@ -363,7 +364,9 @@
         image_keys,
         _,
         features_image_stacked,
-    ) = ms.MultimodalSearch.parsing_images(testdict, pre_model)
+    ) = ms.MultimodalSearch.parsing_images(
+        testdict, pre_model, path_to_saved_tensors=tmp_path
+    )

     for i, num in zip(range(10), features_image_stacked[0, 10:12].tolist()):
         assert (
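`tmp_path` is pytest's built-in fixture that hands each test a fresh, automatically cleaned-up `pathlib.Path`, so the test no longer leaves .pt files in the repository. A standalone illustration of the fixture, independent of this codebase:

import torch

def test_tensor_roundtrip(tmp_path):
    # tmp_path is a unique empty directory created by pytest for this test
    target = tmp_path / "features.pt"
    torch.save(torch.ones(3), target)
    assert torch.equal(torch.load(target), torch.ones(3))

One caveat: since save_tensors concatenates `str(path)` with the file name, passing `tmp_path` (which has no trailing separator) drops the file next to the per-test directory rather than inside it. It still stays under pytest's temporary root, so nothing leaks into the working tree, but an os.path.join would make the intent explicit.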

notebooks/multimodal_search.ipynb (generated; 52 lines changed)

@@ -20,7 +20,9 @@
    "cell_type": "code",
    "execution_count": null,
    "id": "f10ad6c9-b1a0-4043-8c5d-ed660d77be37",
-   "metadata": {},
+   "metadata": {
+    "tags": []
+   },
    "outputs": [],
    "source": [
     "import misinformation\n",
@@ -39,7 +41,9 @@
    "cell_type": "code",
    "execution_count": null,
    "id": "8d3fe589-ff3c-4575-b8f5-650db85596bc",
-   "metadata": {},
+   "metadata": {
+    "tags": []
+   },
    "outputs": [],
    "source": [
     "images = misinformation.utils.find_files(\n",
@@ -52,7 +56,9 @@
    "cell_type": "code",
    "execution_count": null,
    "id": "adf3db21-1f8b-4d44-bbef-ef0acf4623a0",
-   "metadata": {},
+   "metadata": {
+    "tags": []
+   },
    "outputs": [],
    "source": [
     "mydict = misinformation.utils.initialize_dict(images)"
@@ -62,7 +68,9 @@
    "cell_type": "code",
    "execution_count": null,
    "id": "d98b6227-886d-41b8-a377-896dd8ab3c2a",
-   "metadata": {},
+   "metadata": {
+    "tags": []
+   },
    "outputs": [],
    "source": [
     "mydict"
@@ -88,7 +96,9 @@
    "cell_type": "code",
    "execution_count": null,
    "id": "7bbca1f0-d4b0-43cd-8e05-ee39d37c328e",
-   "metadata": {},
+   "metadata": {
+    "tags": []
+   },
    "outputs": [],
    "source": [
     "model_type = \"blip\"\n",
@@ -103,7 +113,9 @@
    "cell_type": "code",
    "execution_count": null,
    "id": "ca095404-57d0-4f5d-aeb0-38c232252b17",
-   "metadata": {},
+   "metadata": {
+    "tags": []
+   },
    "outputs": [],
    "source": [
     "(\n",
@@ -113,7 +125,9 @@
     "    image_keys,\n",
     "    image_names,\n",
     "    features_image_stacked,\n",
-    ") = ms.MultimodalSearch.parsing_images(mydict, model_type)"
+    ") = ms.MultimodalSearch.parsing_images(\n",
+    "    mydict, model_type, path_to_saved_tensors=\"./saved_tensors/\"\n",
+    ")"
    ]
   },
   {
@@ -128,17 +142,23 @@
    "cell_type": "code",
    "execution_count": null,
    "id": "56c6d488-f093-4661-835a-5c73a329c874",
-   "metadata": {},
+   "metadata": {
+    "tags": []
+   },
    "outputs": [],
    "source": [
     "# (\n",
-    "# model,\n",
-    "# vis_processors,\n",
-    "# txt_processors,\n",
-    "# image_keys,\n",
-    "# image_names,\n",
-    "# features_image_stacked,\n",
-    "# ) = ms.MultimodalSearch.parsing_images(mydict, model_type,\"18_clip_base_saved_features_image.pt\")"
+    "#     model,\n",
+    "#     vis_processors,\n",
+    "#     txt_processors,\n",
+    "#     image_keys,\n",
+    "#     image_names,\n",
+    "#     features_image_stacked,\n",
+    "# ) = ms.MultimodalSearch.parsing_images(\n",
+    "#     mydict,\n",
+    "#     model_type,\n",
+    "#     path_to_load_tensors=\"./saved_tensors/18_blip_saved_features_image.pt\",\n",
+    "# )"
    ]
   },
   {
@@ -328,7 +348,7 @@
    "name": "python",
    "nbconvert_exporter": "python",
    "pygments_lexer": "ipython3",
-   "version": "3.10.8"
+   "version": "3.9.0"
   }
  },
 "nbformat": 4,