зеркало из
https://github.com/ssciwr/AMMICO.git
synced 2025-10-29 13:06:04 +02:00
Wiriting docs to multimodal_search.py
Этот коммит содержится в:
родитель
ca3994c72f
Коммит
f524ecaad8
@ -20,7 +20,18 @@ class MultimodalSearch(AnalysisMethod):
|
||||
|
||||
multimodal_device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||
|
||||
def load_feature_extractor_model_blip2(self, device):
|
||||
def load_feature_extractor_model_blip2(self, device: str = "cpu"):
|
||||
"""
|
||||
Load pretrain blip2_feature_extractor model and preprocessors for visual and text inputs from lavis.models.
|
||||
|
||||
Args:
|
||||
device (str): device to use. Can be "cpu" or "cuda". Default: "cpu".
|
||||
|
||||
Returns:
|
||||
model (torch.nn.Module): model.
|
||||
vis_processors (dict): preprocessors for visual inputs.
|
||||
txt_processors (dict): preprocessors for text inputs.
|
||||
"""
|
||||
model, vis_processors, txt_processors = load_model_and_preprocess(
|
||||
name="blip2_feature_extractor",
|
||||
model_type="pretrain",
|
||||
@ -29,7 +40,18 @@ class MultimodalSearch(AnalysisMethod):
|
||||
)
|
||||
return model, vis_processors, txt_processors
|
||||
|
||||
def load_feature_extractor_model_blip(self, device):
|
||||
def load_feature_extractor_model_blip(self, device: str = "cpu"):
|
||||
"""
|
||||
Load base blip_feature_extractor model and preprocessors for visual and text inputs from lavis.models.
|
||||
|
||||
Args:
|
||||
device (str): device to use. Can be "cpu" or "cuda". Default: "cpu".
|
||||
|
||||
Returns:
|
||||
model (torch.nn.Module): model.
|
||||
vis_processors (dict): preprocessors for visual inputs.
|
||||
txt_processors (dict): preprocessors for text inputs.
|
||||
"""
|
||||
model, vis_processors, txt_processors = load_model_and_preprocess(
|
||||
name="blip_feature_extractor",
|
||||
model_type="base",
|
||||
@ -38,7 +60,18 @@ class MultimodalSearch(AnalysisMethod):
|
||||
)
|
||||
return model, vis_processors, txt_processors
|
||||
|
||||
def load_feature_extractor_model_albef(self, device):
|
||||
def load_feature_extractor_model_albef(self, device: str = "cpu"):
|
||||
"""
|
||||
Load base albef_feature_extractor model and preprocessors for visual and text inputs from lavis.models.
|
||||
|
||||
Args:
|
||||
device (str): device to use. Can be "cpu" or "cuda". Default: "cpu".
|
||||
|
||||
Returns:
|
||||
model (torch.nn.Module): model.
|
||||
vis_processors (dict): preprocessors for visual inputs.
|
||||
txt_processors (dict): preprocessors for text inputs.
|
||||
"""
|
||||
model, vis_processors, txt_processors = load_model_and_preprocess(
|
||||
name="albef_feature_extractor",
|
||||
model_type="base",
|
||||
@ -47,7 +80,18 @@ class MultimodalSearch(AnalysisMethod):
|
||||
)
|
||||
return model, vis_processors, txt_processors
|
||||
|
||||
def load_feature_extractor_model_clip_base(self, device):
|
||||
def load_feature_extractor_model_clip_base(self, device: str = "cpu"):
|
||||
"""
|
||||
Load base clip_feature_extractor model and preprocessors for visual and text inputs from lavis.models.
|
||||
|
||||
Args:
|
||||
device (str): device to use. Can be "cpu" or "cuda". Default: "cpu".
|
||||
|
||||
Returns:
|
||||
model (torch.nn.Module): model.
|
||||
vis_processors (dict): preprocessors for visual inputs.
|
||||
txt_processors (dict): preprocessors for text inputs.
|
||||
"""
|
||||
model, vis_processors, txt_processors = load_model_and_preprocess(
|
||||
name="clip_feature_extractor",
|
||||
model_type="base",
|
||||
@ -56,7 +100,18 @@ class MultimodalSearch(AnalysisMethod):
|
||||
)
|
||||
return model, vis_processors, txt_processors
|
||||
|
||||
def load_feature_extractor_model_clip_vitl14(self, device):
|
||||
def load_feature_extractor_model_clip_vitl14(self, device: str = "cpu"):
|
||||
"""
|
||||
Load ViT-L-14 clip_feature_extractor model and preprocessors for visual and text inputs from lavis.models.
|
||||
|
||||
Args:
|
||||
device (str): device to use. Can be "cpu" or "cuda". Default: "cpu".
|
||||
|
||||
Returns:
|
||||
model (torch.nn.Module): model.
|
||||
vis_processors (dict): preprocessors for visual inputs.
|
||||
txt_processors (dict): preprocessors for text inputs.
|
||||
"""
|
||||
model, vis_processors, txt_processors = load_model_and_preprocess(
|
||||
name="clip_feature_extractor",
|
||||
model_type="ViT-L-14",
|
||||
@ -65,7 +120,18 @@ class MultimodalSearch(AnalysisMethod):
|
||||
)
|
||||
return model, vis_processors, txt_processors
|
||||
|
||||
def load_feature_extractor_model_clip_vitl14_336(self, device):
|
||||
def load_feature_extractor_model_clip_vitl14_336(self, device: str = "cpu"):
|
||||
"""
|
||||
Load ViT-L-14-336 clip_feature_extractor model and preprocessors for visual and text inputs from lavis.models.
|
||||
|
||||
Args:
|
||||
device (str): device to use. Can be "cpu" or "cuda". Default: "cpu".
|
||||
|
||||
Returns:
|
||||
model (torch.nn.Module): model.
|
||||
vis_processors (dict): preprocessors for visual inputs.
|
||||
txt_processors (dict): preprocessors for text inputs.
|
||||
"""
|
||||
model, vis_processors, txt_processors = load_model_and_preprocess(
|
||||
name="clip_feature_extractor",
|
||||
model_type="ViT-L-14-336",
|
||||
@ -74,11 +140,31 @@ class MultimodalSearch(AnalysisMethod):
|
||||
)
|
||||
return model, vis_processors, txt_processors
|
||||
|
||||
def read_img(self, filepath):
|
||||
def read_img(self, filepath: str) -> Image:
|
||||
"""
|
||||
Load Image from filepath.
|
||||
|
||||
Args:
|
||||
filepath (str): path to image.
|
||||
|
||||
Returns:
|
||||
raw_image (PIL.Image): image.
|
||||
"""
|
||||
raw_image = Image.open(filepath).convert("RGB")
|
||||
return raw_image
|
||||
|
||||
def read_and_process_images(self, image_paths, vis_processor):
|
||||
def read_and_process_images(self, image_paths: str, vis_processor) -> tuple:
|
||||
"""
|
||||
Read and process images with vis_processor.
|
||||
|
||||
Args:
|
||||
image_paths (str): paths to images.
|
||||
vis_processor (dict): preprocessors for visual inputs.
|
||||
|
||||
Returns:
|
||||
raw_images (list): list of images.
|
||||
images_tensors (torch.Tensor): tensors of images stacked in device.
|
||||
"""
|
||||
raw_images = [MultimodalSearch.read_img(self, path) for path in image_paths]
|
||||
images = [
|
||||
vis_processor["eval"](r_img)
|
||||
@ -90,7 +176,19 @@ class MultimodalSearch(AnalysisMethod):
|
||||
|
||||
return raw_images, images_tensors
|
||||
|
||||
def extract_image_features_blip2(self, model, images_tensors):
|
||||
def extract_image_features_blip2(
|
||||
self, model, images_tensors: torch.Tensor
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Extract image features from images_tensors with blip2_feature_extractor model.
|
||||
|
||||
Args:
|
||||
model (torch.nn.Module): model.
|
||||
images_tensors (torch.Tensor): tensors of images stacked in device.
|
||||
|
||||
Returns:
|
||||
features_image_stacked (torch.Tensor): tensors of images features stacked in device.
|
||||
"""
|
||||
with torch.cuda.amp.autocast(
|
||||
enabled=(MultimodalSearch.multimodal_device != torch.device("cpu"))
|
||||
):
|
||||
@ -103,7 +201,19 @@ class MultimodalSearch(AnalysisMethod):
|
||||
)
|
||||
return features_image_stacked
|
||||
|
||||
def extract_image_features_clip(self, model, images_tensors):
|
||||
def extract_image_features_clip(
|
||||
self, model, images_tensors: torch.Tensor
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Extract image features from images_tensors with clip_feature_extractor model.
|
||||
|
||||
Args:
|
||||
model (torch.nn.Module): model.
|
||||
images_tensors (torch.Tensor): tensors of images stacked in device.
|
||||
|
||||
Returns:
|
||||
features_image_stacked (torch.Tensor): tensors of images features stacked in device.
|
||||
"""
|
||||
features_image = [
|
||||
model.extract_features({"image": ten}) for ten in images_tensors
|
||||
]
|
||||
@ -112,7 +222,19 @@ class MultimodalSearch(AnalysisMethod):
|
||||
)
|
||||
return features_image_stacked
|
||||
|
||||
def extract_image_features_basic(self, model, images_tensors):
|
||||
def extract_image_features_basic(
|
||||
self, model, images_tensors: torch.Tensor
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Extract image features from images_tensors with blip_feature_extractor or albef_feature_extractor model.
|
||||
|
||||
Args:
|
||||
model (torch.nn.Module): model.
|
||||
images_tensors (torch.Tensor): tensors of images stacked in device.
|
||||
|
||||
Returns:
|
||||
features_image_stacked (torch.Tensor): tensors of images features stacked in device.
|
||||
"""
|
||||
features_image = [
|
||||
model.extract_features({"image": ten, "text_input": ""}, mode="image")
|
||||
for ten in images_tensors
|
||||
@ -124,11 +246,23 @@ class MultimodalSearch(AnalysisMethod):
|
||||
|
||||
def save_tensors(
|
||||
self,
|
||||
model_type,
|
||||
features_image_stacked,
|
||||
name="saved_features_image.pt",
|
||||
path="./saved_tensors/",
|
||||
):
|
||||
model_type: str,
|
||||
features_image_stacked: torch.Tensor,
|
||||
name: str = "saved_features_image.pt",
|
||||
path: str = "./saved_tensors/",
|
||||
) -> str:
|
||||
"""
|
||||
Save tensors as binary to given path.
|
||||
|
||||
Args:
|
||||
model_type (str): type of the model.
|
||||
features_image_stacked (torch.Tensor): tensors of images features stacked in device.
|
||||
name (str): name of the file. Default: "saved_features_image.pt".
|
||||
path (str): path to save the file. Default: "./saved_tensors/".
|
||||
|
||||
Returns:
|
||||
name (str): name of the file.
|
||||
"""
|
||||
if not os.path.exists(path):
|
||||
os.makedirs(path)
|
||||
with open(
|
||||
@ -143,11 +277,30 @@ class MultimodalSearch(AnalysisMethod):
|
||||
torch.save(features_image_stacked, f)
|
||||
return name
|
||||
|
||||
def load_tensors(self, name):
|
||||
def load_tensors(self, name: str) -> torch.Tensor:
|
||||
"""
|
||||
Load tensors from given path.
|
||||
|
||||
Args:
|
||||
name (str): name of the file.
|
||||
|
||||
Returns:
|
||||
features_image_stacked (torch.Tensor): tensors of images features.
|
||||
"""
|
||||
features_image_stacked = torch.load(name)
|
||||
return features_image_stacked
|
||||
|
||||
def extract_text_features(self, model, text_input):
|
||||
def extract_text_features(self, model, text_input: str) -> torch.Tensor:
|
||||
"""
|
||||
Extract text features from text_input with feature_extractor model.
|
||||
|
||||
Args:
|
||||
model (torch.nn.Module): model.
|
||||
text_input (str): text.
|
||||
|
||||
Returns:
|
||||
features_text (torch.Tensor): tensors of text features.
|
||||
"""
|
||||
sample_text = {"text_input": [text_input]}
|
||||
features_text = model.extract_features(sample_text, mode="text")
|
||||
|
||||
@ -155,10 +308,26 @@ class MultimodalSearch(AnalysisMethod):
|
||||
|
||||
def parsing_images(
|
||||
self,
|
||||
model_type,
|
||||
path_to_saved_tensors="./saved_tensors/",
|
||||
path_to_load_tensors=None,
|
||||
):
|
||||
model_type: str,
|
||||
path_to_save_tensors: str = "./saved_tensors/",
|
||||
path_to_load_tensors: str = None,
|
||||
) -> tuple:
|
||||
"""
|
||||
Parsing images with feature_extractor model.
|
||||
|
||||
Args:
|
||||
model_type (str): type of the model.
|
||||
path_to_save_tensors (str): path to save the tensors. Default: "./saved_tensors/".
|
||||
path_to_load_tensors (str): path to load the tesors. Default: None.
|
||||
|
||||
Returns:
|
||||
model (torch.nn.Module): model.
|
||||
vis_processors (dict): preprocessors for visual inputs.
|
||||
txt_processors (dict): preprocessors for text inputs.
|
||||
image_keys (list): sorted list of image keys.
|
||||
image_names (list): sorted list of image names.
|
||||
features_image_stacked (torch.Tensor): tensors of images features stacked in device.
|
||||
"""
|
||||
if model_type in ("clip_base", "clip_vitl14_336", "clip_vitl14"):
|
||||
path_to_lib = lavis.__file__[:-11] + "models/clip_models/"
|
||||
url = "https://raw.githubusercontent.com/salesforce/LAVIS/main/lavis/models/clip_models/bpe_simple_vocab_16e6.txt.gz"
|
||||
@ -208,7 +377,7 @@ class MultimodalSearch(AnalysisMethod):
|
||||
self, model, images_tensors
|
||||
)
|
||||
MultimodalSearch.save_tensors(
|
||||
self, model_type, features_image_stacked, path=path_to_saved_tensors
|
||||
self, model_type, features_image_stacked, path=path_to_save_tensors
|
||||
)
|
||||
else:
|
||||
features_image_stacked = MultimodalSearch.load_tensors(
|
||||
@ -225,8 +394,21 @@ class MultimodalSearch(AnalysisMethod):
|
||||
)
|
||||
|
||||
def querys_processing(
|
||||
self, search_query, model, txt_processors, vis_processors, model_type
|
||||
):
|
||||
self, search_query: list, model, txt_processors, vis_processors, model_type: str
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Process querys.
|
||||
|
||||
Args:
|
||||
search_query (list): list of querys.
|
||||
model (torch.nn.Module): model.
|
||||
txt_processors (dict): preprocessors for text inputs.
|
||||
vis_processors (dict): preprocessors for visual inputs.
|
||||
model_type (str): type of the model.
|
||||
|
||||
Returns:
|
||||
multi_features_stacked (torch.Tensor): tensors of querys features.
|
||||
"""
|
||||
select_extract_image_features = {
|
||||
"blip2": MultimodalSearch.extract_image_features_blip2,
|
||||
"blip": MultimodalSearch.extract_image_features_basic,
|
||||
@ -295,14 +477,33 @@ class MultimodalSearch(AnalysisMethod):
|
||||
model,
|
||||
vis_processors,
|
||||
txt_processors,
|
||||
model_type,
|
||||
image_keys,
|
||||
features_image_stacked,
|
||||
search_query,
|
||||
filter_number_of_images=None,
|
||||
filter_val_limit=None,
|
||||
filter_rel_error=None,
|
||||
):
|
||||
model_type: str,
|
||||
image_keys: list,
|
||||
features_image_stacked: torch.Tensor,
|
||||
search_query: list,
|
||||
filter_number_of_images: str = None,
|
||||
filter_val_limit: str = None,
|
||||
filter_rel_error: str = None,
|
||||
) -> tuple:
|
||||
"""
|
||||
Search for images with given querys.
|
||||
|
||||
Args:
|
||||
model (torch.nn.Module): model.
|
||||
vis_processors (dict): preprocessors for visual inputs.
|
||||
txt_processors (dict): preprocessors for text inputs.
|
||||
model_type (str): type of the model.
|
||||
image_keys (list): sorted list of image keys.
|
||||
features_image_stacked (torch.Tensor): tensors of images features stacked in device.
|
||||
search_query (list): list of querys.
|
||||
filter_number_of_images (str): number of images to show. Default: None.
|
||||
filter_val_limit (str): limit of similarity value. Default: None.
|
||||
filter_rel_error (str): limit of relative error. Default: None.
|
||||
|
||||
Returns:
|
||||
similarity (torch.Tensor): similarity between images and querys.
|
||||
sorted_lists (list): sorted list of similarity.
|
||||
"""
|
||||
if filter_number_of_images is None:
|
||||
filter_number_of_images = len(self)
|
||||
if filter_val_limit is None:
|
||||
@ -343,7 +544,16 @@ class MultimodalSearch(AnalysisMethod):
|
||||
self[image_keys[key]][list(search_query[q].values())[0]] = 0
|
||||
return similarity, sorted_lists
|
||||
|
||||
def itm_text_precessing(self, search_query):
|
||||
def itm_text_precessing(self, search_query: list[dict[str, str]]) -> list:
|
||||
"""
|
||||
Process text querys for itm model.
|
||||
|
||||
Args:
|
||||
search_query (list): list of querys.
|
||||
|
||||
Returns:
|
||||
text_query_index (list): list of indexes of text querys.
|
||||
"""
|
||||
for query in search_query:
|
||||
if (len(query) != 1) and (query in ("image", "text_input")):
|
||||
raise SyntaxError(
|
||||
@ -356,7 +566,17 @@ class MultimodalSearch(AnalysisMethod):
|
||||
|
||||
return text_query_index
|
||||
|
||||
def get_pathes_from_query(self, query):
|
||||
def get_pathes_from_query(self, query: dict[str, str]) -> tuple:
|
||||
"""
|
||||
Get pathes and image names from query.
|
||||
|
||||
Args:
|
||||
query (dict): query.
|
||||
|
||||
Returns:
|
||||
paths (list): list of pathes.
|
||||
image_names (list): list of image names.
|
||||
"""
|
||||
paths = []
|
||||
image_names = []
|
||||
for s in sorted(
|
||||
@ -368,7 +588,18 @@ class MultimodalSearch(AnalysisMethod):
|
||||
image_names.append(s[0])
|
||||
return paths, image_names
|
||||
|
||||
def read_and_process_images_itm(self, image_paths, vis_processor):
|
||||
def read_and_process_images_itm(self, image_paths: list, vis_processor) -> tuple:
|
||||
"""
|
||||
Read and process images with vis_processor for itm model.
|
||||
|
||||
Args:
|
||||
image_paths (list): paths to images.
|
||||
vis_processor (dict): preprocessors for visual inputs.
|
||||
|
||||
Returns:
|
||||
raw_images (list): list of images.
|
||||
images_tensors (torch.Tensor): tensors of images stacked in device.
|
||||
"""
|
||||
raw_images = [MultimodalSearch.read_img(self, path) for path in image_paths]
|
||||
images = [vis_processor(r_img) for r_img in raw_images]
|
||||
images_tensors = torch.stack(images).to(MultimodalSearch.multimodal_device)
|
||||
@ -377,12 +608,26 @@ class MultimodalSearch(AnalysisMethod):
|
||||
|
||||
def compute_gradcam_batch(
|
||||
self,
|
||||
model,
|
||||
visual_input,
|
||||
text_input,
|
||||
tokenized_text,
|
||||
block_num=6,
|
||||
):
|
||||
model: torch.nn.Module,
|
||||
visual_input: torch.Tensor,
|
||||
text_input: str,
|
||||
tokenized_text: torch.Tensor,
|
||||
block_num: str = 6,
|
||||
) -> tuple:
|
||||
"""
|
||||
Compute gradcam for itm model.
|
||||
|
||||
Args:
|
||||
model (torch.nn.Module): model.
|
||||
visual_input (torch.Tensor): tensors of images features stacked in device.
|
||||
text_input (str): text.
|
||||
tokenized_text (torch.Tensor): tokenized text.
|
||||
block_num (int): number of block. Default: 6.
|
||||
|
||||
Returns:
|
||||
gradcam (torch.Tensor): gradcam.
|
||||
output (torch.Tensor): output of model.
|
||||
"""
|
||||
model.text_encoder.base_model.base_model.encoder.layer[
|
||||
block_num
|
||||
].crossattention.self.save_attention = True
|
||||
@ -430,7 +675,16 @@ class MultimodalSearch(AnalysisMethod):
|
||||
|
||||
return gradcam, output
|
||||
|
||||
def resize_img(self, raw_img):
|
||||
def resize_img(self, raw_img: Image):
|
||||
"""
|
||||
Proportional resize image to 240 p width.
|
||||
|
||||
Args:
|
||||
raw_img (PIL.Image): image.
|
||||
|
||||
Returns:
|
||||
resized_image (PIL.Image): proportional resized image to 240p.
|
||||
"""
|
||||
w, h = raw_img.size
|
||||
scaling_factor = 240 / w
|
||||
resized_image = raw_img.resize(
|
||||
@ -438,7 +692,25 @@ class MultimodalSearch(AnalysisMethod):
|
||||
)
|
||||
return resized_image
|
||||
|
||||
def get_att_map(self, img, att_map, blur=True, overlap=True):
|
||||
def get_att_map(
|
||||
self,
|
||||
img: np.ndarray,
|
||||
att_map: np.ndarray,
|
||||
blur: bool = True,
|
||||
overlap: bool = True,
|
||||
) -> np.ndarray:
|
||||
"""
|
||||
Get attention map.
|
||||
|
||||
Args:
|
||||
img (np.ndarray): image.
|
||||
att_map (np.ndarray): attention map.
|
||||
blur (bool): blur attention map. Default: True.
|
||||
overlap (bool): overlap attention map with image. Default: True.
|
||||
|
||||
Returns:
|
||||
att_map (np.ndarray): attention map.
|
||||
"""
|
||||
att_map -= att_map.min()
|
||||
if att_map.max() > 0:
|
||||
att_map /= att_map.max()
|
||||
@ -459,7 +731,14 @@ class MultimodalSearch(AnalysisMethod):
|
||||
)
|
||||
return att_map
|
||||
|
||||
def upload_model_blip2_coco(self):
|
||||
def upload_model_blip2_coco(self) -> tuple:
|
||||
"""
|
||||
Load coco blip2_image_text_matching model and preprocessors for visual inputs from lavis.models.
|
||||
|
||||
Returns:
|
||||
itm_model (torch.nn.Module): model.
|
||||
vis_processor (dict): preprocessors for visual inputs.
|
||||
"""
|
||||
itm_model = load_model(
|
||||
"blip2_image_text_matching",
|
||||
"coco",
|
||||
@ -469,7 +748,14 @@ class MultimodalSearch(AnalysisMethod):
|
||||
vis_processor = load_processor("blip_image_eval").build(image_size=364)
|
||||
return itm_model, vis_processor
|
||||
|
||||
def upload_model_blip_base(self):
|
||||
def upload_model_blip_base(self) -> tuple:
|
||||
"""
|
||||
Load base blip_image_text_matching model and preprocessors for visual input from lavis.models.
|
||||
|
||||
Returns:
|
||||
itm_model (torch.nn.Module): model.
|
||||
vis_processor (dict): preprocessors for visual inputs.
|
||||
"""
|
||||
itm_model = load_model(
|
||||
"blip_image_text_matching",
|
||||
"base",
|
||||
@ -479,7 +765,14 @@ class MultimodalSearch(AnalysisMethod):
|
||||
vis_processor = load_processor("blip_image_eval").build(image_size=384)
|
||||
return itm_model, vis_processor
|
||||
|
||||
def upload_model_blip_large(self):
|
||||
def upload_model_blip_large(self) -> tuple:
|
||||
"""
|
||||
Load large blip_image_text_matching model and preprocessors for visual input from lavis.models.
|
||||
|
||||
Returns:
|
||||
itm_model (torch.nn.Module): model.
|
||||
vis_processor (dict): preprocessors for visual inputs.
|
||||
"""
|
||||
itm_model = load_model(
|
||||
"blip_image_text_matching",
|
||||
"large",
|
||||
@ -491,13 +784,28 @@ class MultimodalSearch(AnalysisMethod):
|
||||
|
||||
def image_text_match_reordering(
|
||||
self,
|
||||
search_query,
|
||||
itm_model_type,
|
||||
image_keys,
|
||||
sorted_lists,
|
||||
batch_size=1,
|
||||
need_grad_cam=False,
|
||||
):
|
||||
search_query: list[dict[str, str]],
|
||||
itm_model_type: str,
|
||||
image_keys: list,
|
||||
sorted_lists: list[list],
|
||||
batch_size: int = 1,
|
||||
need_grad_cam: bool = False,
|
||||
) -> tuple:
|
||||
"""
|
||||
Reorder images with itm model.
|
||||
|
||||
Args:
|
||||
search_query (list): list of querys.
|
||||
itm_model_type (str): type of the model.
|
||||
image_keys (list): sorted list of image keys.
|
||||
sorted_lists (list): sorted list of similarity.
|
||||
batch_size (int): batch size. Default: 1.
|
||||
need_grad_cam (bool): need gradcam. Default: False. blip2_coco model does not yet work with gradcam.
|
||||
|
||||
Returns:
|
||||
itm_scores2: list of itm scores.
|
||||
image_gradcam_with_itm: dict of image names and gradcam.
|
||||
"""
|
||||
if itm_model_type == "blip2_coco" and need_grad_cam is True:
|
||||
raise SyntaxError(
|
||||
"The blip2_coco model does not yet work with gradcam. Please set need_grad_cam to False"
|
||||
@ -624,7 +932,17 @@ class MultimodalSearch(AnalysisMethod):
|
||||
torch.cuda.empty_cache()
|
||||
return itm_scores2, image_gradcam_with_itm
|
||||
|
||||
def show_results(self, query, itm=False, image_gradcam_with_itm=False):
|
||||
def show_results(
|
||||
self, query: dict, itm=False, image_gradcam_with_itm=False
|
||||
) -> None:
|
||||
"""
|
||||
Show results of search.
|
||||
|
||||
Args:
|
||||
query (dict): query.
|
||||
itm (bool): use itm model. Default: False.
|
||||
image_gradcam_with_itm (bool): use gradcam. Default: False.
|
||||
"""
|
||||
if "image" in query.keys():
|
||||
pic = Image.open(query["image"]).convert("RGB")
|
||||
pic.thumbnail((400, 400))
|
||||
|
||||
Загрузка…
x
Ссылка в новой задаче
Block a user