from ammico.utils import AnalysisMethod, AnalysisType
from ammico.model import MultimodalSummaryModel

import os
import torch
from PIL import Image
import warnings

from typing import List, Optional, Union, Dict, Any, Tuple
from collections.abc import Sequence as _Sequence
from transformers import GenerationConfig
from qwen_vl_utils import process_vision_info


class ImageSummaryDetector(AnalysisMethod):
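    # Prompt/token-budget presets: "default" favours longer outputs, "concise"
    # shorter ones. Each entry pairs a prompt (used as-is for summaries,
    # prefixed to each question for VQA) with a max_new_tokens budget.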
    token_prompt_config = {
        "default": {
            "summary": {"prompt": "Describe this image.", "max_new_tokens": 256},
            "questions": {"prompt": "", "max_new_tokens": 128},
        },
        "concise": {
            "summary": {
                "prompt": "Describe this image in one concise caption.",
                "max_new_tokens": 64,
            },
            "questions": {"prompt": "Answer concisely: ", "max_new_tokens": 128},
        },
    }
    MAX_QUESTIONS_PER_IMAGE = 32
    KEYS_BATCH_SIZE = 16

    def __init__(
        self,
        summary_model: MultimodalSummaryModel,
        subdict: Optional[Dict[str, Any]] = None,
    ) -> None:
        """
        Class for analysing images using the Qwen2.5-VL model.
        It provides methods for generating captions and answering questions about images.

        Args:
            summary_model (MultimodalSummaryModel): An instance of MultimodalSummaryModel to be used for analysis.
            subdict (dict, optional): Dictionary containing the images to be analysed. Defaults to {}.

        Returns:
            None.
        """
        if subdict is None:
            subdict = {}

        super().__init__(subdict)
        self.summary_model = summary_model

    def _load_pil_if_needed(
        self, filename: Union[str, os.PathLike, Image.Image]
    ) -> Image.Image:
        if isinstance(filename, (str, os.PathLike)):
            return Image.open(filename).convert("RGB")
        elif isinstance(filename, Image.Image):
            return filename.convert("RGB")
        else:
            raise ValueError("filename must be a path or PIL.Image")

    @staticmethod
    def _is_sequence_but_not_str(obj: Any) -> bool:
        """True for sequence-like but not a string/bytes/PIL.Image."""
        return isinstance(obj, _Sequence) and not isinstance(
            obj, (str, bytes, Image.Image)
        )

    def _prepare_inputs(
        self, list_of_questions: list[str], entry: Optional[Dict[str, Any]] = None
    ) -> Dict[str, torch.Tensor]:
        if entry is None:
            raise ValueError("entry must not be None and must contain key 'filename'")
        filename = entry.get("filename")
        if filename is None:
            raise ValueError("entry must contain key 'filename'")

        if isinstance(filename, (str, os.PathLike, Image.Image)):
            images_context = self._load_pil_if_needed(filename)
        elif self._is_sequence_but_not_str(filename):
            images_context = [self._load_pil_if_needed(i) for i in filename]
        else:
            raise ValueError(
                "Unsupported 'filename' entry: expected path, PIL.Image, or sequence."
            )

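        # The image context is assembled once into an image-only message so that
        # process_vision_info can preprocess the vision inputs a single time and
        # reuse them for every question prompt built below.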
        images_only_messages = [
            {
                "role": "user",
                "content": [
                    *(
                        [{"type": "image", "image": img} for img in images_context]
                        if isinstance(images_context, list)
                        else [{"type": "image", "image": images_context}]
                    )
                ],
            }
        ]

        try:
            image_inputs, _ = process_vision_info(images_only_messages)
        except Exception as e:
            raise RuntimeError(f"Image processing failed: {e}") from e

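        # One chat-template string is built per question; the same image context
        # is repeated for each so that a single processor call can batch all
        # questions together.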
        texts: List[str] = []
        for q in list_of_questions:
            messages = [
                {
                    "role": "user",
                    "content": [
                        *(
                            [
                                {"type": "image", "image": image}
                                for image in images_context
                            ]
                            if isinstance(images_context, list)
                            else [{"type": "image", "image": images_context}]
                        ),
                        {"type": "text", "text": q},
                    ],
                }
            ]
            text = self.summary_model.processor.apply_chat_template(
                messages, tokenize=False, add_generation_prompt=True
            )
            texts.append(text)

        images_batch = [image_inputs] * len(texts)
        inputs = self.summary_model.processor(
            text=texts,
            images=images_batch,
            padding=True,
            return_tensors="pt",
        )
        inputs = {k: v.to(self.summary_model.device) for k, v in inputs.items()}

        return inputs

    def _validate_analysis_type(
        self,
        analysis_type: Union["AnalysisType", str],
        list_of_questions: Optional[List[str]],
        max_questions_per_image: int,
    ) -> Tuple[str, List[str], bool, bool]:
        if isinstance(analysis_type, AnalysisType):
            analysis_type = analysis_type.value

        allowed = {"summary", "questions", "summary_and_questions"}
        if analysis_type not in allowed:
            raise ValueError(f"analysis_type must be one of {allowed}")

        if list_of_questions is None:
            list_of_questions = [
                "Are there people in the image?",
                "What is this picture about?",
            ]

        if analysis_type in ("questions", "summary_and_questions"):
            if len(list_of_questions) > max_questions_per_image:
                raise ValueError(
                    f"Number of questions per image ({len(list_of_questions)}) exceeds safety cap ({max_questions_per_image}). Reduce questions or increase max_questions_per_image."
                )

        is_summary = analysis_type in ("summary", "summary_and_questions")
        is_questions = analysis_type in ("questions", "summary_and_questions")

        return analysis_type, list_of_questions, is_summary, is_questions

    def analyse_image(
        self,
        entry: dict,
        analysis_type: Union[str, AnalysisType] = AnalysisType.SUMMARY_AND_QUESTIONS,
        list_of_questions: Optional[List[str]] = None,
        max_questions_per_image: int = MAX_QUESTIONS_PER_IMAGE,
        is_concise_summary: bool = True,
        is_concise_answer: bool = True,
    ) -> Dict[str, Any]:
        """
        Analyse a single image entry. Returns a dict whose keys depend on analysis_type:
        - 'caption' (str) if a summary was requested
        - 'vqa' (list[str]) if questions were requested
        """
        self.subdict = entry
        analysis_type, list_of_questions, is_summary, is_questions = (
            self._validate_analysis_type(
                analysis_type, list_of_questions, max_questions_per_image
            )
        )

        if is_summary:
            try:
                caps = self.generate_caption(
                    entry,
                    num_return_sequences=1,
                    is_concise_summary=is_concise_summary,
                )
                self.subdict["caption"] = caps[0] if caps else ""
            except Exception as e:
                warnings.warn(f"Caption generation failed: {e}")

        if is_questions:
            try:
                vqa_map = self.answer_questions(
                    list_of_questions, entry, is_concise_answer
                )
                self.subdict["vqa"] = vqa_map
            except Exception as e:
                warnings.warn(f"VQA failed: {e}")

        return self.subdict

    def analyse_images_from_dict(
        self,
        analysis_type: Union[AnalysisType, str] = AnalysisType.SUMMARY_AND_QUESTIONS,
        list_of_questions: Optional[List[str]] = None,
        max_questions_per_image: int = MAX_QUESTIONS_PER_IMAGE,
        keys_batch_size: int = KEYS_BATCH_SIZE,
        is_concise_summary: bool = True,
        is_concise_answer: bool = True,
    ) -> Dict[str, dict]:
        """
        Analyse all images in self.subdict with the model.

        Args:
            analysis_type (str): type of the analysis.
            list_of_questions (list[str]): list of questions.
            max_questions_per_image (int): maximum number of questions per image. We recommend keeping it low to avoid long processing times and high memory usage.
            keys_batch_size (int): number of images to process in a batch.
            is_concise_summary (bool): whether to generate a concise summary.
            is_concise_answer (bool): whether to generate concise answers.
        Returns:
            self.subdict (dict): dictionary with analysis results.
        """
        # TODO: add option to ask multiple questions per image as one batch.
        analysis_type, list_of_questions, is_summary, is_questions = (
            self._validate_analysis_type(
                analysis_type, list_of_questions, max_questions_per_image
            )
        )

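        # Keys are walked in slices of keys_batch_size; within a slice each image
        # entry is still analysed one at a time (see the TODO above).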
        keys = list(self.subdict.keys())
        for batch_start in range(0, len(keys), keys_batch_size):
            batch_keys = keys[batch_start : batch_start + keys_batch_size]
            for key in batch_keys:
                entry = self.subdict[key]
                if is_summary:
                    try:
                        caps = self.generate_caption(
                            entry,
                            num_return_sequences=1,
                            is_concise_summary=is_concise_summary,
                        )
                        entry["caption"] = caps[0] if caps else ""
                    except Exception as e:
                        warnings.warn(f"Caption generation failed: {e}")

                if is_questions:
                    try:
                        vqa_map = self.answer_questions(
                            list_of_questions, entry, is_concise_answer
                        )
                        entry["vqa"] = vqa_map
                    except Exception as e:
                        warnings.warn(f"VQA failed: {e}")

                self.subdict[key] = entry
        return self.subdict

    def generate_caption(
        self,
        entry: Optional[Dict[str, Any]] = None,
        num_return_sequences: int = 1,
        is_concise_summary: bool = True,
    ) -> List[str]:
        """
        Create a caption for the image. Depending on is_concise_summary it will be either concise or detailed.

        Args:
            entry (dict): dictionary containing the image to be captioned.
            num_return_sequences (int): number of captions to generate.
            is_concise_summary (bool): whether to generate a concise summary.

        Returns:
            results (list[str]): list of generated captions.
        """
        prompt = self.token_prompt_config[
            "concise" if is_concise_summary else "default"
        ]["summary"]["prompt"]
        max_new_tokens = self.token_prompt_config[
            "concise" if is_concise_summary else "default"
        ]["summary"]["max_new_tokens"]
        inputs = self._prepare_inputs([prompt], entry)

        gen_conf = GenerationConfig(
            max_new_tokens=max_new_tokens,
            do_sample=False,
            num_return_sequences=num_return_sequences,
        )

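        # Generate with autocast on CUDA; if generation hits a RuntimeError,
        # retry once with cuDNN disabled before giving up.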
        with torch.inference_mode():
            try:
                if self.summary_model.device == "cuda":
                    with torch.cuda.amp.autocast(enabled=True):
                        generated_ids = self.summary_model.model.generate(
                            **inputs, generation_config=gen_conf
                        )
                else:
                    generated_ids = self.summary_model.model.generate(
                        **inputs, generation_config=gen_conf
                    )
            except RuntimeError as e:
                warnings.warn(
                    f"Generation failed: {e}. Attempting cudnn-disabled retry."
                )
                cudnn_was_enabled = (
                    torch.backends.cudnn.is_available() and torch.backends.cudnn.enabled
                )
                if cudnn_was_enabled:
                    torch.backends.cudnn.enabled = False
                try:
                    generated_ids = self.summary_model.model.generate(
                        **inputs, generation_config=gen_conf
                    )
                except Exception as retry_error:
                    raise RuntimeError(
                        f"Failed to generate ids after retry: {retry_error}"
                    ) from retry_error
                finally:
                    if cudnn_was_enabled:
                        torch.backends.cudnn.enabled = True

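        # Strip the prompt tokens from each output row so only newly generated
        # text is decoded; fall back to decoding the full sequences if the
        # processor did not return input_ids.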
        if "input_ids" in inputs:
            in_ids = inputs["input_ids"]
            trimmed = [
                out_ids[len(inp_ids) :]
                for inp_ids, out_ids in zip(in_ids, generated_ids)
            ]
            decoded = self.summary_model.tokenizer.batch_decode(
                trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
            )
        else:
            decoded = self.summary_model.tokenizer.batch_decode(
                generated_ids,
                skip_special_tokens=True,
                clean_up_tokenization_spaces=False,
            )

        results = [d.strip() for d in decoded]
        return results

    def _clean_list_of_questions(
        self, list_of_questions: list[str], prompt: str
    ) -> list[str]:
        """Clean the list of questions to contain correctly formatted strings."""
        # remove all None or empty questions
        list_of_questions = [i for i in list_of_questions if i and i.strip()]
        # ensure each question ends with a question mark
        list_of_questions = [
            i.strip() + "?" if not i.strip().endswith("?") else i.strip()
            for i in list_of_questions
        ]
        # ensure each question starts with the prompt
        list_of_questions = [
            i if i.lower().startswith(prompt.lower()) else prompt + i
            for i in list_of_questions
        ]
        return list_of_questions

    def answer_questions(
        self,
        list_of_questions: list[str],
        entry: Optional[Dict[str, Any]] = None,
        is_concise_answer: bool = True,
    ) -> List[str]:
        """
        Create answers for a list of questions about the image.

        Args:
            list_of_questions (list[str]): list of questions.
            entry (dict): dictionary containing the image to be analysed.
            is_concise_answer (bool): whether to generate concise answers.
        Returns:
            answers (list[str]): list of answers, in the same order as the questions.
        """
        prompt = self.token_prompt_config[
            "concise" if is_concise_answer else "default"
        ]["questions"]["prompt"]
        max_new_tokens = self.token_prompt_config[
            "concise" if is_concise_answer else "default"
        ]["questions"]["max_new_tokens"]

        list_of_questions = self._clean_list_of_questions(list_of_questions, prompt)
        gen_conf = GenerationConfig(max_new_tokens=max_new_tokens, do_sample=False)

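        # Questions are processed in small chunks so that long question lists
        # do not exhaust GPU memory in a single generate call.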
        question_chunk_size = 8
        answers: List[str] = []
        n = len(list_of_questions)
        for i in range(0, n, question_chunk_size):
            chunk = list_of_questions[i : i + question_chunk_size]
            inputs = self._prepare_inputs(chunk, entry)
            with torch.inference_mode():
                if self.summary_model.device == "cuda":
                    with torch.cuda.amp.autocast(enabled=True):
                        out_ids = self.summary_model.model.generate(
                            **inputs, generation_config=gen_conf
                        )
                else:
                    out_ids = self.summary_model.model.generate(
                        **inputs, generation_config=gen_conf
                    )

            if "input_ids" in inputs:
                in_ids = inputs["input_ids"]
                trimmed_batch = [
                    out_row[len(inp_row) :] for inp_row, out_row in zip(in_ids, out_ids)
                ]
                decoded = self.summary_model.tokenizer.batch_decode(
                    trimmed_batch,
                    skip_special_tokens=True,
                    clean_up_tokenization_spaces=False,
                )
            else:
                decoded = self.summary_model.tokenizer.batch_decode(
                    out_ids,
                    skip_special_tokens=True,
                    clean_up_tokenization_spaces=False,
                )

            answers.extend([d.strip() for d in decoded])

        if len(answers) != len(list_of_questions):
            raise ValueError(
                f"Expected {len(list_of_questions)} answers, but got {len(answers)}; try reducing the number of questions."
            )

        return answers
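

# Hedged usage sketch (not part of the library API above): a minimal way to
# drive ImageSummaryDetector end to end. The MultimodalSummaryModel
# constructor arguments and the example file path are assumptions for
# illustration; the detector only requires each entry to carry a "filename"
# key holding a path, a PIL.Image, or a sequence of either.
if __name__ == "__main__":
    model = MultimodalSummaryModel()  # assumed default construction
    detector = ImageSummaryDetector(
        summary_model=model,
        subdict={"img0": {"filename": "example.jpg"}},  # hypothetical image path
    )
    results = detector.analyse_images_from_dict(
        analysis_type="summary_and_questions",
        list_of_questions=["Are there people in the image?"],
    )
    for key, entry in results.items():
        print(key, entry.get("caption"), entry.get("vqa"))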