From ece132fe14a6139d88d2d7e7559ab6ed10980b21 Mon Sep 17 00:00:00 2001 From: DimasfromLavoisier Date: Fri, 26 Sep 2025 17:29:46 +0200 Subject: [PATCH] optimize validation of analysis type --- ammico/image_summary.py | 58 +++++++++++++++++++++++++++-------------- 1 file changed, 38 insertions(+), 20 deletions(-) diff --git a/ammico/image_summary.py b/ammico/image_summary.py index 0cdaebe..3ccc3f4 100644 --- a/ammico/image_summary.py +++ b/ammico/image_summary.py @@ -6,7 +6,7 @@ import torch from PIL import Image import warnings -from typing import List, Optional, Union, Dict, Any +from typing import List, Optional, Union, Dict, Any, Tuple from collections.abc import Sequence as _Sequence from transformers import GenerationConfig from qwen_vl_utils import process_vision_info @@ -118,6 +118,36 @@ class ImageSummaryDetector(AnalysisMethod): return inputs + def _validate_analysis_type( + self, + analysis_type: Union["AnalysisType", str], + list_of_questions: Optional[List[str]], + max_questions_per_image: int, + ) -> Tuple[str, List[str], bool, bool]: + if isinstance(analysis_type, AnalysisType): + analysis_type = analysis_type.value + + allowed = {"summary", "questions", "summary_and_questions"} + if analysis_type not in allowed: + raise ValueError(f"analysis_type must be one of {allowed}") + + if list_of_questions is None: + list_of_questions = [ + "Are there people in the image?", + "What is this picture about?", + ] + + if analysis_type in ("questions", "summary_and_questions"): + if len(list_of_questions) > max_questions_per_image: + raise ValueError( + f"Number of questions per image ({len(list_of_questions)}) exceeds safety cap ({max_questions_per_image}). Reduce questions or increase max_questions_per_image." + ) + + is_summary = analysis_type in ("summary", "summary_and_questions") + is_questions = analysis_type in ("questions", "summary_and_questions") + + return analysis_type, list_of_questions, is_summary, is_questions + def analyse_images( self, analysis_type: Union[AnalysisType, str] = AnalysisType.SUMMARY_AND_QUESTIONS, @@ -141,25 +171,18 @@ class ImageSummaryDetector(AnalysisMethod): self.subdict (dict): dictionary with analysis results. """ # TODO: add option to ask multiple questions per image as one batch. - if isinstance(analysis_type, AnalysisType): - analysis_type = analysis_type.value - - allowed = {"summary", "questions", "summary_and_questions"} - if analysis_type not in allowed: - raise ValueError(f"analysis_type must be one of {allowed}") - - if list_of_questions is None: - list_of_questions = [ - "Are there people in the image?", - "What is this picture about?", - ] + analysis_type, list_of_questions, is_summary, is_questions = ( + self._validate_analysis_type( + analysis_type, list_of_questions, max_questions_per_image + ) + ) keys = list(self.subdict.keys()) for batch_start in range(0, len(keys), keys_batch_size): batch_keys = keys[batch_start : batch_start + keys_batch_size] for key in batch_keys: entry = self.subdict[key] - if analysis_type in ("summary", "summary_and_questions"): + if is_summary: try: caps = self.generate_caption( entry, @@ -172,12 +195,7 @@ class ImageSummaryDetector(AnalysisMethod): "Caption generation failed for key %s: %s", key, e ) - if analysis_type in ("questions", "summary_and_questions"): - if len(list_of_questions) > max_questions_per_image: - raise ValueError( - f"Number of questions per image ({len(list_of_questions)}) exceeds safety cap ({max_questions_per_image})." - " Reduce questions or increase max_questions_per_image." - ) + if is_questions: try: vqa_map = self.answer_questions( list_of_questions, entry, is_concise_answer