From 34afed53755d5e13945e036e5ccfc49d8c042824 Mon Sep 17 00:00:00 2001
From: DimasfromLavoisier
Date: Fri, 12 Sep 2025 15:12:07 +0200
Subject: [PATCH 01/23] add new dependencies for upcoming models

---
 environment.yml | 19 +++++++++++++++++++
 pyproject.toml  | 13 ++++++++++---
 2 files changed, 29 insertions(+), 3 deletions(-)
 create mode 100644 environment.yml

diff --git a/environment.yml b/environment.yml
new file mode 100644
index 0000000..2268e08
--- /dev/null
+++ b/environment.yml
@@ -0,0 +1,19 @@
+name: ammico-dev
+channels:
+  - pytorch
+  - nvidia
+  - rapidsai
+  - conda-forge
+  - defaults
+
+dependencies:
+  - python=3.11
+  - cudatoolkit=11.8
+  - pytorch=2.3.1
+  - pytorch-cuda=11.8
+  - torchvision=0.18.1
+  - torchaudio=2.3.1
+  - faiss-gpu-raft=1.8.0
+  - ipykernel
+  - jupyterlab
+  - jupyterlab_widgets
\ No newline at end of file
diff --git a/pyproject.toml b/pyproject.toml
index 4cad313..1507f8b 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -18,16 +18,20 @@ classifiers = [
     "Operating System :: OS Independent",
     "License :: OSI Approved :: MIT License",
 ]
-
+
 dependencies = [
+    "accelerate>=0.22",
+    "bitsandbytes",
     "colorgram.py",
     "colour-science",
     "dash",
     "dash-bootstrap-components",
+    "decord",
     "deepface",
     "google-cloud-vision",
-    "googletrans==4.0.0rc1",
+    "googletrans-py", # instead of googletrans4.0.0rc1, for a temporary solution due the incompatibility with jupyterlab
     "grpcio",
+    "huggingface-hub>=0.34.0",
     "importlib_metadata",
     "importlib_resources",
     "matplotlib",
@@ -36,12 +40,15 @@ dependencies = [
     "pandas",
     "Pillow",
     "pooch",
+    "qwen-vl-utils[decord]==0.0.8",
     "retina_face",
+    "safetensors>=0.6.2",
     "setuptools",
     "spacy",
-    "tensorflow<=2.16.0",
+    "tensorflow<2.15", # instead of <=2.16.0 to make it compatible with CUDA 11.8, may change after updating CUDA version.
     "tf-keras",
     "tqdm",
+    "transformers>=4.54",
     "webcolors",
 ]

From 5583bbed08f48c133b01e4b0dcbdb528dc6f38a3 Mon Sep 17 00:00:00 2001
From: DimasfromLavoisier
Date: Fri, 22 Aug 2025 15:43:38 +0200
Subject: [PATCH 02/23] add Model class

---
 ammico/__init__.py |   2 +
 ammico/model.py    | 111 +++++++++++++++++++++++++++++++++++++++++++++
 2 files changed, 113 insertions(+)
 create mode 100644 ammico/model.py

diff --git a/ammico/__init__.py b/ammico/__init__.py
index 67a7065..9a25ade 100644
--- a/ammico/__init__.py
+++ b/ammico/__init__.py
@@ -1,5 +1,6 @@
 from ammico.display import AnalysisExplorer
 from ammico.faces import EmotionDetector, ethical_disclosure
+from ammico.model import MultimodalSummaryModel
 from ammico.text import TextDetector, TextAnalyzer, privacy_disclosure
 from ammico.utils import find_files, get_dataframe
 
@@ -14,6 +15,7 @@ except ImportError:
 __all__ = [
     "AnalysisExplorer",
     "EmotionDetector",
+    "MultimodalSummaryModel",
     "TextDetector",
     "TextAnalyzer",
     "find_files",
diff --git a/ammico/model.py b/ammico/model.py
new file mode 100644
index 0000000..80cc31f
--- /dev/null
+++ b/ammico/model.py
@@ -0,0 +1,111 @@
+import torch
+import warnings
+from transformers import (
+    Qwen2_5_VLForConditionalGeneration,
+    AutoProcessor,
+    BitsAndBytesConfig,
+    AutoTokenizer,
+)
+from typing import Optional
+
+
+class MultimodalSummaryModel:
+    DEFAULT_CUDA_MODEL = "Qwen/Qwen2.5-VL-7B-Instruct"
+    DEFAULT_CPU_MODEL = "Qwen/Qwen2.5-VL-3B-Instruct"
+
+    def __init__(
+        self,
+        model_id: Optional[str] = None,
+        device: Optional[str] = None,
+        cache_dir: Optional[str] = None,
+    ) -> None:
+        """
+        Class for QWEN-2.5-VL model loading and inference.
+        Args:
+            model_id: Type of model to load, defaults to a smaller version for CPU if device is "cpu".
+            device: "cuda" or "cpu" (auto-detected when None).
+            cache_dir: huggingface cache dir (optional).
+        """
+        self.device = self._resolve_device(device)
+        self.model_id = model_id or (
+            self.DEFAULT_CUDA_MODEL if self.device == "cuda" else self.DEFAULT_CPU_MODEL
+        )
+
+        self.cache_dir = cache_dir
+        self._trust_remote_code = True
+        self._quantize = True
+
+        self.model = None
+        self.processor = None
+        self.tokenizer = None
+
+        self._load_model_and_processor()
+
+    @staticmethod
+    def _resolve_device(device: Optional[str]) -> str:
+        if device is None:
+            return "cuda" if torch.cuda.is_available() else "cpu"
+        if device.lower() not in ("cuda", "cpu"):
+            raise ValueError("device must be 'cuda' or 'cpu'")
+        if device.lower() == "cuda" and not torch.cuda.is_available():
+            warnings.warn(
+                "Although 'cuda' was requested, no CUDA device is available. Using CPU instead.",
+                RuntimeWarning,
+                stacklevel=2,
+            )
+            return "cpu"
+        return device.lower()
+
+    def _load_model_and_processor(self):
+        load_kwargs = {"trust_remote_code": self._trust_remote_code, "use_cache": True}
+        if self.cache_dir:
+            load_kwargs["cache_dir"] = self.cache_dir
+
+        self.processor = AutoProcessor.from_pretrained(
+            self.model_id, padding_side="left", **load_kwargs
+        )
+        self.tokenizer = AutoTokenizer.from_pretrained(self.model_id, **load_kwargs)
+
+        if self.device == "cuda":
+            compute_dtype = (
+                torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16
+            )
+            bnb_config = BitsAndBytesConfig(
+                load_in_4bit=True,
+                bnb_4bit_use_double_quant=True,
+                bnb_4bit_quant_type="nf4",
+                bnb_4bit_compute_dtype=compute_dtype,
+            )
+            load_kwargs["quantization_config"] = bnb_config
+            load_kwargs["device_map"] = "auto"
+
+        else:
+            load_kwargs.pop("quantization_config", None)
+            load_kwargs.pop("device_map", None)
+
+        self.model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
+            self.model_id, **load_kwargs
+        )
+        self.model.eval()
+
+    def _close(self) -> None:
+        """Free model resources (helpful in long-running processes)."""
+        try:
+            if self.model is not None:
+                del self.model
+                self.model = None
+        finally:
+            try:
+                if torch.cuda.is_available():
+                    torch.cuda.empty_cache()
+            except Exception as e:
+                warnings.warn(
+                    "Failed to empty CUDA cache. This is not critical, but may lead to memory lingering: "
+                    f"{e!r}",
+                    RuntimeWarning,
+                    stacklevel=2,
+                )
+
+    def close(self) -> None:
+        """Free model resources (helpful in long-running processes)."""
+        self._close()

From d20c4d68e4d1b2eacf59bb6ad8c1f03623731f80 Mon Sep 17 00:00:00 2001
From: DimasfromLavoisier
Date: Fri, 12 Sep 2025 17:48:57 +0200
Subject: [PATCH 04/23] vqa

---
 ammico/image_summary.py | 343 ++++++++++++++++++++++++++++++++++++++++
 1 file changed, 343 insertions(+)
 create mode 100644 ammico/image_summary.py

diff --git a/ammico/image_summary.py b/ammico/image_summary.py
new file mode 100644
index 0000000..c4b2444
--- /dev/null
+++ b/ammico/image_summary.py
@@ -0,0 +1,343 @@
+from ammico.utils import AnalysisMethod, AnalysisType
+from ammico.model import MultimodalSummaryModel
+
+import os
+import torch
+from PIL import Image
+import warnings
+
+from typing import List, Optional, Union, Dict, Any
+from collections.abc import Sequence as _Sequence
+from transformers import GenerationConfig
+import re
+from qwen_vl_utils import process_vision_info
+
+
+class ImageSummaryDetector(AnalysisMethod):
+    def __init__(
+        self,
+        summary_model: MultimodalSummaryModel,
+        subdict: dict = {},
+    ) -> None:
+        """
+        Class for analysing images using QWEN-2.5-VL model.
+        It provides methods for generating captions and answering questions about images.
+
+        Args:
+            summary_model ([type], optional): An instance of MultimodalSummaryModel to be used for analysis.
+            subdict (dict, optional): Dictionary containing the image to be analysed. Defaults to {}.
+
+        Returns:
+            None.
+ """ + + super().__init__(subdict) + self.summary_model = summary_model + + def _load_pil_if_needed( + self, filename: Union[str, os.PathLike, Image.Image] + ) -> Image.Image: + if isinstance(filename, (str, os.PathLike)): + return Image.open(filename).convert("RGB") + elif isinstance(filename, Image.Image): + return filename.convert("RGB") + else: + raise ValueError("filename must be a path or PIL.Image") + + @staticmethod + def _is_sequence_but_not_str(obj: Any) -> bool: + """True for sequence-like but not a string/bytes/PIL.Image.""" + return isinstance(obj, _Sequence) and not isinstance( + obj, (str, bytes, Image.Image) + ) + + def _prepare_inputs( + self, list_of_questions: list[str], entry: Optional[Dict[str, Any]] = None + ) -> Dict[str, torch.Tensor]: + filename = entry.get("filename") + if filename is None: + raise ValueError("entry must contain key 'filename'") + + if isinstance(filename, (str, os.PathLike, Image.Image)): + images_context = self._load_pil_if_needed(filename) + elif self._is_sequence_but_not_str(filename): + images_context = [self._load_pil_if_needed(i) for i in filename] + else: + raise ValueError( + "Unsupported 'filename' entry: expected path, PIL.Image, or sequence." + ) + + images_only_messages = [ + { + "role": "user", + "content": [ + *( + [{"type": "image", "image": img} for img in images_context] + if isinstance(images_context, list) + else [{"type": "image", "image": images_context}] + ) + ], + } + ] + + try: + image_inputs, _ = process_vision_info(images_only_messages) + except Exception as e: + raise RuntimeError(f"Image processing failed: {e}") + + texts: List[str] = [] + for q in list_of_questions: + messages = [ + { + "role": "user", + "content": [ + *( + [ + {"type": "image", "image": image} + for image in images_context + ] + if isinstance(images_context, list) + else [{"type": "image", "image": images_context}] + ), + {"type": "text", "text": q}, + ], + } + ] + text = self.summary_model.processor.apply_chat_template( + messages, tokenize=False, add_generation_prompt=True + ) + texts.append(text) + + images_batch = [image_inputs] * len(texts) + inputs = self.summary_model.processor( + text=texts, + images=images_batch, + padding=True, + return_tensors="pt", + ) + inputs = {k: v.to(self.summary_model.device) for k, v in inputs.items()} + + return inputs + + def analyse_images( + self, + analysis_type: Union[AnalysisType, str] = AnalysisType.SUMMARY_AND_QUESTIONS, + list_of_questions: Optional[List[str]] = None, + max_questions_per_image: int = 32, + keys_batch_size: int = 16, + is_concise_summary: bool = True, + is_concise_answer: bool = True, + ) -> Dict[str, dict]: + """ + Analyse image with model. + + Args: + analysis_type (str): type of the analysis. + list_of_questions (list[str]): list of questions. + max_questions_per_image (int): maximum number of questions per image. We recommend to keep it low to avoid long processing times and high memory usage. + keys_batch_size (int): number of images to process in a batch. + is_concise_summary (bool): whether to generate concise summary. + is_concise_answer (bool): whether to generate concise answers. + Returns: + self.subdict (dict): dictionary with analysis results. + """ + # TODO: add option to ask multiple questions per image as one batch. 
+ if isinstance(analysis_type, AnalysisType): + analysis_type = analysis_type.value + + allowed = {"summary", "questions", "summary_and_questions"} + if analysis_type not in allowed: + raise ValueError(f"analysis_type must be one of {allowed}") + + if list_of_questions is None: + list_of_questions = [ + "Are there people in the image?", + "What is this picture about?", + ] + + keys = list(self.subdict.keys()) + for batch_start in range(0, len(keys), keys_batch_size): + batch_keys = keys[batch_start : batch_start + keys_batch_size] + for key in batch_keys: + entry = self.subdict[key] + if analysis_type in ("summary", "summary_and_questions"): + try: + caps = self.generate_caption( + entry, + num_return_sequences=1, + is_concise_summary=is_concise_summary, + ) + entry["caption"] = caps[0] if caps else "" + except Exception as e: + warnings.warn( + "Caption generation failed for key %s: %s", key, e + ) + + if analysis_type in ("questions", "summary_and_questions"): + if len(list_of_questions) > max_questions_per_image: + raise ValueError( + f"Number of questions per image ({len(list_of_questions)}) exceeds safety cap ({max_questions_per_image})." + " Reduce questions or increase max_questions_per_image." + ) + try: + vqa_map = self.answer_questions( + list_of_questions, entry, is_concise_answer + ) + entry["vqa"] = vqa_map + except Exception as e: + warnings.warn("VQA failed for key %s: %s", key, e) + + self.subdict[key] = entry + return self.subdict + + def generate_caption( + self, + entry: Optional[Dict[str, Any]] = None, + num_return_sequences: int = 1, + is_concise_summary: bool = True, + ) -> List[str]: + """ + Create caption for image. Depending on is_concise_summary it will be either concise or detailed. + + Args: + entry (dict): dictionary containing the image to be captioned. + num_return_sequences (int): number of captions to generate. + is_concise_summary (bool): whether to generate concise summary. + + Returns: + results (list[str]): list of generated captions. + """ + if is_concise_summary: + prompt = ["Describe this image in one concise caption."] + max_new_tokens = 64 + else: + prompt = ["Describe this image."] + max_new_tokens = 256 + inputs = self._prepare_inputs(prompt, entry) + + gen_conf = GenerationConfig( + max_new_tokens=max_new_tokens, + do_sample=False, + num_return_sequences=num_return_sequences, + ) + + with torch.inference_mode(): + try: + if self.summary_model.device == "cuda": + with torch.cuda.amp.autocast(enabled=True): + generated_ids = self.summary_model.model.generate( + **inputs, generation_config=gen_conf + ) + else: + generated_ids = self.summary_model.model.generate( + **inputs, generation_config=gen_conf + ) + except RuntimeError as e: + warnings.warn( + "Retry without autocast failed: %s. 
Attempting cudnn-disabled retry.", + e, + ) + cudnn_was_enabled = ( + torch.backends.cudnn.is_available() and torch.backends.cudnn.enabled + ) + if cudnn_was_enabled: + torch.backends.cudnn.enabled = False + try: + generated_ids = self.summary_model.model.generate( + **inputs, generation_config=gen_conf + ) + except Exception as retry_error: + raise RuntimeError( + f"Failed to generate ids after retry: {retry_error}" + ) from retry_error + finally: + if cudnn_was_enabled: + torch.backends.cudnn.enabled = True + + decoded = None + if "input_ids" in inputs: + in_ids = inputs["input_ids"] + trimmed = [ + out_ids[len(inp_ids) :] + for inp_ids, out_ids in zip(in_ids, generated_ids) + ] + decoded = self.summary_model.tokenizer.batch_decode( + trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False + ) + else: + decoded = self.summary_model.tokenizer.batch_decode( + generated_ids, + skip_special_tokens=True, + clean_up_tokenization_spaces=False, + ) + + results = [d.strip() for d in decoded] + return results + + def answer_questions( + self, + list_of_questions: list[str], + entry: Optional[Dict[str, Any]] = None, + is_concise_answer: bool = True, + ) -> List[str]: + """ + Create answers for list of questions about image. + Args: + list_of_questions (list[str]): list of questions. + entry (dict): dictionary containing the image to be captioned. + is_concise_answer (bool): whether to generate concise answers. + Returns: + answers (list[str]): list of answers. + """ + if is_concise_answer: + gen_conf = GenerationConfig(max_new_tokens=64, do_sample=False) + for i in range(len(list_of_questions)): + if not list_of_questions[i].strip().endswith("?"): + list_of_questions[i] = list_of_questions[i].strip() + "?" + if not list_of_questions[i].lower().startswith("answer concisely"): + list_of_questions[i] = "Answer concisely: " + list_of_questions[i] + else: + gen_conf = GenerationConfig(max_new_tokens=128, do_sample=False) + + question_chunk_size = 8 + answers: List[str] = [] + n = len(list_of_questions) + for i in range(0, n, question_chunk_size): + chunk = list_of_questions[i : i + question_chunk_size] + inputs = self._prepare_inputs(chunk, entry) + with torch.inference_mode(): + if self.summary_model.device == "cuda": + with torch.cuda.amp.autocast(enabled=True): + out_ids = self.summary_model.model.generate( + **inputs, generation_config=gen_conf + ) + else: + out_ids = self.summary_model.model.generate( + **inputs, generation_config=gen_conf + ) + + if "input_ids" in inputs: + in_ids = inputs["input_ids"] + trimmed_batch = [ + out_row[len(inp_row) :] for inp_row, out_row in zip(in_ids, out_ids) + ] + decoded = self.summary_model.tokenizer.batch_decode( + trimmed_batch, + skip_special_tokens=True, + clean_up_tokenization_spaces=False, + ) + else: + decoded = self.summary_model.tokenizer.batch_decode( + out_ids, + skip_special_tokens=True, + clean_up_tokenization_spaces=False, + ) + + answers.extend([d.strip() for d in decoded]) + + if len(answers) != len(list_of_questions): + raise ValueError( + f"Expected {len(list_of_questions)} answers, but got {len(answers)}, try vary amount of questions" + ) + + return answers From 2326aef4b561e360918c5fa78a5cf6c322517df5 Mon Sep 17 00:00:00 2001 From: Dmitrii Kapitan Date: Mon, 22 Sep 2025 16:40:02 +0200 Subject: [PATCH 05/23] Add example notebook and small fixes --- ammico/__init__.py | 2 + ammico/image_summary.py | 1 - ammico/notebooks/DemoImageSummaryVQA.ipynb | 190 +++++++++++++++++++++ ammico/utils.py | 9 + 4 files changed, 201 insertions(+), 
1 deletion(-) create mode 100644 ammico/notebooks/DemoImageSummaryVQA.ipynb diff --git a/ammico/__init__.py b/ammico/__init__.py index 9a25ade..1bf343d 100644 --- a/ammico/__init__.py +++ b/ammico/__init__.py @@ -2,6 +2,7 @@ from ammico.display import AnalysisExplorer from ammico.faces import EmotionDetector, ethical_disclosure from ammico.model import MultimodalSummaryModel from ammico.text import TextDetector, TextAnalyzer, privacy_disclosure +from ammico.image_summary import ImageSummaryDetector from ammico.utils import find_files, get_dataframe # Export the version defined in project metadata @@ -18,6 +19,7 @@ __all__ = [ "MultimodalSummaryModel", "TextDetector", "TextAnalyzer", + "ImageSummaryDetector", "find_files", "get_dataframe", "ethical_disclosure", diff --git a/ammico/image_summary.py b/ammico/image_summary.py index c4b2444..0cdaebe 100644 --- a/ammico/image_summary.py +++ b/ammico/image_summary.py @@ -9,7 +9,6 @@ import warnings from typing import List, Optional, Union, Dict, Any from collections.abc import Sequence as _Sequence from transformers import GenerationConfig -import re from qwen_vl_utils import process_vision_info diff --git a/ammico/notebooks/DemoImageSummaryVQA.ipynb b/ammico/notebooks/DemoImageSummaryVQA.ipynb new file mode 100644 index 0000000..b067e1f --- /dev/null +++ b/ammico/notebooks/DemoImageSummaryVQA.ipynb @@ -0,0 +1,190 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "0", + "metadata": {}, + "source": [ + "# Image summary and visual question answering" + ] + }, + { + "cell_type": "markdown", + "id": "1", + "metadata": {}, + "source": [ + "This notebook shows how to generate image captions and use the visual question answering with AMMICO. \n", + "\n", + "The first cell imports `ammico`.\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "2", + "metadata": {}, + "outputs": [], + "source": [ + "import ammico" + ] + }, + { + "cell_type": "markdown", + "id": "3", + "metadata": {}, + "source": [ + "The cell below loads the model for VQA tasks. By default, it loads a large model on the GPU (if your device supports CUDA), otherwise it loads a relatively smaller model on the CPU. But you can specify other settings (e.g., a small model on the GPU) if you want." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "4", + "metadata": {}, + "outputs": [], + "source": [ + "model = ammico.MultimodalSummaryModel()" + ] + }, + { + "cell_type": "markdown", + "id": "5", + "metadata": {}, + "source": [ + "Here you need to provide the path to your google drive folder or local folder containing the images" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "6", + "metadata": {}, + "outputs": [], + "source": [ + "image_dict = ammico.find_files(\n", + " path=str(\"/insert/your/path/here/\"),\n", + " limit=-1, # -1 means no limit on the number of files, by default it is set to 20\n", + ")" + ] + }, + { + "cell_type": "markdown", + "id": "7", + "metadata": {}, + "source": [ + "The cell below creates an object that analyzes images and generates a summary using a specific model and image data." 
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "8", + "metadata": {}, + "outputs": [], + "source": [ + "img = ammico.ImageSummaryDetector(summary_model=model, subdict=image_dict)" + ] + }, + { + "cell_type": "markdown", + "id": "9", + "metadata": {}, + "source": [ + "## Image summary " + ] + }, + { + "cell_type": "markdown", + "id": "10", + "metadata": {}, + "source": [ + "To start your work with images, you should call the `analyse_images` method.\n", + "\n", + "You can specify what kind of analysis you want to perform with `analysis_type`. `\"summary\"` will generate a summary for all pictures in your dictionary, `\"questions\"` will prepare answers to your questions for all pictures, and `\"summary_and_questions\"` will do both.\n", + "\n", + "Parameter `\"is_concise_summary\"` regulates the length of an answer.\n", + "\n", + "Here we want to get a long summary on each object in our image dictionary." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "11", + "metadata": {}, + "outputs": [], + "source": [ + "summaries = img.analyse_images(analysis_type=\"summary\", is_concise_summary=False)" + ] + }, + { + "cell_type": "markdown", + "id": "12", + "metadata": {}, + "source": [ + "## VQA" + ] + }, + { + "cell_type": "markdown", + "id": "13", + "metadata": {}, + "source": [ + "In addition to analyzing images in `ammico`, the same model can be used in VQA mode. To do this, you need to define the questions that will be applied to all images from your dict." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "14", + "metadata": {}, + "outputs": [], + "source": [ + "questions = [\"Are there any visible signs of violence?\", \"Is it safe to be there?\"]" + ] + }, + { + "cell_type": "markdown", + "id": "15", + "metadata": {}, + "source": [ + "Here is an example of VQA mode usage. You can specify whether you want to receive short answers (recommended option) or not." 
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "16", + "metadata": {}, + "outputs": [], + "source": [ + "vqa_results = img.analyse_images(\n", + " analysis_type=\"questions\",\n", + " list_of_questions=questions,\n", + " is_concise_answer=True,\n", + ")" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "ammico-dev", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.11.13" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/ammico/utils.py b/ammico/utils.py index 39a0ecb..38f4144 100644 --- a/ammico/utils.py +++ b/ammico/utils.py @@ -7,6 +7,9 @@ import collections import random +from enum import Enum + + pkg = importlib_resources.files("ammico") @@ -40,6 +43,12 @@ def ammico_prefetch_models(): res.get() +class AnalysisType(str, Enum): + SUMMARY = "summary" + QUESTIONS = "questions" + SUMMARY_AND_QUESTIONS = "summary_and_questions" + + class AnalysisMethod: """Base class to be inherited by all analysis methods.""" From 0f6f9026cd0dc74e899e4e077e6cafb2606e2a77 Mon Sep 17 00:00:00 2001 From: Inga Ulusoy Date: Thu, 25 Sep 2025 12:45:14 +0200 Subject: [PATCH 06/23] fix: missing dependency, obsolete keyword, dash maintenance, demo notebook for new summary --- ammico/display.py | 25 ++++-------------- ammico/notebooks/DemoNotebook_ammico.ipynb | 30 +++++++++------------- ammico/text.py | 11 ++------ pyproject.toml | 1 + 4 files changed, 20 insertions(+), 47 deletions(-) diff --git a/ammico/display.py b/ammico/display.py index c80d0db..2e7c9e4 100644 --- a/ammico/display.py +++ b/ammico/display.py @@ -94,7 +94,6 @@ class AnalysisExplorer: State("left_select_id", "options"), State("left_select_id", "value"), State("Dropdown_select_Detector", "value"), - State("setting_Text_analyse_text", "value"), State("setting_privacy_env_var", "value"), State("setting_Emotion_emotion_threshold", "value"), State("setting_Emotion_race_threshold", "value"), @@ -157,14 +156,6 @@ class AnalysisExplorer: id="settings_TextDetector", style={"display": "none"}, children=[ - dbc.Row( - dcc.Checklist( - ["Analyse text"], - ["Analyse text"], - id="setting_Text_analyse_text", - style={"margin-bottom": "10px"}, - ), - ), # row 1 dbc.Row( dbc.Col( @@ -344,7 +335,7 @@ class AnalysisExplorer: port (int, optional): The port number to run the server on (default: 8050). 
""" - self.app.run_server(debug=True, port=port) + self.app.run(debug=True, port=port) # Dash callbacks def update_picture(self, img_path: str): @@ -375,16 +366,15 @@ class AnalysisExplorer: } if setting_input == "TextDetector": - return display_flex, display_none, display_none, display_none + return display_flex, display_none, display_none if setting_input == "EmotionDetector": - return display_none, display_flex, display_none, display_none - + return display_none, display_flex, display_none if setting_input == "ColorDetector": - return display_none, display_none, display_flex, display_none + return display_none, display_none, display_flex else: - return display_none, display_none, display_none, display_none + return display_none, display_none, display_none def _right_output_analysis( self, @@ -392,7 +382,6 @@ class AnalysisExplorer: all_img_options: dict, current_img_value: str, detector_value: str, - settings_text_analyse_text: list, setting_privacy_env_var: str, setting_emotion_emotion_threshold: int, setting_emotion_race_threshold: int, @@ -426,12 +415,8 @@ class AnalysisExplorer: identify_function = identify_dict[detector_value] if detector_value == "TextDetector": - analyse_text = ( - True if settings_text_analyse_text == ["Analyse text"] else False - ) detector_class = identify_function( image_copy, - analyse_text=analyse_text, accept_privacy=( setting_privacy_env_var if setting_privacy_env_var diff --git a/ammico/notebooks/DemoNotebook_ammico.ipynb b/ammico/notebooks/DemoNotebook_ammico.ipynb index fc8fe22..e17860c 100644 --- a/ammico/notebooks/DemoNotebook_ammico.ipynb +++ b/ammico/notebooks/DemoNotebook_ammico.ipynb @@ -104,7 +104,8 @@ "import ammico\n", "\n", "# for displaying a progress bar\n", - "from tqdm import tqdm" + "from tqdm import tqdm\n", + "import os" ] }, { @@ -140,7 +141,9 @@ "metadata": {}, "outputs": [], "source": [ - "# os.environ[\"GOOGLE_APPLICATION_CREDENTIALS\"] = \"/content/drive/MyDrive/misinformation-data/misinformation-campaign-981aa55a3b13.json\"" + "os.environ[\"GOOGLE_APPLICATION_CREDENTIALS\"] = (\n", + " \"/home/inga/projects/misinformation-project/misinformation-notes/misinformation-campaign-981aa55a3b13.json\"\n", + ")" ] }, { @@ -171,6 +174,7 @@ "metadata": {}, "outputs": [], "source": [ + "data_path = \"./data-test\"\n", "image_dict = ammico.find_files(\n", " # path = \"/content/drive/MyDrive/misinformation-data/\",\n", " path=str(data_path),\n", @@ -337,7 +341,7 @@ " enumerate(image_dict.keys()), total=len(image_dict)\n", "): # loop through all images\n", " image_dict[key] = ammico.TextDetector(\n", - " image_dict[key], analyse_text=True\n", + " image_dict[key]\n", " ).analyse_image() # analyse image with EmotionDetector and update dict\n", "\n", " if (\n", @@ -361,23 +365,12 @@ "outputs": [], "source": [ "# initialize the models\n", - "image_summary_detector = ammico.SummaryDetector(\n", - " subdict=image_dict, analysis_type=\"summary\", model_type=\"base\"\n", + "model = ammico.MultimodalSummaryModel()\n", + "image_summary_detector = ammico.ImageSummaryDetector(\n", + " subdict=image_dict, summary_model=model\n", ")\n", "\n", - "# run the analysis without having to re-iniatialize the model\n", - "for num, key in tqdm(\n", - " enumerate(image_dict.keys()), total=len(image_dict)\n", - "): # loop through all images\n", - " image_dict[key] = image_summary_detector.analyse_image(\n", - " subdict=image_dict[key], analysis_type=\"summary\"\n", - " ) # analyse image with SummaryDetector and update dict\n", - "\n", - " if (\n", - " num % dump_every 
== 0 | num == len(image_dict) - 1\n", - " ): # save results every dump_every to dump_file\n", - " image_df = ammico.get_dataframe(image_dict)\n", - " image_df.to_csv(dump_file)" + "image_summary_detector.analyse_images(analysis_type=\"summary\")" ] }, { @@ -394,6 +387,7 @@ "outputs": [], "source": [ "# initialize the models\n", + "# currently this does not work because of the way the summary detector is implemented\n", "image_summary_detector = ammico.SummaryDetector(\n", " subdict=image_dict, analysis_type=\"summary\", model_type=\"base\"\n", ")\n", diff --git a/ammico/text.py b/ammico/text.py index bf39cc6..4bec28c 100644 --- a/ammico/text.py +++ b/ammico/text.py @@ -67,7 +67,6 @@ class TextDetector(AnalysisMethod): def __init__( self, subdict: dict, - analyse_text: bool = False, skip_extraction: bool = False, accept_privacy: str = "PRIVACY_AMMICO", ) -> None: @@ -76,8 +75,6 @@ class TextDetector(AnalysisMethod): Args: subdict (dict): Dictionary containing file name/path, and possibly previous analysis results from other modules. - analyse_text (bool, optional): Decide if extracted text will be further subject - to analysis. Defaults to False. skip_extraction (bool, optional): Decide if text will be extracted from images or is already provided via a csv. Defaults to False. accept_privacy (str, optional): Environment variable to accept the privacy @@ -96,17 +93,13 @@ class TextDetector(AnalysisMethod): "Privacy disclosure not accepted - skipping text detection." ) self.translator = Translator(raise_exception=True) - if not isinstance(analyse_text, bool): - raise ValueError("analyse_text needs to be set to true or false") - self.analyse_text = analyse_text self.skip_extraction = skip_extraction if not isinstance(skip_extraction, bool): raise ValueError("skip_extraction needs to be set to true or false") if self.skip_extraction: print("Skipping text extraction from image.") print("Reading text directly from provided dictionary.") - if self.analyse_text: - self._initialize_spacy() + self._initialize_spacy() def set_keys(self) -> dict: """Set the default keys for text analysis. @@ -183,7 +176,7 @@ class TextDetector(AnalysisMethod): self._truncate_text() self.translate_text() self.remove_linebreaks() - if self.analyse_text and self.subdict["text_english"]: + if self.subdict["text_english"]: self._run_spacy() return self.subdict diff --git a/pyproject.toml b/pyproject.toml index 1507f8b..c0f5440 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -47,6 +47,7 @@ dependencies = [ "spacy", "tensorflow<2.15", # instead of <=2.16.0 to make it compatible with CUDA 11.8, may change after updating CUDA version. 
"tf-keras", + "torchvision", "tqdm", "transformers>=4.54", "webcolors", From d1a4954669b4917fa78cb40f23ab9f3fe21437ed Mon Sep 17 00:00:00 2001 From: Inga Ulusoy Date: Fri, 26 Sep 2025 08:51:09 +0200 Subject: [PATCH 07/23] tests: remove analyse_text keyword for text detector --- ammico/test/test_display.py | 1 - ammico/test/test_text.py | 20 ++++---------------- 2 files changed, 4 insertions(+), 17 deletions(-) diff --git a/ammico/test/test_display.py b/ammico/test/test_display.py index 83d53dc..d7241a2 100644 --- a/ammico/test/test_display.py +++ b/ammico/test/test_display.py @@ -50,7 +50,6 @@ def test_right_output_analysis_emotions(get_AE, get_options, monkeypatch): get_options[3], get_options[0], "EmotionDetector", - True, "SOME_VAR", 50, 50, diff --git a/ammico/test/test_text.py b/ammico/test/test_text.py index cffb321..cd9f863 100644 --- a/ammico/test/test_text.py +++ b/ammico/test/test_text.py @@ -52,24 +52,16 @@ def test_privacy_statement(monkeypatch): def test_TextDetector(set_testdict, accepted): for item in set_testdict: test_obj = tt.TextDetector(set_testdict[item], accept_privacy=accepted) - assert not test_obj.analyse_text assert not test_obj.skip_extraction assert test_obj.subdict["filename"] == set_testdict[item]["filename"] - test_obj = tt.TextDetector( - {}, analyse_text=True, skip_extraction=True, accept_privacy=accepted - ) - assert test_obj.analyse_text + test_obj = tt.TextDetector({}, skip_extraction=True, accept_privacy=accepted) assert test_obj.skip_extraction - with pytest.raises(ValueError): - tt.TextDetector({}, analyse_text=1.0, accept_privacy=accepted) with pytest.raises(ValueError): tt.TextDetector({}, skip_extraction=1.0, accept_privacy=accepted) def test_run_spacy(set_testdict, get_path, accepted): - test_obj = tt.TextDetector( - set_testdict["IMG_3755"], analyse_text=True, accept_privacy=accepted - ) + test_obj = tt.TextDetector(set_testdict["IMG_3755"], accept_privacy=accepted) ref_file = get_path + "text_IMG_3755.txt" with open(ref_file, "r") as file: reference_text = file.read() @@ -108,15 +100,11 @@ def test_analyse_image(set_testdict, set_environ, accepted): for item in set_testdict: test_obj = tt.TextDetector(set_testdict[item], accept_privacy=accepted) test_obj.analyse_image() - test_obj = tt.TextDetector( - set_testdict[item], analyse_text=True, accept_privacy=accepted - ) + test_obj = tt.TextDetector(set_testdict[item], accept_privacy=accepted) test_obj.analyse_image() testdict = {} testdict["text"] = 20000 * "m" - test_obj = tt.TextDetector( - testdict, skip_extraction=True, analyse_text=True, accept_privacy=accepted - ) + test_obj = tt.TextDetector(testdict, skip_extraction=True, accept_privacy=accepted) test_obj.analyse_image() assert test_obj.subdict["text_truncated"] == 5000 * "m" assert test_obj.subdict["text"] == 20000 * "m" From ece132fe14a6139d88d2d7e7559ab6ed10980b21 Mon Sep 17 00:00:00 2001 From: DimasfromLavoisier Date: Fri, 26 Sep 2025 17:29:46 +0200 Subject: [PATCH 08/23] optimize validation of analysis type --- ammico/image_summary.py | 58 +++++++++++++++++++++++++++-------------- 1 file changed, 38 insertions(+), 20 deletions(-) diff --git a/ammico/image_summary.py b/ammico/image_summary.py index 0cdaebe..3ccc3f4 100644 --- a/ammico/image_summary.py +++ b/ammico/image_summary.py @@ -6,7 +6,7 @@ import torch from PIL import Image import warnings -from typing import List, Optional, Union, Dict, Any +from typing import List, Optional, Union, Dict, Any, Tuple from collections.abc import Sequence as _Sequence from transformers import 
GenerationConfig from qwen_vl_utils import process_vision_info @@ -118,6 +118,36 @@ class ImageSummaryDetector(AnalysisMethod): return inputs + def _validate_analysis_type( + self, + analysis_type: Union["AnalysisType", str], + list_of_questions: Optional[List[str]], + max_questions_per_image: int, + ) -> Tuple[str, List[str], bool, bool]: + if isinstance(analysis_type, AnalysisType): + analysis_type = analysis_type.value + + allowed = {"summary", "questions", "summary_and_questions"} + if analysis_type not in allowed: + raise ValueError(f"analysis_type must be one of {allowed}") + + if list_of_questions is None: + list_of_questions = [ + "Are there people in the image?", + "What is this picture about?", + ] + + if analysis_type in ("questions", "summary_and_questions"): + if len(list_of_questions) > max_questions_per_image: + raise ValueError( + f"Number of questions per image ({len(list_of_questions)}) exceeds safety cap ({max_questions_per_image}). Reduce questions or increase max_questions_per_image." + ) + + is_summary = analysis_type in ("summary", "summary_and_questions") + is_questions = analysis_type in ("questions", "summary_and_questions") + + return analysis_type, list_of_questions, is_summary, is_questions + def analyse_images( self, analysis_type: Union[AnalysisType, str] = AnalysisType.SUMMARY_AND_QUESTIONS, @@ -141,25 +171,18 @@ class ImageSummaryDetector(AnalysisMethod): self.subdict (dict): dictionary with analysis results. """ # TODO: add option to ask multiple questions per image as one batch. - if isinstance(analysis_type, AnalysisType): - analysis_type = analysis_type.value - - allowed = {"summary", "questions", "summary_and_questions"} - if analysis_type not in allowed: - raise ValueError(f"analysis_type must be one of {allowed}") - - if list_of_questions is None: - list_of_questions = [ - "Are there people in the image?", - "What is this picture about?", - ] + analysis_type, list_of_questions, is_summary, is_questions = ( + self._validate_analysis_type( + analysis_type, list_of_questions, max_questions_per_image + ) + ) keys = list(self.subdict.keys()) for batch_start in range(0, len(keys), keys_batch_size): batch_keys = keys[batch_start : batch_start + keys_batch_size] for key in batch_keys: entry = self.subdict[key] - if analysis_type in ("summary", "summary_and_questions"): + if is_summary: try: caps = self.generate_caption( entry, @@ -172,12 +195,7 @@ class ImageSummaryDetector(AnalysisMethod): "Caption generation failed for key %s: %s", key, e ) - if analysis_type in ("questions", "summary_and_questions"): - if len(list_of_questions) > max_questions_per_image: - raise ValueError( - f"Number of questions per image ({len(list_of_questions)}) exceeds safety cap ({max_questions_per_image})." - " Reduce questions or increase max_questions_per_image." 
- ) + if is_questions: try: vqa_map = self.answer_questions( list_of_questions, entry, is_concise_answer From 5c7e2c3f640241fb00b4a51d82e9d200675e4a1e Mon Sep 17 00:00:00 2001 From: DimasfromLavoisier Date: Fri, 26 Sep 2025 18:23:58 +0200 Subject: [PATCH 09/23] 1st try --- ammico/display.py | 79 ++++++++++++++++++++++++++++++++++++++++++++++- 1 file changed, 78 insertions(+), 1 deletion(-) diff --git a/ammico/display.py b/ammico/display.py index 2e7c9e4..ff86a3e 100644 --- a/ammico/display.py +++ b/ammico/display.py @@ -5,6 +5,8 @@ import pandas as pd from dash import html, Input, Output, dcc, State, Dash from PIL import Image import dash_bootstrap_components as dbc +import warnings +from typing import Dict, Any, List COLOR_SCHEMES = [ @@ -94,6 +96,9 @@ class AnalysisExplorer: State("left_select_id", "options"), State("left_select_id", "value"), State("Dropdown_select_Detector", "value"), + State("Dropdown_analysis_type", "value"), + State("checkbox_enable_image_tasks", "value"), + State("textarea_questions", "value"), State("setting_privacy_env_var", "value"), State("setting_Emotion_emotion_threshold", "value"), State("setting_Emotion_race_threshold", "value"), @@ -291,6 +296,37 @@ class AnalysisExplorer: ), justify="start", ), + # NEW: Analysis-type selector (summary/questions/summary_and_questions) + dbc.Row( + dcc.Dropdown( + id="Dropdown_analysis_type", + options=[{"label": v, "value": v} for v in SUMMARY_ANALYSIS_TYPE], + value="summary_and_questions", + style={"width": "60%", "margin-top": "8px"}, + ), + justify="start", + ), + # NEW: Enable image-level tasks (VQA / caption) checkbox + dbc.Row( + dcc.Checklist( + id="checkbox_enable_image_tasks", + options=[{"label": "Enable Image Tasks (Caption / VQA)", "value": "enabled"}], + value=["enabled"], # default enabled + inline=True, + style={"margin-top": "8px"}, + ), + justify="start", + ), + # NEW: Questions textarea (newline-separated). Only used if analysis_type includes "questions". 
+ dbc.Row( + dcc.Textarea( + id="textarea_questions", + value="Are there people in the image?\nWhat is this picture about?", + placeholder="One question per line...", + style={"width": "60%", "height": "120px", "margin-top": "8px"}, + ), + justify="start", + ), dbc.Row( children=[self._create_setting_layout()], id="div_detector_args", @@ -383,6 +419,7 @@ class AnalysisExplorer: current_img_value: str, detector_value: str, setting_privacy_env_var: str, + checkbox_enable_image_tasks_value: List[str], setting_emotion_emotion_threshold: int, setting_emotion_race_threshold: int, setting_emotion_gender_threshold: int, @@ -414,6 +451,10 @@ class AnalysisExplorer: # detector value is the string name of the chosen detector identify_function = identify_dict[detector_value] + identify_function = identify_dict.get(detector_value) + if identify_function is None: + detector_class = None + if detector_value == "TextDetector": detector_class = identify_function( image_copy, @@ -442,8 +483,32 @@ class AnalysisExplorer: ) else: detector_class = identify_function(image_copy) - analysis_dict = detector_class.analyse_image() + + if detector_class is not None: + analysis_dict = detector_class.analyse_image() + else: + analysis_dict = {} + image_tasks_result: Dict[str, Any] = {} + enable_image_tasks = "enabled" in (checkbox_enable_image_tasks_value or []) + if enable_image_tasks: + # parse questions textarea: newline separated + if textarea_questions_value: + questions_list = [q.strip() for q in textarea_questions_value.splitlines() if q.strip()] + else: + questions_list = None + + try: + image_tasks_result = self.analyse_image( + image_copy, + analysis_type=analysis_type_value, + list_of_questions=questions_list, + is_concise_summary=True, + is_concise_answer=True, + ) + except Exception as e: + warnings.warn(f"Image tasks failed: {e}") + image_tasks_result = {"image_tasks_error": str(e)} # Initialize an empty dictionary new_analysis_dict = {} @@ -459,6 +524,18 @@ class AnalysisExplorer: # Add the new key-value pair to the new dictionary new_analysis_dict[k] = new_value + if "caption" in image_tasks_result: + new_analysis_dict["caption"] = image_tasks_result.get("caption", "") + if "vqa" in image_tasks_result: + # vqa is expected to be a dict; convert to readable string + vqa_entries = image_tasks_result["vqa"] + if isinstance(vqa_entries, dict): + new_analysis_dict["vqa"] = "; ".join([f"{q}: {a}" for q, a in vqa_entries.items()]) + else: + new_analysis_dict["vqa"] = str(vqa_entries) + for err_key in ("caption_error", "vqa_error", "image_tasks_error"): + if err_key in image_tasks_result: + new_analysis_dict[err_key] = image_tasks_result[err_key] df = pd.DataFrame([new_analysis_dict]).set_index("filename").T df.index.rename("filename", inplace=True) From 402a379f9c819d89b752ebb4c27af261c3f7c097 Mon Sep 17 00:00:00 2001 From: Dmitrii Kapitan Date: Sat, 27 Sep 2025 16:42:05 +0200 Subject: [PATCH 10/23] basic integration into display functionality --- ammico/display.py | 251 +++++++++++++++++++++++----------------- ammico/image_summary.py | 58 ++++++++-- 2 files changed, 192 insertions(+), 117 deletions(-) diff --git a/ammico/display.py b/ammico/display.py index ff86a3e..b916dbf 100644 --- a/ammico/display.py +++ b/ammico/display.py @@ -1,12 +1,14 @@ import ammico.faces as faces import ammico.text as text import ammico.colors as colors +import ammico.image_summary as image_summary +from ammico.model import MultimodalSummaryModel import pandas as pd from dash import html, Input, Output, dcc, State, Dash from PIL 
import Image import dash_bootstrap_components as dbc import warnings -from typing import Dict, Any, List +from typing import Dict, Any, List, Optional COLOR_SCHEMES = [ @@ -97,7 +99,6 @@ class AnalysisExplorer: State("left_select_id", "value"), State("Dropdown_select_Detector", "value"), State("Dropdown_analysis_type", "value"), - State("checkbox_enable_image_tasks", "value"), State("textarea_questions", "value"), State("setting_privacy_env_var", "value"), State("setting_Emotion_emotion_threshold", "value"), @@ -112,9 +113,15 @@ class AnalysisExplorer: Output("settings_TextDetector", "style"), Output("settings_EmotionDetector", "style"), Output("settings_ColorDetector", "style"), + Output("settings_VQA", "style"), Input("Dropdown_select_Detector", "value"), )(self._update_detector_setting) + self.app.callback( + Output("textarea_questions", "style"), + Input("Dropdown_analysis_type", "value"), + )(self._show_questions_textarea_on_demand) + # I split the different sections into subfunctions for better clarity def _top_file_explorer(self, mydict: dict) -> html.Div: """Initialize the file explorer dropdown for selecting the file to be analyzed. @@ -268,8 +275,69 @@ class AnalysisExplorer: ) ], ), + # start VQA settings + html.Div( + id="settings_VQA", + style={"display": "none"}, + children=[ + dbc.Card( + [ + dbc.CardBody( + [ + dbc.Row( + dbc.Col( + dcc.Dropdown( + id="Dropdown_analysis_type", + options=[ + {"label": v, "value": v} + for v in SUMMARY_ANALYSIS_TYPE + ], + value="summary_and_questions", + clearable=False, + style={ + "width": "100%", + "minWidth": "240px", + "maxWidth": "520px", + }, + ), + ), + justify="start", + ), + html.Div(style={"height": "8px"}), + dbc.Row( + [ + dbc.Col( + dcc.Textarea( + id="textarea_questions", + value="Are there people in the image?\nWhat is this picture about?", + placeholder="One question per line...", + style={ + "width": "100%", + "minHeight": "160px", + "height": "220px", + "resize": "vertical", + "overflow": "auto", + }, + rows=8, + ), + width=12, + ), + ], + justify="start", + ), + ] + ) + ], + style={ + "width": "100%", + "marginTop": "10px", + "zIndex": 2000, + }, + ) + ], + ), ], - style={"width": "100%", "display": "inline-block"}, + style={"width": "100%", "display": "inline-block", "overflow": "visible"}, ) return settings_layout @@ -289,6 +357,7 @@ class AnalysisExplorer: "TextDetector", "EmotionDetector", "ColorDetector", + "VQA", ], value="TextDetector", id="Dropdown_select_Detector", @@ -296,37 +365,6 @@ class AnalysisExplorer: ), justify="start", ), - # NEW: Analysis-type selector (summary/questions/summary_and_questions) - dbc.Row( - dcc.Dropdown( - id="Dropdown_analysis_type", - options=[{"label": v, "value": v} for v in SUMMARY_ANALYSIS_TYPE], - value="summary_and_questions", - style={"width": "60%", "margin-top": "8px"}, - ), - justify="start", - ), - # NEW: Enable image-level tasks (VQA / caption) checkbox - dbc.Row( - dcc.Checklist( - id="checkbox_enable_image_tasks", - options=[{"label": "Enable Image Tasks (Caption / VQA)", "value": "enabled"}], - value=["enabled"], # default enabled - inline=True, - style={"margin-top": "8px"}, - ), - justify="start", - ), - # NEW: Questions textarea (newline-separated). Only used if analysis_type includes "questions". 
- dbc.Row( - dcc.Textarea( - id="textarea_questions", - value="Are there people in the image?\nWhat is this picture about?", - placeholder="One question per line...", - style={"width": "60%", "height": "120px", "margin-top": "8px"}, - ), - justify="start", - ), dbc.Row( children=[self._create_setting_layout()], id="div_detector_args", @@ -402,15 +440,22 @@ class AnalysisExplorer: } if setting_input == "TextDetector": - return display_flex, display_none, display_none + return display_flex, display_none, display_none, display_none if setting_input == "EmotionDetector": - return display_none, display_flex, display_none + return display_none, display_flex, display_none, display_none if setting_input == "ColorDetector": - return display_none, display_none, display_flex - + return display_none, display_none, display_flex, display_none + if setting_input == "VQA": + return display_none, display_none, display_none, display_flex else: - return display_none, display_none, display_none + return display_none, display_none, display_none, display_none + + def _parse_questions(self, text: Optional[str]) -> Optional[List[str]]: + if not text: + return None + qs = [q.strip() for q in text.splitlines() if q.strip()] + return qs if qs else None def _right_output_analysis( self, @@ -418,8 +463,9 @@ class AnalysisExplorer: all_img_options: dict, current_img_value: str, detector_value: str, + analysis_type_value: str, + textarea_questions_value: str, setting_privacy_env_var: str, - checkbox_enable_image_tasks_value: List[str], setting_emotion_emotion_threshold: int, setting_emotion_race_threshold: int, setting_emotion_gender_threshold: int, @@ -439,78 +485,71 @@ class AnalysisExplorer: "EmotionDetector": faces.EmotionDetector, "TextDetector": text.TextDetector, "ColorDetector": colors.ColorDetector, + "VQA": image_summary.ImageSummaryDetector, } # Get image ID from dropdown value, which is the filepath if current_img_value is None: return {} image_id = all_img_options[current_img_value] - # copy image so prvious runs don't leave their default values in the dict - image_copy = self.mydict[image_id].copy() - - # detector value is the string name of the chosen detector - identify_function = identify_dict[detector_value] - - identify_function = identify_dict.get(detector_value) - if identify_function is None: - detector_class = None - - if detector_value == "TextDetector": - detector_class = identify_function( - image_copy, - accept_privacy=( - setting_privacy_env_var - if setting_privacy_env_var - else "PRIVACY_AMMICO" - ), - ) - elif detector_value == "EmotionDetector": - detector_class = identify_function( - image_copy, - emotion_threshold=setting_emotion_emotion_threshold, - race_threshold=setting_emotion_race_threshold, - gender_threshold=setting_emotion_gender_threshold, - accept_disclosure=( - setting_emotion_env_var - if setting_emotion_env_var - else "DISCLOSURE_AMMICO" - ), - ) - elif detector_value == "ColorDetector": - detector_class = identify_function( - image_copy, - delta_e_method=setting_color_delta_e_method, - ) - else: - detector_class = identify_function(image_copy) - - if detector_class is not None: - analysis_dict = detector_class.analyse_image() - else: - analysis_dict = {} - - image_tasks_result: Dict[str, Any] = {} - enable_image_tasks = "enabled" in (checkbox_enable_image_tasks_value or []) - if enable_image_tasks: - # parse questions textarea: newline separated - if textarea_questions_value: - questions_list = [q.strip() for q in textarea_questions_value.splitlines() if q.strip()] - else: 
- questions_list = None + image_copy = self.mydict.get(image_id, {}).copy() + analysis_dict: Dict[str, Any] = {} + if detector_value == "VQA": try: - image_tasks_result = self.analyse_image( + qwen_model = MultimodalSummaryModel( + model_id="Qwen/Qwen2.5-VL-3B-Instruct" + ) # TODO: allow user to specify model + vqa_cls = identify_dict.get("VQA") + vqa_detector = vqa_cls(qwen_model, subdict={}) + questions_list = self._parse_questions(textarea_questions_value) + analysis_result = vqa_detector.analyse_image( image_copy, analysis_type=analysis_type_value, list_of_questions=questions_list, is_concise_summary=True, is_concise_answer=True, ) + analysis_dict = analysis_result or {} except Exception as e: - warnings.warn(f"Image tasks failed: {e}") - image_tasks_result = {"image_tasks_error": str(e)} - # Initialize an empty dictionary - new_analysis_dict = {} + warnings.warn(f"VQA/Image tasks failed: {e}") + analysis_dict = {"image_tasks_error": str(e)} + else: + # detector value is the string name of the chosen detector + identify_function = identify_dict[detector_value] + + if detector_value == "TextDetector": + detector_class = identify_function( + image_copy, + accept_privacy=( + setting_privacy_env_var + if setting_privacy_env_var + else "PRIVACY_AMMICO" + ), + ) + elif detector_value == "EmotionDetector": + detector_class = identify_function( + image_copy, + emotion_threshold=setting_emotion_emotion_threshold, + race_threshold=setting_emotion_race_threshold, + gender_threshold=setting_emotion_gender_threshold, + accept_disclosure=( + setting_emotion_env_var + if setting_emotion_env_var + else "DISCLOSURE_AMMICO" + ), + ) + elif detector_value == "ColorDetector": + detector_class = identify_function( + image_copy, + delta_e_method=setting_color_delta_e_method, + ) + else: + detector_class = identify_function(image_copy) + + analysis_dict = detector_class.analyse_image() + + new_analysis_dict: Dict[str, Any] = {} # Iterate over the items in the original dictionary for k, v in analysis_dict.items(): @@ -524,21 +563,15 @@ class AnalysisExplorer: # Add the new key-value pair to the new dictionary new_analysis_dict[k] = new_value - if "caption" in image_tasks_result: - new_analysis_dict["caption"] = image_tasks_result.get("caption", "") - if "vqa" in image_tasks_result: - # vqa is expected to be a dict; convert to readable string - vqa_entries = image_tasks_result["vqa"] - if isinstance(vqa_entries, dict): - new_analysis_dict["vqa"] = "; ".join([f"{q}: {a}" for q, a in vqa_entries.items()]) - else: - new_analysis_dict["vqa"] = str(vqa_entries) - for err_key in ("caption_error", "vqa_error", "image_tasks_error"): - if err_key in image_tasks_result: - new_analysis_dict[err_key] = image_tasks_result[err_key] df = pd.DataFrame([new_analysis_dict]).set_index("filename").T df.index.rename("filename", inplace=True) return dbc.Table.from_dataframe( df, striped=True, bordered=True, hover=True, index=True ) + + def _show_questions_textarea_on_demand(self, analysis_type_value: str) -> dict: + if analysis_type_value in ("questions", "summary_and_questions"): + return {"display": "block", "width": "100%"} + else: + return {"display": "none"} diff --git a/ammico/image_summary.py b/ammico/image_summary.py index 3ccc3f4..203ef21 100644 --- a/ammico/image_summary.py +++ b/ammico/image_summary.py @@ -16,7 +16,7 @@ class ImageSummaryDetector(AnalysisMethod): def __init__( self, summary_model: MultimodalSummaryModel, - subdict: dict = {}, + subdict: Optional[Dict[str, Any]] = None, ) -> None: """ Class for analysing 
images using QWEN-2.5-VL model. @@ -29,6 +29,8 @@ class ImageSummaryDetector(AnalysisMethod): Returns: None. """ + if subdict is None: + subdict = {} super().__init__(subdict) self.summary_model = summary_model @@ -148,7 +150,50 @@ class ImageSummaryDetector(AnalysisMethod): return analysis_type, list_of_questions, is_summary, is_questions - def analyse_images( + def analyse_image( + self, + entry: dict, + analysis_type: Union[str, AnalysisType] = AnalysisType.SUMMARY_AND_QUESTIONS, + list_of_questions: Optional[List[str]] = None, + max_questions_per_image: int = 32, + is_concise_summary: bool = True, + is_concise_answer: bool = True, + ) -> Dict[str, Any]: + """ + Analyse a single image entry. Returns dict with keys depending on analysis_type: + - 'caption' (str) if summary requested + - 'vqa' (dict) if questions requested + """ + self.subdict = entry + analysis_type, list_of_questions, is_summary, is_questions = ( + self._validate_analysis_type( + analysis_type, list_of_questions, max_questions_per_image + ) + ) + + if is_summary: + try: + caps = self.generate_caption( + entry, + num_return_sequences=1, + is_concise_summary=is_concise_summary, + ) + self.subdict["caption"] = caps[0] if caps else "" + except Exception as e: + warnings.warn(f"Caption generation failed: {e}") + + if is_questions: + try: + vqa_map = self.answer_questions( + list_of_questions, entry, is_concise_answer + ) + self.subdict["vqa"] = vqa_map + except Exception as e: + warnings.warn(f"VQA failed: {e}") + + return self.subdict + + def analyse_images_from_dict( self, analysis_type: Union[AnalysisType, str] = AnalysisType.SUMMARY_AND_QUESTIONS, list_of_questions: Optional[List[str]] = None, @@ -191,9 +236,7 @@ class ImageSummaryDetector(AnalysisMethod): ) entry["caption"] = caps[0] if caps else "" except Exception as e: - warnings.warn( - "Caption generation failed for key %s: %s", key, e - ) + warnings.warn(f"Caption generation failed: {e}") if is_questions: try: @@ -202,7 +245,7 @@ class ImageSummaryDetector(AnalysisMethod): ) entry["vqa"] = vqa_map except Exception as e: - warnings.warn("VQA failed for key %s: %s", key, e) + warnings.warn(f"VQA failed: {e}") self.subdict[key] = entry return self.subdict @@ -251,8 +294,7 @@ class ImageSummaryDetector(AnalysisMethod): ) except RuntimeError as e: warnings.warn( - "Retry without autocast failed: %s. Attempting cudnn-disabled retry.", - e, + f"Retry without autocast failed: {e}. Attempting cudnn-disabled retry." ) cudnn_was_enabled = ( torch.backends.cudnn.is_available() and torch.backends.cudnn.enabled From 75b9bc101bd122232ceab01b598423b86b3f939b Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Fri, 19 Sep 2025 10:33:59 +0200 Subject: [PATCH 11/23] [pre-commit.ci] pre-commit autoupdate (#265) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit updates: - [github.com/astral-sh/ruff-pre-commit: v0.12.10 → v0.13.0](https://github.com/astral-sh/ruff-pre-commit/compare/v0.12.10...v0.13.0) Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> --- .pre-commit-config.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 7b28fc9..6dd7854 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -6,7 +6,7 @@ repos: files: ".ipynb" - repo: https://github.com/astral-sh/ruff-pre-commit # Ruff version. - rev: v0.12.10 + rev: v0.13.0 hooks: # Run the linter. 
- id: ruff-check From 32d032595d1632fb77263c6e1407cf1856a10d11 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 23 Sep 2025 09:53:38 +0200 Subject: [PATCH 12/23] [pre-commit.ci] pre-commit autoupdate (#267) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit updates: - [github.com/astral-sh/ruff-pre-commit: v0.13.0 → v0.13.1](https://github.com/astral-sh/ruff-pre-commit/compare/v0.13.0...v0.13.1) Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> --- .pre-commit-config.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 6dd7854..d879bc0 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -6,7 +6,7 @@ repos: files: ".ipynb" - repo: https://github.com/astral-sh/ruff-pre-commit # Ruff version. - rev: v0.13.0 + rev: v0.13.1 hooks: # Run the linter. - id: ruff-check From 483f128f9608c5c376c8334f2fc8dcbf4af403c8 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Thu, 9 Oct 2025 09:26:18 +0200 Subject: [PATCH 13/23] [pre-commit.ci] pre-commit autoupdate (#269) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit updates: - [github.com/astral-sh/ruff-pre-commit: v0.13.1 → v0.13.3](https://github.com/astral-sh/ruff-pre-commit/compare/v0.13.1...v0.13.3) Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> --- .pre-commit-config.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index d879bc0..03d5cf5 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -6,7 +6,7 @@ repos: files: ".ipynb" - repo: https://github.com/astral-sh/ruff-pre-commit # Ruff version. - rev: v0.13.1 + rev: v0.13.3 hooks: # Run the linter. 
- id: ruff-check From 3018800ed47843c88fc9956f2755bec399ee37f6 Mon Sep 17 00:00:00 2001 From: DimasfromLavoisier Date: Thu, 9 Oct 2025 17:15:32 +0200 Subject: [PATCH 14/23] update test-display --- ammico/display.py | 6 ++++++ ammico/test/test_display.py | 3 +++ 2 files changed, 9 insertions(+) diff --git a/ammico/display.py b/ammico/display.py index b916dbf..5b860a7 100644 --- a/ammico/display.py +++ b/ammico/display.py @@ -100,6 +100,7 @@ class AnalysisExplorer: State("Dropdown_select_Detector", "value"), State("Dropdown_analysis_type", "value"), State("textarea_questions", "value"), + State("setting_Text_analyse_text", "value"), State("setting_privacy_env_var", "value"), State("setting_Emotion_emotion_threshold", "value"), State("setting_Emotion_race_threshold", "value"), @@ -465,6 +466,7 @@ class AnalysisExplorer: detector_value: str, analysis_type_value: str, textarea_questions_value: str, + settings_text_analyse_text: list, setting_privacy_env_var: str, setting_emotion_emotion_threshold: int, setting_emotion_race_threshold: int, @@ -519,8 +521,12 @@ class AnalysisExplorer: identify_function = identify_dict[detector_value] if detector_value == "TextDetector": + analyse_text = ( + True if settings_text_analyse_text == ["Analyse text"] else False + ) detector_class = identify_function( image_copy, + analyse_text=analyse_text, accept_privacy=( setting_privacy_env_var if setting_privacy_env_var diff --git a/ammico/test/test_display.py b/ammico/test/test_display.py index d7241a2..3cdb333 100644 --- a/ammico/test/test_display.py +++ b/ammico/test/test_display.py @@ -50,6 +50,9 @@ def test_right_output_analysis_emotions(get_AE, get_options, monkeypatch): get_options[3], get_options[0], "EmotionDetector", + "summary", + "Some question", + True, "SOME_VAR", 50, 50, From d810dbc3669d640098a4135f6c06c25845729539 Mon Sep 17 00:00:00 2001 From: DimasfromLavoisier Date: Fri, 10 Oct 2025 17:05:48 +0200 Subject: [PATCH 15/23] add base model tests --- ammico/model.py | 15 +++++++++++++++ ammico/test/conftest.py | 10 ++++++++++ ammico/test/test_model.py | 27 +++++++++++++++++++++++++++ pyproject.toml | 3 +-- 4 files changed, 53 insertions(+), 2 deletions(-) create mode 100644 ammico/test/test_model.py diff --git a/ammico/model.py b/ammico/model.py index 80cc31f..cdc1161 100644 --- a/ammico/model.py +++ b/ammico/model.py @@ -27,6 +27,15 @@ class MultimodalSummaryModel: cache_dir: huggingface cache dir (optional). 
""" self.device = self._resolve_device(device) + + if model_id is not None and model_id not in ( + self.DEFAULT_CUDA_MODEL, + self.DEFAULT_CPU_MODEL, + ): + raise ValueError( + f"model_id must be one of {self.DEFAULT_CUDA_MODEL} or {self.DEFAULT_CPU_MODEL}" + ) + self.model_id = model_id or ( self.DEFAULT_CUDA_MODEL if self.device == "cuda" else self.DEFAULT_CPU_MODEL ) @@ -94,6 +103,12 @@ class MultimodalSummaryModel: if self.model is not None: del self.model self.model = None + if self.processor is not None: + del self.processor + self.processor = None + if self.tokenizer is not None: + del self.tokenizer + self.tokenizer = None finally: try: if torch.cuda.is_available(): diff --git a/ammico/test/conftest.py b/ammico/test/conftest.py index cb42774..2010e1e 100644 --- a/ammico/test/conftest.py +++ b/ammico/test/conftest.py @@ -1,5 +1,6 @@ import os import pytest +from ammico.model import MultimodalSummaryModel @pytest.fixture @@ -46,3 +47,12 @@ def get_test_my_dict(get_path): }, } return test_my_dict + + +@pytest.fixture(scope="session") +def model(): + m = MultimodalSummaryModel(device="cpu") + try: + yield m + finally: + m.close() diff --git a/ammico/test/test_model.py b/ammico/test/test_model.py new file mode 100644 index 0000000..ac652c0 --- /dev/null +++ b/ammico/test/test_model.py @@ -0,0 +1,27 @@ +import pytest +import torch +from ammico.model import MultimodalSummaryModel + + +def test_model_init(model): + assert model.model is not None + assert model.processor is not None + assert model.tokenizer is not None + assert model.device is not None + + +def test_model_invalid_device(): + with pytest.raises(ValueError): + MultimodalSummaryModel(device="invalid_device") + + +def test_model_invalid_model_id(): + with pytest.raises(ValueError): + MultimodalSummaryModel(model_id="non_existent_model", device="cpu") + + +def test_free_resources(): + model = MultimodalSummaryModel(device="cpu") + model.close() + assert model.model is None + assert model.processor is None diff --git a/pyproject.toml b/pyproject.toml index c0f5440..4f29217 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -26,10 +26,9 @@ dependencies = [ "colour-science", "dash", "dash-bootstrap-components", - "decord", "deepface", "google-cloud-vision", - "googletrans-py", # instead of googletrans4.0.0rc1, for a temporary solution due the incompatibility with jupyterlab + "googletrans-py", # instead of googletrans4.0.0rc1, for a temporary solution due to incompatibility with jupyterlab "grpcio", "huggingface-hub>=0.34.0", "importlib_metadata", From d6e0fbeffe6841839d0ae6114f08525b43567631 Mon Sep 17 00:00:00 2001 From: DimasfromLavoisier Date: Mon, 13 Oct 2025 13:51:24 +0200 Subject: [PATCH 16/23] add vqa tests --- ammico/test/test_image_summary.py | 37 +++++++++++++++++++++++++++++++ ammico/test/test_model.py | 5 ++++- pyproject.toml | 2 +- 3 files changed, 42 insertions(+), 2 deletions(-) create mode 100644 ammico/test/test_image_summary.py diff --git a/ammico/test/test_image_summary.py b/ammico/test/test_image_summary.py new file mode 100644 index 0000000..ad48298 --- /dev/null +++ b/ammico/test/test_image_summary.py @@ -0,0 +1,37 @@ +from ammico.image_summary import ImageSummaryDetector + +import pytest + + +@pytest.mark.long +def test_image_summary_detector(model, get_testdict): + detector = ImageSummaryDetector(summary_model=model, subdict=get_testdict) + results = detector.analyse_images_from_dict(analysis_type="summary") + assert len(results) == 2 + for key in get_testdict.keys(): + assert key in results + assert 
"caption" in results[key] + assert isinstance(results[key]["caption"], str) + assert len(results[key]["caption"]) > 0 + + +@pytest.mark.long +def test_image_summary_detector_questions(model, get_testdict): + list_of_questions = [ + "What is happening in the image?", + "How many cars are in the image in total?", + ] + detector = ImageSummaryDetector(summary_model=model, subdict=get_testdict) + results = detector.analyse_images_from_dict( + analysis_type="questions", list_of_questions=list_of_questions + ) + assert len(results) == 2 + for key in get_testdict.keys(): + assert "vqa" in results[key] + if key == "IMG_2746": + assert "marathon" in results[key]["vqa"][0].lower() + + if key == "IMG_2809": + assert ( + "two" in results[key]["vqa"][1].lower() or "2" in results[key]["vqa"][1] + ) diff --git a/ammico/test/test_model.py b/ammico/test/test_model.py index ac652c0..d82dd86 100644 --- a/ammico/test/test_model.py +++ b/ammico/test/test_model.py @@ -1,8 +1,8 @@ import pytest -import torch from ammico.model import MultimodalSummaryModel +@pytest.mark.long def test_model_init(model): assert model.model is not None assert model.processor is not None @@ -10,16 +10,19 @@ def test_model_init(model): assert model.device is not None +@pytest.mark.long def test_model_invalid_device(): with pytest.raises(ValueError): MultimodalSummaryModel(device="invalid_device") +@pytest.mark.long def test_model_invalid_model_id(): with pytest.raises(ValueError): MultimodalSummaryModel(model_id="non_existent_model", device="cpu") +@pytest.mark.long def test_free_resources(): model = MultimodalSummaryModel(device="cpu") model.close() diff --git a/pyproject.toml b/pyproject.toml index 4f29217..1ead16e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -39,7 +39,7 @@ dependencies = [ "pandas", "Pillow", "pooch", - "qwen-vl-utils[decord]==0.0.8", + "qwen-vl-utils", "retina_face", "safetensors>=0.6.2", "setuptools", From af97981547b2a88609eb9ceafe4e74e93081f085 Mon Sep 17 00:00:00 2001 From: Dmitrii Kapitan Date: Tue, 14 Oct 2025 11:31:27 +0200 Subject: [PATCH 17/23] Excluding `long` tests from github actions, since there is not enough memory for it --- .github/workflows/ci.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 015a8ee..1b1675a 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -31,7 +31,7 @@ jobs: - name: Run pytest run: | cd ammico - python -m pytest -svv -m "not gcv" --cov=. --cov-report=xml + python -m pytest -svv -m "not gcv and not long" --cov=. 
--cov-report=xml - name: Upload coverage if: matrix.os == 'ubuntu-latest' && matrix.python-version == '3.11' uses: codecov/codecov-action@v3 From a1b30f433152d25700ee9b511068f5c4778622b9 Mon Sep 17 00:00:00 2001 From: DimasfromLavoisier Date: Fri, 5 Sep 2025 12:04:31 +0200 Subject: [PATCH 18/23] first, brute-force version of video analysis --- ammico/video_summary.py | 408 ++++++++++++++++++++++++++++++++++++++++ 1 file changed, 408 insertions(+) create mode 100644 ammico/video_summary.py diff --git a/ammico/video_summary.py b/ammico/video_summary.py new file mode 100644 index 0000000..1f3f8c5 --- /dev/null +++ b/ammico/video_summary.py @@ -0,0 +1,408 @@ +import decord +import os +import re +import math +import torch +import warnings +import numpy as np +from PIL import Image + +from ammico.model import MultimodalSummaryModel +from ammico.utils import AnalysisMethod, AnalysisType + +from typing import List, Optional, Union, Dict, Any, Generator, Tuple +from transformers import GenerationConfig + + +class VideoSummaryDetector(AnalysisMethod): + def __init__( + self, + summary_model: MultimodalSummaryModel, + subdict: dict = {}, + gpu_id: int = 0, + ) -> None: + """ + Class for analysing videos using QWEN-2.5-VL model. + It provides methods for generating captions and answering questions about videos. + + Args: + summary_model ([type], optional): An instance of MultimodalSummaryModel to be used for analysis. + subdict (dict, optional): Dictionary containing the video to be analysed. Defaults to {}. + + Returns: + None. + """ + + super().__init__(subdict) + self.summary_model = summary_model + self.gpu_id = gpu_id + + def _normalize_whitespace(self, s: str) -> str: + return re.sub(r"\s+", " ", s).strip() + + def _strip_prompt_prefix_literal(self, decoded: str, prompt: str) -> str: + """ + Remove any literal prompt prefix from decoded text using a normalized-substring match. + Guarantees no prompt text remains at the start of returned string (best-effort). + """ + if not decoded: + return "" + if not prompt: + return decoded.strip() + + d_norm = self._normalize_whitespace(decoded) + p_norm = self._normalize_whitespace(prompt) + + idx = d_norm.find(p_norm) + if idx != -1: + running = [] + for i, ch in enumerate(decoded): + running.append(ch if not ch.isspace() else " ") + cur_norm = self._normalize_whitespace("".join(running)) + if cur_norm.endswith(p_norm): + return decoded[i + 1 :].lstrip() if i + 1 < len(decoded) else "" + m = re.match( + r"^(?:\s*(system|user|assistant)[:\s-]*\n?)+", decoded, flags=re.IGNORECASE + ) + if m: + return decoded[m.end() :].lstrip() + + return decoded.lstrip("\n\r ").lstrip(":;- ").strip() + + def _decode_trimmed_outputs( + self, + generated_ids: torch.Tensor, + inputs: Dict[str, torch.Tensor], + tokenizer, + prompt_texts: List[str], + ) -> List[str]: + """ + Trim prompt tokens using attention_mask/input_ids when available and decode to strings. + Then remove any literal prompt prefix using prompt_texts (one per batch element). 
+ """ + + decoded_results = [] + batch_size = generated_ids.shape[0] + + if "input_ids" in inputs: + lengths = ( + inputs["input_ids"] + .ne( + tokenizer.pad_token_id + if tokenizer.pad_token_id is not None + else tokenizer.eos_token_id + ) + .sum(dim=1) + .tolist() + ) + else: + lengths = [0] * batch_size + + trimmed_ids = [] + for i in range(batch_size): + out_ids = generated_ids[i] + in_len = int(lengths[i]) if i < len(lengths) else 0 + if out_ids.size(0) > in_len: + t = out_ids[in_len:] + else: + t = out_ids.new_empty((0,), dtype=out_ids.dtype) + trimmed_ids.append(t) + + decoded = tokenizer.batch_decode( + trimmed_ids, skip_special_tokens=True, clean_up_tokenization_spaces=True + ) + for ptext, raw in zip(prompt_texts, decoded): + cleaned = self._strip_prompt_prefix_literal(raw, ptext) + decoded_results.append(cleaned) + return decoded_results + + def _generate_from_processor_inputs( + self, + processor_inputs: Dict[str, torch.Tensor], + prompt_texts: List[str], + model, + tokenizer, + ): + """ + Run model.generate on already-processed processor_inputs (tensors moved to device), + then decode and trim prompt tokens & remove literal prompt prefixes using prompt_texts. + """ + gen_conf = GenerationConfig( + max_new_tokens=64, + do_sample=False, + num_return_sequences=1, + ) + + for k, v in list(processor_inputs.items()): + if isinstance(v, torch.Tensor): + processor_inputs[k] = v.to(model.device) + + with torch.inference_mode(): + try: + if self.summary_model.device == "cuda": + with torch.cuda.amp.autocast(enabled=True): + generated_ids = self.summary_model.model.generate( + **processor_inputs, generation_config=gen_conf + ) + else: + generated_ids = self.summary_model.model.generate( + **processor_inputs, generation_config=gen_conf + ) + except RuntimeError as e: + warnings.warn( + "Retry without autocast failed: %s. Attempting cudnn-disabled retry.", + e, + ) + cudnn_was_enabled = ( + torch.backends.cudnn.is_available() and torch.backends.cudnn.enabled + ) + if cudnn_was_enabled: + torch.backends.cudnn.enabled = False + try: + generated_ids = self.summary_model.model.generate( + **processor_inputs, generation_config=gen_conf + ) + except Exception as retry_error: + raise RuntimeError( + f"Failed to generate ids after retry: {retry_error}" + ) from retry_error + finally: + if cudnn_was_enabled: + torch.backends.cudnn.enabled = True + + decoded = self._decode_trimmed_outputs( + generated_ids, processor_inputs, tokenizer, prompt_texts + ) + return decoded + + def _tensor_batch_to_pil_list(self, batch: torch.Tensor) -> List[Image.Image]: + """ + Convert a uint8 torch tensor batch (B, C, H, W) on CPU to list of PIL images (RGB). + The conversion is done on CPU and returns PIL.Image objects. + """ + if batch.device.type != "cpu": + batch = batch.to("cpu") + + batch = batch.contiguous() + if batch.dtype != torch.uint8: + batch = batch.to(torch.uint8) + pil_list: List[Image.Image] = [] + for frame in batch: + arr = frame.permute(1, 2, 0).numpy() + pil_list.append(Image.fromarray(arr)) + return pil_list + + def _extract_video_frames( + self, + video_path: Union[str, os.PathLike], + frame_rate_per_second: float = 2, + batch_size: int = 32, + ) -> Dict[str, Any]: + """ + Extract frames from a video at a specified frame rate and return them as a generator of batches. + Args: + video_path (Union[str, os.PathLike]): Path to the video file. + frame_rate_per_second (float, optional): Frame extraction rate in frames per second. Default is 2. 
+ batch_size (int, optional): Number of frames to include in each batch. Default is 32. + Returns: + Dict[str, Any]: A dictionary containing a generator that yields batches of frames and their timestamps + and the total number of extracted frames. + """ + + device = ( + torch.device("cuda") if (torch.cuda.is_available()) else torch.device("cpu") + ) + if device == "cuda": + ctx = decord.gpu(self.gpu_id) + else: + ctx = decord.cpu() + # TODO: to support GPU version of decord: build from source to enable GPU acclerator + # https://github.com/dmlc/decord + + vr = decord.VideoReader(video_path, ctx=ctx) + nframes = len(vr) + video_fps = vr.get_avg_fps() + if video_fps is None or video_fps <= 0: + video_fps = 30.0 + + duration = nframes / float(video_fps) + + if frame_rate_per_second <= 0: + raise ValueError("frame_rate_per_second must be > 0") + + n_samples = max(1, int(math.floor(duration * frame_rate_per_second))) + sample_times = ( + torch.linspace(0, duration, steps=n_samples) + if n_samples > 1 + else torch.tensor([0.0]) + ) + indices = (sample_times * video_fps).round().long() + indices = torch.clamp(indices, 0, nframes - 1).unique(sorted=True) + timestamps = indices.to(torch.float32) / float(video_fps) + + total_samples = indices.numel() + + def gen() -> Generator[Tuple[torch.Tensor, torch.Tensor], None, None]: + for batch_start in range(0, total_samples, batch_size): + batch_idx_tensor = indices[batch_start : batch_start + batch_size] + batch_idx_list = [int(x.item()) for x in batch_idx_tensor] + batch_frames_np = vr.get_batch(batch_idx_list).asnumpy() + batch_frames = ( + torch.from_numpy(batch_frames_np).permute(0, 3, 1, 2).contiguous() + ) + batch_times = timestamps[batch_start : batch_start + batch_size] + if device is not None: + batch_frames = batch_frames.to(device, non_blocking=True) + batch_times = batch_times.to(device, non_blocking=True) + yield batch_frames, batch_times + + return {"generator": gen(), "n_frames": total_samples} + + def brute_force_summary( + self, + extracted_video_dict: Dict[str, Any], + summary_instruction: str = "Summarize the following frame captions into a concise paragraph (1-3 sentences):", + ) -> Dict[str, Any]: + """ + Generate captions for all extracted frames and then produce a concise summary of the video. + Args: + extracted_video_dict (Dict[str, Any]): Dictionary containing the frame generator and number of frames. + summary_instruction (str, optional): Instruction for summarizing the captions. Defaults to a concise paragraph. + Returns: + Dict[str, Any]: A dictionary containing the list of captions with timestamps and the final summary. + """ + + gen = extracted_video_dict["generator"] + caption_instruction = "Describe this image in one concise caption." 
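+        # Caption-then-summarize flow: each batch of sampled frames is converted to
+        # PIL images, every frame is captioned individually with the instruction
+        # above, and the resulting (timestamp, caption) pairs are collected so they
+        # can be condensed into a single summary prompt afterwards.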
+ collected: List[Tuple[float, str]] = [] + + for batch_frames, batch_times in gen: + pil_list = self._tensor_batch_to_pil_list(batch_frames.cpu()) + proc = self.summary_model.processor + + prompt_texts = [] + for p in pil_list: + messages = [ + { + "role": "user", + "content": [ + {"type": "image", "image": p}, + {"type": "text", "text": caption_instruction}, + ], + } + ] + try: + prompt_text = proc.apply_chat_template( + messages, tokenize=False, add_generation_prompt=True + ) + except TypeError: + prompt_text = proc.apply_chat_template(messages) + prompt_texts.append(prompt_text) + + processor_inputs = proc( + text=prompt_texts, images=pil_list, return_tensors="pt", padding=True + ) + captions = self._generate_from_processor_inputs( + processor_inputs, + prompt_texts, + self.summary_model, + self.summary_model.tokenizer, + ) + + batch_times_cpu = ( + batch_times.cpu().tolist() + if isinstance(batch_times, torch.Tensor) + else list(batch_times) + ) + for t, c in zip(batch_times_cpu, captions): + collected.append((float(t), c)) + + collected.sort(key=lambda x: x[0]) + + MAX_CAPTIONS_FOR_SUMMARY = 200 + caps_for_summary = ( + collected[-MAX_CAPTIONS_FOR_SUMMARY:] + if len(collected) > MAX_CAPTIONS_FOR_SUMMARY + else collected + ) + + bullets = [] + for t, c in caps_for_summary: + snippet = c.replace("\n", " ").strip() + bullets.append(f"- [{t:.3f}s] {snippet}") + + combined_captions_text = "\n".join(bullets) + summary_user_text = ( + summary_instruction + + "\n\n" + + combined_captions_text + + "\n\nPlease produce a single concise paragraph." + ) + + proc = self.summary_model.processor + if hasattr(proc, "apply_chat_template"): + messages = [ + { + "role": "user", + "content": [{"type": "text", "text": summary_user_text}], + } + ] + try: + summary_prompt_text = proc.apply_chat_template( + messages, tokenize=False, add_generation_prompt=True + ) + except TypeError: + summary_prompt_text = proc.apply_chat_template(messages) + summary_inputs = proc( + text=[summary_prompt_text], return_tensors="pt", padding=True + ) + else: + summary_prompt_text = summary_user_text + summary_inputs = self.summary_model.tokenizer( + summary_prompt_text, return_tensors="pt" + ) + + summary_inputs = { + k: v.to(self.summary_model.device) if isinstance(v, torch.Tensor) else v + for k, v in summary_inputs.items() + } + final_summary_list = self._generate_from_processor_inputs( + summary_inputs, + [summary_prompt_text], + self.summary_model.model, + self.summary_model.tokenizer, + ) + final_summary = final_summary_list[0].strip() if final_summary_list else "" + + return { + "captions": collected, + "summary": final_summary, + } + + def analyse_videos(self, frame_rate_per_second: float = 2.0) -> Dict[str, Any]: + """ + Analyse the video specified in self.subdict using frame extraction and captioning. + For short videos (<=50 frames at the specified frame rate), it uses brute-force captioning. + For longer videos, it currently defaults to brute-force captioning but can be extended for more complex methods. + + Args: + frame_rate_per_second (float): Frame extraction rate in frames per second. Default is 2.0. + Returns: + Dict[str, Any]: A dictionary containing the analysis results, including captions and summary. 
+ """ + + minimal_edge_of_frames = 50 + extracted_video_dict = self._extract_video_frames( + self.subdict["video_path"], frame_rate_per_second=frame_rate_per_second + ) + if extracted_video_dict["n_frames"] <= minimal_edge_of_frames: + answer = self.brute_force_summary(extracted_video_dict) + + else: + # TODO: implement processing for long videos + summary_instruction = "Describe this image in a single caption, including all important details." + answer = self.brute_force_summary( + extracted_video_dict, summary_instruction=summary_instruction + ) + + return answer From 83cfff9dcebf93e25f85c548c6354cf392724491 Mon Sep 17 00:00:00 2001 From: DimasfromLavoisier Date: Fri, 26 Sep 2025 16:13:47 +0200 Subject: [PATCH 19/23] add example --- ammico/__init__.py | 6 +- ammico/notebooks/DemoVideoSummaryVQA.ipynb | 96 ++++++++ ammico/utils.py | 26 ++- ammico/video_summary.py | 255 +++++++++++---------- 4 files changed, 258 insertions(+), 125 deletions(-) create mode 100644 ammico/notebooks/DemoVideoSummaryVQA.ipynb diff --git a/ammico/__init__.py b/ammico/__init__.py index 1bf343d..5554ee0 100644 --- a/ammico/__init__.py +++ b/ammico/__init__.py @@ -3,7 +3,8 @@ from ammico.faces import EmotionDetector, ethical_disclosure from ammico.model import MultimodalSummaryModel from ammico.text import TextDetector, TextAnalyzer, privacy_disclosure from ammico.image_summary import ImageSummaryDetector -from ammico.utils import find_files, get_dataframe +from ammico.utils import find_files, get_dataframe, AnalysisType, find_videos +from ammico.video_summary import VideoSummaryDetector # Export the version defined in project metadata try: @@ -14,13 +15,16 @@ except ImportError: __version__ = "unknown" __all__ = [ + "AnalysisType", "AnalysisExplorer", "EmotionDetector", "MultimodalSummaryModel", "TextDetector", "TextAnalyzer", "ImageSummaryDetector", + "VideoSummaryDetector", "find_files", + "find_videos", "get_dataframe", "ethical_disclosure", "privacy_disclosure", diff --git a/ammico/notebooks/DemoVideoSummaryVQA.ipynb b/ammico/notebooks/DemoVideoSummaryVQA.ipynb new file mode 100644 index 0000000..2e80165 --- /dev/null +++ b/ammico/notebooks/DemoVideoSummaryVQA.ipynb @@ -0,0 +1,96 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Video summary and visual question answering" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import ammico" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Currently this module supports only video summarization, but it will be updated in the nearest future" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "video_dict = ammico.find_videos(\n", + " path=str(\"/insert/your/path/here/\"), # path to the folder with images\n", + " limit=-1, # -1 means no limit on the number of files, by default it is set to 20\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "model = ammico.MultimodalSummaryModel()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "vid_summary_model = ammico.VideoSummaryDetector(summary_model=model, subdict=video_dict)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "summary_dict = vid_summary_model.analyse_video()" + ] + }, + { + "cell_type": "code", + "execution_count": null, 
+ "metadata": {}, + "outputs": [], + "source": [ + "summary_dict[\"summary\"]" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "ammico-dev", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.11.13" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/ammico/utils.py b/ammico/utils.py index 38f4144..d698708 100644 --- a/ammico/utils.py +++ b/ammico/utils.py @@ -5,8 +5,6 @@ import pooch import importlib_resources import collections import random - - from enum import Enum @@ -103,6 +101,30 @@ def _limit_results(results, limit): return results +def find_videos( + path: str = None, + pattern=["mp4"], # TODO: test with more video formats + recursive: bool = True, + limit=5, + random_seed: int = None, +) -> dict: + """Find video files on the file system.""" + if path is None: + path = os.environ.get("AMMICO_DATA_HOME", ".") + if isinstance(pattern, str): + pattern = [pattern] + results = [] + for p in pattern: + results.extend(_match_pattern(path, p, recursive=recursive)) + if len(results) == 0: + raise FileNotFoundError(f"No files found in {path} with pattern '{pattern}'") + if random_seed is not None: + random.seed(random_seed) + random.shuffle(results) + videos = _limit_results(results, limit) + return initialize_dict(videos) + + def find_files( path: str = None, pattern=["png", "jpg", "jpeg", "gif", "webp", "avif", "tiff"], diff --git a/ammico/video_summary.py b/ammico/video_summary.py index 1f3f8c5..52890dd 100644 --- a/ammico/video_summary.py +++ b/ammico/video_summary.py @@ -1,16 +1,14 @@ import decord -import os import re import math import torch import warnings -import numpy as np from PIL import Image from ammico.model import MultimodalSummaryModel -from ammico.utils import AnalysisMethod, AnalysisType +from ammico.utils import AnalysisMethod -from typing import List, Optional, Union, Dict, Any, Generator, Tuple +from typing import List, Optional, Dict, Any, Generator, Tuple from transformers import GenerationConfig @@ -19,7 +17,6 @@ class VideoSummaryDetector(AnalysisMethod): self, summary_model: MultimodalSummaryModel, subdict: dict = {}, - gpu_id: int = 0, ) -> None: """ Class for analysing videos using QWEN-2.5-VL model. @@ -35,7 +32,89 @@ class VideoSummaryDetector(AnalysisMethod): super().__init__(subdict) self.summary_model = summary_model - self.gpu_id = gpu_id + + def _frame_batch_generator( + self, + indices: torch.Tensor, + timestamps: torch.Tensor, + batch_size: int, + vr, + ) -> Generator[Tuple[torch.Tensor, torch.Tensor], None, None]: + """ + Yield batches of (frames, timestamps) for given frame indices. + - frames are returned as a torch.Tensor with shape (B, C, H, W). + - timestamps is a 1D torch.Tensor with B elements. 
+ """ + total = int(indices.numel()) + device = torch.device("cpu") + + for start in range(0, total, batch_size): + batch_idx_tensor = indices[start : start + batch_size] + # convert to python ints for decord API + batch_idx_list = [int(x.item()) for x in batch_idx_tensor] + + # decord returns ndarray-like object; keep memory layout minimal and convert once + batch_frames_np = vr.get_batch(batch_idx_list).asnumpy() + + # convert to CHW torch layout + batch_frames = ( + torch.from_numpy(batch_frames_np).permute(0, 3, 1, 2).contiguous() + ).to(device, non_blocking=True) + + batch_times = timestamps[start : start + batch_size].to( + device, non_blocking=True + ) + + yield batch_frames, batch_times + + def _extract_video_frames( + self, + entry: Optional[Dict[str, Any]], + frame_rate_per_second: float = 2, + batch_size: int = 32, + ) -> Dict[str, Any]: + """ + Extract frames from a video at a specified frame rate and return them as a generator of batches. + Args: + filename (Union[str, os.PathLike]): Path to the video file. + frame_rate_per_second (float, optional): Frame extraction rate in frames per second. Default is 2. + batch_size (int, optional): Number of frames to include in each batch. Default is 32. + Returns: + Dict[str, Any]: A dictionary containing a generator that yields batches of frames and their timestamps + and the total number of extracted frames. + """ + + filename = entry.get("filename") + if not filename: + raise ValueError("entry must contain key 'filename'") + + # TODO: consider using torchcodec for video decoding, since decord is no longer actively maintained + vr = decord.VideoReader(filename) + + nframes = len(vr) + video_fps = vr.get_avg_fps() + if video_fps is None or video_fps <= 0: + video_fps = 30.0 + + duration = nframes / float(video_fps) + + if frame_rate_per_second <= 0: + raise ValueError("frame_rate_per_second must be > 0") + + n_samples = max(1, int(math.floor(duration * frame_rate_per_second))) + sample_times = ( + torch.linspace(0, duration, steps=n_samples) + if n_samples > 1 + else torch.tensor([0.0]) + ) + indices = (sample_times * video_fps).round().long() + indices = torch.clamp(indices, 0, nframes - 1).unique(sorted=True) + timestamps = indices.to(torch.float32) / float(video_fps) + + total_samples = int(indices.numel()) + generator = self._frame_batch_generator(indices, timestamps, batch_size, vr) + + return {"generator": generator, "n_frames": total_samples} def _normalize_whitespace(self, s: str) -> str: return re.sub(r"\s+", " ", s).strip() @@ -192,72 +271,6 @@ class VideoSummaryDetector(AnalysisMethod): pil_list.append(Image.fromarray(arr)) return pil_list - def _extract_video_frames( - self, - video_path: Union[str, os.PathLike], - frame_rate_per_second: float = 2, - batch_size: int = 32, - ) -> Dict[str, Any]: - """ - Extract frames from a video at a specified frame rate and return them as a generator of batches. - Args: - video_path (Union[str, os.PathLike]): Path to the video file. - frame_rate_per_second (float, optional): Frame extraction rate in frames per second. Default is 2. - batch_size (int, optional): Number of frames to include in each batch. Default is 32. - Returns: - Dict[str, Any]: A dictionary containing a generator that yields batches of frames and their timestamps - and the total number of extracted frames. 
- """ - - device = ( - torch.device("cuda") if (torch.cuda.is_available()) else torch.device("cpu") - ) - if device == "cuda": - ctx = decord.gpu(self.gpu_id) - else: - ctx = decord.cpu() - # TODO: to support GPU version of decord: build from source to enable GPU acclerator - # https://github.com/dmlc/decord - - vr = decord.VideoReader(video_path, ctx=ctx) - nframes = len(vr) - video_fps = vr.get_avg_fps() - if video_fps is None or video_fps <= 0: - video_fps = 30.0 - - duration = nframes / float(video_fps) - - if frame_rate_per_second <= 0: - raise ValueError("frame_rate_per_second must be > 0") - - n_samples = max(1, int(math.floor(duration * frame_rate_per_second))) - sample_times = ( - torch.linspace(0, duration, steps=n_samples) - if n_samples > 1 - else torch.tensor([0.0]) - ) - indices = (sample_times * video_fps).round().long() - indices = torch.clamp(indices, 0, nframes - 1).unique(sorted=True) - timestamps = indices.to(torch.float32) / float(video_fps) - - total_samples = indices.numel() - - def gen() -> Generator[Tuple[torch.Tensor, torch.Tensor], None, None]: - for batch_start in range(0, total_samples, batch_size): - batch_idx_tensor = indices[batch_start : batch_start + batch_size] - batch_idx_list = [int(x.item()) for x in batch_idx_tensor] - batch_frames_np = vr.get_batch(batch_idx_list).asnumpy() - batch_frames = ( - torch.from_numpy(batch_frames_np).permute(0, 3, 1, 2).contiguous() - ) - batch_times = timestamps[batch_start : batch_start + batch_size] - if device is not None: - batch_frames = batch_frames.to(device, non_blocking=True) - batch_times = batch_times.to(device, non_blocking=True) - yield batch_frames, batch_times - - return {"generator": gen(), "n_frames": total_samples} - def brute_force_summary( self, extracted_video_dict: Dict[str, Any], @@ -275,10 +288,10 @@ class VideoSummaryDetector(AnalysisMethod): gen = extracted_video_dict["generator"] caption_instruction = "Describe this image in one concise caption." collected: List[Tuple[float, str]] = [] + proc = self.summary_model.processor for batch_frames, batch_times in gen: pil_list = self._tensor_batch_to_pil_list(batch_frames.cpu()) - proc = self.summary_model.processor prompt_texts = [] for p in pil_list: @@ -291,12 +304,10 @@ class VideoSummaryDetector(AnalysisMethod): ], } ] - try: - prompt_text = proc.apply_chat_template( - messages, tokenize=False, add_generation_prompt=True - ) - except TypeError: - prompt_text = proc.apply_chat_template(messages) + + prompt_text = proc.apply_chat_template( + messages, tokenize=False, add_generation_prompt=True + ) prompt_texts.append(prompt_text) processor_inputs = proc( @@ -309,15 +320,16 @@ class VideoSummaryDetector(AnalysisMethod): self.summary_model.tokenizer, ) - batch_times_cpu = ( - batch_times.cpu().tolist() - if isinstance(batch_times, torch.Tensor) - else list(batch_times) - ) - for t, c in zip(batch_times_cpu, captions): + # normalize batch_times to Python floats + if isinstance(batch_times, torch.Tensor): + batch_times_list = batch_times.cpu().tolist() + else: + batch_times_list = list(batch_times) + for t, c in zip(batch_times_list, captions): collected.append((float(t), c)) collected.sort(key=lambda x: x[0]) + gen.close() MAX_CAPTIONS_FOR_SUMMARY = 200 caps_for_summary = ( @@ -339,28 +351,20 @@ class VideoSummaryDetector(AnalysisMethod): + "\n\nPlease produce a single concise paragraph." 
) - proc = self.summary_model.processor - if hasattr(proc, "apply_chat_template"): - messages = [ - { - "role": "user", - "content": [{"type": "text", "text": summary_user_text}], - } - ] - try: - summary_prompt_text = proc.apply_chat_template( - messages, tokenize=False, add_generation_prompt=True - ) - except TypeError: - summary_prompt_text = proc.apply_chat_template(messages) - summary_inputs = proc( - text=[summary_prompt_text], return_tensors="pt", padding=True - ) - else: - summary_prompt_text = summary_user_text - summary_inputs = self.summary_model.tokenizer( - summary_prompt_text, return_tensors="pt" - ) + messages = [ + { + "role": "user", + "content": [{"type": "text", "text": summary_user_text}], + } + ] + + summary_prompt_text = proc.apply_chat_template( + messages, tokenize=False, add_generation_prompt=True + ) + + summary_inputs = proc( + text=[summary_prompt_text], return_tensors="pt", padding=True + ) summary_inputs = { k: v.to(self.summary_model.device) if isinstance(v, torch.Tensor) else v @@ -379,11 +383,11 @@ class VideoSummaryDetector(AnalysisMethod): "summary": final_summary, } - def analyse_videos(self, frame_rate_per_second: float = 2.0) -> Dict[str, Any]: + def analyse_video(self, frame_rate_per_second: float = 2.0) -> Dict[str, Any]: """ Analyse the video specified in self.subdict using frame extraction and captioning. - For short videos (<=50 frames at the specified frame rate), it uses brute-force captioning. - For longer videos, it currently defaults to brute-force captioning but can be extended for more complex methods. + For short videos (<=100 frames at the specified frame rate), it uses brute-force captioning. + For longer videos, it currently defaults to brute-force captioning, but can be extended for more complex methods. Args: frame_rate_per_second (float): Frame extraction rate in frames per second. Default is 2.0. @@ -391,18 +395,25 @@ class VideoSummaryDetector(AnalysisMethod): Dict[str, Any]: A dictionary containing the analysis results, including captions and summary. """ - minimal_edge_of_frames = 50 - extracted_video_dict = self._extract_video_frames( - self.subdict["video_path"], frame_rate_per_second=frame_rate_per_second - ) - if extracted_video_dict["n_frames"] <= minimal_edge_of_frames: - answer = self.brute_force_summary(extracted_video_dict) - - else: - # TODO: implement processing for long videos - summary_instruction = "Describe this image in a single caption, including all important details." - answer = self.brute_force_summary( - extracted_video_dict, summary_instruction=summary_instruction + minimal_edge_of_frames = 100 + all_answers = {} + # TODO: add support for answering questions about videos + for video_key in list(self.subdict.keys()): + entry = self.subdict[video_key] + extracted_video_dict = self._extract_video_frames( + entry, frame_rate_per_second=frame_rate_per_second ) + if extracted_video_dict["n_frames"] <= minimal_edge_of_frames: + answer = self.brute_force_summary(extracted_video_dict) + + else: + # TODO: implement processing for long videos + summary_instruction = "Describe this image in a single caption, including all important details." 
+ answer = self.brute_force_summary( + extracted_video_dict, summary_instruction=summary_instruction + ) + + all_answers[video_key] = {"summary": answer["summary"]} + # TODO: captions has to be post-processed with foreseeing audio analysis return answer From 8b207cb347ce6bf189184a8dbd7437824ffd7638 Mon Sep 17 00:00:00 2001 From: Dmitrii Kapitan Date: Fri, 17 Oct 2025 18:54:40 +0200 Subject: [PATCH 20/23] update dependencies to integrate torchcodec --- environment.yml | 9 +++++---- pyproject.toml | 1 + 2 files changed, 6 insertions(+), 4 deletions(-) diff --git a/environment.yml b/environment.yml index 2268e08..58cd20e 100644 --- a/environment.yml +++ b/environment.yml @@ -9,11 +9,12 @@ channels: dependencies: - python=3.11 - cudatoolkit=11.8 - - pytorch=2.3.1 + - pytorch=2.5.1 - pytorch-cuda=11.8 - - torchvision=0.18.1 - - torchaudio=2.3.1 + - torchvision=0.20.1 + - torchaudio=2.5.1 - faiss-gpu-raft=1.8.0 - ipykernel - jupyterlab - - jupyterlab_widgets \ No newline at end of file + - jupyterlab_widgets + - ffmpeg<8 \ No newline at end of file diff --git a/pyproject.toml b/pyproject.toml index 1ead16e..b672d06 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -49,6 +49,7 @@ dependencies = [ "torchvision", "tqdm", "transformers>=4.54", + "torchcodec<0.2", "webcolors", ] From 731353dfdb6771207312b2dc20d4806494b2ff83 Mon Sep 17 00:00:00 2001 From: Dmitrii Kapitan Date: Fri, 17 Oct 2025 19:02:07 +0200 Subject: [PATCH 21/23] integrate torchcodec into decoding video --- ammico/video_summary.py | 121 ++++++++++++++++++++-------------------- 1 file changed, 61 insertions(+), 60 deletions(-) diff --git a/ammico/video_summary.py b/ammico/video_summary.py index 52890dd..88836d9 100644 --- a/ammico/video_summary.py +++ b/ammico/video_summary.py @@ -1,9 +1,9 @@ -import decord import re import math import torch import warnings from PIL import Image +from torchcodec.decoders import VideoDecoder from ammico.model import MultimodalSummaryModel from ammico.utils import AnalysisMethod @@ -35,44 +35,38 @@ class VideoSummaryDetector(AnalysisMethod): def _frame_batch_generator( self, - indices: torch.Tensor, timestamps: torch.Tensor, batch_size: int, - vr, + video_decoder: VideoDecoder, ) -> Generator[Tuple[torch.Tensor, torch.Tensor], None, None]: """ Yield batches of (frames, timestamps) for given frame indices. - frames are returned as a torch.Tensor with shape (B, C, H, W). - timestamps is a 1D torch.Tensor with B elements. 
""" - total = int(indices.numel()) - device = torch.device("cpu") + total = int(timestamps.numel()) for start in range(0, total, batch_size): - batch_idx_tensor = indices[start : start + batch_size] - # convert to python ints for decord API - batch_idx_list = [int(x.item()) for x in batch_idx_tensor] + batch_secs = timestamps[start : start + batch_size].tolist() + fb = video_decoder.get_frames_played_at(batch_secs) + frames = fb.data - # decord returns ndarray-like object; keep memory layout minimal and convert once - batch_frames_np = vr.get_batch(batch_idx_list).asnumpy() - - # convert to CHW torch layout - batch_frames = ( - torch.from_numpy(batch_frames_np).permute(0, 3, 1, 2).contiguous() - ).to(device, non_blocking=True) - - batch_times = timestamps[start : start + batch_size].to( - device, non_blocking=True + if not frames.is_contiguous(): + frames = frames.contiguous() + pts = fb.pts_seconds + pts_out = ( + pts.cpu().to(dtype=torch.float32) + if isinstance(pts, torch.Tensor) + else torch.tensor(pts, dtype=torch.float32) ) - - yield batch_frames, batch_times + yield frames, pts_out def _extract_video_frames( self, - entry: Optional[Dict[str, Any]], - frame_rate_per_second: float = 2, + entry: Dict[str, Any], + frame_rate_per_second: float = 2.0, batch_size: int = 32, - ) -> Dict[str, Any]: + ) -> Generator[Tuple[torch.Tensor, torch.Tensor], None, None]: """ Extract frames from a video at a specified frame rate and return them as a generator of batches. Args: @@ -80,41 +74,57 @@ class VideoSummaryDetector(AnalysisMethod): frame_rate_per_second (float, optional): Frame extraction rate in frames per second. Default is 2. batch_size (int, optional): Number of frames to include in each batch. Default is 32. Returns: - Dict[str, Any]: A dictionary containing a generator that yields batches of frames and their timestamps - and the total number of extracted frames. + Generator[Tuple[torch.Tensor, torch.Tensor], None, None]: A generator yielding tuples of + (frames, timestamps), where frames is a tensor of shape (B, C, H, W) and timestamps is a 1D tensor of length B. 
""" filename = entry.get("filename") if not filename: raise ValueError("entry must contain key 'filename'") - # TODO: consider using torchcodec for video decoding, since decord is no longer actively maintained - vr = decord.VideoReader(filename) + video_decoder = VideoDecoder(filename) + meta = video_decoder.metadata - nframes = len(vr) - video_fps = vr.get_avg_fps() - if video_fps is None or video_fps <= 0: + video_fps = getattr(meta, "average_fps", None) + if video_fps is None or not ( + isinstance(video_fps, (int, float)) and video_fps > 0 + ): video_fps = 30.0 - duration = nframes / float(video_fps) + begin_stream_seconds = getattr(meta, "begin_stream_seconds", None) + end_stream_seconds = getattr(meta, "end_stream_seconds", None) + nframes = len(video_decoder) + if getattr(meta, "duration_seconds", None) is not None: + duration = float(meta.duration_seconds) + elif begin_stream_seconds is not None and end_stream_seconds is not None: + duration = float(end_stream_seconds) - float(begin_stream_seconds) + elif nframes: + duration = float(nframes) / float(video_fps) + else: + duration = 0.0 if frame_rate_per_second <= 0: raise ValueError("frame_rate_per_second must be > 0") n_samples = max(1, int(math.floor(duration * frame_rate_per_second))) - sample_times = ( - torch.linspace(0, duration, steps=n_samples) - if n_samples > 1 - else torch.tensor([0.0]) - ) - indices = (sample_times * video_fps).round().long() - indices = torch.clamp(indices, 0, nframes - 1).unique(sorted=True) - timestamps = indices.to(torch.float32) / float(video_fps) - total_samples = int(indices.numel()) - generator = self._frame_batch_generator(indices, timestamps, batch_size, vr) + if begin_stream_seconds is not None and end_stream_seconds is not None: + sample_times = torch.linspace( + float(begin_stream_seconds), float(end_stream_seconds), steps=n_samples + ) + if sample_times.numel() > 1: + sample_times = torch.clamp( + sample_times, + min=float(begin_stream_seconds), + max=float(end_stream_seconds) - 1e-6, + ) + else: + sample_times = torch.linspace(0.0, max(0.0, duration), steps=n_samples) - return {"generator": generator, "n_frames": total_samples} + sample_times = sample_times.to(dtype=torch.float32, device="cpu") + generator = self._frame_batch_generator(sample_times, batch_size, video_decoder) + + return generator def _normalize_whitespace(self, s: str) -> str: return re.sub(r"\s+", " ", s).strip() @@ -273,8 +283,8 @@ class VideoSummaryDetector(AnalysisMethod): def brute_force_summary( self, - extracted_video_dict: Dict[str, Any], - summary_instruction: str = "Summarize the following frame captions into a concise paragraph (1-3 sentences):", + extracted_video_gen: Generator[Tuple[torch.Tensor, torch.Tensor], None, None], + summary_instruction: str = "Analyze the following captions from multiple frames of the same video and summarize the overall content of the video in one concise paragraph (1-3 sentences). Focus on the key themes, actions, or events across the video, not just the individual frames.", ) -> Dict[str, Any]: """ Generate captions for all extracted frames and then produce a concise summary of the video. @@ -285,12 +295,11 @@ class VideoSummaryDetector(AnalysisMethod): Dict[str, Any]: A dictionary containing the list of captions with timestamps and the final summary. """ - gen = extracted_video_dict["generator"] caption_instruction = "Describe this image in one concise caption." 
collected: List[Tuple[float, str]] = [] proc = self.summary_model.processor - for batch_frames, batch_times in gen: + for batch_frames, batch_times in extracted_video_gen: pil_list = self._tensor_batch_to_pil_list(batch_frames.cpu()) prompt_texts = [] @@ -320,7 +329,6 @@ class VideoSummaryDetector(AnalysisMethod): self.summary_model.tokenizer, ) - # normalize batch_times to Python floats if isinstance(batch_times, torch.Tensor): batch_times_list = batch_times.cpu().tolist() else: @@ -329,7 +337,7 @@ class VideoSummaryDetector(AnalysisMethod): collected.append((float(t), c)) collected.sort(key=lambda x: x[0]) - gen.close() + extracted_video_gen.close() MAX_CAPTIONS_FOR_SUMMARY = 200 caps_for_summary = ( @@ -383,7 +391,9 @@ class VideoSummaryDetector(AnalysisMethod): "summary": final_summary, } - def analyse_video(self, frame_rate_per_second: float = 2.0) -> Dict[str, Any]: + def analyse_videos_from_dict( + self, frame_rate_per_second: float = 2.0 + ) -> Dict[str, Any]: """ Analyse the video specified in self.subdict using frame extraction and captioning. For short videos (<=100 frames at the specified frame rate), it uses brute-force captioning. @@ -395,24 +405,15 @@ class VideoSummaryDetector(AnalysisMethod): Dict[str, Any]: A dictionary containing the analysis results, including captions and summary. """ - minimal_edge_of_frames = 100 all_answers = {} # TODO: add support for answering questions about videos for video_key in list(self.subdict.keys()): entry = self.subdict[video_key] - extracted_video_dict = self._extract_video_frames( + extracted_video_gen = self._extract_video_frames( entry, frame_rate_per_second=frame_rate_per_second ) - if extracted_video_dict["n_frames"] <= minimal_edge_of_frames: - answer = self.brute_force_summary(extracted_video_dict) - - else: - # TODO: implement processing for long videos - summary_instruction = "Describe this image in a single caption, including all important details." 
- answer = self.brute_force_summary( - extracted_video_dict, summary_instruction=summary_instruction - ) + answer = self.brute_force_summary(extracted_video_gen) all_answers[video_key] = {"summary": answer["summary"]} # TODO: captions has to be post-processed with foreseeing audio analysis From 8c26a8de5e67fbd687983e063436e69b90b8e080 Mon Sep 17 00:00:00 2001 From: DimasfromLavoisier Date: Mon, 20 Oct 2025 15:22:59 +0200 Subject: [PATCH 22/23] update deprecated torch.cuda.amp --- ammico/video_summary.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/ammico/video_summary.py b/ammico/video_summary.py index 88836d9..666c373 100644 --- a/ammico/video_summary.py +++ b/ammico/video_summary.py @@ -8,7 +8,7 @@ from torchcodec.decoders import VideoDecoder from ammico.model import MultimodalSummaryModel from ammico.utils import AnalysisMethod -from typing import List, Optional, Dict, Any, Generator, Tuple +from typing import List, Dict, Any, Generator, Tuple from transformers import GenerationConfig @@ -229,7 +229,7 @@ class VideoSummaryDetector(AnalysisMethod): with torch.inference_mode(): try: if self.summary_model.device == "cuda": - with torch.cuda.amp.autocast(enabled=True): + with torch.amp.autocast("cuda", enabled=True): generated_ids = self.summary_model.model.generate( **processor_inputs, generation_config=gen_conf ) From c2b2079d4e6d70820042f5f78edded37d8eb0e62 Mon Sep 17 00:00:00 2001 From: DimasfromLavoisier Date: Wed, 22 Oct 2025 17:12:19 +0200 Subject: [PATCH 23/23] add functionality for video vqa --- ammico/image_summary.py | 55 ++----- ammico/utils.py | 114 ++++++++++++++ ammico/video_summary.py | 328 ++++++++++++++++++++++++++-------------- 3 files changed, 334 insertions(+), 163 deletions(-) diff --git a/ammico/image_summary.py b/ammico/image_summary.py index 203ef21..adc01af 100644 --- a/ammico/image_summary.py +++ b/ammico/image_summary.py @@ -6,7 +6,7 @@ import torch from PIL import Image import warnings -from typing import List, Optional, Union, Dict, Any, Tuple +from typing import List, Optional, Union, Dict, Any from collections.abc import Sequence as _Sequence from transformers import GenerationConfig from qwen_vl_utils import process_vision_info @@ -120,42 +120,11 @@ class ImageSummaryDetector(AnalysisMethod): return inputs - def _validate_analysis_type( - self, - analysis_type: Union["AnalysisType", str], - list_of_questions: Optional[List[str]], - max_questions_per_image: int, - ) -> Tuple[str, List[str], bool, bool]: - if isinstance(analysis_type, AnalysisType): - analysis_type = analysis_type.value - - allowed = {"summary", "questions", "summary_and_questions"} - if analysis_type not in allowed: - raise ValueError(f"analysis_type must be one of {allowed}") - - if list_of_questions is None: - list_of_questions = [ - "Are there people in the image?", - "What is this picture about?", - ] - - if analysis_type in ("questions", "summary_and_questions"): - if len(list_of_questions) > max_questions_per_image: - raise ValueError( - f"Number of questions per image ({len(list_of_questions)}) exceeds safety cap ({max_questions_per_image}). Reduce questions or increase max_questions_per_image." 
- ) - - is_summary = analysis_type in ("summary", "summary_and_questions") - is_questions = analysis_type in ("questions", "summary_and_questions") - - return analysis_type, list_of_questions, is_summary, is_questions - def analyse_image( self, entry: dict, - analysis_type: Union[str, AnalysisType] = AnalysisType.SUMMARY_AND_QUESTIONS, + analysis_type: Union[str, AnalysisType] = AnalysisType.SUMMARY, list_of_questions: Optional[List[str]] = None, - max_questions_per_image: int = 32, is_concise_summary: bool = True, is_concise_answer: bool = True, ) -> Dict[str, Any]: @@ -165,10 +134,8 @@ class ImageSummaryDetector(AnalysisMethod): - 'vqa' (dict) if questions requested """ self.subdict = entry - analysis_type, list_of_questions, is_summary, is_questions = ( - self._validate_analysis_type( - analysis_type, list_of_questions, max_questions_per_image - ) + analysis_type, is_summary, is_questions = AnalysisType._validate_analysis_type( + analysis_type, list_of_questions ) if is_summary: @@ -195,9 +162,8 @@ class ImageSummaryDetector(AnalysisMethod): def analyse_images_from_dict( self, - analysis_type: Union[AnalysisType, str] = AnalysisType.SUMMARY_AND_QUESTIONS, + analysis_type: Union[AnalysisType, str] = AnalysisType.SUMMARY, list_of_questions: Optional[List[str]] = None, - max_questions_per_image: int = 32, keys_batch_size: int = 16, is_concise_summary: bool = True, is_concise_answer: bool = True, @@ -208,7 +174,6 @@ class ImageSummaryDetector(AnalysisMethod): Args: analysis_type (str): type of the analysis. list_of_questions (list[str]): list of questions. - max_questions_per_image (int): maximum number of questions per image. We recommend to keep it low to avoid long processing times and high memory usage. keys_batch_size (int): number of images to process in a batch. is_concise_summary (bool): whether to generate concise summary. is_concise_answer (bool): whether to generate concise answers. @@ -216,10 +181,8 @@ class ImageSummaryDetector(AnalysisMethod): self.subdict (dict): dictionary with analysis results. """ # TODO: add option to ask multiple questions per image as one batch. 
-        analysis_type, list_of_questions, is_summary, is_questions = (
-            self._validate_analysis_type(
-                analysis_type, list_of_questions, max_questions_per_image
-            )
+        analysis_type, is_summary, is_questions = AnalysisType._validate_analysis_type(
+            analysis_type, list_of_questions
         )
 
         keys = list(self.subdict.keys())
@@ -284,7 +247,7 @@ class ImageSummaryDetector(AnalysisMethod):
         with torch.inference_mode():
             try:
                 if self.summary_model.device == "cuda":
-                    with torch.cuda.amp.autocast(enabled=True):
+                    with torch.amp.autocast("cuda", enabled=True):
                         generated_ids = self.summary_model.model.generate(
                             **inputs, generation_config=gen_conf
                         )
@@ -366,7 +329,7 @@ class ImageSummaryDetector(AnalysisMethod):
                 inputs = self._prepare_inputs(chunk, entry)
                 with torch.inference_mode():
                     if self.summary_model.device == "cuda":
-                        with torch.cuda.amp.autocast(enabled=True):
+                        with torch.amp.autocast("cuda", enabled=True):
                             out_ids = self.summary_model.model.generate(
                                 **inputs, generation_config=gen_conf
                             )
diff --git a/ammico/utils.py b/ammico/utils.py
index d698708..07bae90 100644
--- a/ammico/utils.py
+++ b/ammico/utils.py
@@ -6,6 +6,8 @@ import importlib_resources
 import collections
 import random
 from enum import Enum
+from typing import List, Tuple, Optional, Union
+import re
 
 pkg = importlib_resources.files("ammico")
 
@@ -46,6 +48,35 @@ class AnalysisType(str, Enum):
     QUESTIONS = "questions"
     SUMMARY_AND_QUESTIONS = "summary_and_questions"
 
+    @classmethod
+    def _validate_analysis_type(
+        cls,
+        analysis_type: Union["AnalysisType", str],
+        list_of_questions: Optional[List[str]],
+    ) -> Tuple[str, bool, bool]:
+        max_questions_per_image = 15  # safety cap to avoid too many questions
+        if isinstance(analysis_type, AnalysisType):
+            analysis_type = analysis_type.value
+
+        allowed = {item.value for item in AnalysisType}
+        if analysis_type not in allowed:
+            raise ValueError(f"analysis_type must be one of {allowed}")
+
+        if analysis_type in ("questions", "summary_and_questions"):
+            if not list_of_questions:
+                raise ValueError(
+                    "list_of_questions must be provided for QUESTIONS analysis type."
+                )
+
+            if len(list_of_questions) > max_questions_per_image:
+                raise ValueError(
+                    f"Number of questions per image ({len(list_of_questions)}) exceeds the safety cap ({max_questions_per_image}). Reduce the number of questions."
+                )
+
+        is_summary = analysis_type in ("summary", "summary_and_questions")
+        is_questions = analysis_type in ("questions", "summary_and_questions")
+        return analysis_type, is_summary, is_questions
+
 
 class AnalysisMethod:
     """Base class to be inherited by all analysis methods."""
@@ -101,6 +132,89 @@ def _limit_results(results, limit):
     return results
 
 
+def _categorize_outputs(
+    collected: List[Tuple[float, str]],
+    include_questions: bool = False,
+) -> Tuple[List[str], List[str]]:
+    """
+    Categorize collected outputs into summary bullets and VQA bullets.
+    Args:
+        collected (List[Tuple[float, str]]): List of tuples containing timestamps and generated texts.
+    Returns:
+        Tuple[List[str], List[str]]: A tuple containing two lists - summary bullets and VQA bullets.
+    """
+    MAX_CAPTIONS_FOR_SUMMARY = 600  # TODO: make this cap adjustable; a later improvement could drop near-duplicate frames instead of plain truncation to reduce the load on the system.
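+    # Truncate to the last MAX_CAPTIONS_FOR_SUMMARY (timestamp, text) pairs so that very long
+    # videos do not blow up the size of the follow-up summary/VQA prompts.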
+ caps_for_summary_vqa = ( + collected[-MAX_CAPTIONS_FOR_SUMMARY:] + if len(collected) > MAX_CAPTIONS_FOR_SUMMARY + else collected + ) + bullets_summary = [] + bullets_vqa = [] + + for t, c in caps_for_summary_vqa: + if include_questions: + result_sections = c.strip() + m = re.search( + r"Summary\s*:\s*(.*?)\s*(?:VQA\s+Answers\s*:\s*(.*))?$", + result_sections, + flags=re.IGNORECASE | re.DOTALL, + ) + if m: + summary_text = ( + m.group(1).replace("\n", " ").strip() if m.group(1) else None + ) + vqa_text = m.group(2).strip() if m.group(2) else None + if not summary_text or not vqa_text: + raise ValueError( + f"Model output is missing either summary or VQA answers: {c}" + ) + bullets_summary.append(f"- [{t:.3f}s] {summary_text}") + bullets_vqa.append(f"- [{t:.3f}s] {vqa_text}") + else: + raise ValueError( + f"Failed to parse summary and VQA answers from model output: {c}" + ) + else: + snippet = c.replace("\n", " ").strip() + bullets_summary.append(f"- [{t:.3f}s] {snippet}") + return bullets_summary, bullets_vqa + + +def _normalize_whitespace(s: str) -> str: + return re.sub(r"\s+", " ", s).strip() + + +def _strip_prompt_prefix_literal(decoded: str, prompt: str) -> str: + """ + Remove any literal prompt prefix from decoded text using a normalized-substring match. + Guarantees no prompt text remains at the start of returned string (best-effort). + """ + if not decoded: + return "" + if not prompt: + return decoded.strip() + + d_norm = _normalize_whitespace(decoded) + p_norm = _normalize_whitespace(prompt) + + idx = d_norm.find(p_norm) + if idx != -1: + running = [] + for i, ch in enumerate(decoded): + running.append(ch if not ch.isspace() else " ") + cur_norm = _normalize_whitespace("".join(running)) + if cur_norm.endswith(p_norm): + return decoded[i + 1 :].lstrip() if i + 1 < len(decoded) else "" + m = re.match( + r"^(?:\s*(system|user|assistant)[:\s-]*\n?)+", decoded, flags=re.IGNORECASE + ) + if m: + return decoded[m.end() :].lstrip() + + return decoded.lstrip("\n\r ").lstrip(":;- ").strip() + + def find_videos( path: str = None, pattern=["mp4"], # TODO: test with more video formats diff --git a/ammico/video_summary.py b/ammico/video_summary.py index 666c373..391f2d0 100644 --- a/ammico/video_summary.py +++ b/ammico/video_summary.py @@ -6,17 +6,24 @@ from PIL import Image from torchcodec.decoders import VideoDecoder from ammico.model import MultimodalSummaryModel -from ammico.utils import AnalysisMethod +from ammico.utils import ( + AnalysisMethod, + AnalysisType, + _categorize_outputs, + _strip_prompt_prefix_literal, +) -from typing import List, Dict, Any, Generator, Tuple +from typing import List, Dict, Any, Generator, Tuple, Union, Optional from transformers import GenerationConfig class VideoSummaryDetector(AnalysisMethod): + MAX_SAMPLES_CAP = 1000 # safety cap for total extracted frames + def __init__( self, summary_model: MultimodalSummaryModel, - subdict: dict = {}, + subdict: Optional[Dict[str, Any]] = None, ) -> None: """ Class for analysing videos using QWEN-2.5-VL model. @@ -29,6 +36,8 @@ class VideoSummaryDetector(AnalysisMethod): Returns: None. 
""" + if subdict is None: + subdict = {} super().__init__(subdict) self.summary_model = summary_model @@ -107,6 +116,7 @@ class VideoSummaryDetector(AnalysisMethod): raise ValueError("frame_rate_per_second must be > 0") n_samples = max(1, int(math.floor(duration * frame_rate_per_second))) + n_samples = min(n_samples, self.MAX_SAMPLES_CAP) if begin_stream_seconds is not None and end_stream_seconds is not None: sample_times = torch.linspace( @@ -126,38 +136,6 @@ class VideoSummaryDetector(AnalysisMethod): return generator - def _normalize_whitespace(self, s: str) -> str: - return re.sub(r"\s+", " ", s).strip() - - def _strip_prompt_prefix_literal(self, decoded: str, prompt: str) -> str: - """ - Remove any literal prompt prefix from decoded text using a normalized-substring match. - Guarantees no prompt text remains at the start of returned string (best-effort). - """ - if not decoded: - return "" - if not prompt: - return decoded.strip() - - d_norm = self._normalize_whitespace(decoded) - p_norm = self._normalize_whitespace(prompt) - - idx = d_norm.find(p_norm) - if idx != -1: - running = [] - for i, ch in enumerate(decoded): - running.append(ch if not ch.isspace() else " ") - cur_norm = self._normalize_whitespace("".join(running)) - if cur_norm.endswith(p_norm): - return decoded[i + 1 :].lstrip() if i + 1 < len(decoded) else "" - m = re.match( - r"^(?:\s*(system|user|assistant)[:\s-]*\n?)+", decoded, flags=re.IGNORECASE - ) - if m: - return decoded[m.end() :].lstrip() - - return decoded.lstrip("\n\r ").lstrip(":;- ").strip() - def _decode_trimmed_outputs( self, generated_ids: torch.Tensor, @@ -170,20 +148,18 @@ class VideoSummaryDetector(AnalysisMethod): Then remove any literal prompt prefix using prompt_texts (one per batch element). """ - decoded_results = [] batch_size = generated_ids.shape[0] if "input_ids" in inputs: - lengths = ( - inputs["input_ids"] - .ne( - tokenizer.pad_token_id - if tokenizer.pad_token_id is not None - else tokenizer.eos_token_id - ) - .sum(dim=1) - .tolist() + token_for_padding = ( + tokenizer.pad_token_id + if getattr(tokenizer, "pad_token_id", None) is not None + else getattr(tokenizer, "eos_token_id", None) ) + if token_for_padding is None: + lengths = [int(inputs["input_ids"].shape[1])] * batch_size + else: + lengths = inputs["input_ids"].ne(token_for_padding).sum(dim=1).tolist() else: lengths = [0] * batch_size @@ -195,13 +171,15 @@ class VideoSummaryDetector(AnalysisMethod): t = out_ids[in_len:] else: t = out_ids.new_empty((0,), dtype=out_ids.dtype) - trimmed_ids.append(t) + t_cpu = t.to("cpu") + trimmed_ids.append(t_cpu.tolist()) decoded = tokenizer.batch_decode( trimmed_ids, skip_special_tokens=True, clean_up_tokenization_spaces=True ) + decoded_results = [] for ptext, raw in zip(prompt_texts, decoded): - cleaned = self._strip_prompt_prefix_literal(raw, ptext) + cleaned = _strip_prompt_prefix_literal(raw, ptext) decoded_results.append(cleaned) return decoded_results @@ -209,7 +187,6 @@ class VideoSummaryDetector(AnalysisMethod): self, processor_inputs: Dict[str, torch.Tensor], prompt_texts: List[str], - model, tokenizer, ): """ @@ -224,7 +201,7 @@ class VideoSummaryDetector(AnalysisMethod): for k, v in list(processor_inputs.items()): if isinstance(v, torch.Tensor): - processor_inputs[k] = v.to(model.device) + processor_inputs[k] = v.to(self.summary_model.device) with torch.inference_mode(): try: @@ -239,8 +216,8 @@ class VideoSummaryDetector(AnalysisMethod): ) except RuntimeError as e: warnings.warn( - "Retry without autocast failed: %s. 
Attempting cudnn-disabled retry.", - e, + f"Generation failed with error: {e}. Retrying with cuDNN disabled.", + RuntimeWarning, ) cudnn_was_enabled = ( torch.backends.cudnn.is_available() and torch.backends.cudnn.enabled @@ -273,7 +250,9 @@ class VideoSummaryDetector(AnalysisMethod): batch = batch.to("cpu") batch = batch.contiguous() - if batch.dtype != torch.uint8: + if batch.dtype.is_floating_point: + batch = (batch.clamp(0.0, 1.0) * 255.0).to(torch.uint8) + elif batch.dtype != torch.uint8: batch = batch.to(torch.uint8) pil_list: List[Image.Image] = [] for frame in batch: @@ -281,10 +260,10 @@ class VideoSummaryDetector(AnalysisMethod): pil_list.append(Image.fromarray(arr)) return pil_list - def brute_force_summary( + def make_captions_from_extracted_frames( self, extracted_video_gen: Generator[Tuple[torch.Tensor, torch.Tensor], None, None], - summary_instruction: str = "Analyze the following captions from multiple frames of the same video and summarize the overall content of the video in one concise paragraph (1-3 sentences). Focus on the key themes, actions, or events across the video, not just the individual frames.", + list_of_questions: Optional[List[str]] = None, ) -> Dict[str, Any]: """ Generate captions for all extracted frames and then produce a concise summary of the video. @@ -296,68 +275,93 @@ class VideoSummaryDetector(AnalysisMethod): """ caption_instruction = "Describe this image in one concise caption." + include_questions = bool(list_of_questions) + if include_questions: + q_block = "\n".join( + [f"{i + 1}. {q.strip()}" for i, q in enumerate(list_of_questions)] + ) + caption_instruction += ( + " In addition to the concise caption, also answer the following questions based ONLY on the image. Answers must be very brief and concise." + " Produce exactly two labeled sections: \n\n" + "Summary: \n\n" + "VQA Answers: \n1. \n2. \n etc." + "\nReturn only those two sections for each image (do not add extra commentary)." 
+ "\nIf the answer cannot be determined based on the provided answer blocks," + ' reply with the line "The answer cannot be determined based on the information provided."' + f"\n\nQuestions:\n{q_block}" + ) collected: List[Tuple[float, str]] = [] proc = self.summary_model.processor + try: + for batch_frames, batch_times in extracted_video_gen: + pil_list = self._tensor_batch_to_pil_list(batch_frames.cpu()) - for batch_frames, batch_times in extracted_video_gen: - pil_list = self._tensor_batch_to_pil_list(batch_frames.cpu()) + prompt_texts = [] + for p in pil_list: + messages = [ + { + "role": "user", + "content": [ + {"type": "image", "image": p}, + {"type": "text", "text": caption_instruction}, + ], + } + ] - prompt_texts = [] - for p in pil_list: - messages = [ - { - "role": "user", - "content": [ - {"type": "image", "image": p}, - {"type": "text", "text": caption_instruction}, - ], - } - ] + prompt_text = proc.apply_chat_template( + messages, tokenize=False, add_generation_prompt=True + ) + prompt_texts.append(prompt_text) - prompt_text = proc.apply_chat_template( - messages, tokenize=False, add_generation_prompt=True + processor_inputs = proc( + text=prompt_texts, + images=pil_list, + return_tensors="pt", + padding=True, + ) + captions = self._generate_from_processor_inputs( + processor_inputs, + prompt_texts, + self.summary_model.tokenizer, ) - prompt_texts.append(prompt_text) - processor_inputs = proc( - text=prompt_texts, images=pil_list, return_tensors="pt", padding=True - ) - captions = self._generate_from_processor_inputs( - processor_inputs, - prompt_texts, - self.summary_model, - self.summary_model.tokenizer, - ) - - if isinstance(batch_times, torch.Tensor): - batch_times_list = batch_times.cpu().tolist() - else: - batch_times_list = list(batch_times) - for t, c in zip(batch_times_list, captions): - collected.append((float(t), c)) + if isinstance(batch_times, torch.Tensor): + batch_times_list = batch_times.cpu().tolist() + else: + batch_times_list = list(batch_times) + for t, c in zip(batch_times_list, captions): + collected.append((float(t), c)) + finally: + try: + extracted_video_gen.close() + except Exception: + warnings.warn("Failed to close video frame generator.", RuntimeWarning) collected.sort(key=lambda x: x[0]) - extracted_video_gen.close() + bullets_summary, bullets_vqa = _categorize_outputs(collected, include_questions) - MAX_CAPTIONS_FOR_SUMMARY = 200 - caps_for_summary = ( - collected[-MAX_CAPTIONS_FOR_SUMMARY:] - if len(collected) > MAX_CAPTIONS_FOR_SUMMARY - else collected - ) + return { + "summary_bullets": bullets_summary, + "vqa_bullets": bullets_vqa, + } # TODO consider taking out time stamps from the returned structure - bullets = [] - for t, c in caps_for_summary: - snippet = c.replace("\n", " ").strip() - bullets.append(f"- [{t:.3f}s] {snippet}") + def final_summary(self, summary_dict: Dict[str, Any]) -> Dict[str, Any]: + """ + Produce a concise summary of the video, based on generated captions for all extracted frames. + Args: + summary_dict (Dict[str, Any]): Dictionary containing captions for the frames. + Returns: + Dict[str, Any]: A dictionary containing the list of captions with timestamps and the final summary. + """ + summary_instruction = "Analyze the following captions from multiple frames of the same video and summarize the overall content of the video in one concise paragraph (1-3 sentences). Focus on the key themes, actions, or events across the video, not just the individual frames." 
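+        # summary_dict["summary_bullets"] holds the per-frame captions as "- [<time>s] <caption>" bullets
+        # (see _categorize_outputs); they are joined below into a single prompt for one summarization pass.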
+        proc = self.summary_model.processor
+
+        bullets = summary_dict.get("summary_bullets", [])
+        if not bullets:
+            raise ValueError("No captions available for summary generation.")
         combined_captions_text = "\n".join(bullets)
-        summary_user_text = (
-            summary_instruction
-            + "\n\n"
-            + combined_captions_text
-            + "\n\nPlease produce a single concise paragraph."
-        )
+        summary_user_text = summary_instruction + "\n\n" + combined_captions_text + "\n"
 
         messages = [
             {
@@ -381,40 +385,130 @@ class VideoSummaryDetector(AnalysisMethod):
         final_summary_list = self._generate_from_processor_inputs(
             summary_inputs,
             [summary_prompt_text],
-            self.summary_model.model,
             self.summary_model.tokenizer,
         )
         final_summary = final_summary_list[0].strip() if final_summary_list else ""
 
         return {
-            "captions": collected,
             "summary": final_summary,
         }
 
+    def final_answers(
+        self,
+        answers_dict: Dict[str, Any],
+        list_of_questions: List[str],
+    ) -> Dict[str, Any]:
+        """
+        Answer the list of questions for the video based on the VQA bullets from the frames.
+        Args:
+            answers_dict (Dict[str, Any]): Dictionary containing the VQA bullets.
+        Returns:
+            Dict[str, Any]: A dictionary containing the list of answers to the questions.
+        """
+        vqa_bullets = answers_dict.get("vqa_bullets", [])
+        if not vqa_bullets:
+            raise ValueError(
+                "No per-frame VQA bullets are available for answering questions."
+            )
+
+        include_questions = bool(list_of_questions)
+        if include_questions:
+            q_block = "\n".join(
+                [f"{i + 1}. {q.strip()}" for i, q in enumerate(list_of_questions)]
+            )
+            prompt = (
+                "You are provided with a set of short VQA-captions, each of which is a block of short answers"
+                " extracted from individual frames of the same video.\n\n"
+                "VQA-captions (use ONLY these to answer):\n"
+                f"{vqa_bullets}\n\n"
+                "Answer the following questions briefly, based ONLY on the lists of answers provided above. The VQA-captions above contain answers"
+                " to the same questions you are about to answer. If the answer cannot be determined based on the provided answer blocks,"
+                ' reply with the line "The answer cannot be determined based on the information provided."'
+                "\n\nQuestions:\n"
+                f"{q_block}\n\n"
+                "Produce an ordered list of answers in the same order as the questions. Your output must have exactly this structure: "
+                "Answers: \n1. \n2. \n etc."
+                "Return ONLY the ordered list with answers and NOTHING else — no commentary, no explanation, no surrounding markdown."
+            )
+        else:
+            raise ValueError(
+                "list_of_questions must be provided for making final answers."
+ ) + + proc = self.summary_model.processor + messages = [ + { + "role": "user", + "content": [{"type": "text", "text": prompt}], + } + ] + final_vqa_prompt_text = proc.apply_chat_template( + messages, tokenize=False, add_generation_prompt=True + ) + final_vqa_inputs = proc( + text=[final_vqa_prompt_text], return_tensors="pt", padding=True + ) + final_vqa_inputs = { + k: v.to(self.summary_model.device) if isinstance(v, torch.Tensor) else v + for k, v in final_vqa_inputs.items() + } + + final_vqa_list = self._generate_from_processor_inputs( + final_vqa_inputs, + [final_vqa_prompt_text], + self.summary_model.tokenizer, + ) + + final_vqa_output = final_vqa_list[0].strip() if final_vqa_list else "" + vqa_answers = [] + answer_matches = re.findall( + r"\d+\.\s*(.*?)(?=\n\d+\.|$)", final_vqa_output, flags=re.DOTALL + ) + for answer in answer_matches: + vqa_answers.append(answer.strip()) + return { + "vqa_answers": vqa_answers, + } + def analyse_videos_from_dict( - self, frame_rate_per_second: float = 2.0 + self, + analysis_type: Union[AnalysisType, str] = AnalysisType.SUMMARY, + frame_rate_per_second: float = 2.0, + list_of_questions: Optional[List[str]] = None, ) -> Dict[str, Any]: """ Analyse the video specified in self.subdict using frame extraction and captioning. - For short videos (<=100 frames at the specified frame rate), it uses brute-force captioning. - For longer videos, it currently defaults to brute-force captioning, but can be extended for more complex methods. - Args: + analysis_type (Union[AnalysisType, str], optional): Type of analysis to perform. Defaults to AnalysisType.SUMMARY. frame_rate_per_second (float): Frame extraction rate in frames per second. Default is 2.0. + list_of_questions (List[str], optional): List of questions to answer about the video. Required if analysis_type includes questions. Returns: - Dict[str, Any]: A dictionary containing the analysis results, including captions and summary. + Dict[str, Any]: A dictionary containing the analysis results, including summary and answers for provided questions(if any). """ all_answers = {} - # TODO: add support for answering questions about videos - for video_key in list(self.subdict.keys()): - entry = self.subdict[video_key] - extracted_video_gen = self._extract_video_frames( + analysis_type, is_summary, is_questions = AnalysisType._validate_analysis_type( + analysis_type, list_of_questions + ) + + for video_key, entry in self.subdict.items(): + summary_and_vqa = {} + frames_generator = self._extract_video_frames( entry, frame_rate_per_second=frame_rate_per_second ) - answer = self.brute_force_summary(extracted_video_gen) - all_answers[video_key] = {"summary": answer["summary"]} + answers_dict = self.make_captions_from_extracted_frames( + frames_generator, list_of_questions=list_of_questions + ) # TODO: captions has to be post-processed with foreseeing audio analysis + # TODO: captions and answers may lead to prompt, that superior model limits. Consider hierarchical approach. + if is_summary: + answer = self.final_summary(answers_dict) + summary_and_vqa["summary"] = answer["summary"] + if is_questions: + answer = self.final_answers(answers_dict, list_of_questions) + summary_and_vqa["vqa_answers"] = answer["vqa_answers"] - return answer + all_answers[video_key] = summary_and_vqa + + return all_answers
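A minimal usage sketch for the video summary/VQA path introduced in PATCH 23/23. The per-video dict layout with a "filename" key and the example file name are assumptions for illustration (the diffs above do not show how video entries are built); the calls otherwise follow the signatures added in ammico/video_summary.py:

    from ammico.model import MultimodalSummaryModel
    from ammico.video_summary import VideoSummaryDetector

    # Loads the default Qwen2.5-VL checkpoint; device and quantization are picked automatically.
    model = MultimodalSummaryModel()

    # One entry per video key; the "filename" field is an assumed convention, not confirmed by the diffs.
    videos = {"clip1": {"filename": "clip1.mp4"}}

    detector = VideoSummaryDetector(summary_model=model, subdict=videos)
    results = detector.analyse_videos_from_dict(
        analysis_type="summary_and_questions",
        frame_rate_per_second=1.0,
        list_of_questions=["How many people appear in the video?"],
    )
    # results["clip1"]["summary"]     -> one concise paragraph for the whole video
    # results["clip1"]["vqa_answers"] -> list with one answer per question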