optimize validation of analysis type

Этот коммит содержится в:
DimasfromLavoisier 2025-09-26 17:29:46 +02:00
родитель d1a4954669
Коммит ece132fe14

Просмотреть файл

@ -6,7 +6,7 @@ import torch
from PIL import Image from PIL import Image
import warnings import warnings
from typing import List, Optional, Union, Dict, Any from typing import List, Optional, Union, Dict, Any, Tuple
from collections.abc import Sequence as _Sequence from collections.abc import Sequence as _Sequence
from transformers import GenerationConfig from transformers import GenerationConfig
from qwen_vl_utils import process_vision_info from qwen_vl_utils import process_vision_info
@ -118,6 +118,36 @@ class ImageSummaryDetector(AnalysisMethod):
return inputs return inputs
def _validate_analysis_type(
    self,
    analysis_type: Union["AnalysisType", str],
    list_of_questions: Optional[List[str]],
    max_questions_per_image: int,
) -> Tuple[str, List[str], bool, bool]:
    """Normalise the requested analysis mode and check the question list.

    Args:
        analysis_type: Either an ``AnalysisType`` member or one of the
            strings "summary", "questions", "summary_and_questions".
        list_of_questions: Questions to ask per image. When ``None``, a
            built-in pair of default questions is used instead.
        max_questions_per_image: Upper bound on ``len(list_of_questions)``;
            only enforced when the mode includes questions.

    Returns:
        A 4-tuple ``(analysis_type, list_of_questions, is_summary,
        is_questions)`` where ``analysis_type`` is the plain string value
        and the two booleans tell whether the summary and the question
        parts of the analysis are requested.

    Raises:
        ValueError: If ``analysis_type`` is not an allowed value, or if
            questions are requested and there are more of them than
            ``max_questions_per_image``.
    """
    # Enum members are reduced to their underlying value before checking.
    if isinstance(analysis_type, AnalysisType):
        analysis_type = analysis_type.value

    allowed = {"summary", "questions", "summary_and_questions"}
    if analysis_type not in allowed:
        raise ValueError(f"analysis_type must be one of {allowed}")

    # Derive both mode flags once; they drive the cap check below and are
    # handed back to the caller.
    is_summary = analysis_type in ("summary", "summary_and_questions")
    is_questions = analysis_type in ("questions", "summary_and_questions")

    if list_of_questions is None:
        list_of_questions = [
            "Are there people in the image?",
            "What is this picture about?",
        ]

    # The cap applies only when questions will actually be asked.
    if is_questions and len(list_of_questions) > max_questions_per_image:
        raise ValueError(
            f"Number of questions per image ({len(list_of_questions)}) exceeds safety cap ({max_questions_per_image}). Reduce questions or increase max_questions_per_image."
        )

    return analysis_type, list_of_questions, is_summary, is_questions
def analyse_images( def analyse_images(
self, self,
analysis_type: Union[AnalysisType, str] = AnalysisType.SUMMARY_AND_QUESTIONS, analysis_type: Union[AnalysisType, str] = AnalysisType.SUMMARY_AND_QUESTIONS,
@ -141,25 +171,18 @@ class ImageSummaryDetector(AnalysisMethod):
self.subdict (dict): dictionary with analysis results. self.subdict (dict): dictionary with analysis results.
""" """
# TODO: add option to ask multiple questions per image as one batch. # TODO: add option to ask multiple questions per image as one batch.
if isinstance(analysis_type, AnalysisType): analysis_type, list_of_questions, is_summary, is_questions = (
analysis_type = analysis_type.value self._validate_analysis_type(
analysis_type, list_of_questions, max_questions_per_image
allowed = {"summary", "questions", "summary_and_questions"} )
if analysis_type not in allowed: )
raise ValueError(f"analysis_type must be one of {allowed}")
if list_of_questions is None:
list_of_questions = [
"Are there people in the image?",
"What is this picture about?",
]
keys = list(self.subdict.keys()) keys = list(self.subdict.keys())
for batch_start in range(0, len(keys), keys_batch_size): for batch_start in range(0, len(keys), keys_batch_size):
batch_keys = keys[batch_start : batch_start + keys_batch_size] batch_keys = keys[batch_start : batch_start + keys_batch_size]
for key in batch_keys: for key in batch_keys:
entry = self.subdict[key] entry = self.subdict[key]
if analysis_type in ("summary", "summary_and_questions"): if is_summary:
try: try:
caps = self.generate_caption( caps = self.generate_caption(
entry, entry,
@ -172,12 +195,7 @@ class ImageSummaryDetector(AnalysisMethod):
"Caption generation failed for key %s: %s", key, e "Caption generation failed for key %s: %s", key, e
) )
if analysis_type in ("questions", "summary_and_questions"): if is_questions:
if len(list_of_questions) > max_questions_per_image:
raise ValueError(
f"Number of questions per image ({len(list_of_questions)}) exceeds safety cap ({max_questions_per_image})."
" Reduce questions or increase max_questions_per_image."
)
try: try:
vqa_map = self.answer_questions( vqa_map = self.answer_questions(
list_of_questions, entry, is_concise_answer list_of_questions, entry, is_concise_answer