fixed saving path, added tmp_path (#67)

This commit is contained in:
Petr Andriushchenko 2023-03-31 21:36:58 +02:00, committed by GitHub
parent f89fa4e519
commit 14a1a03597
No key found matching this signature
GPG key ID: 4AEE18F83AFDEB23
3 changed files with 67 additions and 25 deletions

View file

@@ -3,6 +3,7 @@ import torch
 import torch.nn.functional as Func
 import requests
 import lavis
+import os
 from PIL import Image
 from IPython.display import display
 from lavis.models import load_model_and_preprocess
@@ -118,15 +119,27 @@ class MultimodalSearch(AnalysisMethod):
         return features_image_stacked

     def save_tensors(
-        self, model_type, features_image_stacked, name="saved_features_image.pt"
+        self,
+        model_type,
+        features_image_stacked,
+        name="saved_features_image.pt",
+        path="./saved_tensors/",
     ):
+        if not os.path.exists(path):
+            os.makedirs(path)
         with open(
-            str(len(features_image_stacked)) + "_" + model_type + "_" + name, "wb"
+            str(path)
+            + str(len(features_image_stacked))
+            + "_"
+            + model_type
+            + "_"
+            + name,
+            "wb",
         ) as f:
             torch.save(features_image_stacked, f)
         return name

-    def load_tensors(self, name="saved_features_image.pt"):
+    def load_tensors(self, name):
         features_image_stacked = torch.load(name)
         return features_image_stacked
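To make the new saving scheme concrete: save_tensors now creates the target directory if needed and builds the file name as path + "<count>_<model_type>_<name>". A minimal sketch of that logic, with build_save_path as a hypothetical helper (not part of the module); note that the plain string concatenation makes the trailing slash in the default path load-bearing:

import os
import torch

def build_save_path(num_images, model_type,
                    name="saved_features_image.pt", path="./saved_tensors/"):
    # Mirrors save_tensors above: create the directory, then concatenate the name.
    if not os.path.exists(path):
        os.makedirs(path)
    return str(path) + str(num_images) + "_" + model_type + "_" + name

features = torch.rand(18, 768)  # stand-in for features_image_stacked
torch.save(features, build_save_path(len(features), "blip"))
# writes ./saved_tensors/18_blip_saved_features_image.pt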
@@ -136,7 +149,12 @@ class MultimodalSearch(AnalysisMethod):
         return features_text

-    def parsing_images(self, model_type, path_to_saved_tensors=None):
+    def parsing_images(
+        self,
+        model_type,
+        path_to_saved_tensors="./saved_tensors/",
+        path_to_load_tensors=None,
+    ):
         if model_type in ("clip_base", "clip_vitl14_336", "clip_vitl14"):
             path_to_lib = lavis.__file__[:-11] + "models/clip_models/"
             url = "https://raw.githubusercontent.com/salesforce/LAVIS/main/lavis/models/clip_models/bpe_simple_vocab_16e6.txt.gz"
@ -180,15 +198,17 @@ class MultimodalSearch(AnalysisMethod):
_, images_tensors = MultimodalSearch.read_and_process_images( _, images_tensors = MultimodalSearch.read_and_process_images(
self, image_names, vis_processors self, image_names, vis_processors
) )
if path_to_saved_tensors is None: if path_to_load_tensors is None:
with torch.no_grad(): with torch.no_grad():
features_image_stacked = select_extract_image_features[model_type]( features_image_stacked = select_extract_image_features[model_type](
self, model, images_tensors self, model, images_tensors
) )
MultimodalSearch.save_tensors(self, model_type, features_image_stacked) MultimodalSearch.save_tensors(
self, model_type, features_image_stacked, path=path_to_saved_tensors
)
else: else:
features_image_stacked = MultimodalSearch.load_tensors( features_image_stacked = MultimodalSearch.load_tensors(
self, str(path_to_saved_tensors) self, str(path_to_load_tensors)
) )
return ( return (
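With the save and load paths now separated, a first call extracts and saves the features, and later calls can load them instead of re-extracting. A usage sketch mirroring the notebook below, assuming the module is imported as ms and mydict/model_type are set up as in the notebook:

# First run: extract image features and save them under ./saved_tensors/.
(
    model,
    vis_processors,
    txt_processors,
    image_keys,
    image_names,
    features_image_stacked,
) = ms.MultimodalSearch.parsing_images(
    mydict, model_type, path_to_saved_tensors="./saved_tensors/"
)

# Later runs: pass the saved file (named <count>_<model_type>_<name> by
# save_tensors) to skip feature extraction entirely.
(
    model,
    vis_processors,
    txt_processors,
    image_keys,
    image_names,
    features_image_stacked,
) = ms.MultimodalSearch.parsing_images(
    mydict, model_type,
    path_to_load_tensors="./saved_tensors/18_blip_saved_features_image.pt",
)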
@@ -303,7 +323,6 @@ class MultimodalSearch(AnalysisMethod):
         for q in range(len(search_query)):
             max_val = similarity[sorted_lists[q][0]][q].item()
-            print(max_val)
             for i, key in zip(range(len(image_keys)), sorted_lists[q]):
                 if (
                     i < filter_number_of_images

View file

@@ -354,6 +354,7 @@ def test_parsing_images(
     pre_extracted_feature_text,
     pre_simularity,
     pre_sorted,
+    tmp_path,
 ):
     ms.MultimodalSearch.multimodal_device = pre_multimodal_device
     (
@@ -363,7 +364,9 @@ def test_parsing_images(
         image_keys,
         _,
         features_image_stacked,
-    ) = ms.MultimodalSearch.parsing_images(testdict, pre_model)
+    ) = ms.MultimodalSearch.parsing_images(
+        testdict, pre_model, path_to_saved_tensors=tmp_path
+    )
     for i, num in zip(range(10), features_image_stacked[0, 10:12].tolist()):
         assert (
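tmp_path is pytest's built-in fixture: each test receives a fresh pathlib.Path to a unique temporary directory, so the saved tensors no longer land in the repository. A toy illustration of the fixture (a hypothetical test, independent of the suite above):

import torch

def test_tensor_roundtrip(tmp_path):
    # tmp_path is a pathlib.Path to a per-test temporary directory.
    target = tmp_path / "features.pt"
    torch.save(torch.rand(2, 3), target)
    assert torch.load(target).shape == (2, 3)

One caveat: save_tensors joins path and file name by plain string concatenation, and str(tmp_path) carries no trailing separator, so the directory and file names would fuse; os.path.join (or the Path / operator) would be the more robust join.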

notebooks/multimodal_search.ipynb (generated, 52 lines changed)
View file

@@ -20,7 +20,9 @@
    "cell_type": "code",
    "execution_count": null,
    "id": "f10ad6c9-b1a0-4043-8c5d-ed660d77be37",
-   "metadata": {},
+   "metadata": {
+    "tags": []
+   },
    "outputs": [],
    "source": [
     "import misinformation\n",
@@ -39,7 +41,9 @@
    "cell_type": "code",
    "execution_count": null,
    "id": "8d3fe589-ff3c-4575-b8f5-650db85596bc",
-   "metadata": {},
+   "metadata": {
+    "tags": []
+   },
    "outputs": [],
    "source": [
     "images = misinformation.utils.find_files(\n",
@@ -52,7 +56,9 @@
    "cell_type": "code",
    "execution_count": null,
    "id": "adf3db21-1f8b-4d44-bbef-ef0acf4623a0",
-   "metadata": {},
+   "metadata": {
+    "tags": []
+   },
    "outputs": [],
    "source": [
     "mydict = misinformation.utils.initialize_dict(images)"
@@ -62,7 +68,9 @@
    "cell_type": "code",
    "execution_count": null,
    "id": "d98b6227-886d-41b8-a377-896dd8ab3c2a",
-   "metadata": {},
+   "metadata": {
+    "tags": []
+   },
    "outputs": [],
    "source": [
     "mydict"
@@ -88,7 +96,9 @@
    "cell_type": "code",
    "execution_count": null,
    "id": "7bbca1f0-d4b0-43cd-8e05-ee39d37c328e",
-   "metadata": {},
+   "metadata": {
+    "tags": []
+   },
    "outputs": [],
    "source": [
     "model_type = \"blip\"\n",
@@ -103,7 +113,9 @@
    "cell_type": "code",
    "execution_count": null,
    "id": "ca095404-57d0-4f5d-aeb0-38c232252b17",
-   "metadata": {},
+   "metadata": {
+    "tags": []
+   },
    "outputs": [],
    "source": [
     "(\n",
@@ -113,7 +125,9 @@
     "    image_keys,\n",
     "    image_names,\n",
     "    features_image_stacked,\n",
-    ") = ms.MultimodalSearch.parsing_images(mydict, model_type)"
+    ") = ms.MultimodalSearch.parsing_images(\n",
+    "    mydict, model_type, path_to_saved_tensors=\"./saved_tensors/\"\n",
+    ")"
    ]
   },
   {
@@ -128,17 +142,23 @@
    "cell_type": "code",
    "execution_count": null,
    "id": "56c6d488-f093-4661-835a-5c73a329c874",
-   "metadata": {},
+   "metadata": {
+    "tags": []
+   },
    "outputs": [],
    "source": [
     "# (\n",
     "#     model,\n",
     "#     vis_processors,\n",
     "#     txt_processors,\n",
     "#     image_keys,\n",
     "#     image_names,\n",
     "#     features_image_stacked,\n",
-    "# ) = ms.MultimodalSearch.parsing_images(mydict, model_type,\"18_clip_base_saved_features_image.pt\")"
+    "# ) = ms.MultimodalSearch.parsing_images(\n",
+    "#     mydict,\n",
+    "#     model_type,\n",
+    "#     path_to_load_tensors=\"./saved_tensors/18_blip_saved_features_image.pt\",\n",
+    "# )"
    ]
   },
   {
@@ -328,7 +348,7 @@
    "name": "python",
    "nbconvert_exporter": "python",
    "pygments_lexer": "ipython3",
-   "version": "3.10.8"
+   "version": "3.9.0"
   }
  },
 "nbformat": 4,