optimize validation of analysis type

Этот коммит содержится в:
DimasfromLavoisier 2025-09-26 17:29:46 +02:00
родитель d1a4954669
Коммит ece132fe14

Просмотреть файл

@ -6,7 +6,7 @@ import torch
from PIL import Image
import warnings
from typing import List, Optional, Union, Dict, Any
from typing import List, Optional, Union, Dict, Any, Tuple
from collections.abc import Sequence as _Sequence
from transformers import GenerationConfig
from qwen_vl_utils import process_vision_info
@ -118,6 +118,36 @@ class ImageSummaryDetector(AnalysisMethod):
return inputs
def _validate_analysis_type(
self,
analysis_type: Union["AnalysisType", str],
list_of_questions: Optional[List[str]],
max_questions_per_image: int,
) -> Tuple[str, List[str], bool, bool]:
if isinstance(analysis_type, AnalysisType):
analysis_type = analysis_type.value
allowed = {"summary", "questions", "summary_and_questions"}
if analysis_type not in allowed:
raise ValueError(f"analysis_type must be one of {allowed}")
if list_of_questions is None:
list_of_questions = [
"Are there people in the image?",
"What is this picture about?",
]
if analysis_type in ("questions", "summary_and_questions"):
if len(list_of_questions) > max_questions_per_image:
raise ValueError(
f"Number of questions per image ({len(list_of_questions)}) exceeds safety cap ({max_questions_per_image}). Reduce questions or increase max_questions_per_image."
)
is_summary = analysis_type in ("summary", "summary_and_questions")
is_questions = analysis_type in ("questions", "summary_and_questions")
return analysis_type, list_of_questions, is_summary, is_questions
def analyse_images(
self,
analysis_type: Union[AnalysisType, str] = AnalysisType.SUMMARY_AND_QUESTIONS,
@ -141,25 +171,18 @@ class ImageSummaryDetector(AnalysisMethod):
self.subdict (dict): dictionary with analysis results.
"""
# TODO: add option to ask multiple questions per image as one batch.
if isinstance(analysis_type, AnalysisType):
analysis_type = analysis_type.value
allowed = {"summary", "questions", "summary_and_questions"}
if analysis_type not in allowed:
raise ValueError(f"analysis_type must be one of {allowed}")
if list_of_questions is None:
list_of_questions = [
"Are there people in the image?",
"What is this picture about?",
]
analysis_type, list_of_questions, is_summary, is_questions = (
self._validate_analysis_type(
analysis_type, list_of_questions, max_questions_per_image
)
)
keys = list(self.subdict.keys())
for batch_start in range(0, len(keys), keys_batch_size):
batch_keys = keys[batch_start : batch_start + keys_batch_size]
for key in batch_keys:
entry = self.subdict[key]
if analysis_type in ("summary", "summary_and_questions"):
if is_summary:
try:
caps = self.generate_caption(
entry,
@ -172,12 +195,7 @@ class ImageSummaryDetector(AnalysisMethod):
"Caption generation failed for key %s: %s", key, e
)
if analysis_type in ("questions", "summary_and_questions"):
if len(list_of_questions) > max_questions_per_image:
raise ValueError(
f"Number of questions per image ({len(list_of_questions)}) exceeds safety cap ({max_questions_per_image})."
" Reduce questions or increase max_questions_per_image."
)
if is_questions:
try:
vqa_map = self.answer_questions(
list_of_questions, entry, is_concise_answer