Writing docs to multimodal_search.py

This commit is contained in:
Petr Andriushchenko 2023-05-23 14:06:02 +02:00
parent ca3994c72f
commit f524ecaad8
No key found matching this signature
GPG key ID: 4C4A5DCF634115B6


@@ -20,7 +20,18 @@ class MultimodalSearch(AnalysisMethod):
multimodal_device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
def load_feature_extractor_model_blip2(self, device: str = "cpu"):
"""
Load the pretrained blip2_feature_extractor model and preprocessors for visual and text inputs from lavis.models.
Args:
device (str): device to use. Can be "cpu" or "cuda". Default: "cpu".
Returns:
model (torch.nn.Module): model.
vis_processors (dict): preprocessors for visual inputs.
txt_processors (dict): preprocessors for text inputs.
"""
model, vis_processors, txt_processors = load_model_and_preprocess(
name="blip2_feature_extractor",
model_type="pretrain",
@@ -29,7 +40,18 @@ class MultimodalSearch(AnalysisMethod):
)
return model, vis_processors, txt_processors
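A minimal usage sketch for the loader documented above; ms is assumed to be an already constructed MultimodalSearch instance (its constructor is not part of this diff):

import torch

# mirror the device selection used for MultimodalSearch.multimodal_device
device = "cuda" if torch.cuda.is_available() else "cpu"
model, vis_processors, txt_processors = ms.load_feature_extractor_model_blip2(device=device)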
def load_feature_extractor_model_blip(self, device: str = "cpu"):
"""
Load base blip_feature_extractor model and preprocessors for visual and text inputs from lavis.models.
Args:
device (str): device to use. Can be "cpu" or "cuda". Default: "cpu".
Returns:
model (torch.nn.Module): model.
vis_processors (dict): preprocessors for visual inputs.
txt_processors (dict): preprocessors for text inputs.
"""
model, vis_processors, txt_processors = load_model_and_preprocess(
name="blip_feature_extractor",
model_type="base",
@@ -38,7 +60,18 @@ class MultimodalSearch(AnalysisMethod):
)
return model, vis_processors, txt_processors
def load_feature_extractor_model_albef(self, device: str = "cpu"):
"""
Load base albef_feature_extractor model and preprocessors for visual and text inputs from lavis.models.
Args:
device (str): device to use. Can be "cpu" or "cuda". Default: "cpu".
Returns:
model (torch.nn.Module): model.
vis_processors (dict): preprocessors for visual inputs.
txt_processors (dict): preprocessors for text inputs.
"""
model, vis_processors, txt_processors = load_model_and_preprocess(
name="albef_feature_extractor",
model_type="base",
@@ -47,7 +80,18 @@ class MultimodalSearch(AnalysisMethod):
)
return model, vis_processors, txt_processors
def load_feature_extractor_model_clip_base(self, device: str = "cpu"):
"""
Load base clip_feature_extractor model and preprocessors for visual and text inputs from lavis.models.
Args:
device (str): device to use. Can be "cpu" or "cuda". Default: "cpu".
Returns:
model (torch.nn.Module): model.
vis_processors (dict): preprocessors for visual inputs.
txt_processors (dict): preprocessors for text inputs.
"""
model, vis_processors, txt_processors = load_model_and_preprocess(
name="clip_feature_extractor",
model_type="base",
@@ -56,7 +100,18 @@ class MultimodalSearch(AnalysisMethod):
)
return model, vis_processors, txt_processors
def load_feature_extractor_model_clip_vitl14(self, device: str = "cpu"):
"""
Load ViT-L-14 clip_feature_extractor model and preprocessors for visual and text inputs from lavis.models.
Args:
device (str): device to use. Can be "cpu" or "cuda". Default: "cpu".
Returns:
model (torch.nn.Module): model.
vis_processors (dict): preprocessors for visual inputs.
txt_processors (dict): preprocessors for text inputs.
"""
model, vis_processors, txt_processors = load_model_and_preprocess(
name="clip_feature_extractor",
model_type="ViT-L-14",
@@ -65,7 +120,18 @@ class MultimodalSearch(AnalysisMethod):
)
return model, vis_processors, txt_processors
def load_feature_extractor_model_clip_vitl14_336(self, device: str = "cpu"):
"""
Load ViT-L-14-336 clip_feature_extractor model and preprocessors for visual and text inputs from lavis.models.
Args:
device (str): device to use. Can be "cpu" or "cuda". Default: "cpu".
Returns:
model (torch.nn.Module): model.
vis_processors (dict): preprocessors for visual inputs.
txt_processors (dict): preprocessors for text inputs.
"""
model, vis_processors, txt_processors = load_model_and_preprocess(
name="clip_feature_extractor",
model_type="ViT-L-14-336",
@@ -74,11 +140,31 @@ class MultimodalSearch(AnalysisMethod):
)
return model, vis_processors, txt_processors
def read_img(self, filepath: str) -> Image:
"""
Load Image from filepath.
Args:
filepath (str): path to image.
Returns:
raw_image (PIL.Image): image.
"""
raw_image = Image.open(filepath).convert("RGB")
return raw_image
def read_and_process_images(self, image_paths: list, vis_processor) -> tuple:
"""
Read and process images with vis_processor.
Args:
image_paths (list): list of paths to images.
vis_processor (dict): preprocessors for visual inputs.
Returns:
raw_images (list): list of images.
images_tensors (torch.Tensor): tensor of stacked images on the device.
"""
raw_images = [MultimodalSearch.read_img(self, path) for path in image_paths]
images = [
vis_processor["eval"](r_img)
@@ -90,7 +176,19 @@ class MultimodalSearch(AnalysisMethod):
return raw_images, images_tensors
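A hedged sketch of calling the preprocessing helper above; the file names are illustrative and vis_processors is the dict returned by one of the load_feature_extractor_model_* methods:

# illustrative paths; any RGB-readable image files work
image_paths = ["images/img0001.png", "images/img0002.png"]
raw_images, images_tensors = ms.read_and_process_images(image_paths, vis_processors)
# images_tensors is a single stacked tensor on multimodal_device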
def extract_image_features_blip2(
self, model, images_tensors: torch.Tensor
) -> torch.Tensor:
"""
Extract image features from images_tensors with blip2_feature_extractor model.
Args:
model (torch.nn.Module): model.
images_tensors (torch.Tensor): tensor of stacked images on the device.
Returns:
features_image_stacked (torch.Tensor): tensor of stacked image features on the device.
"""
with torch.cuda.amp.autocast(
enabled=(MultimodalSearch.multimodal_device != torch.device("cpu"))
):
@@ -103,7 +201,19 @@ class MultimodalSearch(AnalysisMethod):
)
return features_image_stacked
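A sketch of the corresponding feature-extraction call, reusing the model and images_tensors from the previous steps:

# extract and stack one feature tensor per image on the configured device
features_image_stacked = ms.extract_image_features_blip2(model, images_tensors)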
def extract_image_features_clip(
self, model, images_tensors: torch.Tensor
) -> torch.Tensor:
"""
Extract image features from images_tensors with clip_feature_extractor model.
Args:
model (torch.nn.Module): model.
images_tensors (torch.Tensor): tensor of stacked images on the device.
Returns:
features_image_stacked (torch.Tensor): tensor of stacked image features on the device.
"""
features_image = [
model.extract_features({"image": ten}) for ten in images_tensors
]
@@ -112,7 +222,19 @@ class MultimodalSearch(AnalysisMethod):
)
return features_image_stacked
def extract_image_features_basic(
self, model, images_tensors: torch.Tensor
) -> torch.Tensor:
"""
Extract image features from images_tensors with blip_feature_extractor or albef_feature_extractor model.
Args:
model (torch.nn.Module): model.
images_tensors (torch.Tensor): tensor of stacked images on the device.
Returns:
features_image_stacked (torch.Tensor): tensor of stacked image features on the device.
"""
features_image = [
model.extract_features({"image": ten, "text_input": ""}, mode="image")
for ten in images_tensors
@@ -124,11 +246,23 @@ class MultimodalSearch(AnalysisMethod):
def save_tensors(
self,
model_type: str,
features_image_stacked: torch.Tensor,
name: str = "saved_features_image.pt",
path: str = "./saved_tensors/",
) -> str:
"""
Save tensors as binary to given path.
Args:
model_type (str): type of the model.
features_image_stacked (torch.Tensor): tensor of stacked image features on the device.
name (str): name of the file. Default: "saved_features_image.pt".
path (str): path to save the file. Default: "./saved_tensors/".
Returns:
name (str): name of the file.
"""
if not os.path.exists(path):
os.makedirs(path)
with open(
@@ -143,11 +277,30 @@ class MultimodalSearch(AnalysisMethod):
torch.save(features_image_stacked, f)
return name
def load_tensors(self, name: str) -> torch.Tensor:
"""
Load tensors from given path.
Args:
name (str): name of the file.
Returns:
features_image_stacked (torch.Tensor): tensor of image features.
"""
features_image_stacked = torch.load(name)
return features_image_stacked
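A sketch of the save/load round trip described above; the exact on-disk file name is composed inside save_tensors, so the reload path below is an assumption and may need adjusting:

# persist the extracted features, then reload them in a later session
saved_name = ms.save_tensors("blip2", features_image_stacked, path="./saved_tensors/")
features_image_stacked = ms.load_tensors("./saved_tensors/" + saved_name)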
def extract_text_features(self, model, text_input: str) -> torch.Tensor:
"""
Extract text features from text_input with feature_extractor model.
Args:
model (torch.nn.Module): model.
text_input (str): text.
Returns:
features_text (torch.Tensor): tensors of text features.
"""
sample_text = {"text_input": [text_input]}
features_text = model.extract_features(sample_text, mode="text")
@@ -155,10 +308,26 @@ class MultimodalSearch(AnalysisMethod):
def parsing_images(
self,
model_type: str,
path_to_save_tensors: str = "./saved_tensors/",
path_to_load_tensors: str = None,
) -> tuple:
"""
Parse images with the feature_extractor model.
Args:
model_type (str): type of the model.
path_to_save_tensors (str): path to save the tensors. Default: "./saved_tensors/".
path_to_load_tensors (str): path to load the tensors. Default: None.
Returns:
model (torch.nn.Module): model.
vis_processors (dict): preprocessors for visual inputs.
txt_processors (dict): preprocessors for text inputs.
image_keys (list): sorted list of image keys.
image_names (list): sorted list of image names.
features_image_stacked (torch.Tensor): tensor of stacked image features on the device.
"""
if model_type in ("clip_base", "clip_vitl14_336", "clip_vitl14"):
path_to_lib = lavis.__file__[:-11] + "models/clip_models/"
url = "https://raw.githubusercontent.com/salesforce/LAVIS/main/lavis/models/clip_models/bpe_simple_vocab_16e6.txt.gz"
@@ -208,7 +377,7 @@ class MultimodalSearch(AnalysisMethod):
self, model, images_tensors
)
MultimodalSearch.save_tensors(
self, model_type, features_image_stacked, path=path_to_save_tensors
)
else:
features_image_stacked = MultimodalSearch.load_tensors(
@@ -225,8 +394,21 @@ class MultimodalSearch(AnalysisMethod):
)
def querys_processing(
self, search_query: list, model, txt_processors, vis_processors, model_type: str
) -> torch.Tensor:
"""
Process queries.
Args:
search_query (list): list of queries.
model (torch.nn.Module): model.
txt_processors (dict): preprocessors for text inputs.
vis_processors (dict): preprocessors for visual inputs.
model_type (str): type of the model.
Returns:
multi_features_stacked (torch.Tensor): tensor of query features.
"""
select_extract_image_features = {
"blip2": MultimodalSearch.extract_image_features_blip2,
"blip": MultimodalSearch.extract_image_features_basic,
@@ -295,14 +477,33 @@ class MultimodalSearch(AnalysisMethod):
model,
vis_processors,
txt_processors,
model_type: str,
image_keys: list,
features_image_stacked: torch.Tensor,
search_query: list,
filter_number_of_images: int = None,
filter_val_limit: float = None,
filter_rel_error: float = None,
) -> tuple:
"""
Search for images matching the given queries.
Args:
model (torch.nn.Module): model.
vis_processors (dict): preprocessors for visual inputs.
txt_processors (dict): preprocessors for text inputs.
model_type (str): type of the model.
image_keys (list): sorted list of image keys.
features_image_stacked (torch.Tensor): tensor of stacked image features on the device.
search_query (list): list of queries.
filter_number_of_images (int): number of images to show. Default: None.
filter_val_limit (float): limit of the similarity value. Default: None.
filter_rel_error (float): limit of the relative error. Default: None.
Returns:
similarity (torch.Tensor): similarity between images and queries.
sorted_lists (list): sorted lists of similarity values.
"""
if filter_number_of_images is None:
filter_number_of_images = len(self)
if filter_val_limit is None:
@@ -343,7 +544,16 @@ class MultimodalSearch(AnalysisMethod):
self[image_keys[key]][list(search_query[q].values())[0]] = 0
return similarity, sorted_lists
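A hedged end-to-end sketch that combines parsing_images and multimodal_search as documented above; the model type and query text are illustrative:

(
    model,
    vis_processors,
    txt_processors,
    image_keys,
    image_names,
    features_image_stacked,
) = ms.parsing_images("blip2", path_to_save_tensors="./saved_tensors/")

search_query = [{"text_input": "a crowd protesting in the street"}]  # illustrative query
similarity, sorted_lists = ms.multimodal_search(
    model,
    vis_processors,
    txt_processors,
    "blip2",
    image_keys,
    features_image_stacked,
    search_query,
    filter_number_of_images=10,
)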
def itm_text_precessing(self, search_query: list[dict[str, str]]) -> list:
"""
Process text queries for the itm model.
Args:
search_query (list): list of queries.
Returns:
text_query_index (list): list of indices of text queries.
"""
for query in search_query:
if (len(query) != 1) and (query in ("image", "text_input")):
raise SyntaxError(
@@ -356,7 +566,17 @@ class MultimodalSearch(AnalysisMethod):
return text_query_index
def get_pathes_from_query(self, query: dict[str, str]) -> tuple:
"""
Get paths and image names from query.
Args:
query (dict): query.
Returns:
paths (list): list of paths.
image_names (list): list of image names.
"""
paths = []
image_names = []
for s in sorted(
@@ -368,7 +588,18 @@ class MultimodalSearch(AnalysisMethod):
image_names.append(s[0])
return paths, image_names
def read_and_process_images_itm(self, image_paths: list, vis_processor) -> tuple:
"""
Read and process images with vis_processor for itm model.
Args:
image_paths (list): paths to images.
vis_processor (dict): preprocessors for visual inputs.
Returns:
raw_images (list): list of images.
images_tensors (torch.Tensor): tensor of stacked images on the device.
"""
raw_images = [MultimodalSearch.read_img(self, path) for path in image_paths]
images = [vis_processor(r_img) for r_img in raw_images]
images_tensors = torch.stack(images).to(MultimodalSearch.multimodal_device)
@@ -377,12 +608,26 @@ class MultimodalSearch(AnalysisMethod):
def compute_gradcam_batch(
self,
model: torch.nn.Module,
visual_input: torch.Tensor,
text_input: str,
tokenized_text: torch.Tensor,
block_num: int = 6,
) -> tuple:
"""
Compute gradcam for itm model.
Args:
model (torch.nn.Module): model.
visual_input (torch.Tensor): tensor of stacked image features on the device.
text_input (str): text.
tokenized_text (torch.Tensor): tokenized text.
block_num (int): number of block. Default: 6.
Returns:
gradcam (torch.Tensor): gradcam.
output (torch.Tensor): output of model.
"""
model.text_encoder.base_model.base_model.encoder.layer[
block_num
].crossattention.self.save_attention = True
@@ -430,7 +675,16 @@ class MultimodalSearch(AnalysisMethod):
return gradcam, output
def resize_img(self, raw_img: Image):
"""
Proportionally resize an image to 240p width.
Args:
raw_img (PIL.Image): image.
Returns:
resized_image (PIL.Image): image proportionally resized to 240p width.
"""
w, h = raw_img.size
scaling_factor = 240 / w
resized_image = raw_img.resize(
@@ -438,7 +692,25 @@ class MultimodalSearch(AnalysisMethod):
)
return resized_image
def get_att_map(
self,
img: np.ndarray,
att_map: np.ndarray,
blur: bool = True,
overlap: bool = True,
) -> np.ndarray:
"""
Get attention map.
Args:
img (np.ndarray): image.
att_map (np.ndarray): attention map.
blur (bool): blur attention map. Default: True.
overlap (bool): overlap attention map with image. Default: True.
Returns:
att_map (np.ndarray): attention map.
"""
att_map -= att_map.min()
if att_map.max() > 0:
att_map /= att_map.max()
@@ -459,7 +731,14 @@ class MultimodalSearch(AnalysisMethod):
)
return att_map
def upload_model_blip2_coco(self) -> tuple:
"""
Load coco blip2_image_text_matching model and preprocessors for visual inputs from lavis.models.
Returns:
itm_model (torch.nn.Module): model.
vis_processor (dict): preprocessors for visual inputs.
"""
itm_model = load_model(
"blip2_image_text_matching",
"coco",
@@ -469,7 +748,14 @@ class MultimodalSearch(AnalysisMethod):
vis_processor = load_processor("blip_image_eval").build(image_size=364)
return itm_model, vis_processor
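A sketch of loading the coco ITM model directly, should it be needed outside image_text_match_reordering:

# returns the image-text matching model and its image preprocessor
itm_model, itm_vis_processor = ms.upload_model_blip2_coco()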
def upload_model_blip_base(self) -> tuple:
"""
Load base blip_image_text_matching model and preprocessors for visual input from lavis.models.
Returns:
itm_model (torch.nn.Module): model.
vis_processor (dict): preprocessors for visual inputs.
"""
itm_model = load_model(
"blip_image_text_matching",
"base",
@@ -479,7 +765,14 @@ class MultimodalSearch(AnalysisMethod):
vis_processor = load_processor("blip_image_eval").build(image_size=384)
return itm_model, vis_processor
def upload_model_blip_large(self) -> tuple:
"""
Load large blip_image_text_matching model and preprocessors for visual input from lavis.models.
Returns:
itm_model (torch.nn.Module): model.
vis_processor (dict): preprocessors for visual inputs.
"""
itm_model = load_model(
"blip_image_text_matching",
"large",
@@ -491,13 +784,28 @@ class MultimodalSearch(AnalysisMethod):
def image_text_match_reordering(
self,
search_query: list[dict[str, str]],
itm_model_type: str,
image_keys: list,
sorted_lists: list[list],
batch_size: int = 1,
need_grad_cam: bool = False,
) -> tuple:
"""
Reorder images with itm model.
Args:
search_query (list): list of queries.
itm_model_type (str): type of the model.
image_keys (list): sorted list of image keys.
sorted_lists (list): sorted list of similarity.
batch_size (int): batch size. Default: 1.
need_grad_cam (bool): whether to compute gradcam. Default: False. The blip2_coco model does not yet work with gradcam.
Returns:
itm_scores2 (list): list of itm scores.
image_gradcam_with_itm (dict): dict of image names and gradcam.
"""
if itm_model_type == "blip2_coco" and need_grad_cam is True:
raise SyntaxError(
"The blip2_coco model does not yet work with gradcam. Please set need_grad_cam to False"
@@ -624,7 +932,17 @@ class MultimodalSearch(AnalysisMethod):
torch.cuda.empty_cache()
return itm_scores2, image_gradcam_with_itm
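A sketch of the ITM re-ranking step, reusing search_query, image_keys and sorted_lists from multimodal_search; blip_base is chosen here because blip2_coco does not yet support gradcam:

itm_scores, image_gradcam_with_itm = ms.image_text_match_reordering(
    search_query,
    "blip_base",
    image_keys,
    sorted_lists,
    batch_size=1,
    need_grad_cam=True,
)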
def show_results(
self, query: dict, itm=False, image_gradcam_with_itm=False
) -> None:
"""
Show results of search.
Args:
query (dict): query.
itm (bool): use itm model. Default: False.
image_gradcam_with_itm (bool): use gradcam. Default: False.
"""
if "image" in query.keys(): if "image" in query.keys():
pic = Image.open(query["image"]).convert("RGB") pic = Image.open(query["image"]).convert("RGB")
pic.thumbnail((400, 400)) pic.thumbnail((400, 400))
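Finally, a sketch of displaying the ranked results for one query; the itm and image_gradcam_with_itm arguments are assumptions based on the docstring above:

# plain similarity ranking for the first query
ms.show_results(search_query[0])
# with ITM re-ranking and gradcam overlays from image_text_match_reordering
ms.show_results(search_query[0], itm=True, image_gradcam_with_itm=image_gradcam_with_itm)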