зеркало из
https://github.com/ssciwr/AMMICO.git
synced 2025-10-29 13:06:04 +02:00
refactor: use dictionary mapping for values, check question list strings for None
Этот коммит содержится в:
родитель
f277e86b29
Коммит
de8ee83432
@ -13,6 +13,22 @@ from qwen_vl_utils import process_vision_info
|
||||
|
||||
|
||||
class ImageSummaryDetector(AnalysisMethod):
|
||||
token_prompt_config = {
|
||||
"default": {
|
||||
"summary": {"prompt": "Describe this image.", "max_new_tokens": 256},
|
||||
"questions": {"prompt": "", "max_new_tokens": 128},
|
||||
},
|
||||
"concise": {
|
||||
"summary": {
|
||||
"prompt": "Describe this image in one concise caption.",
|
||||
"max_new_tokens": 64,
|
||||
},
|
||||
"questions": {"prompt": "Answer concisely: ", "max_new_tokens": 128},
|
||||
},
|
||||
}
|
||||
MAX_QUESTIONS_PER_IMAGE = 32
|
||||
KEYS_BATCH_SIZE = 16
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
summary_model: MultimodalSummaryModel,
|
||||
@ -155,7 +171,7 @@ class ImageSummaryDetector(AnalysisMethod):
|
||||
entry: dict,
|
||||
analysis_type: Union[str, AnalysisType] = AnalysisType.SUMMARY_AND_QUESTIONS,
|
||||
list_of_questions: Optional[List[str]] = None,
|
||||
max_questions_per_image: int = 32,
|
||||
max_questions_per_image: int = MAX_QUESTIONS_PER_IMAGE,
|
||||
is_concise_summary: bool = True,
|
||||
is_concise_answer: bool = True,
|
||||
) -> Dict[str, Any]:
|
||||
@ -197,8 +213,8 @@ class ImageSummaryDetector(AnalysisMethod):
|
||||
self,
|
||||
analysis_type: Union[AnalysisType, str] = AnalysisType.SUMMARY_AND_QUESTIONS,
|
||||
list_of_questions: Optional[List[str]] = None,
|
||||
max_questions_per_image: int = 32,
|
||||
keys_batch_size: int = 16,
|
||||
max_questions_per_image: int = MAX_QUESTIONS_PER_IMAGE,
|
||||
keys_batch_size: int = KEYS_BATCH_SIZE,
|
||||
is_concise_summary: bool = True,
|
||||
is_concise_answer: bool = True,
|
||||
) -> Dict[str, dict]:
|
||||
@ -267,12 +283,12 @@ class ImageSummaryDetector(AnalysisMethod):
|
||||
Returns:
|
||||
results (list[str]): list of generated captions.
|
||||
"""
|
||||
if is_concise_summary:
|
||||
prompt = ["Describe this image in one concise caption."]
|
||||
max_new_tokens = 64
|
||||
else:
|
||||
prompt = ["Describe this image."]
|
||||
max_new_tokens = 256
|
||||
prompt = self.token_prompt_config[
|
||||
"concise" if is_concise_summary else "default"
|
||||
]["summary"]["prompt"]
|
||||
max_new_tokens = self.token_prompt_config[
|
||||
"concise" if is_concise_summary else "default"
|
||||
]["summary"]["max_new_tokens"]
|
||||
inputs = self._prepare_inputs(prompt, entry)
|
||||
|
||||
gen_conf = GenerationConfig(
|
||||
@ -333,6 +349,24 @@ class ImageSummaryDetector(AnalysisMethod):
|
||||
results = [d.strip() for d in decoded]
|
||||
return results
|
||||
|
||||
def _clean_list_of_questions(
|
||||
self, list_of_questions: list[str], prompt: str
|
||||
) -> list[str]:
|
||||
"""Clean the list of questions to contain correctly formatted strings."""
|
||||
# remove all None or empty questions
|
||||
list_of_questions = [i for i in list_of_questions if i and i.strip()]
|
||||
# ensure each question ends with a question mark
|
||||
list_of_questions = [
|
||||
i.strip() + "?" if not i.strip().endswith("?") else i.strip()
|
||||
for i in list_of_questions
|
||||
]
|
||||
# ensure each question starts with the prompt
|
||||
list_of_questions = [
|
||||
i if i.lower().startswith(prompt.lower()) else prompt + i
|
||||
for i in list_of_questions
|
||||
]
|
||||
return list_of_questions
|
||||
|
||||
def answer_questions(
|
||||
self,
|
||||
list_of_questions: list[str],
|
||||
@ -348,15 +382,15 @@ class ImageSummaryDetector(AnalysisMethod):
|
||||
Returns:
|
||||
answers (list[str]): list of answers.
|
||||
"""
|
||||
if is_concise_answer:
|
||||
gen_conf = GenerationConfig(max_new_tokens=64, do_sample=False)
|
||||
for i in range(len(list_of_questions)):
|
||||
if not list_of_questions[i].strip().endswith("?"):
|
||||
list_of_questions[i] = list_of_questions[i].strip() + "?"
|
||||
if not list_of_questions[i].lower().startswith("answer concisely"):
|
||||
list_of_questions[i] = "Answer concisely: " + list_of_questions[i]
|
||||
else:
|
||||
gen_conf = GenerationConfig(max_new_tokens=128, do_sample=False)
|
||||
# token_prompt_config keeps the question-answering settings under the
# "questions" key (its only keys per mode are "summary" and "questions"),
# so looking up "answer" would raise KeyError on every call.
question_config = self.token_prompt_config[
    "concise" if is_concise_answer else "default"
]["questions"]
prompt = question_config["prompt"]
max_new_tokens = question_config["max_new_tokens"]

