from ammico.utils import AnalysisMethod, AnalysisType
from ammico.model import MultimodalSummaryModel

import os
import torch
from PIL import Image
import warnings

from typing import List, Optional, Union, Dict, Any, Tuple
from collections.abc import Sequence as _Sequence
from transformers import GenerationConfig
from qwen_vl_utils import process_vision_info


class ImageSummaryDetector(AnalysisMethod):
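    # Prompt presets: "default" favors longer output, "concise" shorter.
    # Each preset sets the prompt text and the max_new_tokens generation
    # budget separately for summaries and for question answering.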
    token_prompt_config = {
        "default": {
            "summary": {"prompt": "Describe this image.", "max_new_tokens": 256},
            "questions": {"prompt": "", "max_new_tokens": 128},
        },
        "concise": {
            "summary": {
                "prompt": "Describe this image in one concise caption.",
                "max_new_tokens": 64,
            },
            "questions": {"prompt": "Answer concisely: ", "max_new_tokens": 128},
        },
    }
    MAX_QUESTIONS_PER_IMAGE = 32
    KEYS_BATCH_SIZE = 16

    def __init__(
        self,
        summary_model: MultimodalSummaryModel,
        subdict: Optional[Dict[str, Any]] = None,
    ) -> None:
        """
        Class for analysing images using the Qwen2.5-VL model.
        It provides methods for generating captions and answering questions about images.

        Args:
            summary_model (MultimodalSummaryModel): An instance of MultimodalSummaryModel to be used for analysis.
            subdict (dict, optional): Dictionary mapping image keys to entry dicts;
                each entry must contain a "filename" key. Defaults to {}.

        Returns:
            None.
        """
        if subdict is None:
            subdict = {}

        super().__init__(subdict)
        self.summary_model = summary_model

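    # Entry dicts are plain dictionaries; the only key this class requires is
    # "filename", which may be a path, a PIL.Image, or a sequence of either
    # (see _prepare_inputs).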
    def _load_pil_if_needed(
        self, filename: Union[str, os.PathLike, Image.Image]
    ) -> Image.Image:
        """Open a path as an RGB PIL image, or convert an existing PIL image to RGB."""
        if isinstance(filename, (str, os.PathLike)):
            return Image.open(filename).convert("RGB")
        elif isinstance(filename, Image.Image):
            return filename.convert("RGB")
        else:
            raise ValueError("filename must be a path or PIL.Image")

    @staticmethod
    def _is_sequence_but_not_str(obj: Any) -> bool:
        """True for sequence-like objects, but not for strings/bytes/PIL images."""
        return isinstance(obj, _Sequence) and not isinstance(
            obj, (str, bytes, Image.Image)
        )

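    # _prepare_inputs builds one chat-template string per question, all of them
    # referring to the same image (or image sequence) from the entry, then
    # tokenizes text + vision into a single padded batch on the model device.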
    def _prepare_inputs(
        self, list_of_questions: list[str], entry: Optional[Dict[str, Any]] = None
    ) -> Dict[str, torch.Tensor]:
        if entry is None:
            raise ValueError("entry must be provided and contain key 'filename'")
        filename = entry.get("filename")
        if filename is None:
            raise ValueError("entry must contain key 'filename'")

        if isinstance(filename, (str, os.PathLike, Image.Image)):
            images_context = self._load_pil_if_needed(filename)
        elif self._is_sequence_but_not_str(filename):
            images_context = [self._load_pil_if_needed(i) for i in filename]
        else:
            raise ValueError(
                "Unsupported 'filename' entry: expected path, PIL.Image, or sequence."
            )

        images_only_messages = [
            {
                "role": "user",
                "content": [
                    *(
                        [{"type": "image", "image": img} for img in images_context]
                        if isinstance(images_context, list)
                        else [{"type": "image", "image": images_context}]
                    )
                ],
            }
        ]

        try:
            image_inputs, _ = process_vision_info(images_only_messages)
        except Exception as e:
            raise RuntimeError(f"Image processing failed: {e}") from e

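        # Render one chat-formatted prompt per question; every prompt embeds
        # the same image content so the questions can be batched together.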
        texts: List[str] = []
        for q in list_of_questions:
            messages = [
                {
                    "role": "user",
                    "content": [
                        *(
                            [
                                {"type": "image", "image": image}
                                for image in images_context
                            ]
                            if isinstance(images_context, list)
                            else [{"type": "image", "image": images_context}]
                        ),
                        {"type": "text", "text": q},
                    ],
                }
            ]
            text = self.summary_model.processor.apply_chat_template(
                messages, tokenize=False, add_generation_prompt=True
            )
            texts.append(text)

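        # The preprocessed vision inputs are replicated once per prompt so the
        # processor produces a single aligned batch of texts and images.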
        images_batch = [image_inputs] * len(texts)
        inputs = self.summary_model.processor(
            text=texts,
            images=images_batch,
            padding=True,
            return_tensors="pt",
        )
        inputs = {k: v.to(self.summary_model.device) for k, v in inputs.items()}

        return inputs

    def _validate_analysis_type(
        self,
        analysis_type: Union["AnalysisType", str],
        list_of_questions: Optional[List[str]],
        max_questions_per_image: int,
    ) -> Tuple[str, List[str], bool, bool]:
        """Normalize the analysis type, supply default questions, and enforce the question cap."""
        if isinstance(analysis_type, AnalysisType):
            analysis_type = analysis_type.value

        allowed = {"summary", "questions", "summary_and_questions"}
        if analysis_type not in allowed:
            raise ValueError(f"analysis_type must be one of {allowed}")

        if list_of_questions is None:
            list_of_questions = [
                "Are there people in the image?",
                "What is this picture about?",
            ]

        if analysis_type in ("questions", "summary_and_questions"):
            if len(list_of_questions) > max_questions_per_image:
                raise ValueError(
                    f"Number of questions per image ({len(list_of_questions)}) exceeds safety cap ({max_questions_per_image}). Reduce questions or increase max_questions_per_image."
                )

        is_summary = analysis_type in ("summary", "summary_and_questions")
        is_questions = analysis_type in ("questions", "summary_and_questions")

        return analysis_type, list_of_questions, is_summary, is_questions

    def analyse_image(
        self,
        entry: dict,
        analysis_type: Union[str, AnalysisType] = AnalysisType.SUMMARY_AND_QUESTIONS,
        list_of_questions: Optional[List[str]] = None,
        max_questions_per_image: int = MAX_QUESTIONS_PER_IMAGE,
        is_concise_summary: bool = True,
        is_concise_answer: bool = True,
    ) -> Dict[str, Any]:
        """
        Analyse a single image entry. Returns a dict whose keys depend on analysis_type:
            - 'caption' (str) if a summary was requested
            - 'vqa' (list[str]) if questions were requested, in the same order as list_of_questions
        """
        self.subdict = entry
        analysis_type, list_of_questions, is_summary, is_questions = (
            self._validate_analysis_type(
                analysis_type, list_of_questions, max_questions_per_image
            )
        )

        if is_summary:
            try:
                caps = self.generate_caption(
                    entry,
                    num_return_sequences=1,
                    is_concise_summary=is_concise_summary,
                )
                self.subdict["caption"] = caps[0] if caps else ""
            except Exception as e:
                warnings.warn(f"Caption generation failed: {e}")

        if is_questions:
            try:
                vqa_map = self.answer_questions(
                    list_of_questions, entry, is_concise_answer
                )
                self.subdict["vqa"] = vqa_map
            except Exception as e:
                warnings.warn(f"VQA failed: {e}")

        return self.subdict

    def analyse_images_from_dict(
        self,
        analysis_type: Union[AnalysisType, str] = AnalysisType.SUMMARY_AND_QUESTIONS,
        list_of_questions: Optional[List[str]] = None,
        max_questions_per_image: int = MAX_QUESTIONS_PER_IMAGE,
        keys_batch_size: int = KEYS_BATCH_SIZE,
        is_concise_summary: bool = True,
        is_concise_answer: bool = True,
    ) -> Dict[str, dict]:
        """
        Analyse all images in self.subdict with the summary model.

        Args:
            analysis_type (str): type of the analysis.
            list_of_questions (list[str]): list of questions.
            max_questions_per_image (int): maximum number of questions per image. We recommend keeping it low to avoid long processing times and high memory usage.
            keys_batch_size (int): number of images to process in a batch.
            is_concise_summary (bool): whether to generate concise summaries.
            is_concise_answer (bool): whether to generate concise answers.
        Returns:
            self.subdict (dict): dictionary with analysis results.
        """
        # TODO: add option to ask multiple questions per image as one batch.
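        # Keys are walked in groups of keys_batch_size, but each image is still
        # analysed individually; the grouping only bounds how many entries are
        # touched per outer iteration (see the TODO above for true batching).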
        analysis_type, list_of_questions, is_summary, is_questions = (
            self._validate_analysis_type(
                analysis_type, list_of_questions, max_questions_per_image
            )
        )

        keys = list(self.subdict.keys())
        for batch_start in range(0, len(keys), keys_batch_size):
            batch_keys = keys[batch_start : batch_start + keys_batch_size]
            for key in batch_keys:
                entry = self.subdict[key]
                if is_summary:
                    try:
                        caps = self.generate_caption(
                            entry,
                            num_return_sequences=1,
                            is_concise_summary=is_concise_summary,
                        )
                        entry["caption"] = caps[0] if caps else ""
                    except Exception as e:
                        warnings.warn(f"Caption generation failed: {e}")

                if is_questions:
                    try:
                        vqa_map = self.answer_questions(
                            list_of_questions, entry, is_concise_answer
                        )
                        entry["vqa"] = vqa_map
                    except Exception as e:
                        warnings.warn(f"VQA failed: {e}")

                self.subdict[key] = entry
        return self.subdict

    def generate_caption(
        self,
        entry: Optional[Dict[str, Any]] = None,
        num_return_sequences: int = 1,
        is_concise_summary: bool = True,
    ) -> List[str]:
        """
        Create a caption for an image. Depending on is_concise_summary it will be either concise or detailed.

        Args:
            entry (dict): dictionary containing the image to be captioned.
            num_return_sequences (int): number of captions to generate.
            is_concise_summary (bool): whether to generate a concise summary.

        Returns:
            results (list[str]): list of generated captions.
        """
        preset = self.token_prompt_config[
            "concise" if is_concise_summary else "default"
        ]["summary"]
        prompt = preset["prompt"]
        max_new_tokens = preset["max_new_tokens"]
        inputs = self._prepare_inputs([prompt], entry)

        gen_conf = GenerationConfig(
            max_new_tokens=max_new_tokens,
            do_sample=False,
            num_return_sequences=num_return_sequences,
        )

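        # Greedy generation; on CUDA we run under autocast, and if generation
        # still fails we retry once with cuDNN disabled, since some cuDNN
        # kernels can fail on unusual input shapes (a workaround, not a fix).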
        with torch.inference_mode():
            try:
                if self.summary_model.device == "cuda":
                    with torch.autocast(device_type="cuda", enabled=True):
                        generated_ids = self.summary_model.model.generate(
                            **inputs, generation_config=gen_conf
                        )
                else:
                    generated_ids = self.summary_model.model.generate(
                        **inputs, generation_config=gen_conf
                    )
            except RuntimeError as e:
                warnings.warn(
                    f"Generation failed: {e}. Attempting cudnn-disabled retry."
                )
                cudnn_was_enabled = (
                    torch.backends.cudnn.is_available() and torch.backends.cudnn.enabled
                )
                if cudnn_was_enabled:
                    torch.backends.cudnn.enabled = False
                try:
                    generated_ids = self.summary_model.model.generate(
                        **inputs, generation_config=gen_conf
                    )
                except Exception as retry_error:
                    raise RuntimeError(
                        f"Failed to generate ids after retry: {retry_error}"
                    ) from retry_error
                finally:
                    if cudnn_was_enabled:
                        torch.backends.cudnn.enabled = True

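        # generate() returns the prompt tokens followed by the new tokens, so
        # the prompt prefix is trimmed off before decoding whenever the input
        # ids are available.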
        if "input_ids" in inputs:
            in_ids = inputs["input_ids"]
            if num_return_sequences > 1:
                # generate() yields batch_size * num_return_sequences rows;
                # repeat the prompt rows so the zip below stays aligned when
                # more than one sequence is returned.
                in_ids = in_ids.repeat_interleave(num_return_sequences, dim=0)
            trimmed = [
                out_ids[len(inp_ids) :]
                for inp_ids, out_ids in zip(in_ids, generated_ids)
            ]
            decoded = self.summary_model.tokenizer.batch_decode(
                trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
            )
        else:
            decoded = self.summary_model.tokenizer.batch_decode(
                generated_ids,
                skip_special_tokens=True,
                clean_up_tokenization_spaces=False,
            )

        results = [d.strip() for d in decoded]
        return results

    def _clean_list_of_questions(
        self, list_of_questions: list[str], prompt: str
    ) -> list[str]:
        """Clean the list of questions to contain correctly formatted strings."""
        # remove all None or empty questions
        list_of_questions = [i for i in list_of_questions if i and i.strip()]
        # ensure each question ends with a question mark
        list_of_questions = [
            i.strip() + "?" if not i.strip().endswith("?") else i.strip()
            for i in list_of_questions
        ]
        # ensure each question starts with the prompt
        list_of_questions = [
            i if i.lower().startswith(prompt.lower()) else prompt + i
            for i in list_of_questions
        ]
        return list_of_questions

    def answer_questions(
        self,
        list_of_questions: list[str],
        entry: Optional[Dict[str, Any]] = None,
        is_concise_answer: bool = True,
    ) -> List[str]:
        """
        Create answers for a list of questions about an image.

        Args:
            list_of_questions (list[str]): list of questions.
            entry (dict): dictionary containing the image to be analysed.
            is_concise_answer (bool): whether to generate concise answers.
        Returns:
            answers (list[str]): list of answers, in the same order as the questions.
        """
        preset = self.token_prompt_config[
            "concise" if is_concise_answer else "default"
        ]["questions"]
        prompt = preset["prompt"]
        max_new_tokens = preset["max_new_tokens"]

        list_of_questions = self._clean_list_of_questions(list_of_questions, prompt)
        gen_conf = GenerationConfig(max_new_tokens=max_new_tokens, do_sample=False)

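        # Questions are processed in fixed-size chunks to bound peak memory;
        # the chunk size of 8 is a heuristic and can be tuned for the hardware.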
        question_chunk_size = 8
        answers: List[str] = []
        n = len(list_of_questions)
        for i in range(0, n, question_chunk_size):
            chunk = list_of_questions[i : i + question_chunk_size]
            inputs = self._prepare_inputs(chunk, entry)
            with torch.inference_mode():
                if self.summary_model.device == "cuda":
                    with torch.autocast(device_type="cuda", enabled=True):
                        out_ids = self.summary_model.model.generate(
                            **inputs, generation_config=gen_conf
                        )
                else:
                    out_ids = self.summary_model.model.generate(
                        **inputs, generation_config=gen_conf
                    )

|             if "input_ids" in inputs:
 | |
|                 in_ids = inputs["input_ids"]
 | |
|                 trimmed_batch = [
 | |
|                     out_row[len(inp_row) :] for inp_row, out_row in zip(in_ids, out_ids)
 | |
|                 ]
 | |
|                 decoded = self.summary_model.tokenizer.batch_decode(
 | |
|                     trimmed_batch,
 | |
|                     skip_special_tokens=True,
 | |
|                     clean_up_tokenization_spaces=False,
 | |
|                 )
 | |
|             else:
 | |
|                 decoded = self.summary_model.tokenizer.batch_decode(
 | |
|                     out_ids,
 | |
|                     skip_special_tokens=True,
 | |
|                     clean_up_tokenization_spaces=False,
 | |
|                 )
 | |
| 
 | |
|             answers.extend([d.strip() for d in decoded])
 | |
| 
 | |
|         if len(answers) != len(list_of_questions):
 | |
|             raise ValueError(
 | |
|                 f"Expected {len(list_of_questions)} answers, but got {len(answers)}, try varying amount of questions"
 | |
|             )
 | |
| 
 | |
|         return answers
 | 
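

# A minimal usage sketch (illustrative only: it assumes MultimodalSummaryModel
# can be constructed with default arguments and that "example.jpg" exists):
#
#     model = MultimodalSummaryModel()
#     detector = ImageSummaryDetector(
#         model, subdict={"img1": {"filename": "example.jpg"}}
#     )
#     results = detector.analyse_images_from_dict(analysis_type="summary")
#     print(results["img1"]["caption"])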
