diff --git a/misinformation/multimodal_search.py b/misinformation/multimodal_search.py
index 908aea1..7294ed4 100644
--- a/misinformation/multimodal_search.py
+++ b/misinformation/multimodal_search.py
@@ -15,7 +15,7 @@ class MultimodalSearch(AnalysisMethod):
 
     multimodal_device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 
-    def load_feature_extractor_model_blip2(device):
+    def load_feature_extractor_model_blip2(self, device):
         model, vis_processors, txt_processors = load_model_and_preprocess(
             name="blip2_feature_extractor",
             model_type="pretrain",
@@ -24,7 +24,7 @@ class MultimodalSearch(AnalysisMethod):
         )
         return model, vis_processors, txt_processors
 
-    def load_feature_extractor_model_blip(device):
+    def load_feature_extractor_model_blip(self, device):
         model, vis_processors, txt_processors = load_model_and_preprocess(
             name="blip_feature_extractor",
             model_type="base",
@@ -33,7 +33,7 @@ class MultimodalSearch(AnalysisMethod):
         )
         return model, vis_processors, txt_processors
 
-    def load_feature_extractor_model_albef(device):
+    def load_feature_extractor_model_albef(self, device):
         model, vis_processors, txt_processors = load_model_and_preprocess(
             name="albef_feature_extractor",
             model_type="base",
@@ -42,7 +42,7 @@ class MultimodalSearch(AnalysisMethod):
         )
         return model, vis_processors, txt_processors
 
-    def load_feature_extractor_model_clip_base(device):
+    def load_feature_extractor_model_clip_base(self, device):
         model, vis_processors, txt_processors = load_model_and_preprocess(
             name="clip_feature_extractor",
             model_type="base",
@@ -51,7 +51,7 @@ class MultimodalSearch(AnalysisMethod):
         )
         return model, vis_processors, txt_processors
 
-    def load_feature_extractor_model_clip_vitl14(device):
+    def load_feature_extractor_model_clip_vitl14(self, device):
         model, vis_processors, txt_processors = load_model_and_preprocess(
             name="clip_feature_extractor",
             model_type="ViT-L-14",
@@ -60,7 +60,7 @@ class MultimodalSearch(AnalysisMethod):
         )
         return model, vis_processors, txt_processors
 
-    def load_feature_extractor_model_clip_vitl14_336(device):
+    def load_feature_extractor_model_clip_vitl14_336(self, device):
         model, vis_processors, txt_processors = load_model_and_preprocess(
             name="clip_feature_extractor",
             model_type="ViT-L-14-336",
@@ -69,12 +69,12 @@ class MultimodalSearch(AnalysisMethod):
         )
         return model, vis_processors, txt_processors
 
-    def read_img(filepath):
+    def read_img(self, filepath):
         raw_image = Image.open(filepath).convert("RGB")
         return raw_image
 
-    def read_and_process_images(image_paths, vis_processor):
-        raw_images = [MultimodalSearch.read_img(path) for path in image_paths]
+    def read_and_process_images(self, image_paths, vis_processor):
+        raw_images = [MultimodalSearch.read_img(self, path) for path in image_paths]
         images = [
             vis_processor["eval"](r_img)
             .unsqueeze(0)
@@ -85,7 +85,7 @@ class MultimodalSearch(AnalysisMethod):
 
         return raw_images, images_tensors
 
-    def extract_image_features_blip2(model, images_tensors):
+    def extract_image_features_blip2(self, model, images_tensors):
         with torch.cuda.amp.autocast(
             enabled=(MultimodalSearch.multimodal_device != torch.device("cpu"))
         ):
@@ -98,7 +98,7 @@ class MultimodalSearch(AnalysisMethod):
             )
         return features_image_stacked
 
-    def extract_image_features_clip(model, images_tensors):
+    def extract_image_features_clip(self, model, images_tensors):
         features_image = [
             model.extract_features({"image": ten}) for ten in images_tensors
         ]
@@ -107,7 +107,7 @@ class MultimodalSearch(AnalysisMethod):
         )
         return features_image_stacked
 
-    def extract_image_features_basic(model, images_tensors):
+    def extract_image_features_basic(self, model, images_tensors):
         features_image = [
             model.extract_features({"image": ten, "text_input": ""}, mode="image")
             for ten in images_tensors
@@ -118,7 +118,7 @@ class MultimodalSearch(AnalysisMethod):
         return features_image_stacked
 
     def save_tensors(
-        model_type, features_image_stacked, name="saved_features_image.pt"
+        self, model_type, features_image_stacked, name="saved_features_image.pt"
     ):
         with open(
             str(len(features_image_stacked)) + "_" + model_type + "_" + name, "wb"
@@ -126,11 +126,11 @@ class MultimodalSearch(AnalysisMethod):
             torch.save(features_image_stacked, f)
         return name
 
-    def load_tensors(name="saved_features_image.pt"):
+    def load_tensors(self, name="saved_features_image.pt"):
         features_image_stacked = torch.load(name)
         return features_image_stacked
 
-    def extract_text_features(model, text_input):
+    def extract_text_features(self, model, text_input):
         sample_text = {"text_input": [text_input]}
 
         features_text = model.extract_features(sample_text, mode="text")
@@ -168,24 +168,24 @@ class MultimodalSearch(AnalysisMethod):
         if model_type in select_model.keys():
             (model, vis_processors, txt_processors,) = select_model[
                 model_type
-            ](MultimodalSearch.multimodal_device)
+            ](self, MultimodalSearch.multimodal_device)
         else:
             raise SyntaxError(
                 "Please, use one of the following models: blip2, blip, albef, clip_base, clip_vitl14, clip_vitl14_336"
             )
 
         raw_images, images_tensors = MultimodalSearch.read_and_process_images(
-            image_names, vis_processors
+            self, image_names, vis_processors
         )
         if path_to_saved_tensors is None:
             with torch.no_grad():
                 features_image_stacked = select_extract_image_features[model_type](
-                    model, images_tensors
+                    self, model, images_tensors
                 )
-            MultimodalSearch.save_tensors(model_type, features_image_stacked)
+            MultimodalSearch.save_tensors(self, model_type, features_image_stacked)
         else:
             features_image_stacked = MultimodalSearch.load_tensors(
-                str(path_to_saved_tensors)
+                self, str(path_to_saved_tensors)
             )
 
         return (
@@ -222,7 +222,7 @@ class MultimodalSearch(AnalysisMethod):
                 images_tensors = ""
             elif "image" in query.keys():
                 _, images_tensors = MultimodalSearch.read_and_process_images(
-                    [query["image"]], vis_processors
+                    self, [query["image"]], vis_processors
                 )
                 text_processing = ""
             multi_sample.append(
@@ -253,7 +253,9 @@ class MultimodalSearch(AnalysisMethod):
                 multi_features_query.append(features_squeeze)
             if query["text_input"] == "":
                 multi_features_query.append(
-                    select_extract_image_features[model_type](model, query["image"])
+                    select_extract_image_features[model_type](
+                        self, model, query["image"]
+                    )
                 )
 
         multi_features_stacked = torch.stack(
@@ -280,7 +282,7 @@ class MultimodalSearch(AnalysisMethod):
         )
 
         similarity = features_image_stacked @ multi_features_stacked.t()
-        similarity_soft_max = torch.nn.Softmax(dim=0)(similarity / 0.01)
+        # similarity_soft_max = torch.nn.Softmax(dim=0)(similarity / 0.01)
         sorted_lists = [
             sorted(range(len(similarity)), key=lambda k: similarity[k, i], reverse=True)
             for i in range(len(similarity[0]))
@@ -292,7 +294,7 @@ class MultimodalSearch(AnalysisMethod):
                 self[key]["rank " + list(search_query[q].values())[0]] = places[q][i]
                 self[key][list(search_query[q].values())[0]] = similarity[i][q].item()
 
-        return similarity, similarity_soft_max
+        return similarity, sorted_lists
 
     def show_results(self, query):
         if "image" in query.keys():
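Note (not part of the patch): a minimal usage sketch of the calling convention this diff establishes. The helper methods now take an explicit `self` but are still dispatched through the class (e.g. `MultimodalSearch.read_img(self, path)`), so callers outside the class are expected to pass an instance as the first argument. Only the method names, signatures, and the module path come from this diff; the constructor argument, the wrapping dict, and the image file path below are assumptions for illustration.

    # Hypothetical usage sketch; the constructor input and file paths are made up.
    import torch

    from misinformation.multimodal_search import MultimodalSearch

    my_obj = MultimodalSearch({"filename": "images/example.png"})  # assumed constructor input

    # Loaders now take (self, device); dispatch through the class with an explicit
    # instance, mirroring the calls inside the diff.
    model, vis_processors, txt_processors = MultimodalSearch.load_feature_extractor_model_blip(
        my_obj, MultimodalSearch.multimodal_device
    )

    # Image preprocessing and feature extraction follow the same explicit-self pattern.
    raw_images, images_tensors = MultimodalSearch.read_and_process_images(
        my_obj, ["images/example.png"], vis_processors
    )
    with torch.no_grad():
        features = MultimodalSearch.extract_image_features_basic(
            my_obj, model, images_tensors
        )
    print(features.shape)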