Этот коммит содержится в:
Petr Andriushchenko 2023-05-22 17:03:50 +02:00
родитель 90d049fdf7
Коммит 8c3641ac24
Не найден ключ, соответствующий данной подписи
Идентификатор ключа GPG: 4C4A5DCF634115B6

Просмотреть файл

@ -10,6 +10,15 @@ class SummaryDetector(AnalysisMethod):
self.summary_device = device("cuda" if cuda.is_available() else "cpu")
def load_model_base(self):
"""
Load base_coco blip_caption model and preprocessors for visual inputs from lavis.models.
Args:
Returns:
model (torch.nn.Module): model.
vis_processors (dict): preprocessors for visual inputs.
"""
summary_model, summary_vis_processors, _ = load_model_and_preprocess(
name="blip_caption",
model_type="base_coco",
@ -19,6 +28,15 @@ class SummaryDetector(AnalysisMethod):
return summary_model, summary_vis_processors
def load_model_large(self):
"""
Load large_coco blip_caption model and preprocessors for visual inputs from lavis.models.
Args:
Returns:
model (torch.nn.Module): model.
vis_processors (dict): preprocessors for visual inputs.
"""
summary_model, summary_vis_processors, _ = load_model_and_preprocess(
name="blip_caption",
model_type="large_coco",
@ -27,7 +45,17 @@ class SummaryDetector(AnalysisMethod):
)
return summary_model, summary_vis_processors
def load_model(self, model_type):
def load_model(self, model_type: str):
"""
Load blip_caption model and preprocessors for visual inputs from lavis.models.
Args:
model_type (str): type of the model.
Returns:
model (torch.nn.Module): model.
vis_processors (dict): preprocessors for visual inputs.
"""
select_model = {
"base": SummaryDetector.load_model_base,
"large": SummaryDetector.load_model_large,
@ -36,6 +64,16 @@ class SummaryDetector(AnalysisMethod):
return summary_model, summary_vis_processors
def analyse_image(self, summary_model=None, summary_vis_processors=None):
"""
Create 1 constant and 3 non deterministic captions for image.
Args:
summary_model (str): model.
summary_vis_processors (str): preprocessors for visual inputs.
Returns:
self.subdict (dict): dictionary with constant image summary and 3 non deterministic summary.
"""
if summary_model is None and summary_vis_processors is None:
summary_model, summary_vis_processors = self.load_model_base()
@ -55,7 +93,16 @@ class SummaryDetector(AnalysisMethod):
)
return self.subdict
def analyse_questions(self, list_of_questions):
def analyse_questions(self, list_of_questions: list[str]) -> dict:
"""
Generate answers to free-form questions about image written in natural language.
Args:
list_of_questions (list[str]): list of questions.
Returns:
self.subdict (dict): dictionary with answers to questions.
"""
(
summary_vqa_model,
summary_vqa_vis_processors,