Wiriting docs to multimodal_search.py

Этот коммит содержится в:
Petr Andriushchenko 2023-05-23 14:06:02 +02:00
родитель ca3994c72f
Коммит f524ecaad8
Не найден ключ, соответствующий данной подписи
Идентификатор ключа GPG: 4C4A5DCF634115B6

Просмотреть файл

@ -20,7 +20,18 @@ class MultimodalSearch(AnalysisMethod):
multimodal_device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
def load_feature_extractor_model_blip2(self, device):
def load_feature_extractor_model_blip2(self, device: str = "cpu"):
"""
Load pretrain blip2_feature_extractor model and preprocessors for visual and text inputs from lavis.models.
Args:
device (str): device to use. Can be "cpu" or "cuda". Default: "cpu".
Returns:
model (torch.nn.Module): model.
vis_processors (dict): preprocessors for visual inputs.
txt_processors (dict): preprocessors for text inputs.
"""
model, vis_processors, txt_processors = load_model_and_preprocess(
name="blip2_feature_extractor",
model_type="pretrain",
@ -29,7 +40,18 @@ class MultimodalSearch(AnalysisMethod):
)
return model, vis_processors, txt_processors
def load_feature_extractor_model_blip(self, device):
def load_feature_extractor_model_blip(self, device: str = "cpu"):
"""
Load base blip_feature_extractor model and preprocessors for visual and text inputs from lavis.models.
Args:
device (str): device to use. Can be "cpu" or "cuda". Default: "cpu".
Returns:
model (torch.nn.Module): model.
vis_processors (dict): preprocessors for visual inputs.
txt_processors (dict): preprocessors for text inputs.
"""
model, vis_processors, txt_processors = load_model_and_preprocess(
name="blip_feature_extractor",
model_type="base",
@ -38,7 +60,18 @@ class MultimodalSearch(AnalysisMethod):
)
return model, vis_processors, txt_processors
def load_feature_extractor_model_albef(self, device):
def load_feature_extractor_model_albef(self, device: str = "cpu"):
"""
Load base albef_feature_extractor model and preprocessors for visual and text inputs from lavis.models.
Args:
device (str): device to use. Can be "cpu" or "cuda". Default: "cpu".
Returns:
model (torch.nn.Module): model.
vis_processors (dict): preprocessors for visual inputs.
txt_processors (dict): preprocessors for text inputs.
"""
model, vis_processors, txt_processors = load_model_and_preprocess(
name="albef_feature_extractor",
model_type="base",
@ -47,7 +80,18 @@ class MultimodalSearch(AnalysisMethod):
)
return model, vis_processors, txt_processors
def load_feature_extractor_model_clip_base(self, device):
def load_feature_extractor_model_clip_base(self, device: str = "cpu"):
"""
Load base clip_feature_extractor model and preprocessors for visual and text inputs from lavis.models.
Args:
device (str): device to use. Can be "cpu" or "cuda". Default: "cpu".
Returns:
model (torch.nn.Module): model.
vis_processors (dict): preprocessors for visual inputs.
txt_processors (dict): preprocessors for text inputs.
"""
model, vis_processors, txt_processors = load_model_and_preprocess(
name="clip_feature_extractor",
model_type="base",
@ -56,7 +100,18 @@ class MultimodalSearch(AnalysisMethod):
)
return model, vis_processors, txt_processors
def load_feature_extractor_model_clip_vitl14(self, device):
def load_feature_extractor_model_clip_vitl14(self, device: str = "cpu"):
"""
Load ViT-L-14 clip_feature_extractor model and preprocessors for visual and text inputs from lavis.models.
Args:
device (str): device to use. Can be "cpu" or "cuda". Default: "cpu".
Returns:
model (torch.nn.Module): model.
vis_processors (dict): preprocessors for visual inputs.
txt_processors (dict): preprocessors for text inputs.
"""
model, vis_processors, txt_processors = load_model_and_preprocess(
name="clip_feature_extractor",
model_type="ViT-L-14",
@ -65,7 +120,18 @@ class MultimodalSearch(AnalysisMethod):
)
return model, vis_processors, txt_processors
def load_feature_extractor_model_clip_vitl14_336(self, device):
def load_feature_extractor_model_clip_vitl14_336(self, device: str = "cpu"):
"""
Load ViT-L-14-336 clip_feature_extractor model and preprocessors for visual and text inputs from lavis.models.
Args:
device (str): device to use. Can be "cpu" or "cuda". Default: "cpu".
Returns:
model (torch.nn.Module): model.
vis_processors (dict): preprocessors for visual inputs.
txt_processors (dict): preprocessors for text inputs.
"""
model, vis_processors, txt_processors = load_model_and_preprocess(
name="clip_feature_extractor",
model_type="ViT-L-14-336",
@ -74,11 +140,31 @@ class MultimodalSearch(AnalysisMethod):
)
return model, vis_processors, txt_processors
def read_img(self, filepath):
def read_img(self, filepath: str) -> Image:
"""
Load Image from filepath.
Args:
filepath (str): path to image.
Returns:
raw_image (PIL.Image): image.
"""
raw_image = Image.open(filepath).convert("RGB")
return raw_image
def read_and_process_images(self, image_paths, vis_processor):
def read_and_process_images(self, image_paths: str, vis_processor) -> tuple:
"""
Read and process images with vis_processor.
Args:
image_paths (str): paths to images.
vis_processor (dict): preprocessors for visual inputs.
Returns:
raw_images (list): list of images.
images_tensors (torch.Tensor): tensors of images stacked in device.
"""
raw_images = [MultimodalSearch.read_img(self, path) for path in image_paths]
images = [
vis_processor["eval"](r_img)
@ -90,7 +176,19 @@ class MultimodalSearch(AnalysisMethod):
return raw_images, images_tensors
def extract_image_features_blip2(self, model, images_tensors):
def extract_image_features_blip2(
self, model, images_tensors: torch.Tensor
) -> torch.Tensor:
"""
Extract image features from images_tensors with blip2_feature_extractor model.
Args:
model (torch.nn.Module): model.
images_tensors (torch.Tensor): tensors of images stacked in device.
Returns:
features_image_stacked (torch.Tensor): tensors of images features stacked in device.
"""
with torch.cuda.amp.autocast(
enabled=(MultimodalSearch.multimodal_device != torch.device("cpu"))
):
@ -103,7 +201,19 @@ class MultimodalSearch(AnalysisMethod):
)
return features_image_stacked
def extract_image_features_clip(self, model, images_tensors):
def extract_image_features_clip(
self, model, images_tensors: torch.Tensor
) -> torch.Tensor:
"""
Extract image features from images_tensors with clip_feature_extractor model.
Args:
model (torch.nn.Module): model.
images_tensors (torch.Tensor): tensors of images stacked in device.
Returns:
features_image_stacked (torch.Tensor): tensors of images features stacked in device.
"""
features_image = [
model.extract_features({"image": ten}) for ten in images_tensors
]
@ -112,7 +222,19 @@ class MultimodalSearch(AnalysisMethod):
)
return features_image_stacked
def extract_image_features_basic(self, model, images_tensors):
def extract_image_features_basic(
self, model, images_tensors: torch.Tensor
) -> torch.Tensor:
"""
Extract image features from images_tensors with blip_feature_extractor or albef_feature_extractor model.
Args:
model (torch.nn.Module): model.
images_tensors (torch.Tensor): tensors of images stacked in device.
Returns:
features_image_stacked (torch.Tensor): tensors of images features stacked in device.
"""
features_image = [
model.extract_features({"image": ten, "text_input": ""}, mode="image")
for ten in images_tensors
@ -124,11 +246,23 @@ class MultimodalSearch(AnalysisMethod):
def save_tensors(
self,
model_type,
features_image_stacked,
name="saved_features_image.pt",
path="./saved_tensors/",
):
model_type: str,
features_image_stacked: torch.Tensor,
name: str = "saved_features_image.pt",
path: str = "./saved_tensors/",
) -> str:
"""
Save tensors as binary to given path.
Args:
model_type (str): type of the model.
features_image_stacked (torch.Tensor): tensors of images features stacked in device.
name (str): name of the file. Default: "saved_features_image.pt".
path (str): path to save the file. Default: "./saved_tensors/".
Returns:
name (str): name of the file.
"""
if not os.path.exists(path):
os.makedirs(path)
with open(
@ -143,11 +277,30 @@ class MultimodalSearch(AnalysisMethod):
torch.save(features_image_stacked, f)
return name
def load_tensors(self, name):
def load_tensors(self, name: str) -> torch.Tensor:
"""
Load tensors from given path.
Args:
name (str): name of the file.
Returns:
features_image_stacked (torch.Tensor): tensors of images features.
"""
features_image_stacked = torch.load(name)
return features_image_stacked
def extract_text_features(self, model, text_input):
def extract_text_features(self, model, text_input: str) -> torch.Tensor:
"""
Extract text features from text_input with feature_extractor model.
Args:
model (torch.nn.Module): model.
text_input (str): text.
Returns:
features_text (torch.Tensor): tensors of text features.
"""
sample_text = {"text_input": [text_input]}
features_text = model.extract_features(sample_text, mode="text")
@ -155,10 +308,26 @@ class MultimodalSearch(AnalysisMethod):
def parsing_images(
self,
model_type,
path_to_saved_tensors="./saved_tensors/",
path_to_load_tensors=None,
):
model_type: str,
path_to_save_tensors: str = "./saved_tensors/",
path_to_load_tensors: str = None,
) -> tuple:
"""
Parsing images with feature_extractor model.
Args:
model_type (str): type of the model.
path_to_save_tensors (str): path to save the tensors. Default: "./saved_tensors/".
path_to_load_tensors (str): path to load the tesors. Default: None.
Returns:
model (torch.nn.Module): model.
vis_processors (dict): preprocessors for visual inputs.
txt_processors (dict): preprocessors for text inputs.
image_keys (list): sorted list of image keys.
image_names (list): sorted list of image names.
features_image_stacked (torch.Tensor): tensors of images features stacked in device.
"""
if model_type in ("clip_base", "clip_vitl14_336", "clip_vitl14"):
path_to_lib = lavis.__file__[:-11] + "models/clip_models/"
url = "https://raw.githubusercontent.com/salesforce/LAVIS/main/lavis/models/clip_models/bpe_simple_vocab_16e6.txt.gz"
@ -208,7 +377,7 @@ class MultimodalSearch(AnalysisMethod):
self, model, images_tensors
)
MultimodalSearch.save_tensors(
self, model_type, features_image_stacked, path=path_to_saved_tensors
self, model_type, features_image_stacked, path=path_to_save_tensors
)
else:
features_image_stacked = MultimodalSearch.load_tensors(
@ -225,8 +394,21 @@ class MultimodalSearch(AnalysisMethod):
)
def querys_processing(
self, search_query, model, txt_processors, vis_processors, model_type
):
self, search_query: list, model, txt_processors, vis_processors, model_type: str
) -> torch.Tensor:
"""
Process querys.
Args:
search_query (list): list of querys.
model (torch.nn.Module): model.
txt_processors (dict): preprocessors for text inputs.
vis_processors (dict): preprocessors for visual inputs.
model_type (str): type of the model.
Returns:
multi_features_stacked (torch.Tensor): tensors of querys features.
"""
select_extract_image_features = {
"blip2": MultimodalSearch.extract_image_features_blip2,
"blip": MultimodalSearch.extract_image_features_basic,
@ -295,14 +477,33 @@ class MultimodalSearch(AnalysisMethod):
model,
vis_processors,
txt_processors,
model_type,
image_keys,
features_image_stacked,
search_query,
filter_number_of_images=None,
filter_val_limit=None,
filter_rel_error=None,
):
model_type: str,
image_keys: list,
features_image_stacked: torch.Tensor,
search_query: list,
filter_number_of_images: str = None,
filter_val_limit: str = None,
filter_rel_error: str = None,
) -> tuple:
"""
Search for images with given querys.
Args:
model (torch.nn.Module): model.
vis_processors (dict): preprocessors for visual inputs.
txt_processors (dict): preprocessors for text inputs.
model_type (str): type of the model.
image_keys (list): sorted list of image keys.
features_image_stacked (torch.Tensor): tensors of images features stacked in device.
search_query (list): list of querys.
filter_number_of_images (str): number of images to show. Default: None.
filter_val_limit (str): limit of similarity value. Default: None.
filter_rel_error (str): limit of relative error. Default: None.
Returns:
similarity (torch.Tensor): similarity between images and querys.
sorted_lists (list): sorted list of similarity.
"""
if filter_number_of_images is None:
filter_number_of_images = len(self)
if filter_val_limit is None:
@ -343,7 +544,16 @@ class MultimodalSearch(AnalysisMethod):
self[image_keys[key]][list(search_query[q].values())[0]] = 0
return similarity, sorted_lists
def itm_text_precessing(self, search_query):
def itm_text_precessing(self, search_query: list[dict[str, str]]) -> list:
"""
Process text querys for itm model.
Args:
search_query (list): list of querys.
Returns:
text_query_index (list): list of indexes of text querys.
"""
for query in search_query:
if (len(query) != 1) and (query in ("image", "text_input")):
raise SyntaxError(
@ -356,7 +566,17 @@ class MultimodalSearch(AnalysisMethod):
return text_query_index
def get_pathes_from_query(self, query):
def get_pathes_from_query(self, query: dict[str, str]) -> tuple:
"""
Get pathes and image names from query.
Args:
query (dict): query.
Returns:
paths (list): list of pathes.
image_names (list): list of image names.
"""
paths = []
image_names = []
for s in sorted(
@ -368,7 +588,18 @@ class MultimodalSearch(AnalysisMethod):
image_names.append(s[0])
return paths, image_names
def read_and_process_images_itm(self, image_paths, vis_processor):
def read_and_process_images_itm(self, image_paths: list, vis_processor) -> tuple:
"""
Read and process images with vis_processor for itm model.
Args:
image_paths (list): paths to images.
vis_processor (dict): preprocessors for visual inputs.
Returns:
raw_images (list): list of images.
images_tensors (torch.Tensor): tensors of images stacked in device.
"""
raw_images = [MultimodalSearch.read_img(self, path) for path in image_paths]
images = [vis_processor(r_img) for r_img in raw_images]
images_tensors = torch.stack(images).to(MultimodalSearch.multimodal_device)
@ -377,12 +608,26 @@ class MultimodalSearch(AnalysisMethod):
def compute_gradcam_batch(
self,
model,
visual_input,
text_input,
tokenized_text,
block_num=6,
):
model: torch.nn.Module,
visual_input: torch.Tensor,
text_input: str,
tokenized_text: torch.Tensor,
block_num: str = 6,
) -> tuple:
"""
Compute gradcam for itm model.
Args:
model (torch.nn.Module): model.
visual_input (torch.Tensor): tensors of images features stacked in device.
text_input (str): text.
tokenized_text (torch.Tensor): tokenized text.
block_num (int): number of block. Default: 6.
Returns:
gradcam (torch.Tensor): gradcam.
output (torch.Tensor): output of model.
"""
model.text_encoder.base_model.base_model.encoder.layer[
block_num
].crossattention.self.save_attention = True
@ -430,7 +675,16 @@ class MultimodalSearch(AnalysisMethod):
return gradcam, output
def resize_img(self, raw_img):
def resize_img(self, raw_img: Image):
"""
Proportional resize image to 240 p width.
Args:
raw_img (PIL.Image): image.
Returns:
resized_image (PIL.Image): proportional resized image to 240p.
"""
w, h = raw_img.size
scaling_factor = 240 / w
resized_image = raw_img.resize(
@ -438,7 +692,25 @@ class MultimodalSearch(AnalysisMethod):
)
return resized_image
def get_att_map(self, img, att_map, blur=True, overlap=True):
def get_att_map(
self,
img: np.ndarray,
att_map: np.ndarray,
blur: bool = True,
overlap: bool = True,
) -> np.ndarray:
"""
Get attention map.
Args:
img (np.ndarray): image.
att_map (np.ndarray): attention map.
blur (bool): blur attention map. Default: True.
overlap (bool): overlap attention map with image. Default: True.
Returns:
att_map (np.ndarray): attention map.
"""
att_map -= att_map.min()
if att_map.max() > 0:
att_map /= att_map.max()
@ -459,7 +731,14 @@ class MultimodalSearch(AnalysisMethod):
)
return att_map
def upload_model_blip2_coco(self):
def upload_model_blip2_coco(self) -> tuple:
"""
Load coco blip2_image_text_matching model and preprocessors for visual inputs from lavis.models.
Returns:
itm_model (torch.nn.Module): model.
vis_processor (dict): preprocessors for visual inputs.
"""
itm_model = load_model(
"blip2_image_text_matching",
"coco",
@ -469,7 +748,14 @@ class MultimodalSearch(AnalysisMethod):
vis_processor = load_processor("blip_image_eval").build(image_size=364)
return itm_model, vis_processor
def upload_model_blip_base(self):
def upload_model_blip_base(self) -> tuple:
"""
Load base blip_image_text_matching model and preprocessors for visual input from lavis.models.
Returns:
itm_model (torch.nn.Module): model.
vis_processor (dict): preprocessors for visual inputs.
"""
itm_model = load_model(
"blip_image_text_matching",
"base",
@ -479,7 +765,14 @@ class MultimodalSearch(AnalysisMethod):
vis_processor = load_processor("blip_image_eval").build(image_size=384)
return itm_model, vis_processor
def upload_model_blip_large(self):
def upload_model_blip_large(self) -> tuple:
"""
Load large blip_image_text_matching model and preprocessors for visual input from lavis.models.
Returns:
itm_model (torch.nn.Module): model.
vis_processor (dict): preprocessors for visual inputs.
"""
itm_model = load_model(
"blip_image_text_matching",
"large",
@ -491,13 +784,28 @@ class MultimodalSearch(AnalysisMethod):
def image_text_match_reordering(
self,
search_query,
itm_model_type,
image_keys,
sorted_lists,
batch_size=1,
need_grad_cam=False,
):
search_query: list[dict[str, str]],
itm_model_type: str,
image_keys: list,
sorted_lists: list[list],
batch_size: int = 1,
need_grad_cam: bool = False,
) -> tuple:
"""
Reorder images with itm model.
Args:
search_query (list): list of querys.
itm_model_type (str): type of the model.
image_keys (list): sorted list of image keys.
sorted_lists (list): sorted list of similarity.
batch_size (int): batch size. Default: 1.
need_grad_cam (bool): need gradcam. Default: False. blip2_coco model does not yet work with gradcam.
Returns:
itm_scores2: list of itm scores.
image_gradcam_with_itm: dict of image names and gradcam.
"""
if itm_model_type == "blip2_coco" and need_grad_cam is True:
raise SyntaxError(
"The blip2_coco model does not yet work with gradcam. Please set need_grad_cam to False"
@ -624,7 +932,17 @@ class MultimodalSearch(AnalysisMethod):
torch.cuda.empty_cache()
return itm_scores2, image_gradcam_with_itm
def show_results(self, query, itm=False, image_gradcam_with_itm=False):
def show_results(
self, query: dict, itm=False, image_gradcam_with_itm=False
) -> None:
"""
Show results of search.
Args:
query (dict): query.
itm (bool): use itm model. Default: False.
image_gradcam_with_itm (bool): use gradcam. Default: False.
"""
if "image" in query.keys():
pic = Image.open(query["image"]).convert("RGB")
pic.thumbnail((400, 400))