# Drop None/blank questions, append a missing "?" and add the prompt prefix.
list_of_questions = self._clean_list_of_questions(list_of_questions, prompt)
gen_conf = GenerationConfig(max_new_tokens=max_new_tokens, do_sample=False)
|
||||
|
||||
question_chunk_size = 8
|
||||
answers: List[str] = []
|
||||
@ -396,7 +430,7 @@ class ImageSummaryDetector(AnalysisMethod):
|
||||
|
||||
if len(answers) != len(list_of_questions):
|
||||
raise ValueError(
|
||||
f"Expected {len(list_of_questions)} answers, but got {len(answers)}, try vary amount of questions"
|
||||
f"Expected {len(list_of_questions)} answers, but got {len(answers)}, try varying amount of questions"
|
||||
)
|
||||
|
||||
return answers
|
||||
|
||||
@ -35,3 +35,24 @@ def test_image_summary_detector_questions(model, get_testdict):
|
||||
assert (
|
||||
"two" in results[key]["vqa"][1].lower() or "2" in results[key]["vqa"][1]
|
||||
)
|
||||
|
||||
|
||||
def test_clean_list_of_questions(model):
    """Question cleaning drops None/blank entries, adds "?" and the prompt prefix."""
    raw_questions = [
        "What is happening in the image?",
        "",
        " ",
        None,
        "How many cars are in the image in total",
    ]
    first_expected = "What is happening in the image?"
    second_expected = "How many cars are in the image in total?"

    detector = ImageSummaryDetector(summary_model=model, subdict={})
    config = detector.token_prompt_config

    # Default mode: the questions are expected back without any prefix.
    default_prompt = config["default"]["questions"]["prompt"]
    default_result = detector._clean_list_of_questions(raw_questions, default_prompt)
    assert len(default_result) == 2
    assert default_result[0] == first_expected
    assert default_result[1] == second_expected

    # Concise mode: every question is expected to carry the configured prefix.
    concise_prompt = config["concise"]["questions"]["prompt"]
    concise_result = detector._clean_list_of_questions(raw_questions, concise_prompt)
    assert len(concise_result) == 2
    assert concise_result[0] == concise_prompt + first_expected
    assert concise_result[1] == concise_prompt + second_expected
|
||||
|
||||
Загрузка…
x
Ссылка в новой задаче
Block a user