diff --git a/ammico/summary.py b/ammico/summary.py index 43a464d..64cc502 100644 --- a/ammico/summary.py +++ b/ammico/summary.py @@ -10,6 +10,15 @@ class SummaryDetector(AnalysisMethod): self.summary_device = device("cuda" if cuda.is_available() else "cpu") def load_model_base(self): + """ + Load base_coco blip_caption model and preprocessors for visual inputs from lavis.models. + + Args: + + Returns: + model (torch.nn.Module): model. + vis_processors (dict): preprocessors for visual inputs. + """ summary_model, summary_vis_processors, _ = load_model_and_preprocess( name="blip_caption", model_type="base_coco", @@ -19,6 +28,15 @@ class SummaryDetector(AnalysisMethod): return summary_model, summary_vis_processors def load_model_large(self): + """ + Load large_coco blip_caption model and preprocessors for visual inputs from lavis.models. + + Args: + + Returns: + model (torch.nn.Module): model. + vis_processors (dict): preprocessors for visual inputs. + """ summary_model, summary_vis_processors, _ = load_model_and_preprocess( name="blip_caption", model_type="large_coco", @@ -27,7 +45,17 @@ class SummaryDetector(AnalysisMethod): ) return summary_model, summary_vis_processors - def load_model(self, model_type): + def load_model(self, model_type: str): + """ + Load blip_caption model and preprocessors for visual inputs from lavis.models. + + Args: + model_type (str): type of the model, e.g. "base" or "large". + + Returns: + model (torch.nn.Module): model. + vis_processors (dict): preprocessors for visual inputs. + """ select_model = { "base": SummaryDetector.load_model_base, "large": SummaryDetector.load_model_large, @@ -36,6 +64,16 @@ class SummaryDetector(AnalysisMethod): return summary_model, summary_vis_processors def analyse_image(self, summary_model=None, summary_vis_processors=None): + """ + Create 1 constant and 3 non-deterministic captions for the image. + + Args: + summary_model (torch.nn.Module): model. + summary_vis_processors (dict): preprocessors for visual inputs. 
+ + Returns: + self.subdict (dict): dictionary with a constant image summary and 3 non-deterministic summaries. + """ if summary_model is None and summary_vis_processors is None: summary_model, summary_vis_processors = self.load_model_base() @@ -55,7 +93,16 @@ class SummaryDetector(AnalysisMethod): ) return self.subdict - def analyse_questions(self, list_of_questions): + def analyse_questions(self, list_of_questions: list[str]) -> dict: + """ + Generate answers to free-form natural-language questions about the image. + + Args: + list_of_questions (list[str]): list of questions. + + Returns: + self.subdict (dict): dictionary with answers to questions. + """ ( summary_vqa_model, summary_vqa_vis_processors,