Mirror of https://github.com/ssciwr/AMMICO.git, synced 2025-10-30 21:46:04 +02:00
Writing docs to multimodal_search.py
This commit is contained in:
Parent ca3994c72f
Commit f524ecaad8
@@ -20,7 +20,18 @@ class MultimodalSearch(AnalysisMethod):
 
     multimodal_device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 
-    def load_feature_extractor_model_blip2(self, device):
+    def load_feature_extractor_model_blip2(self, device: str = "cpu"):
+        """
+        Load pretrain blip2_feature_extractor model and preprocessors for visual and text inputs from lavis.models.
+
+        Args:
+            device (str): device to use. Can be "cpu" or "cuda". Default: "cpu".
+
+        Returns:
+            model (torch.nn.Module): model.
+            vis_processors (dict): preprocessors for visual inputs.
+            txt_processors (dict): preprocessors for text inputs.
+        """
         model, vis_processors, txt_processors = load_model_and_preprocess(
             name="blip2_feature_extractor",
             model_type="pretrain",
@@ -29,7 +40,18 @@ class MultimodalSearch(AnalysisMethod):
         )
         return model, vis_processors, txt_processors
 
-    def load_feature_extractor_model_blip(self, device):
+    def load_feature_extractor_model_blip(self, device: str = "cpu"):
+        """
+        Load base blip_feature_extractor model and preprocessors for visual and text inputs from lavis.models.
+
+        Args:
+            device (str): device to use. Can be "cpu" or "cuda". Default: "cpu".
+
+        Returns:
+            model (torch.nn.Module): model.
+            vis_processors (dict): preprocessors for visual inputs.
+            txt_processors (dict): preprocessors for text inputs.
+        """
         model, vis_processors, txt_processors = load_model_and_preprocess(
             name="blip_feature_extractor",
             model_type="base",
@@ -38,7 +60,18 @@ class MultimodalSearch(AnalysisMethod):
         )
         return model, vis_processors, txt_processors
 
-    def load_feature_extractor_model_albef(self, device):
+    def load_feature_extractor_model_albef(self, device: str = "cpu"):
+        """
+        Load base albef_feature_extractor model and preprocessors for visual and text inputs from lavis.models.
+
+        Args:
+            device (str): device to use. Can be "cpu" or "cuda". Default: "cpu".
+
+        Returns:
+            model (torch.nn.Module): model.
+            vis_processors (dict): preprocessors for visual inputs.
+            txt_processors (dict): preprocessors for text inputs.
+        """
         model, vis_processors, txt_processors = load_model_and_preprocess(
             name="albef_feature_extractor",
             model_type="base",
@@ -47,7 +80,18 @@ class MultimodalSearch(AnalysisMethod):
         )
         return model, vis_processors, txt_processors
 
-    def load_feature_extractor_model_clip_base(self, device):
+    def load_feature_extractor_model_clip_base(self, device: str = "cpu"):
+        """
+        Load base clip_feature_extractor model and preprocessors for visual and text inputs from lavis.models.
+
+        Args:
+            device (str): device to use. Can be "cpu" or "cuda". Default: "cpu".
+
+        Returns:
+            model (torch.nn.Module): model.
+            vis_processors (dict): preprocessors for visual inputs.
+            txt_processors (dict): preprocessors for text inputs.
+        """
         model, vis_processors, txt_processors = load_model_and_preprocess(
             name="clip_feature_extractor",
             model_type="base",
@@ -56,7 +100,18 @@ class MultimodalSearch(AnalysisMethod):
         )
         return model, vis_processors, txt_processors
 
-    def load_feature_extractor_model_clip_vitl14(self, device):
+    def load_feature_extractor_model_clip_vitl14(self, device: str = "cpu"):
+        """
+        Load ViT-L-14 clip_feature_extractor model and preprocessors for visual and text inputs from lavis.models.
+
+        Args:
+            device (str): device to use. Can be "cpu" or "cuda". Default: "cpu".
+
+        Returns:
+            model (torch.nn.Module): model.
+            vis_processors (dict): preprocessors for visual inputs.
+            txt_processors (dict): preprocessors for text inputs.
+        """
         model, vis_processors, txt_processors = load_model_and_preprocess(
             name="clip_feature_extractor",
             model_type="ViT-L-14",
@@ -65,7 +120,18 @@ class MultimodalSearch(AnalysisMethod):
         )
         return model, vis_processors, txt_processors
 
-    def load_feature_extractor_model_clip_vitl14_336(self, device):
+    def load_feature_extractor_model_clip_vitl14_336(self, device: str = "cpu"):
+        """
+        Load ViT-L-14-336 clip_feature_extractor model and preprocessors for visual and text inputs from lavis.models.
+
+        Args:
+            device (str): device to use. Can be "cpu" or "cuda". Default: "cpu".
+
+        Returns:
+            model (torch.nn.Module): model.
+            vis_processors (dict): preprocessors for visual inputs.
+            txt_processors (dict): preprocessors for text inputs.
+        """
         model, vis_processors, txt_processors = load_model_and_preprocess(
             name="clip_feature_extractor",
             model_type="ViT-L-14-336",
@@ -74,11 +140,31 @@ class MultimodalSearch(AnalysisMethod):
         )
         return model, vis_processors, txt_processors
 
-    def read_img(self, filepath):
+    def read_img(self, filepath: str) -> Image:
+        """
+        Load Image from filepath.
+
+        Args:
+            filepath (str): path to image.
+
+        Returns:
+            raw_image (PIL.Image): image.
+        """
         raw_image = Image.open(filepath).convert("RGB")
         return raw_image
 
-    def read_and_process_images(self, image_paths, vis_processor):
+    def read_and_process_images(self, image_paths: str, vis_processor) -> tuple:
+        """
+        Read and process images with vis_processor.
+
+        Args:
+            image_paths (str): paths to images.
+            vis_processor (dict): preprocessors for visual inputs.
+
+        Returns:
+            raw_images (list): list of images.
+            images_tensors (torch.Tensor): tensors of images stacked in device.
+        """
         raw_images = [MultimodalSearch.read_img(self, path) for path in image_paths]
         images = [
             vis_processor["eval"](r_img)
@@ -90,7 +176,19 @@ class MultimodalSearch(AnalysisMethod):
 
         return raw_images, images_tensors
 
-    def extract_image_features_blip2(self, model, images_tensors):
+    def extract_image_features_blip2(
+        self, model, images_tensors: torch.Tensor
+    ) -> torch.Tensor:
+        """
+        Extract image features from images_tensors with blip2_feature_extractor model.
+
+        Args:
+            model (torch.nn.Module): model.
+            images_tensors (torch.Tensor): tensors of images stacked in device.
+
+        Returns:
+            features_image_stacked (torch.Tensor): tensors of images features stacked in device.
+        """
         with torch.cuda.amp.autocast(
             enabled=(MultimodalSearch.multimodal_device != torch.device("cpu"))
         ):
@@ -103,7 +201,19 @@ class MultimodalSearch(AnalysisMethod):
         )
         return features_image_stacked
 
-    def extract_image_features_clip(self, model, images_tensors):
+    def extract_image_features_clip(
+        self, model, images_tensors: torch.Tensor
+    ) -> torch.Tensor:
+        """
+        Extract image features from images_tensors with clip_feature_extractor model.
+
+        Args:
+            model (torch.nn.Module): model.
+            images_tensors (torch.Tensor): tensors of images stacked in device.
+
+        Returns:
+            features_image_stacked (torch.Tensor): tensors of images features stacked in device.
+        """
         features_image = [
             model.extract_features({"image": ten}) for ten in images_tensors
         ]
@@ -112,7 +222,19 @@ class MultimodalSearch(AnalysisMethod):
         )
         return features_image_stacked
 
-    def extract_image_features_basic(self, model, images_tensors):
+    def extract_image_features_basic(
+        self, model, images_tensors: torch.Tensor
+    ) -> torch.Tensor:
+        """
+        Extract image features from images_tensors with blip_feature_extractor or albef_feature_extractor model.
+
+        Args:
+            model (torch.nn.Module): model.
+            images_tensors (torch.Tensor): tensors of images stacked in device.
+
+        Returns:
+            features_image_stacked (torch.Tensor): tensors of images features stacked in device.
+        """
         features_image = [
             model.extract_features({"image": ten, "text_input": ""}, mode="image")
             for ten in images_tensors
@@ -124,11 +246,23 @@ class MultimodalSearch(AnalysisMethod):
 
     def save_tensors(
         self,
-        model_type,
-        features_image_stacked,
-        name="saved_features_image.pt",
-        path="./saved_tensors/",
-    ):
+        model_type: str,
+        features_image_stacked: torch.Tensor,
+        name: str = "saved_features_image.pt",
+        path: str = "./saved_tensors/",
+    ) -> str:
+        """
+        Save tensors as binary to given path.
+
+        Args:
+            model_type (str): type of the model.
+            features_image_stacked (torch.Tensor): tensors of images features stacked in device.
+            name (str): name of the file. Default: "saved_features_image.pt".
+            path (str): path to save the file. Default: "./saved_tensors/".
+
+        Returns:
+            name (str): name of the file.
+        """
         if not os.path.exists(path):
             os.makedirs(path)
         with open(
@@ -143,11 +277,30 @@ class MultimodalSearch(AnalysisMethod):
             torch.save(features_image_stacked, f)
         return name
 
-    def load_tensors(self, name):
+    def load_tensors(self, name: str) -> torch.Tensor:
+        """
+        Load tensors from given path.
+
+        Args:
+            name (str): name of the file.
+
+        Returns:
+            features_image_stacked (torch.Tensor): tensors of images features.
+        """
         features_image_stacked = torch.load(name)
         return features_image_stacked
 
-    def extract_text_features(self, model, text_input):
+    def extract_text_features(self, model, text_input: str) -> torch.Tensor:
+        """
+        Extract text features from text_input with feature_extractor model.
+
+        Args:
+            model (torch.nn.Module): model.
+            text_input (str): text.
+
+        Returns:
+            features_text (torch.Tensor): tensors of text features.
+        """
         sample_text = {"text_input": [text_input]}
         features_text = model.extract_features(sample_text, mode="text")
 
@@ -155,10 +308,26 @@ class MultimodalSearch(AnalysisMethod):
 
     def parsing_images(
         self,
-        model_type,
-        path_to_saved_tensors="./saved_tensors/",
-        path_to_load_tensors=None,
-    ):
+        model_type: str,
+        path_to_save_tensors: str = "./saved_tensors/",
+        path_to_load_tensors: str = None,
+    ) -> tuple:
+        """
+        Parsing images with feature_extractor model.
+
+        Args:
+            model_type (str): type of the model.
+            path_to_save_tensors (str): path to save the tensors. Default: "./saved_tensors/".
+            path_to_load_tensors (str): path to load the tensors. Default: None.
+
+        Returns:
+            model (torch.nn.Module): model.
+            vis_processors (dict): preprocessors for visual inputs.
+            txt_processors (dict): preprocessors for text inputs.
+            image_keys (list): sorted list of image keys.
+            image_names (list): sorted list of image names.
+            features_image_stacked (torch.Tensor): tensors of images features stacked in device.
+        """
         if model_type in ("clip_base", "clip_vitl14_336", "clip_vitl14"):
             path_to_lib = lavis.__file__[:-11] + "models/clip_models/"
             url = "https://raw.githubusercontent.com/salesforce/LAVIS/main/lavis/models/clip_models/bpe_simple_vocab_16e6.txt.gz"
@@ -208,7 +377,7 @@ class MultimodalSearch(AnalysisMethod):
                 self, model, images_tensors
             )
             MultimodalSearch.save_tensors(
-                self, model_type, features_image_stacked, path=path_to_saved_tensors
+                self, model_type, features_image_stacked, path=path_to_save_tensors
             )
         else:
             features_image_stacked = MultimodalSearch.load_tensors(
@@ -225,8 +394,21 @@ class MultimodalSearch(AnalysisMethod):
         )
 
     def querys_processing(
-        self, search_query, model, txt_processors, vis_processors, model_type
-    ):
+        self, search_query: list, model, txt_processors, vis_processors, model_type: str
+    ) -> torch.Tensor:
+        """
+        Process queries.
+
+        Args:
+            search_query (list): list of queries.
+            model (torch.nn.Module): model.
+            txt_processors (dict): preprocessors for text inputs.
+            vis_processors (dict): preprocessors for visual inputs.
+            model_type (str): type of the model.
+
+        Returns:
+            multi_features_stacked (torch.Tensor): tensors of query features.
+        """
         select_extract_image_features = {
             "blip2": MultimodalSearch.extract_image_features_blip2,
             "blip": MultimodalSearch.extract_image_features_basic,
@@ -295,14 +477,33 @@ class MultimodalSearch(AnalysisMethod):
         model,
         vis_processors,
         txt_processors,
-        model_type,
-        image_keys,
-        features_image_stacked,
-        search_query,
-        filter_number_of_images=None,
-        filter_val_limit=None,
-        filter_rel_error=None,
-    ):
+        model_type: str,
+        image_keys: list,
+        features_image_stacked: torch.Tensor,
+        search_query: list,
+        filter_number_of_images: str = None,
+        filter_val_limit: str = None,
+        filter_rel_error: str = None,
+    ) -> tuple:
+        """
+        Search for images with given queries.
+
+        Args:
+            model (torch.nn.Module): model.
+            vis_processors (dict): preprocessors for visual inputs.
+            txt_processors (dict): preprocessors for text inputs.
+            model_type (str): type of the model.
+            image_keys (list): sorted list of image keys.
+            features_image_stacked (torch.Tensor): tensors of images features stacked in device.
+            search_query (list): list of queries.
+            filter_number_of_images (str): number of images to show. Default: None.
+            filter_val_limit (str): limit of similarity value. Default: None.
+            filter_rel_error (str): limit of relative error. Default: None.
+
+        Returns:
+            similarity (torch.Tensor): similarity between images and queries.
+            sorted_lists (list): sorted list of similarity.
+        """
         if filter_number_of_images is None:
             filter_number_of_images = len(self)
         if filter_val_limit is None:
@@ -343,7 +544,16 @@ class MultimodalSearch(AnalysisMethod):
                     self[image_keys[key]][list(search_query[q].values())[0]] = 0
         return similarity, sorted_lists
 
-    def itm_text_precessing(self, search_query):
+    def itm_text_precessing(self, search_query: list[dict[str, str]]) -> list:
+        """
+        Process text queries for itm model.
+
+        Args:
+            search_query (list): list of queries.
+
+        Returns:
+            text_query_index (list): list of indexes of text queries.
+        """
         for query in search_query:
             if (len(query) != 1) and (query in ("image", "text_input")):
                 raise SyntaxError(
@@ -356,7 +566,17 @@ class MultimodalSearch(AnalysisMethod):
 
         return text_query_index
 
-    def get_pathes_from_query(self, query):
+    def get_pathes_from_query(self, query: dict[str, str]) -> tuple:
+        """
+        Get paths and image names from query.
+
+        Args:
+            query (dict): query.
+
+        Returns:
+            paths (list): list of paths.
+            image_names (list): list of image names.
+        """
         paths = []
         image_names = []
         for s in sorted(
@@ -368,7 +588,18 @@ class MultimodalSearch(AnalysisMethod):
             image_names.append(s[0])
         return paths, image_names
 
-    def read_and_process_images_itm(self, image_paths, vis_processor):
+    def read_and_process_images_itm(self, image_paths: list, vis_processor) -> tuple:
+        """
+        Read and process images with vis_processor for itm model.
+
+        Args:
+            image_paths (list): paths to images.
+            vis_processor (dict): preprocessors for visual inputs.
+
+        Returns:
+            raw_images (list): list of images.
+            images_tensors (torch.Tensor): tensors of images stacked in device.
+        """
         raw_images = [MultimodalSearch.read_img(self, path) for path in image_paths]
         images = [vis_processor(r_img) for r_img in raw_images]
         images_tensors = torch.stack(images).to(MultimodalSearch.multimodal_device)
@@ -377,12 +608,26 @@ class MultimodalSearch(AnalysisMethod):
 
     def compute_gradcam_batch(
         self,
-        model,
-        visual_input,
-        text_input,
-        tokenized_text,
-        block_num=6,
-    ):
+        model: torch.nn.Module,
+        visual_input: torch.Tensor,
+        text_input: str,
+        tokenized_text: torch.Tensor,
+        block_num: int = 6,
+    ) -> tuple:
+        """
+        Compute gradcam for itm model.
+
+        Args:
+            model (torch.nn.Module): model.
+            visual_input (torch.Tensor): tensors of images features stacked in device.
+            text_input (str): text.
+            tokenized_text (torch.Tensor): tokenized text.
+            block_num (int): number of block. Default: 6.
+
+        Returns:
+            gradcam (torch.Tensor): gradcam.
+            output (torch.Tensor): output of model.
+        """
         model.text_encoder.base_model.base_model.encoder.layer[
             block_num
         ].crossattention.self.save_attention = True
@@ -430,7 +675,16 @@ class MultimodalSearch(AnalysisMethod):
 
         return gradcam, output
 
-    def resize_img(self, raw_img):
+    def resize_img(self, raw_img: Image):
+        """
+        Proportionally resize image to 240p width.
+
+        Args:
+            raw_img (PIL.Image): image.
+
+        Returns:
+            resized_image (PIL.Image): proportionally resized image to 240p width.
+        """
         w, h = raw_img.size
         scaling_factor = 240 / w
         resized_image = raw_img.resize(
@@ -438,7 +692,25 @@ class MultimodalSearch(AnalysisMethod):
         )
         return resized_image
 
-    def get_att_map(self, img, att_map, blur=True, overlap=True):
+    def get_att_map(
+        self,
+        img: np.ndarray,
+        att_map: np.ndarray,
+        blur: bool = True,
+        overlap: bool = True,
+    ) -> np.ndarray:
+        """
+        Get attention map.
+
+        Args:
+            img (np.ndarray): image.
+            att_map (np.ndarray): attention map.
+            blur (bool): blur attention map. Default: True.
+            overlap (bool): overlap attention map with image. Default: True.
+
+        Returns:
+            att_map (np.ndarray): attention map.
+        """
         att_map -= att_map.min()
         if att_map.max() > 0:
             att_map /= att_map.max()
@@ -459,7 +731,14 @@ class MultimodalSearch(AnalysisMethod):
         )
         return att_map
 
-    def upload_model_blip2_coco(self):
+    def upload_model_blip2_coco(self) -> tuple:
+        """
+        Load coco blip2_image_text_matching model and preprocessors for visual inputs from lavis.models.
+
+        Returns:
+            itm_model (torch.nn.Module): model.
+            vis_processor (dict): preprocessors for visual inputs.
+        """
         itm_model = load_model(
             "blip2_image_text_matching",
             "coco",
@@ -469,7 +748,14 @@ class MultimodalSearch(AnalysisMethod):
         vis_processor = load_processor("blip_image_eval").build(image_size=364)
         return itm_model, vis_processor
 
-    def upload_model_blip_base(self):
+    def upload_model_blip_base(self) -> tuple:
+        """
+        Load base blip_image_text_matching model and preprocessors for visual input from lavis.models.
+
+        Returns:
+            itm_model (torch.nn.Module): model.
+            vis_processor (dict): preprocessors for visual inputs.
+        """
         itm_model = load_model(
             "blip_image_text_matching",
             "base",
@@ -479,7 +765,14 @@ class MultimodalSearch(AnalysisMethod):
         vis_processor = load_processor("blip_image_eval").build(image_size=384)
         return itm_model, vis_processor
 
-    def upload_model_blip_large(self):
+    def upload_model_blip_large(self) -> tuple:
+        """
+        Load large blip_image_text_matching model and preprocessors for visual input from lavis.models.
+
+        Returns:
+            itm_model (torch.nn.Module): model.
+            vis_processor (dict): preprocessors for visual inputs.
+        """
         itm_model = load_model(
             "blip_image_text_matching",
             "large",
@@ -491,13 +784,28 @@ class MultimodalSearch(AnalysisMethod):
 
     def image_text_match_reordering(
         self,
-        search_query,
-        itm_model_type,
-        image_keys,
-        sorted_lists,
-        batch_size=1,
-        need_grad_cam=False,
-    ):
+        search_query: list[dict[str, str]],
+        itm_model_type: str,
+        image_keys: list,
+        sorted_lists: list[list],
+        batch_size: int = 1,
+        need_grad_cam: bool = False,
+    ) -> tuple:
+        """
+        Reorder images with itm model.
+
+        Args:
+            search_query (list): list of queries.
+            itm_model_type (str): type of the model.
+            image_keys (list): sorted list of image keys.
+            sorted_lists (list): sorted list of similarity.
+            batch_size (int): batch size. Default: 1.
+            need_grad_cam (bool): need gradcam. Default: False. The blip2_coco model does not yet work with gradcam.
+
+        Returns:
+            itm_scores2: list of itm scores.
+            image_gradcam_with_itm: dict of image names and gradcam.
+        """
         if itm_model_type == "blip2_coco" and need_grad_cam is True:
             raise SyntaxError(
                 "The blip2_coco model does not yet work with gradcam. Please set need_grad_cam to False"
@@ -624,7 +932,17 @@ class MultimodalSearch(AnalysisMethod):
         torch.cuda.empty_cache()
         return itm_scores2, image_gradcam_with_itm
 
-    def show_results(self, query, itm=False, image_gradcam_with_itm=False):
+    def show_results(
+        self, query: dict, itm=False, image_gradcam_with_itm=False
+    ) -> None:
+        """
+        Show results of search.
+
+        Args:
+            query (dict): query.
+            itm (bool): use itm model. Default: False.
+            image_gradcam_with_itm (bool): use gradcam. Default: False.
+        """
         if "image" in query.keys():
            pic = Image.open(query["image"]).convert("RGB")
            pic.thumbnail((400, 400))
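
The methods documented in this commit form a load, embed, and search workflow. The sketch below is only an illustration of how they might be chained, not part of the commit: the import path ammico.multimodal_search, the image-dictionary layout with "filename" entries, the example file names, and the query text are assumptions, and a working LAVIS installation is required.

# Hypothetical usage sketch for the MultimodalSearch API documented above (not part of the commit).
# Assumptions: the class is importable from ammico.multimodal_search and is constructed from a
# dict of image entries keyed by name, each holding a "filename" path.
from ammico.multimodal_search import MultimodalSearch

image_dict = {
    "img1": {"filename": "data/img1.png"},
    "img2": {"filename": "data/img2.png"},
}
my_obj = MultimodalSearch(image_dict)

# Load a feature extractor, embed all images, and cache the feature tensors to disk.
(
    model,
    vis_processors,
    txt_processors,
    image_keys,
    image_names,
    features_image_stacked,
) = my_obj.parsing_images("blip", path_to_save_tensors="./saved_tensors/")

# Each query is a dict with either a "text_input" or an "image" key.
search_query = [{"text_input": "politician press conference"}]
similarity, sorted_lists = my_obj.multimodal_search(
    model,
    vis_processors,
    txt_processors,
    "blip",
    image_keys,
    features_image_stacked,
    search_query,
)

# Display the ranked matches for the first query.
my_obj.show_results(search_query[0])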