зеркало из
https://github.com/ssciwr/AMMICO.git
synced 2025-10-29 13:06:04 +02:00
optimize validation of analysis type
Этот коммит содержится в:
родитель
d1a4954669
Коммит
ece132fe14
@ -6,7 +6,7 @@ import torch
|
||||
from PIL import Image
|
||||
import warnings
|
||||
|
||||
from typing import List, Optional, Union, Dict, Any
|
||||
from typing import List, Optional, Union, Dict, Any, Tuple
|
||||
from collections.abc import Sequence as _Sequence
|
||||
from transformers import GenerationConfig
|
||||
from qwen_vl_utils import process_vision_info
|
||||
@ -118,6 +118,36 @@ class ImageSummaryDetector(AnalysisMethod):
|
||||
|
||||
return inputs
|
||||
|
||||
def _validate_analysis_type(
    self,
    analysis_type: Union["AnalysisType", str],
    list_of_questions: Optional[List[str]],
    max_questions_per_image: int,
) -> Tuple[str, List[str], bool, bool]:
    """Normalise and validate the analysis request before any image is processed.

    Args:
        analysis_type: Either an ``AnalysisType`` member (its ``.value`` is
            used) or a plain string.
        list_of_questions: Questions to ask per image. When ``None``, a
            built-in pair of default questions is substituted.
        max_questions_per_image: Upper bound on the number of questions;
            only enforced when the analysis type includes questions.

    Returns:
        A 4-tuple ``(analysis_type, list_of_questions, is_summary,
        is_questions)`` holding the normalised analysis type, the effective
        question list, and two flags telling whether summaries and/or
        question answering were requested.

    Raises:
        ValueError: If the analysis type is not one of ``"summary"``,
            ``"questions"`` or ``"summary_and_questions"``, or if questions
            are requested and their count exceeds
            ``max_questions_per_image``.
    """
    # Collapse enum members to their underlying value so that the rest of
    # the checks only have to deal with one representation.
    mode = (
        analysis_type.value
        if isinstance(analysis_type, AnalysisType)
        else analysis_type
    )

    allowed = {"summary", "questions", "summary_and_questions"}
    if mode not in allowed:
        raise ValueError(f"analysis_type must be one of {allowed}")

    # Fall back to the default questions when the caller supplied none.
    questions = list_of_questions
    if questions is None:
        questions = [
            "Are there people in the image?",
            "What is this picture about?",
        ]

    wants_summary = mode in ("summary", "summary_and_questions")
    wants_questions = mode in ("questions", "summary_and_questions")

    # The cap is irrelevant for summary-only runs, so it is checked only
    # when questions will actually be asked.
    if wants_questions and len(questions) > max_questions_per_image:
        raise ValueError(
            f"Number of questions per image ({len(questions)}) "
            f"exceeds safety cap ({max_questions_per_image}). "
            "Reduce questions or increase max_questions_per_image."
        )

    return mode, questions, wants_summary, wants_questions
|
||||
|
||||
def analyse_images(
|
||||
self,
|
||||
analysis_type: Union[AnalysisType, str] = AnalysisType.SUMMARY_AND_QUESTIONS,
|
||||
@ -141,25 +171,18 @@ class ImageSummaryDetector(AnalysisMethod):
|
||||
self.subdict (dict): dictionary with analysis results.
|
||||
"""
|
||||
# TODO: add option to ask multiple questions per image as one batch.
|
||||
if isinstance(analysis_type, AnalysisType):
|
||||
analysis_type = analysis_type.value
|
||||
|
||||
allowed = {"summary", "questions", "summary_and_questions"}
|
||||
if analysis_type not in allowed:
|
||||
raise ValueError(f"analysis_type must be one of {allowed}")
|
||||
|
||||
if list_of_questions is None:
|
||||
list_of_questions = [
|
||||
"Are there people in the image?",
|
||||
"What is this picture about?",
|
||||
]
|
||||
analysis_type, list_of_questions, is_summary, is_questions = (
|
||||
self._validate_analysis_type(
|
||||
analysis_type, list_of_questions, max_questions_per_image
|
||||
)
|
||||
)
|
||||
|
||||
keys = list(self.subdict.keys())
|
||||
for batch_start in range(0, len(keys), keys_batch_size):
|
||||
batch_keys = keys[batch_start : batch_start + keys_batch_size]
|
||||
for key in batch_keys:
|
||||
entry = self.subdict[key]
|
||||
if analysis_type in ("summary", "summary_and_questions"):
|
||||
if is_summary:
|
||||
try:
|
||||
caps = self.generate_caption(
|
||||
entry,
|
||||
@ -172,12 +195,7 @@ class ImageSummaryDetector(AnalysisMethod):
|
||||
"Caption generation failed for key %s: %s", key, e
|
||||
)
|
||||
|
||||
if analysis_type in ("questions", "summary_and_questions"):
|
||||
if len(list_of_questions) > max_questions_per_image:
|
||||
raise ValueError(
|
||||
f"Number of questions per image ({len(list_of_questions)}) exceeds safety cap ({max_questions_per_image})."
|
||||
" Reduce questions or increase max_questions_per_image."
|
||||
)
|
||||
if is_questions:
|
||||
try:
|
||||
vqa_map = self.answer_questions(
|
||||
list_of_questions, entry, is_concise_answer
|
||||
|
||||
Загрузка…
x
Ссылка в новой задаче
Block a user