From 4e4b7fac75fa2e7e853a737f66ffe1552b050844 Mon Sep 17 00:00:00 2001
From: Petr Andriushchenko
Date: Wed, 22 Feb 2023 14:22:58 +0100
Subject: [PATCH] Fix if-else chains, add CLIP ViT-L-14-336 model

---
 misinformation/multimodal_search.py | 251 +++++++++++++++-------------
 notebooks/multimodal_search.ipynb   |  29 +++-
 2 files changed, 160 insertions(+), 120 deletions(-)

diff --git a/misinformation/multimodal_search.py b/misinformation/multimodal_search.py
index 6c051ce..908aea1 100644
--- a/misinformation/multimodal_search.py
+++ b/misinformation/multimodal_search.py
@@ -15,54 +15,58 @@ class MultimodalSearch(AnalysisMethod):
 
     multimodal_device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 
-    def load_feature_extractor_model(device, model_type):
-        if model_type == "blip2":
-            model, vis_processors, txt_processors = load_model_and_preprocess(
-                name="blip2_feature_extractor",
-                model_type="pretrain",
-                is_eval=True,
-                device=device,
-            )
-        elif model_type == "blip":
-            model, vis_processors, txt_processors = load_model_and_preprocess(
-                name="blip_feature_extractor",
-                model_type="base",
-                is_eval=True,
-                device=device,
-            )
-        elif model_type == "albef":
-            model, vis_processors, txt_processors = load_model_and_preprocess(
-                name="albef_feature_extractor",
-                model_type="base",
-                is_eval=True,
-                device=device,
-            )
-        elif model_type == "clip_base":
-            model, vis_processors, txt_processors = load_model_and_preprocess(
-                name="clip_feature_extractor",
-                model_type="base",
-                is_eval=True,
-                device=device,
-            )
-        elif model_type == "clip_rn50":
-            model, vis_processors, txt_processors = load_model_and_preprocess(
-                name="clip_feature_extractor",
-                model_type="RN50",
-                is_eval=True,
-                device=device,
-            )
-        elif model_type == "clip_vitl14":
-            model, vis_processors, txt_processors = load_model_and_preprocess(
-                name="clip_feature_extractor",
-                model_type="ViT-L-14",
-                is_eval=True,
-                device=device,
-            )
-        else:
-            print(
-                "Please, use one of the following models: blip2, blip, albef, clip_base, clip_rn50, clip_vitl14"
-            )
+    def load_feature_extractor_model_blip2(device):
+        model, vis_processors, txt_processors = load_model_and_preprocess(
+            name="blip2_feature_extractor",
+            model_type="pretrain",
+            is_eval=True,
+            device=device,
+        )
+        return model, vis_processors, txt_processors
+
+    def load_feature_extractor_model_blip(device):
+        model, vis_processors, txt_processors = load_model_and_preprocess(
+            name="blip_feature_extractor",
+            model_type="base",
+            is_eval=True,
+            device=device,
+        )
+        return model, vis_processors, txt_processors
+
+    def load_feature_extractor_model_albef(device):
+        model, vis_processors, txt_processors = load_model_and_preprocess(
+            name="albef_feature_extractor",
+            model_type="base",
+            is_eval=True,
+            device=device,
+        )
+        return model, vis_processors, txt_processors
+
+    def load_feature_extractor_model_clip_base(device):
+        model, vis_processors, txt_processors = load_model_and_preprocess(
+            name="clip_feature_extractor",
+            model_type="base",
+            is_eval=True,
+            device=device,
+        )
+        return model, vis_processors, txt_processors
+
+    def load_feature_extractor_model_clip_vitl14(device):
+        model, vis_processors, txt_processors = load_model_and_preprocess(
+            name="clip_feature_extractor",
+            model_type="ViT-L-14",
+            is_eval=True,
+            device=device,
+        )
+        return model, vis_processors, txt_processors
+
+    def load_feature_extractor_model_clip_vitl14_336(device):
+        model, vis_processors, txt_processors = load_model_and_preprocess(
+            name="clip_feature_extractor",
+            model_type="ViT-L-14-336",
+            is_eval=True,
+            device=device,
+        )
         return model, vis_processors, txt_processors
 
     def read_img(filepath):
@@ -81,34 +85,10 @@ class MultimodalSearch(AnalysisMethod):
 
         return raw_images, images_tensors
 
-    def extract_image_features(model, images_tensors, model_type):
-        if model_type == "blip2":
-            with torch.cuda.amp.autocast(
-                enabled=(MultimodalSearch.multimodal_device != torch.device("cpu"))
-            ):
-                features_image = [
-                    model.extract_features(
-                        {"image": ten, "text_input": ""}, mode="image"
-                    )
-                    for ten in images_tensors
-                ]
-                features_image_stacked = torch.stack(
-                    [
-                        feat.image_embeds_proj[:, 0, :].squeeze(0)
-                        for feat in features_image
-                    ]
-                )
-        elif model_type in ("clip_base", "clip_rn50", "clip_vitl14"):
-            features_image = [
-                model.extract_features({"image": ten}) for ten in images_tensors
-            ]
-            features_image_stacked = torch.stack(
-                [
-                    Func.normalize(feat.float(), dim=-1).squeeze(0)
-                    for feat in features_image
-                ]
-            )
-        else:
+    def extract_image_features_blip2(model, images_tensors):
+        with torch.cuda.amp.autocast(
+            enabled=(MultimodalSearch.multimodal_device != torch.device("cpu"))
+        ):
             features_image = [
                 model.extract_features({"image": ten, "text_input": ""}, mode="image")
                 for ten in images_tensors
@@ -118,6 +98,25 @@ class MultimodalSearch(AnalysisMethod):
             )
         return features_image_stacked
 
+    def extract_image_features_clip(model, images_tensors):
+        features_image = [
+            model.extract_features({"image": ten}) for ten in images_tensors
+        ]
+        features_image_stacked = torch.stack(
+            [Func.normalize(feat.float(), dim=-1).squeeze(0) for feat in features_image]
+        )
+        return features_image_stacked
+
+    def extract_image_features_basic(model, images_tensors):
+        features_image = [
+            model.extract_features({"image": ten, "text_input": ""}, mode="image")
+            for ten in images_tensors
+        ]
+        features_image_stacked = torch.stack(
+            [feat.image_embeds_proj[:, 0, :].squeeze(0) for feat in features_image]
+        )
+        return features_image_stacked
+
     def save_tensors(
         model_type, features_image_stacked, name="saved_features_image.pt"
     ):
@@ -137,9 +136,9 @@ class MultimodalSearch(AnalysisMethod):
 
         return features_text
 
-    def parsing_images(self, model_type):
+    def parsing_images(self, model_type, path_to_saved_tensors=None):
 
-        if model_type in ("clip_base", "clip_rn50", "clip_vitl14"):
+        if model_type in ("clip_base", "clip_vitl14_336", "clip_vitl14"):
             path_to_lib = lavis.__file__[:-11] + "models/clip_models/"
             url = "https://raw.githubusercontent.com/salesforce/LAVIS/main/lavis/models/clip_models/bpe_simple_vocab_16e6.txt.gz"
             r = requests.get(url, allow_redirects=False)
@@ -148,20 +147,46 @@ class MultimodalSearch(AnalysisMethod):
 
         image_keys = sorted(self.keys())
         image_names = [self[k]["filename"] for k in image_keys]
-        (
-            model,
-            vis_processors,
-            txt_processors,
-        ) = MultimodalSearch.load_feature_extractor_model(
-            MultimodalSearch.multimodal_device, model_type
-        )
+        select_model = {
+            "blip2": MultimodalSearch.load_feature_extractor_model_blip2,
+            "blip": MultimodalSearch.load_feature_extractor_model_blip,
+            "albef": MultimodalSearch.load_feature_extractor_model_albef,
+            "clip_base": MultimodalSearch.load_feature_extractor_model_clip_base,
+            "clip_vitl14": MultimodalSearch.load_feature_extractor_model_clip_vitl14,
+            "clip_vitl14_336": MultimodalSearch.load_feature_extractor_model_clip_vitl14_336,
+        }
+
+        select_extract_image_features = {
+            "blip2": MultimodalSearch.extract_image_features_blip2,
+            "blip": MultimodalSearch.extract_image_features_basic,
+            "albef": MultimodalSearch.extract_image_features_basic,
+            "clip_base": MultimodalSearch.extract_image_features_clip,
+            "clip_vitl14": MultimodalSearch.extract_image_features_clip,
+            "clip_vitl14_336": MultimodalSearch.extract_image_features_clip,
+        }
+
+        if model_type in select_model.keys():
+            (model, vis_processors, txt_processors,) = select_model[
+                model_type
+            ](MultimodalSearch.multimodal_device)
+        else:
+            raise SyntaxError(
+                "Please, use one of the following models: blip2, blip, albef, clip_base, clip_vitl14, clip_vitl14_336"
+            )
+
         raw_images, images_tensors = MultimodalSearch.read_and_process_images(
             image_names, vis_processors
         )
-        features_image_stacked = MultimodalSearch.extract_image_features(
-            model, images_tensors, model_type
-        )
-        MultimodalSearch.save_tensors(model_type, features_image_stacked)
+        if path_to_saved_tensors is None:
+            with torch.no_grad():
+                features_image_stacked = select_extract_image_features[model_type](
+                    model, images_tensors
+                )
+            MultimodalSearch.save_tensors(model_type, features_image_stacked)
+        else:
+            features_image_stacked = MultimodalSearch.load_tensors(
+                str(path_to_saved_tensors)
+            )
 
         return (
             model,
@@ -175,6 +200,16 @@ class MultimodalSearch(AnalysisMethod):
     def querys_processing(
         self, search_query, model, txt_processors, vis_processors, model_type
     ):
+
+        select_extract_image_features = {
+            "blip2": MultimodalSearch.extract_image_features_blip2,
+            "blip": MultimodalSearch.extract_image_features_basic,
+            "albef": MultimodalSearch.extract_image_features_basic,
+            "clip_base": MultimodalSearch.extract_image_features_clip,
+            "clip_vitl14": MultimodalSearch.extract_image_features_clip,
+            "clip_vitl14_336": MultimodalSearch.extract_image_features_clip,
+        }
+
         for query in search_query:
             if not (len(query) == 1) and (query in ("image", "text_input")):
                 raise SyntaxError(
@@ -194,10 +229,10 @@ class MultimodalSearch(AnalysisMethod):
             {"image": images_tensors, "text_input": text_processing}
         )
 
-        if model_type in ("clip_base", "clip_rn50", "clip_vitl14"):
-            multi_features_query = []
-            for query in multi_sample:
-                if query["image"] == "":
+        multi_features_query = []
+        for query in multi_sample:
+            if query["image"] == "":
+                if model_type in ("clip_base", "clip_vitl14_336", "clip_vitl14"):
                     features = model.extract_features(
                         {"text_input": query["text_input"]}
                     )
@@ -208,17 +243,7 @@ class MultimodalSearch(AnalysisMethod):
                     multi_features_query.append(
                         Func.normalize(features_squeeze, dim=-1)
                     )
-                if query["text_input"] == "":
-                    multi_features_query.append(
-                        MultimodalSearch.extract_image_features(
-                            model, query["image"], model_type
-                        )
-                    )
-
-        else:
-            multi_features_query = []
-            for query in multi_sample:
-                if query["image"] == "":
+                else:
                     features = model.extract_features(query, mode="text")
                     features_squeeze = (
                         features.text_embeds_proj[:, 0, :]
                         .squeeze(0)
                         .to(MultimodalSearch.multimodal_device)
                     )
                     multi_features_query.append(features_squeeze)
-                if query["text_input"] == "":
-                    multi_features_query.append(
-                        MultimodalSearch.extract_image_features(
-                            model, query["image"], model_type
-                        )
-                    )
+            if query["text_input"] == "":
+                multi_features_query.append(
+                    select_extract_image_features[model_type](model, query["image"])
+                )
 
         multi_features_stacked = torch.stack(
             [query.squeeze(0) for query in multi_features_query]
@@ -251,11 +274,13 @@ class MultimodalSearch(AnalysisMethod):
     ):
 
         features_image_stacked.to(MultimodalSearch.multimodal_device)
-        multi_features_stacked = MultimodalSearch.querys_processing(
-            self, search_query, model, txt_processors, vis_processors, model_type
-        )
+        with torch.no_grad():
+            multi_features_stacked = MultimodalSearch.querys_processing(
+                self, search_query, model, txt_processors, vis_processors, model_type
+            )
 
         similarity = features_image_stacked @ multi_features_stacked.t()
+        similarity_soft_max = torch.nn.Softmax(dim=0)(similarity / 0.01)
         sorted_lists = [
             sorted(range(len(similarity)), key=lambda k: similarity[k, i], reverse=True)
             for i in range(len(similarity[0]))
@@ -267,7 +292,7 @@ class MultimodalSearch(AnalysisMethod):
                 self[key]["rank " + list(search_query[q].values())[0]] = places[q][i]
                 self[key][list(search_query[q].values())[0]] = similarity[i][q].item()
 
-        return similarity
+        return similarity, similarity_soft_max
 
     def show_results(self, query):
         if "image" in query.keys():
diff --git a/notebooks/multimodal_search.ipynb b/notebooks/multimodal_search.ipynb
index cb7687c..a0a4788 100644
--- a/notebooks/multimodal_search.ipynb
+++ b/notebooks/multimodal_search.ipynb
@@ -81,7 +81,7 @@
    "id": "66d6ede4-00bc-4aeb-9a36-e52d7de33fe5",
    "metadata": {},
    "source": [
-    "You can choose one of the following models: blip, blip2, albef, clip_base, clip_rn50, clip_vitl14"
+    "You can choose one of the following models: blip, blip2, albef, clip_base, clip_vitl14, clip_vitl14_336"
    ]
   },
   {
@@ -91,7 +91,7 @@
    "metadata": {},
    "outputs": [],
    "source": [
-    "model_type = \"blip\""
+    "model_type = \"clip_vitl14_336\""
    ]
   },
   {
@@ -116,17 +116,32 @@
    "id": "9ff8a894-566b-4c4f-acca-21c50b5b1f52",
    "metadata": {},
    "source": [
-    "The of all images `features_image_stacked` was saved in `saved_features_image.pt`. If you run it once for current model and set of images you do not need to repeat it again. Instead you can load this features with the command:"
+    "The tensors of all images `features_image_stacked` were saved in `__saved_features_image.pt`. If you have run this once for the current model and current set of images, you do not need to repeat it. Instead, you can load these features with the command:"
    ]
   },
   {
    "cell_type": "code",
    "execution_count": null,
-   "id": "c40e93f0-6bea-4886-b904-8b46ed6ec819",
+   "id": "56c6d488-f093-4661-835a-5c73a329c874",
    "metadata": {},
    "outputs": [],
    "source": [
-    "# features_image_stacked = ms.MultimodalSearch.load_tensors('saved_features_image.pt')"
+    "# (\n",
+    "# model,\n",
+    "# vis_processors,\n",
+    "# txt_processors,\n",
+    "# image_keys,\n",
+    "# image_names,\n",
+    "# features_image_stacked,\n",
+    "# ) = ms.MultimodalSearch.parsing_images(mydict, model_type,\"18_clip_base_saved_features_image.pt\")"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "309923c1-d6f8-4424-8fca-bde5f3a98b38",
+   "metadata": {},
+   "source": [
+    "Here we have already processed an image folder of 18 images with the `clip_base` model. In that case, simply pass the name of the saved file containing the image tensors, `18_clip_base_saved_features_image.pt`, as the third argument to the previous function."
    ]
   },
   {
@@ -170,7 +185,7 @@
     "    image_keys,\n",
     "    features_image_stacked,\n",
     "    search_query3,\n",
-    ");"
+    ")"
    ]
   },
   {
@@ -206,7 +221,7 @@
    "metadata": {},
    "outputs": [],
    "source": [
-    "ms.MultimodalSearch.show_results(mydict, search_query3[0])"
+    "ms.MultimodalSearch.show_results(mydict, search_query3[4])"
    ]
   },
   {
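
Note for reviewers: the if/elif chains in `load_feature_extractor_model` and `extract_image_features` are replaced above by per-model helpers selected through dispatch dicts. The following is a minimal sketch, not part of the patch, of how those helpers compose for the new `clip_vitl14_336` entry; it assumes LAVIS and the package are installed and only uses names introduced or kept by this diff.

```python
# Sketch only: exercises the per-model helpers introduced in this patch.
from misinformation.multimodal_search import MultimodalSearch

device = MultimodalSearch.multimodal_device

# Dict dispatch replaces the old if/elif chain; "clip_vitl14_336" is the new entry.
(
    model,
    vis_processors,
    txt_processors,
) = MultimodalSearch.load_feature_extractor_model_clip_vitl14_336(device)

# All CLIP variants now share one extraction helper. `images_tensors` would come
# from MultimodalSearch.read_and_process_images(...), which this patch leaves unchanged.
# features_image_stacked = MultimodalSearch.extract_image_features_clip(
#     model, images_tensors
# )
```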
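The notebook changes also exercise the new optional third argument of `parsing_images` and the extra softmax output of `multimodal_search`. Below is an illustrative sketch of that flow, not a cell from the notebook: `mydict` and `search_query3` are assumed to be built as in the earlier, unchanged notebook cells, the `ms` import alias follows the notebook's usage, and the leading arguments of `multimodal_search` are inferred from the cell shown in this diff.

```python
# Illustrative only: mirrors the updated notebook flow after this patch.
# `mydict` (image dictionary) and `search_query3` (list of query dicts) are assumed
# to be defined as in the earlier notebook cells.
import misinformation.multimodal_search as ms

model_type = "clip_vitl14_336"

# First run: extract image features and save them via save_tensors().
(
    model,
    vis_processors,
    txt_processors,
    image_keys,
    image_names,
    features_image_stacked,
) = ms.MultimodalSearch.parsing_images(mydict, model_type)

# Later runs: pass a previously saved tensor file (example name from the notebook)
# as the optional third argument to skip feature extraction:
# ms.MultimodalSearch.parsing_images(
#     mydict, model_type, "18_clip_base_saved_features_image.pt"
# )

# multimodal_search now returns the raw similarity matrix plus a softmax-scaled copy;
# the leading argument order here is an assumption based on the notebook cell.
similarity, similarity_soft_max = ms.MultimodalSearch.multimodal_search(
    mydict,
    model,
    vis_processors,
    txt_processors,
    model_type,
    image_keys,
    features_image_stacked,
    search_query3,
)

ms.MultimodalSearch.show_results(mydict, search_query3[4])
```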