From 14a1a035973a7ab0ddbbe978a419af2d20cb0aed Mon Sep 17 00:00:00 2001
From: Petr Andriushchenko
Date: Fri, 31 Mar 2023 21:36:58 +0200
Subject: [PATCH] fixed saving path, added tmp_path (#67)

---
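Notes:

This change stops feature tensors from being written into the current
working directory: save_tensors() now takes a path argument (default
"./saved_tensors/", created on demand), parsing_images() gains a separate
path_to_load_tensors parameter so that saving and loading no longer share
one argument, the test routes its output through pytest's tmp_path fixture,
and a stray debug print is removed from the search loop.

A minimal sketch of the resulting call pattern, mirroring the notebook cells
below; "mydict" is assumed to be an image dictionary built with
misinformation.utils.initialize_dict, and the file name
"18_blip_saved_features_image.pt" is illustrative (the prefix depends on the
number of images and the model type):

    import misinformation.multimodal_search as ms

    # First run: extract image features and write them under ./saved_tensors/.
    (
        model,
        vis_processors,
        txt_processors,
        image_keys,
        image_names,
        features_image_stacked,
    ) = ms.MultimodalSearch.parsing_images(
        mydict, "blip", path_to_saved_tensors="./saved_tensors/"
    )

    # Later runs: skip extraction and load the previously saved tensors.
    (
        model,
        vis_processors,
        txt_processors,
        image_keys,
        image_names,
        features_image_stacked,
    ) = ms.MultimodalSearch.parsing_images(
        mydict,
        "blip",
        path_to_load_tensors="./saved_tensors/18_blip_saved_features_image.pt",
    )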
 misinformation/multimodal_search.py           | 35 ++++++++++---
 misinformation/test/test_multimodal_search.py |  5 +-
 notebooks/multimodal_search.ipynb             | 52 +++++++++++++------
 3 files changed, 67 insertions(+), 25 deletions(-)

diff --git a/misinformation/multimodal_search.py b/misinformation/multimodal_search.py
index 7b47a0b..2b002de 100644
--- a/misinformation/multimodal_search.py
+++ b/misinformation/multimodal_search.py
@@ -3,6 +3,7 @@ import torch
 import torch.nn.functional as Func
 import requests
 import lavis
+import os
 from PIL import Image
 from IPython.display import display
 from lavis.models import load_model_and_preprocess
@@ -118,15 +119,27 @@ class MultimodalSearch(AnalysisMethod):
         return features_image_stacked

     def save_tensors(
-        self, model_type, features_image_stacked, name="saved_features_image.pt"
+        self,
+        model_type,
+        features_image_stacked,
+        name="saved_features_image.pt",
+        path="./saved_tensors/",
     ):
+        if not os.path.exists(path):
+            os.makedirs(path)
         with open(
-            str(len(features_image_stacked)) + "_" + model_type + "_" + name, "wb"
+            str(path)
+            + str(len(features_image_stacked))
+            + "_"
+            + model_type
+            + "_"
+            + name,
+            "wb",
         ) as f:
             torch.save(features_image_stacked, f)
         return name

-    def load_tensors(self, name="saved_features_image.pt"):
+    def load_tensors(self, name):
         features_image_stacked = torch.load(name)
         return features_image_stacked

@@ -136,7 +149,12 @@ class MultimodalSearch(AnalysisMethod):

         return features_text

-    def parsing_images(self, model_type, path_to_saved_tensors=None):
+    def parsing_images(
+        self,
+        model_type,
+        path_to_saved_tensors="./saved_tensors/",
+        path_to_load_tensors=None,
+    ):
         if model_type in ("clip_base", "clip_vitl14_336", "clip_vitl14"):
             path_to_lib = lavis.__file__[:-11] + "models/clip_models/"
             url = "https://raw.githubusercontent.com/salesforce/LAVIS/main/lavis/models/clip_models/bpe_simple_vocab_16e6.txt.gz"
@@ -180,15 +198,17 @@
         _, images_tensors = MultimodalSearch.read_and_process_images(
             self, image_names, vis_processors
         )
-        if path_to_saved_tensors is None:
+        if path_to_load_tensors is None:
             with torch.no_grad():
                 features_image_stacked = select_extract_image_features[model_type](
                     self, model, images_tensors
                 )
-            MultimodalSearch.save_tensors(self, model_type, features_image_stacked)
+            MultimodalSearch.save_tensors(
+                self, model_type, features_image_stacked, path=path_to_saved_tensors
+            )
         else:
             features_image_stacked = MultimodalSearch.load_tensors(
-                self, str(path_to_saved_tensors)
+                self, str(path_to_load_tensors)
             )

         return (
@@ -303,7 +323,6 @@

         for q in range(len(search_query)):
             max_val = similarity[sorted_lists[q][0]][q].item()
-            print(max_val)
             for i, key in zip(range(len(image_keys)), sorted_lists[q]):
                 if (
                     i < filter_number_of_images
diff --git a/misinformation/test/test_multimodal_search.py b/misinformation/test/test_multimodal_search.py
index 4aaaccf..98ce575 100644
--- a/misinformation/test/test_multimodal_search.py
+++ b/misinformation/test/test_multimodal_search.py
@@ -354,6 +354,7 @@ def test_parsing_images(
     pre_extracted_feature_text,
     pre_simularity,
     pre_sorted,
+    tmp_path,
 ):
     ms.MultimodalSearch.multimodal_device = pre_multimodal_device
     (
@@ -363,7 +364,9 @@
         image_keys,
         _,
         features_image_stacked,
-    ) = ms.MultimodalSearch.parsing_images(testdict, pre_model)
+    ) = ms.MultimodalSearch.parsing_images(
+        testdict, pre_model, path_to_saved_tensors=tmp_path
+    )

     for i, num in zip(range(10), features_image_stacked[0, 10:12].tolist()):
         assert (
diff --git a/notebooks/multimodal_search.ipynb b/notebooks/multimodal_search.ipynb
index 24664a7..8f40d15 100644
--- a/notebooks/multimodal_search.ipynb
+++ b/notebooks/multimodal_search.ipynb
@@ -20,7 +20,9 @@
    "cell_type": "code",
    "execution_count": null,
    "id": "f10ad6c9-b1a0-4043-8c5d-ed660d77be37",
-   "metadata": {},
+   "metadata": {
+    "tags": []
+   },
    "outputs": [],
    "source": [
     "import misinformation\n",
@@ -39,7 +41,9 @@
    "cell_type": "code",
    "execution_count": null,
    "id": "8d3fe589-ff3c-4575-b8f5-650db85596bc",
-   "metadata": {},
+   "metadata": {
+    "tags": []
+   },
    "outputs": [],
    "source": [
     "images = misinformation.utils.find_files(\n",
@@ -52,7 +56,9 @@
    "cell_type": "code",
    "execution_count": null,
    "id": "adf3db21-1f8b-4d44-bbef-ef0acf4623a0",
-   "metadata": {},
+   "metadata": {
+    "tags": []
+   },
    "outputs": [],
    "source": [
     "mydict = misinformation.utils.initialize_dict(images)"
@@ -62,7 +68,9 @@
    "cell_type": "code",
    "execution_count": null,
    "id": "d98b6227-886d-41b8-a377-896dd8ab3c2a",
-   "metadata": {},
+   "metadata": {
+    "tags": []
+   },
    "outputs": [],
    "source": [
     "mydict"
@@ -88,7 +96,9 @@
    "cell_type": "code",
    "execution_count": null,
    "id": "7bbca1f0-d4b0-43cd-8e05-ee39d37c328e",
-   "metadata": {},
+   "metadata": {
+    "tags": []
+   },
    "outputs": [],
    "source": [
     "model_type = \"blip\"\n",
@@ -103,7 +113,9 @@
    "cell_type": "code",
    "execution_count": null,
    "id": "ca095404-57d0-4f5d-aeb0-38c232252b17",
-   "metadata": {},
+   "metadata": {
+    "tags": []
+   },
    "outputs": [],
    "source": [
     "(\n",
     "    model,\n",
     "    vis_processors,\n",
     "    txt_processors,\n",
     "    image_keys,\n",
     "    image_names,\n",
     "    features_image_stacked,\n",
-    ") = ms.MultimodalSearch.parsing_images(mydict, model_type)"
+    ") = ms.MultimodalSearch.parsing_images(\n",
+    "    mydict, model_type, path_to_saved_tensors=\"./saved_tensors/\"\n",
+    ")"
    ]
   },
   {
@@ -128,17 +142,23 @@
    "cell_type": "code",
    "execution_count": null,
    "id": "56c6d488-f093-4661-835a-5c73a329c874",
-   "metadata": {},
+   "metadata": {
+    "tags": []
+   },
    "outputs": [],
    "source": [
     "# (\n",
-    "# model,\n",
-    "# vis_processors,\n",
-    "# txt_processors,\n",
-    "# image_keys,\n",
-    "# image_names,\n",
-    "# features_image_stacked,\n",
-    "# ) = ms.MultimodalSearch.parsing_images(mydict, model_type,\"18_clip_base_saved_features_image.pt\")"
+    "#     model,\n",
+    "#     vis_processors,\n",
+    "#     txt_processors,\n",
+    "#     image_keys,\n",
+    "#     image_names,\n",
+    "#     features_image_stacked,\n",
+    "# ) = ms.MultimodalSearch.parsing_images(\n",
+    "#     mydict,\n",
+    "#     model_type,\n",
+    "#     path_to_load_tensors=\"./saved_tensors/18_blip_saved_features_image.pt\",\n",
+    "# )"
    ]
   },
   {
@@ -328,7 +348,7 @@
    "name": "python",
    "nbconvert_exporter": "python",
    "pygments_lexer": "ipython3",
-   "version": "3.10.8"
+   "version": "3.9.0"
   }
  },
 "nbformat": 4,
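A note on the test change: tmp_path is pytest's built-in fixture that
supplies a fresh per-test pathlib.Path, so the saved feature file no longer
lands in the repository checkout and needs no manual cleanup. A minimal
sketch of the pattern; the test name is hypothetical, while testdict,
tmp_path, and the call itself mirror the diff above, with the parametrized
pre_model replaced by the literal "blip":

    import misinformation.multimodal_search as ms

    def test_parsing_images_keeps_repo_clean(testdict, tmp_path):
        # pytest creates tmp_path before the test and prunes old runs itself.
        *_, features_image_stacked = ms.MultimodalSearch.parsing_images(
            testdict, "blip", path_to_saved_tensors=tmp_path
        )
        assert features_image_stacked is not None

One caveat worth flagging: save_tensors() joins path and file name by plain
string concatenation, so a directory passed without a trailing separator
(a bare tmp_path, for instance) produces a sibling file with a fused name;
os.path.join(path, file_name) would be the more robust construction.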