AMMICO/ammico/summary.py
from ammico.utils import AnalysisMethod, AnalysisType
from ammico.model import MultimodalSummaryModel
import os
import torch
from PIL import Image
import warnings
from typing import List, Optional, Union, Dict, Any
from collections.abc import Sequence as _Sequence
from transformers import GenerationConfig
import re
from qwen_vl_utils import process_vision_info
class ImageSummaryDetector(AnalysisMethod):
def __init__(
self,
summary_model: MultimodalSummaryModel,
subdict: dict = {},
) -> None:
"""
        Class for analysing images using the Qwen2.5-VL model.
        It provides methods for generating captions and answering questions about images.
        Args:
            summary_model (MultimodalSummaryModel): An instance of MultimodalSummaryModel to be used for analysis.
subdict (dict, optional): Dictionary containing the image to be analysed. Defaults to {}.
Returns:
None.
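        Example:
            A minimal usage sketch (assumes ``model`` is a MultimodalSummaryModel
            instance and ``image_dict`` maps image keys to dicts with a "filename"
            entry; both names are placeholders):

                detector = ImageSummaryDetector(summary_model=model, subdict=image_dict)
                results = detector.analyse_images(analysis_type="summary")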
"""
super().__init__(subdict)
self.summary_model = summary_model
def _load_pil_if_needed(
self, filename: Union[str, os.PathLike, Image.Image]
) -> Image.Image:
if isinstance(filename, (str, os.PathLike)):
return Image.open(filename).convert("RGB")
elif isinstance(filename, Image.Image):
return filename.convert("RGB")
else:
raise ValueError("filename must be a path or PIL.Image")
@staticmethod
def _is_sequence_but_not_str(obj: Any) -> bool:
"""True for sequence-like but not a string/bytes/PIL.Image."""
return isinstance(obj, _Sequence) and not isinstance(
obj, (str, bytes, Image.Image)
)
def _prepare_inputs(
self, list_of_questions: list[str], entry: Optional[Dict[str, Any]] = None
) -> Dict[str, torch.Tensor]:
        if entry is None:
            raise ValueError("entry must be a dict containing key 'filename'")
        filename = entry.get("filename")
        if filename is None:
            raise ValueError("entry must contain key 'filename'")
if isinstance(filename, (str, os.PathLike, Image.Image)):
images_context = self._load_pil_if_needed(filename)
elif self._is_sequence_but_not_str(filename):
images_context = [self._load_pil_if_needed(i) for i in filename]
else:
raise ValueError(
"Unsupported 'filename' entry: expected path, PIL.Image, or sequence."
)
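        # Build a chat-style message list containing only the image(s); this is the
        # input format expected by qwen_vl_utils.process_vision_info.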
images_only_messages = [
{
"role": "user",
"content": [
*(
[{"type": "image", "image": img} for img in images_context]
if isinstance(images_context, list)
else [{"type": "image", "image": images_context}]
)
],
}
]
try:
image_inputs, _ = process_vision_info(images_only_messages)
        except Exception as e:
            raise RuntimeError(f"Image processing failed: {e}") from e
texts: List[str] = []
for q in list_of_questions:
messages = [
{
"role": "user",
"content": [
*(
[
{"type": "image", "image": image}
for image in images_context
]
if isinstance(images_context, list)
else [{"type": "image", "image": images_context}]
),
{"type": "text", "text": q},
],
}
]
text = self.summary_model.processor.apply_chat_template(
messages, tokenize=False, add_generation_prompt=True
)
texts.append(text)
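        # Repeat the processed vision inputs once per prompt so the number of image
        # entries matches the number of text prompts in the batch.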
images_batch = [image_inputs] * len(texts)
inputs = self.summary_model.processor(
text=texts,
images=images_batch,
padding=True,
return_tensors="pt",
)
inputs = {k: v.to(self.summary_model.device) for k, v in inputs.items()}
return inputs
def analyse_images(
self,
analysis_type: Union[AnalysisType, str] = AnalysisType.SUMMARY_AND_QUESTIONS,
list_of_questions: Optional[List[str]] = None,
max_questions_per_image: int = 32,
keys_batch_size: int = 16,
is_concise_summary: bool = True,
is_concise_answer: bool = True,
) -> Dict[str, dict]:
"""
Analyse image with model.
Args:
            analysis_type (str): type of the analysis; one of "summary", "questions", or "summary_and_questions".
            list_of_questions (list[str]): list of questions to ask about each image.
            max_questions_per_image (int): maximum number of questions per image. We recommend keeping it low to avoid long processing times and high memory usage.
keys_batch_size (int): number of images to process in a batch.
is_concise_summary (bool): whether to generate concise summary.
is_concise_answer (bool): whether to generate concise answers.
Returns:
self.subdict (dict): dictionary with analysis results.
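        Example:
            A minimal sketch (assumes the detector was constructed with a populated
            subdict; "my_image" is a placeholder key):

                results = detector.analyse_images(
                    analysis_type="summary_and_questions",
                    list_of_questions=["Are there people in the image?"],
                )
                caption = results["my_image"]["caption"]
                answers = results["my_image"]["vqa"]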
"""
# TODO: add option to ask multiple questions per image as one batch.
if isinstance(analysis_type, AnalysisType):
analysis_type = analysis_type.value
allowed = {"summary", "questions", "summary_and_questions"}
if analysis_type not in allowed:
raise ValueError(f"analysis_type must be one of {allowed}")
if list_of_questions is None:
list_of_questions = [
"Are there people in the image?",
"What is this picture about?",
]
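        # Iterate over the image keys in batches of keys_batch_size; within a batch
        # each image is still analysed individually.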
keys = list(self.subdict.keys())
for batch_start in range(0, len(keys), keys_batch_size):
batch_keys = keys[batch_start : batch_start + keys_batch_size]
for key in batch_keys:
entry = self.subdict[key]
if analysis_type in ("summary", "summary_and_questions"):
try:
caps = self.generate_caption(
entry,
num_return_sequences=1,
is_concise_summary=is_concise_summary,
)
entry["caption"] = caps[0] if caps else ""
except Exception as e:
                        warnings.warn(
                            f"Caption generation failed for key {key}: {e}"
                        )
if analysis_type in ("questions", "summary_and_questions"):
if len(list_of_questions) > max_questions_per_image:
raise ValueError(
f"Number of questions per image ({len(list_of_questions)}) exceeds safety cap ({max_questions_per_image})."
" Reduce questions or increase max_questions_per_image."
)
try:
vqa_map = self.answer_questions(
list_of_questions, entry, is_concise_answer
)
entry["vqa"] = vqa_map
except Exception as e:
                        warnings.warn(f"VQA failed for key {key}: {e}")
self.subdict[key] = entry
return self.subdict
def generate_caption(
self,
entry: Optional[Dict[str, Any]] = None,
num_return_sequences: int = 1,
is_concise_summary: bool = True,
) -> List[str]:
"""
        Create a caption for the image. Depending on is_concise_summary, the caption will be either concise or detailed.
Args:
entry (dict): dictionary containing the image to be captioned.
num_return_sequences (int): number of captions to generate.
is_concise_summary (bool): whether to generate concise summary.
Returns:
results (list[str]): list of generated captions.
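        Example:
            A minimal sketch (assumes ``entry = {"filename": "path/to/image.jpg"}``
            with a placeholder path):

                captions = detector.generate_caption(entry, is_concise_summary=True)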
"""
if is_concise_summary:
prompt = ["Describe this image in one concise caption."]
max_new_tokens = 64
else:
prompt = ["Describe this image."]
max_new_tokens = 256
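        # Greedy decoding (do_sample=False) keeps captions deterministic; the token
        # budget above depends on whether a concise or detailed summary was requested.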
inputs = self._prepare_inputs(prompt, entry)
gen_conf = GenerationConfig(
max_new_tokens=max_new_tokens,
do_sample=False,
num_return_sequences=num_return_sequences,
)
with torch.inference_mode():
try:
if self.summary_model.device == "cuda":
with torch.cuda.amp.autocast(enabled=True):
generated_ids = self.summary_model.model.generate(
**inputs, generation_config=gen_conf
)
else:
generated_ids = self.summary_model.model.generate(
**inputs, generation_config=gen_conf
)
except RuntimeError as e:
                warnings.warn(
                    f"Generation failed: {e}. Retrying with cuDNN disabled."
                )
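                # Some cuDNN kernels can fail for unusual input shapes; retry once
                # with cuDNN disabled, then restore its previous state.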
cudnn_was_enabled = (
torch.backends.cudnn.is_available() and torch.backends.cudnn.enabled
)
if cudnn_was_enabled:
torch.backends.cudnn.enabled = False
try:
generated_ids = self.summary_model.model.generate(
**inputs, generation_config=gen_conf
)
except Exception as retry_error:
raise RuntimeError(
f"Failed to generate ids after retry: {retry_error}"
) from retry_error
finally:
if cudnn_was_enabled:
torch.backends.cudnn.enabled = True
decoded = None
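        # Strip the prompt tokens from each generated sequence so that only the newly
        # generated caption text is decoded.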
if "input_ids" in inputs:
in_ids = inputs["input_ids"]
trimmed = [
out_ids[len(inp_ids) :]
for inp_ids, out_ids in zip(in_ids, generated_ids)
]
decoded = self.summary_model.tokenizer.batch_decode(
trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
)
else:
decoded = self.summary_model.tokenizer.batch_decode(
generated_ids,
skip_special_tokens=True,
clean_up_tokenization_spaces=False,
)
results = [d.strip() for d in decoded]
return results
def answer_questions(
self,
list_of_questions: list[str],
entry: Optional[Dict[str, Any]] = None,
is_concise_answer: bool = True,
) -> List[str]:
"""
Create answers for list of questions about image.
Args:
list_of_questions (list[str]): list of questions.
            entry (dict): dictionary containing the image to be analysed.
is_concise_answer (bool): whether to generate concise answers.
Returns:
answers (list[str]): list of answers.
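        Example:
            A minimal sketch (assumes ``entry = {"filename": "path/to/image.jpg"}``
            with a placeholder path):

                answers = detector.answer_questions(
                    ["How many people are visible?"], entry
                )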
"""
if is_concise_answer:
gen_conf = GenerationConfig(max_new_tokens=64, do_sample=False)
for i in range(len(list_of_questions)):
if not list_of_questions[i].strip().endswith("?"):
list_of_questions[i] = list_of_questions[i].strip() + "?"
if not list_of_questions[i].lower().startswith("answer concisely"):
list_of_questions[i] = "Answer concisely: " + list_of_questions[i]
else:
gen_conf = GenerationConfig(max_new_tokens=128, do_sample=False)
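        # Process questions in small chunks to bound peak memory; the chunk size of 8
        # is a heuristic, not a model requirement.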
question_chunk_size = 8
answers: List[str] = []
n = len(list_of_questions)
for i in range(0, n, question_chunk_size):
chunk = list_of_questions[i : i + question_chunk_size]
inputs = self._prepare_inputs(chunk, entry)
with torch.inference_mode():
if self.summary_model.device == "cuda":
with torch.cuda.amp.autocast(enabled=True):
out_ids = self.summary_model.model.generate(
**inputs, generation_config=gen_conf
)
else:
out_ids = self.summary_model.model.generate(
**inputs, generation_config=gen_conf
)
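            # As in generate_caption, drop the prompt tokens before decoding so only
            # the answers themselves are returned.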
if "input_ids" in inputs:
in_ids = inputs["input_ids"]
trimmed_batch = [
out_row[len(inp_row) :] for inp_row, out_row in zip(in_ids, out_ids)
]
decoded = self.summary_model.tokenizer.batch_decode(
trimmed_batch,
skip_special_tokens=True,
clean_up_tokenization_spaces=False,
)
else:
decoded = self.summary_model.tokenizer.batch_decode(
out_ids,
skip_special_tokens=True,
clean_up_tokenization_spaces=False,
)
answers.extend([d.strip() for d in decoded])
if len(answers) != len(list_of_questions):
raise ValueError(
f"Expected {len(list_of_questions)} answers, but got {len(answers)}, try vary amount of questions"
)
return answers