refactor: use dictionary mapping for values, check question list strings for None

Этот коммит содержится в:
Inga Ulusoy 2025-10-24 16:36:03 +02:00
родитель f277e86b29
Коммит de8ee83432
Не найден ключ, соответствующий данной подписи
Идентификатор ключа GPG: 7E8998B3002D3B7C
2 изменённых файлов: 74 добавлений и 19 удалений

Просмотреть файл

@ -13,6 +13,22 @@ from qwen_vl_utils import process_vision_info
class ImageSummaryDetector(AnalysisMethod):
token_prompt_config = {
"default": {
"summary": {"prompt": "Describe this image.", "max_new_tokens": 256},
"questions": {"prompt": "", "max_new_tokens": 128},
},
"concise": {
"summary": {
"prompt": "Describe this image in one concise caption.",
"max_new_tokens": 64,
},
"questions": {"prompt": "Answer concisely: ", "max_new_tokens": 128},
},
}
MAX_QUESTIONS_PER_IMAGE = 32
KEYS_BATCH_SIZE = 16
def __init__(
self,
summary_model: MultimodalSummaryModel,
@ -155,7 +171,7 @@ class ImageSummaryDetector(AnalysisMethod):
entry: dict,
analysis_type: Union[str, AnalysisType] = AnalysisType.SUMMARY_AND_QUESTIONS,
list_of_questions: Optional[List[str]] = None,
max_questions_per_image: int = 32,
max_questions_per_image: int = MAX_QUESTIONS_PER_IMAGE,
is_concise_summary: bool = True,
is_concise_answer: bool = True,
) -> Dict[str, Any]:
@ -197,8 +213,8 @@ class ImageSummaryDetector(AnalysisMethod):
self,
analysis_type: Union[AnalysisType, str] = AnalysisType.SUMMARY_AND_QUESTIONS,
list_of_questions: Optional[List[str]] = None,
max_questions_per_image: int = 32,
keys_batch_size: int = 16,
max_questions_per_image: int = MAX_QUESTIONS_PER_IMAGE,
keys_batch_size: int = KEYS_BATCH_SIZE,
is_concise_summary: bool = True,
is_concise_answer: bool = True,
) -> Dict[str, dict]:
@ -267,12 +283,12 @@ class ImageSummaryDetector(AnalysisMethod):
Returns:
results (list[str]): list of generated captions.
"""
if is_concise_summary:
prompt = ["Describe this image in one concise caption."]
max_new_tokens = 64
else:
prompt = ["Describe this image."]
max_new_tokens = 256
prompt = self.token_prompt_config[
"concise" if is_concise_summary else "default"
]["summary"]["prompt"]
max_new_tokens = self.token_prompt_config[
"concise" if is_concise_summary else "default"
]["summary"]["max_new_tokens"]
inputs = self._prepare_inputs(prompt, entry)
gen_conf = GenerationConfig(
@ -333,6 +349,24 @@ class ImageSummaryDetector(AnalysisMethod):
results = [d.strip() for d in decoded]
return results
def _clean_list_of_questions(
self, list_of_questions: list[str], prompt: str
) -> list[str]:
"""Clean the list of questions to contain correctly formatted strings."""
# remove all None or empty questions
list_of_questions = [i for i in list_of_questions if i and i.strip()]
# ensure each question ends with a question mark
list_of_questions = [
i.strip() + "?" if not i.strip().endswith("?") else i.strip()
for i in list_of_questions
]
# ensure each question starts with the prompt
list_of_questions = [
i if i.lower().startswith(prompt.lower()) else prompt + i
for i in list_of_questions
]
return list_of_questions
def answer_questions(
self,
list_of_questions: list[str],
@ -348,15 +382,15 @@ class ImageSummaryDetector(AnalysisMethod):
Returns:
answers (list[str]): list of answers.
"""
if is_concise_answer:
gen_conf = GenerationConfig(max_new_tokens=64, do_sample=False)
for i in range(len(list_of_questions)):
if not list_of_questions[i].strip().endswith("?"):
list_of_questions[i] = list_of_questions[i].strip() + "?"
if not list_of_questions[i].lower().startswith("answer concisely"):
list_of_questions[i] = "Answer concisely: " + list_of_questions[i]
else:
gen_conf = GenerationConfig(max_new_tokens=128, do_sample=False)
prompt = self.token_prompt_config[
"concise" if is_concise_answer else "default"
]["answer"]["prompt"]
max_new_tokens = self.token_prompt_config[
"concise" if is_concise_answer else "default"
]["answer"]["max_new_tokens"]
list_of_questions = self._clean_list_of_questions(list_of_questions, prompt)
gen_conf = GenerationConfig(max_new_tokens=max_new_tokens, do_sample=False)
question_chunk_size = 8
answers: List[str] = []
@ -396,7 +430,7 @@ class ImageSummaryDetector(AnalysisMethod):
if len(answers) != len(list_of_questions):
raise ValueError(
f"Expected {len(list_of_questions)} answers, but got {len(answers)}, try vary amount of questions"
f"Expected {len(list_of_questions)} answers, but got {len(answers)}, try varying amount of questions"
)
return answers

Просмотреть файл

@ -35,3 +35,24 @@ def test_image_summary_detector_questions(model, get_testdict):
assert (
"two" in results[key]["vqa"][1].lower() or "2" in results[key]["vqa"][1]
)
def test_clean_list_of_questions(model):
list_of_questions = [
"What is happening in the image?",
"",
" ",
None,
"How many cars are in the image in total",
]
detector = ImageSummaryDetector(summary_model=model, subdict={})
prompt = detector.token_prompt_config["default"]["questions"]["prompt"]
cleaned_questions = detector._clean_list_of_questions(list_of_questions, prompt)
assert len(cleaned_questions) == 2
assert cleaned_questions[0] == "What is happening in the image?"
assert cleaned_questions[1] == "How many cars are in the image in total?"
prompt = detector.token_prompt_config["concise"]["questions"]["prompt"]
cleaned_questions = detector._clean_list_of_questions(list_of_questions, prompt)
assert len(cleaned_questions) == 2
assert cleaned_questions[0] == prompt + "What is happening in the image?"
assert cleaned_questions[1] == prompt + "How many cars are in the image in total?"