From de8ee83432af2cdd8982cc1030c6ab6cf923169a Mon Sep 17 00:00:00 2001 From: Inga Ulusoy Date: Fri, 24 Oct 2025 16:36:03 +0200 Subject: [PATCH] refactor: use dictionary mapping for values, check question list strings for None --- ammico/image_summary.py | 72 +++++++++++++++++++++++-------- ammico/test/test_image_summary.py | 21 +++++++++ 2 files changed, 74 insertions(+), 19 deletions(-) diff --git a/ammico/image_summary.py b/ammico/image_summary.py index 203ef21..7d462f2 100644 --- a/ammico/image_summary.py +++ b/ammico/image_summary.py @@ -13,6 +13,22 @@ from qwen_vl_utils import process_vision_info class ImageSummaryDetector(AnalysisMethod): + token_prompt_config = { + "default": { + "summary": {"prompt": "Describe this image.", "max_new_tokens": 256}, + "questions": {"prompt": "", "max_new_tokens": 128}, + }, + "concise": { + "summary": { + "prompt": "Describe this image in one concise caption.", + "max_new_tokens": 64, + }, + "questions": {"prompt": "Answer concisely: ", "max_new_tokens": 128}, + }, + } + MAX_QUESTIONS_PER_IMAGE = 32 + KEYS_BATCH_SIZE = 16 + def __init__( self, summary_model: MultimodalSummaryModel, @@ -155,7 +171,7 @@ class ImageSummaryDetector(AnalysisMethod): entry: dict, analysis_type: Union[str, AnalysisType] = AnalysisType.SUMMARY_AND_QUESTIONS, list_of_questions: Optional[List[str]] = None, - max_questions_per_image: int = 32, + max_questions_per_image: int = MAX_QUESTIONS_PER_IMAGE, is_concise_summary: bool = True, is_concise_answer: bool = True, ) -> Dict[str, Any]: @@ -197,8 +213,8 @@ class ImageSummaryDetector(AnalysisMethod): self, analysis_type: Union[AnalysisType, str] = AnalysisType.SUMMARY_AND_QUESTIONS, list_of_questions: Optional[List[str]] = None, - max_questions_per_image: int = 32, - keys_batch_size: int = 16, + max_questions_per_image: int = MAX_QUESTIONS_PER_IMAGE, + keys_batch_size: int = KEYS_BATCH_SIZE, is_concise_summary: bool = True, is_concise_answer: bool = True, ) -> Dict[str, dict]: @@ -267,12 +283,12 @@ class ImageSummaryDetector(AnalysisMethod): Returns: results (list[str]): list of generated captions. """ - if is_concise_summary: - prompt = ["Describe this image in one concise caption."] - max_new_tokens = 64 - else: - prompt = ["Describe this image."] - max_new_tokens = 256 + prompt = self.token_prompt_config[ + "concise" if is_concise_summary else "default" + ]["summary"]["prompt"] + max_new_tokens = self.token_prompt_config[ + "concise" if is_concise_summary else "default" + ]["summary"]["max_new_tokens"] inputs = self._prepare_inputs(prompt, entry) gen_conf = GenerationConfig( @@ -333,6 +349,24 @@ class ImageSummaryDetector(AnalysisMethod): results = [d.strip() for d in decoded] return results + def _clean_list_of_questions( + self, list_of_questions: list[str], prompt: str + ) -> list[str]: + """Clean the list of questions to contain correctly formatted strings.""" + # remove all None or empty questions + list_of_questions = [i for i in list_of_questions if i and i.strip()] + # ensure each question ends with a question mark + list_of_questions = [ + i.strip() + "?" if not i.strip().endswith("?") else i.strip() + for i in list_of_questions + ] + # ensure each question starts with the prompt + list_of_questions = [ + i if i.lower().startswith(prompt.lower()) else prompt + i + for i in list_of_questions + ] + return list_of_questions + def answer_questions( self, list_of_questions: list[str], @@ -348,15 +382,15 @@ class ImageSummaryDetector(AnalysisMethod): Returns: answers (list[str]): list of answers. 
""" - if is_concise_answer: - gen_conf = GenerationConfig(max_new_tokens=64, do_sample=False) - for i in range(len(list_of_questions)): - if not list_of_questions[i].strip().endswith("?"): - list_of_questions[i] = list_of_questions[i].strip() + "?" - if not list_of_questions[i].lower().startswith("answer concisely"): - list_of_questions[i] = "Answer concisely: " + list_of_questions[i] - else: - gen_conf = GenerationConfig(max_new_tokens=128, do_sample=False) + prompt = self.token_prompt_config[ + "concise" if is_concise_answer else "default" + ]["answer"]["prompt"] + max_new_tokens = self.token_prompt_config[ + "concise" if is_concise_answer else "default" + ]["answer"]["max_new_tokens"] + + list_of_questions = self._clean_list_of_questions(list_of_questions, prompt) + gen_conf = GenerationConfig(max_new_tokens=max_new_tokens, do_sample=False) question_chunk_size = 8 answers: List[str] = [] @@ -396,7 +430,7 @@ class ImageSummaryDetector(AnalysisMethod): if len(answers) != len(list_of_questions): raise ValueError( - f"Expected {len(list_of_questions)} answers, but got {len(answers)}, try vary amount of questions" + f"Expected {len(list_of_questions)} answers, but got {len(answers)}, try varying amount of questions" ) return answers diff --git a/ammico/test/test_image_summary.py b/ammico/test/test_image_summary.py index ad48298..b56d806 100644 --- a/ammico/test/test_image_summary.py +++ b/ammico/test/test_image_summary.py @@ -35,3 +35,24 @@ def test_image_summary_detector_questions(model, get_testdict): assert ( "two" in results[key]["vqa"][1].lower() or "2" in results[key]["vqa"][1] ) + + +def test_clean_list_of_questions(model): + list_of_questions = [ + "What is happening in the image?", + "", + " ", + None, + "How many cars are in the image in total", + ] + detector = ImageSummaryDetector(summary_model=model, subdict={}) + prompt = detector.token_prompt_config["default"]["questions"]["prompt"] + cleaned_questions = detector._clean_list_of_questions(list_of_questions, prompt) + assert len(cleaned_questions) == 2 + assert cleaned_questions[0] == "What is happening in the image?" + assert cleaned_questions[1] == "How many cars are in the image in total?" + prompt = detector.token_prompt_config["concise"]["questions"]["prompt"] + cleaned_questions = detector._clean_list_of_questions(list_of_questions, prompt) + assert len(cleaned_questions) == 2 + assert cleaned_questions[0] == prompt + "What is happening in the image?" + assert cleaned_questions[1] == prompt + "How many cars are in the image in total?"