Mirror of https://github.com/ssciwr/AMMICO.git
synced 2025-10-29 13:06:04 +02:00

fix code smells

Parent: 4e4b7fac75
Commit: c208039b7c
@@ -15,7 +15,7 @@ class MultimodalSearch(AnalysisMethod):
     multimodal_device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

-    def load_feature_extractor_model_blip2(device):
+    def load_feature_extractor_model_blip2(self, device):
         model, vis_processors, txt_processors = load_model_and_preprocess(
             name="blip2_feature_extractor",
             model_type="pretrain",
@@ -24,7 +24,7 @@ class MultimodalSearch(AnalysisMethod):
         )
         return model, vis_processors, txt_processors

-    def load_feature_extractor_model_blip(device):
+    def load_feature_extractor_model_blip(self, device):
         model, vis_processors, txt_processors = load_model_and_preprocess(
             name="blip_feature_extractor",
             model_type="base",
@@ -33,7 +33,7 @@ class MultimodalSearch(AnalysisMethod):
         )
         return model, vis_processors, txt_processors

-    def load_feature_extractor_model_albef(device):
+    def load_feature_extractor_model_albef(self, device):
         model, vis_processors, txt_processors = load_model_and_preprocess(
             name="albef_feature_extractor",
             model_type="base",
@@ -42,7 +42,7 @@ class MultimodalSearch(AnalysisMethod):
         )
         return model, vis_processors, txt_processors

-    def load_feature_extractor_model_clip_base(device):
+    def load_feature_extractor_model_clip_base(self, device):
         model, vis_processors, txt_processors = load_model_and_preprocess(
             name="clip_feature_extractor",
             model_type="base",
@@ -51,7 +51,7 @@ class MultimodalSearch(AnalysisMethod):
         )
         return model, vis_processors, txt_processors

-    def load_feature_extractor_model_clip_vitl14(device):
+    def load_feature_extractor_model_clip_vitl14(self, device):
         model, vis_processors, txt_processors = load_model_and_preprocess(
             name="clip_feature_extractor",
             model_type="ViT-L-14",
@@ -60,7 +60,7 @@ class MultimodalSearch(AnalysisMethod):
         )
         return model, vis_processors, txt_processors

-    def load_feature_extractor_model_clip_vitl14_336(device):
+    def load_feature_extractor_model_clip_vitl14_336(self, device):
         model, vis_processors, txt_processors = load_model_and_preprocess(
             name="clip_feature_extractor",
             model_type="ViT-L-14-336",
@@ -69,12 +69,12 @@ class MultimodalSearch(AnalysisMethod):
         )
         return model, vis_processors, txt_processors

-    def read_img(filepath):
+    def read_img(self, filepath):
         raw_image = Image.open(filepath).convert("RGB")
         return raw_image

-    def read_and_process_images(image_paths, vis_processor):
-        raw_images = [MultimodalSearch.read_img(path) for path in image_paths]
+    def read_and_process_images(self, image_paths, vis_processor):
+        raw_images = [MultimodalSearch.read_img(self, path) for path in image_paths]
         images = [
             vis_processor["eval"](r_img)
             .unsqueeze(0)
@@ -85,7 +85,7 @@ class MultimodalSearch(AnalysisMethod):

         return raw_images, images_tensors

-    def extract_image_features_blip2(model, images_tensors):
+    def extract_image_features_blip2(self, model, images_tensors):
         with torch.cuda.amp.autocast(
             enabled=(MultimodalSearch.multimodal_device != torch.device("cpu"))
         ):
@@ -98,7 +98,7 @@ class MultimodalSearch(AnalysisMethod):
         )
         return features_image_stacked

-    def extract_image_features_clip(model, images_tensors):
+    def extract_image_features_clip(self, model, images_tensors):
         features_image = [
             model.extract_features({"image": ten}) for ten in images_tensors
         ]
@@ -107,7 +107,7 @@ class MultimodalSearch(AnalysisMethod):
         )
         return features_image_stacked

-    def extract_image_features_basic(model, images_tensors):
+    def extract_image_features_basic(self, model, images_tensors):
         features_image = [
             model.extract_features({"image": ten, "text_input": ""}, mode="image")
             for ten in images_tensors
@@ -118,7 +118,7 @@ class MultimodalSearch(AnalysisMethod):
         return features_image_stacked

     def save_tensors(
-        model_type, features_image_stacked, name="saved_features_image.pt"
+        self, model_type, features_image_stacked, name="saved_features_image.pt"
     ):
         with open(
             str(len(features_image_stacked)) + "_" + model_type + "_" + name, "wb"
@@ -126,11 +126,11 @@ class MultimodalSearch(AnalysisMethod):
             torch.save(features_image_stacked, f)
         return name

-    def load_tensors(name="saved_features_image.pt"):
+    def load_tensors(self, name="saved_features_image.pt"):
         features_image_stacked = torch.load(name)
         return features_image_stacked

-    def extract_text_features(model, text_input):
+    def extract_text_features(self, model, text_input):
         sample_text = {"text_input": [text_input]}
         features_text = model.extract_features(sample_text, mode="text")

@@ -168,24 +168,24 @@ class MultimodalSearch(AnalysisMethod):
         if model_type in select_model.keys():
             (model, vis_processors, txt_processors,) = select_model[
                 model_type
-            ](MultimodalSearch.multimodal_device)
+            ](self, MultimodalSearch.multimodal_device)
         else:
             raise SyntaxError(
                 "Please, use one of the following models: blip2, blip, albef, clip_base, clip_vitl14, clip_vitl14_336"
             )

         raw_images, images_tensors = MultimodalSearch.read_and_process_images(
-            image_names, vis_processors
+            self, image_names, vis_processors
         )
         if path_to_saved_tensors is None:
             with torch.no_grad():
                 features_image_stacked = select_extract_image_features[model_type](
-                    model, images_tensors
+                    self, model, images_tensors
                 )
-            MultimodalSearch.save_tensors(model_type, features_image_stacked)
+            MultimodalSearch.save_tensors(self, model_type, features_image_stacked)
         else:
             features_image_stacked = MultimodalSearch.load_tensors(
-                str(path_to_saved_tensors)
+                self, str(path_to_saved_tensors)
             )

         return (
@@ -222,7 +222,7 @@ class MultimodalSearch(AnalysisMethod):
                 images_tensors = ""
             elif "image" in query.keys():
                 _, images_tensors = MultimodalSearch.read_and_process_images(
-                    [query["image"]], vis_processors
+                    self, [query["image"]], vis_processors
                 )
                 text_processing = ""
             multi_sample.append(
@@ -253,7 +253,9 @@ class MultimodalSearch(AnalysisMethod):
             multi_features_query.append(features_squeeze)
             if query["text_input"] == "":
                 multi_features_query.append(
-                    select_extract_image_features[model_type](model, query["image"])
+                    select_extract_image_features[model_type](
+                        self, model, query["image"]
+                    )
                 )

         multi_features_stacked = torch.stack(
@@ -280,7 +282,7 @@ class MultimodalSearch(AnalysisMethod):
         )

         similarity = features_image_stacked @ multi_features_stacked.t()
-        similarity_soft_max = torch.nn.Softmax(dim=0)(similarity / 0.01)
+        # similarity_soft_max = torch.nn.Softmax(dim=0)(similarity / 0.01)
         sorted_lists = [
             sorted(range(len(similarity)), key=lambda k: similarity[k, i], reverse=True)
             for i in range(len(similarity[0]))
@@ -292,7 +294,7 @@ class MultimodalSearch(AnalysisMethod):
                 self[key]["rank " + list(search_query[q].values())[0]] = places[q][i]
                 self[key][list(search_query[q].values())[0]] = similarity[i][q].item()

-        return similarity, similarity_soft_max
+        return similarity, sorted_lists

     def show_results(self, query):
         if "image" in query.keys():
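The same refactoring recurs in every hunk above: methods that were declared inside MultimodalSearch without self (and then invoked through the class, e.g. MultimodalSearch.read_img(path)) now take self as their first parameter, and every call site passes it explicitly. A minimal runnable sketch of that pattern follows; the class and method names are illustrative stand-ins, not part of the AMMICO API.

import torch


class FeatureSearch:
    # Class-level device choice, mirroring MultimodalSearch.multimodal_device.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    def load_model(self, device):
        # Before the fix this would have read: def load_model(device): ...
        return torch.nn.Identity().to(device)

    def extract_features(self, model, tensors):
        # Before the fix: def extract_features(model, tensors): ...
        with torch.no_grad():
            return torch.stack([model(t) for t in tensors])

    def run(self):
        # "self" is now threaded through every class-qualified call, analogous to
        # MultimodalSearch.save_tensors(self, model_type, features_image_stacked).
        model = FeatureSearch.load_model(self, FeatureSearch.device)
        tensors = [torch.randn(4, device=FeatureSearch.device) for _ in range(3)]
        return FeatureSearch.extract_features(self, model, tensors)


if __name__ == "__main__":
    print(FeatureSearch().run().shape)  # torch.Size([3, 4])

Passing self explicitly keeps the existing ClassName.method(...) call style intact while removing the implicit-static-method smell that the commit message presumably refers to; an alternative would have been decorating the loaders with @staticmethod.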