diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 015a8ee..1b1675a 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -31,7 +31,7 @@ jobs: - name: Run pytest run: | cd ammico - python -m pytest -svv -m "not gcv" --cov=. --cov-report=xml + python -m pytest -svv -m "not gcv and not long" --cov=. --cov-report=xml - name: Upload coverage if: matrix.os == 'ubuntu-latest' && matrix.python-version == '3.11' uses: codecov/codecov-action@v3 diff --git a/ammico/__init__.py b/ammico/__init__.py index 67a7065..5554ee0 100644 --- a/ammico/__init__.py +++ b/ammico/__init__.py @@ -1,7 +1,10 @@ from ammico.display import AnalysisExplorer from ammico.faces import EmotionDetector, ethical_disclosure +from ammico.model import MultimodalSummaryModel from ammico.text import TextDetector, TextAnalyzer, privacy_disclosure -from ammico.utils import find_files, get_dataframe +from ammico.image_summary import ImageSummaryDetector +from ammico.utils import find_files, get_dataframe, AnalysisType, find_videos +from ammico.video_summary import VideoSummaryDetector # Export the version defined in project metadata try: @@ -12,11 +15,16 @@ except ImportError: __version__ = "unknown" __all__ = [ + "AnalysisType", "AnalysisExplorer", "EmotionDetector", + "MultimodalSummaryModel", "TextDetector", "TextAnalyzer", + "ImageSummaryDetector", + "VideoSummaryDetector", "find_files", + "find_videos", "get_dataframe", "ethical_disclosure", "privacy_disclosure", diff --git a/ammico/display.py b/ammico/display.py index c80d0db..5b860a7 100644 --- a/ammico/display.py +++ b/ammico/display.py @@ -1,10 +1,14 @@ import ammico.faces as faces import ammico.text as text import ammico.colors as colors +import ammico.image_summary as image_summary +from ammico.model import MultimodalSummaryModel import pandas as pd from dash import html, Input, Output, dcc, State, Dash from PIL import Image import dash_bootstrap_components as dbc +import warnings +from typing import Dict, Any, List, Optional COLOR_SCHEMES = [ @@ -94,6 +98,8 @@ class AnalysisExplorer: State("left_select_id", "options"), State("left_select_id", "value"), State("Dropdown_select_Detector", "value"), + State("Dropdown_analysis_type", "value"), + State("textarea_questions", "value"), State("setting_Text_analyse_text", "value"), State("setting_privacy_env_var", "value"), State("setting_Emotion_emotion_threshold", "value"), @@ -108,9 +114,15 @@ class AnalysisExplorer: Output("settings_TextDetector", "style"), Output("settings_EmotionDetector", "style"), Output("settings_ColorDetector", "style"), + Output("settings_VQA", "style"), Input("Dropdown_select_Detector", "value"), )(self._update_detector_setting) + self.app.callback( + Output("textarea_questions", "style"), + Input("Dropdown_analysis_type", "value"), + )(self._show_questions_textarea_on_demand) + # I split the different sections into subfunctions for better clarity def _top_file_explorer(self, mydict: dict) -> html.Div: """Initialize the file explorer dropdown for selecting the file to be analyzed. 
@@ -157,14 +169,6 @@ class AnalysisExplorer: id="settings_TextDetector", style={"display": "none"}, children=[ - dbc.Row( - dcc.Checklist( - ["Analyse text"], - ["Analyse text"], - id="setting_Text_analyse_text", - style={"margin-bottom": "10px"}, - ), - ), # row 1 dbc.Row( dbc.Col( @@ -272,8 +276,69 @@ class AnalysisExplorer: ) ], ), + # start VQA settings + html.Div( + id="settings_VQA", + style={"display": "none"}, + children=[ + dbc.Card( + [ + dbc.CardBody( + [ + dbc.Row( + dbc.Col( + dcc.Dropdown( + id="Dropdown_analysis_type", + options=[ + {"label": v, "value": v} + for v in SUMMARY_ANALYSIS_TYPE + ], + value="summary_and_questions", + clearable=False, + style={ + "width": "100%", + "minWidth": "240px", + "maxWidth": "520px", + }, + ), + ), + justify="start", + ), + html.Div(style={"height": "8px"}), + dbc.Row( + [ + dbc.Col( + dcc.Textarea( + id="textarea_questions", + value="Are there people in the image?\nWhat is this picture about?", + placeholder="One question per line...", + style={ + "width": "100%", + "minHeight": "160px", + "height": "220px", + "resize": "vertical", + "overflow": "auto", + }, + rows=8, + ), + width=12, + ), + ], + justify="start", + ), + ] + ) + ], + style={ + "width": "100%", + "marginTop": "10px", + "zIndex": 2000, + }, + ) + ], + ), ], - style={"width": "100%", "display": "inline-block"}, + style={"width": "100%", "display": "inline-block", "overflow": "visible"}, ) return settings_layout @@ -293,6 +358,7 @@ class AnalysisExplorer: "TextDetector", "EmotionDetector", "ColorDetector", + "VQA", ], value="TextDetector", id="Dropdown_select_Detector", @@ -344,7 +410,7 @@ class AnalysisExplorer: port (int, optional): The port number to run the server on (default: 8050). """ - self.app.run_server(debug=True, port=port) + self.app.run(debug=True, port=port) # Dash callbacks def update_picture(self, img_path: str): @@ -379,19 +445,27 @@ class AnalysisExplorer: if setting_input == "EmotionDetector": return display_none, display_flex, display_none, display_none - if setting_input == "ColorDetector": return display_none, display_none, display_flex, display_none - + if setting_input == "VQA": + return display_none, display_none, display_none, display_flex else: return display_none, display_none, display_none, display_none + def _parse_questions(self, text: Optional[str]) -> Optional[List[str]]: + if not text: + return None + qs = [q.strip() for q in text.splitlines() if q.strip()] + return qs if qs else None + def _right_output_analysis( self, n_clicks, all_img_options: dict, current_img_value: str, detector_value: str, + analysis_type_value: str, + textarea_questions_value: str, settings_text_analyse_text: list, setting_privacy_env_var: str, setting_emotion_emotion_threshold: int, @@ -413,54 +487,75 @@ class AnalysisExplorer: "EmotionDetector": faces.EmotionDetector, "TextDetector": text.TextDetector, "ColorDetector": colors.ColorDetector, + "VQA": image_summary.ImageSummaryDetector, } # Get image ID from dropdown value, which is the filepath if current_img_value is None: return {} image_id = all_img_options[current_img_value] - # copy image so prvious runs don't leave their default values in the dict - image_copy = self.mydict[image_id].copy() + image_copy = self.mydict.get(image_id, {}).copy() - # detector value is the string name of the chosen detector - identify_function = identify_dict[detector_value] - - if detector_value == "TextDetector": - analyse_text = ( - True if settings_text_analyse_text == ["Analyse text"] else False - ) - detector_class = 
identify_function( - image_copy, - analyse_text=analyse_text, - accept_privacy=( - setting_privacy_env_var - if setting_privacy_env_var - else "PRIVACY_AMMICO" - ), - ) - elif detector_value == "EmotionDetector": - detector_class = identify_function( - image_copy, - emotion_threshold=setting_emotion_emotion_threshold, - race_threshold=setting_emotion_race_threshold, - gender_threshold=setting_emotion_gender_threshold, - accept_disclosure=( - setting_emotion_env_var - if setting_emotion_env_var - else "DISCLOSURE_AMMICO" - ), - ) - elif detector_value == "ColorDetector": - detector_class = identify_function( - image_copy, - delta_e_method=setting_color_delta_e_method, - ) + analysis_dict: Dict[str, Any] = {} + if detector_value == "VQA": + try: + qwen_model = MultimodalSummaryModel( + model_id="Qwen/Qwen2.5-VL-3B-Instruct" + ) # TODO: allow user to specify model + vqa_cls = identify_dict.get("VQA") + vqa_detector = vqa_cls(qwen_model, subdict={}) + questions_list = self._parse_questions(textarea_questions_value) + analysis_result = vqa_detector.analyse_image( + image_copy, + analysis_type=analysis_type_value, + list_of_questions=questions_list, + is_concise_summary=True, + is_concise_answer=True, + ) + analysis_dict = analysis_result or {} + except Exception as e: + warnings.warn(f"VQA/Image tasks failed: {e}") + analysis_dict = {"image_tasks_error": str(e)} else: - detector_class = identify_function(image_copy) - analysis_dict = detector_class.analyse_image() + # detector value is the string name of the chosen detector + identify_function = identify_dict[detector_value] - # Initialize an empty dictionary - new_analysis_dict = {} + if detector_value == "TextDetector": + analyse_text = ( + True if settings_text_analyse_text == ["Analyse text"] else False + ) + detector_class = identify_function( + image_copy, + analyse_text=analyse_text, + accept_privacy=( + setting_privacy_env_var + if setting_privacy_env_var + else "PRIVACY_AMMICO" + ), + ) + elif detector_value == "EmotionDetector": + detector_class = identify_function( + image_copy, + emotion_threshold=setting_emotion_emotion_threshold, + race_threshold=setting_emotion_race_threshold, + gender_threshold=setting_emotion_gender_threshold, + accept_disclosure=( + setting_emotion_env_var + if setting_emotion_env_var + else "DISCLOSURE_AMMICO" + ), + ) + elif detector_value == "ColorDetector": + detector_class = identify_function( + image_copy, + delta_e_method=setting_color_delta_e_method, + ) + else: + detector_class = identify_function(image_copy) + + analysis_dict = detector_class.analyse_image() + + new_analysis_dict: Dict[str, Any] = {} # Iterate over the items in the original dictionary for k, v in analysis_dict.items(): @@ -480,3 +575,9 @@ class AnalysisExplorer: return dbc.Table.from_dataframe( df, striped=True, bordered=True, hover=True, index=True ) + + def _show_questions_textarea_on_demand(self, analysis_type_value: str) -> dict: + if analysis_type_value in ("questions", "summary_and_questions"): + return {"display": "block", "width": "100%"} + else: + return {"display": "none"} diff --git a/ammico/image_summary.py b/ammico/image_summary.py new file mode 100644 index 0000000..adc01af --- /dev/null +++ b/ammico/image_summary.py @@ -0,0 +1,365 @@ +from ammico.utils import AnalysisMethod, AnalysisType +from ammico.model import MultimodalSummaryModel + +import os +import torch +from PIL import Image +import warnings + +from typing import List, Optional, Union, Dict, Any +from collections.abc import Sequence as _Sequence +from 
transformers import GenerationConfig +from qwen_vl_utils import process_vision_info + + +class ImageSummaryDetector(AnalysisMethod): + def __init__( + self, + summary_model: MultimodalSummaryModel, + subdict: Optional[Dict[str, Any]] = None, + ) -> None: + """ + Class for analysing images using QWEN-2.5-VL model. + It provides methods for generating captions and answering questions about images. + + Args: + summary_model ([type], optional): An instance of MultimodalSummaryModel to be used for analysis. + subdict (dict, optional): Dictionary containing the image to be analysed. Defaults to {}. + + Returns: + None. + """ + if subdict is None: + subdict = {} + + super().__init__(subdict) + self.summary_model = summary_model + + def _load_pil_if_needed( + self, filename: Union[str, os.PathLike, Image.Image] + ) -> Image.Image: + if isinstance(filename, (str, os.PathLike)): + return Image.open(filename).convert("RGB") + elif isinstance(filename, Image.Image): + return filename.convert("RGB") + else: + raise ValueError("filename must be a path or PIL.Image") + + @staticmethod + def _is_sequence_but_not_str(obj: Any) -> bool: + """True for sequence-like but not a string/bytes/PIL.Image.""" + return isinstance(obj, _Sequence) and not isinstance( + obj, (str, bytes, Image.Image) + ) + + def _prepare_inputs( + self, list_of_questions: list[str], entry: Optional[Dict[str, Any]] = None + ) -> Dict[str, torch.Tensor]: + filename = entry.get("filename") + if filename is None: + raise ValueError("entry must contain key 'filename'") + + if isinstance(filename, (str, os.PathLike, Image.Image)): + images_context = self._load_pil_if_needed(filename) + elif self._is_sequence_but_not_str(filename): + images_context = [self._load_pil_if_needed(i) for i in filename] + else: + raise ValueError( + "Unsupported 'filename' entry: expected path, PIL.Image, or sequence." + ) + + images_only_messages = [ + { + "role": "user", + "content": [ + *( + [{"type": "image", "image": img} for img in images_context] + if isinstance(images_context, list) + else [{"type": "image", "image": images_context}] + ) + ], + } + ] + + try: + image_inputs, _ = process_vision_info(images_only_messages) + except Exception as e: + raise RuntimeError(f"Image processing failed: {e}") + + texts: List[str] = [] + for q in list_of_questions: + messages = [ + { + "role": "user", + "content": [ + *( + [ + {"type": "image", "image": image} + for image in images_context + ] + if isinstance(images_context, list) + else [{"type": "image", "image": images_context}] + ), + {"type": "text", "text": q}, + ], + } + ] + text = self.summary_model.processor.apply_chat_template( + messages, tokenize=False, add_generation_prompt=True + ) + texts.append(text) + + images_batch = [image_inputs] * len(texts) + inputs = self.summary_model.processor( + text=texts, + images=images_batch, + padding=True, + return_tensors="pt", + ) + inputs = {k: v.to(self.summary_model.device) for k, v in inputs.items()} + + return inputs + + def analyse_image( + self, + entry: dict, + analysis_type: Union[str, AnalysisType] = AnalysisType.SUMMARY, + list_of_questions: Optional[List[str]] = None, + is_concise_summary: bool = True, + is_concise_answer: bool = True, + ) -> Dict[str, Any]: + """ + Analyse a single image entry. 
Returns dict with keys depending on analysis_type: + - 'caption' (str) if summary requested + - 'vqa' (dict) if questions requested + """ + self.subdict = entry + analysis_type, is_summary, is_questions = AnalysisType._validate_analysis_type( + analysis_type, list_of_questions + ) + + if is_summary: + try: + caps = self.generate_caption( + entry, + num_return_sequences=1, + is_concise_summary=is_concise_summary, + ) + self.subdict["caption"] = caps[0] if caps else "" + except Exception as e: + warnings.warn(f"Caption generation failed: {e}") + + if is_questions: + try: + vqa_map = self.answer_questions( + list_of_questions, entry, is_concise_answer + ) + self.subdict["vqa"] = vqa_map + except Exception as e: + warnings.warn(f"VQA failed: {e}") + + return self.subdict + + def analyse_images_from_dict( + self, + analysis_type: Union[AnalysisType, str] = AnalysisType.SUMMARY, + list_of_questions: Optional[List[str]] = None, + keys_batch_size: int = 16, + is_concise_summary: bool = True, + is_concise_answer: bool = True, + ) -> Dict[str, dict]: + """ + Analyse image with model. + + Args: + analysis_type (str): type of the analysis. + list_of_questions (list[str]): list of questions. + keys_batch_size (int): number of images to process in a batch. + is_concise_summary (bool): whether to generate concise summary. + is_concise_answer (bool): whether to generate concise answers. + Returns: + self.subdict (dict): dictionary with analysis results. + """ + # TODO: add option to ask multiple questions per image as one batch. + analysis_type, is_summary, is_questions = AnalysisType._validate_analysis_type( + analysis_type, list_of_questions + ) + + keys = list(self.subdict.keys()) + for batch_start in range(0, len(keys), keys_batch_size): + batch_keys = keys[batch_start : batch_start + keys_batch_size] + for key in batch_keys: + entry = self.subdict[key] + if is_summary: + try: + caps = self.generate_caption( + entry, + num_return_sequences=1, + is_concise_summary=is_concise_summary, + ) + entry["caption"] = caps[0] if caps else "" + except Exception as e: + warnings.warn(f"Caption generation failed: {e}") + + if is_questions: + try: + vqa_map = self.answer_questions( + list_of_questions, entry, is_concise_answer + ) + entry["vqa"] = vqa_map + except Exception as e: + warnings.warn(f"VQA failed: {e}") + + self.subdict[key] = entry + return self.subdict + + def generate_caption( + self, + entry: Optional[Dict[str, Any]] = None, + num_return_sequences: int = 1, + is_concise_summary: bool = True, + ) -> List[str]: + """ + Create caption for image. Depending on is_concise_summary it will be either concise or detailed. + + Args: + entry (dict): dictionary containing the image to be captioned. + num_return_sequences (int): number of captions to generate. + is_concise_summary (bool): whether to generate concise summary. + + Returns: + results (list[str]): list of generated captions. 
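+
+        Example (illustrative sketch; the file path is a placeholder):
+            model = MultimodalSummaryModel(device="cpu")
+            detector = ImageSummaryDetector(summary_model=model)
+            captions = detector.generate_caption({"filename": "some_image.jpg"})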
+ """ + if is_concise_summary: + prompt = ["Describe this image in one concise caption."] + max_new_tokens = 64 + else: + prompt = ["Describe this image."] + max_new_tokens = 256 + inputs = self._prepare_inputs(prompt, entry) + + gen_conf = GenerationConfig( + max_new_tokens=max_new_tokens, + do_sample=False, + num_return_sequences=num_return_sequences, + ) + + with torch.inference_mode(): + try: + if self.summary_model.device == "cuda": + with torch.amp.autocast("cuda", enabled=True): + generated_ids = self.summary_model.model.generate( + **inputs, generation_config=gen_conf + ) + else: + generated_ids = self.summary_model.model.generate( + **inputs, generation_config=gen_conf + ) + except RuntimeError as e: + warnings.warn( + f"Retry without autocast failed: {e}. Attempting cudnn-disabled retry." + ) + cudnn_was_enabled = ( + torch.backends.cudnn.is_available() and torch.backends.cudnn.enabled + ) + if cudnn_was_enabled: + torch.backends.cudnn.enabled = False + try: + generated_ids = self.summary_model.model.generate( + **inputs, generation_config=gen_conf + ) + except Exception as retry_error: + raise RuntimeError( + f"Failed to generate ids after retry: {retry_error}" + ) from retry_error + finally: + if cudnn_was_enabled: + torch.backends.cudnn.enabled = True + + decoded = None + if "input_ids" in inputs: + in_ids = inputs["input_ids"] + trimmed = [ + out_ids[len(inp_ids) :] + for inp_ids, out_ids in zip(in_ids, generated_ids) + ] + decoded = self.summary_model.tokenizer.batch_decode( + trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False + ) + else: + decoded = self.summary_model.tokenizer.batch_decode( + generated_ids, + skip_special_tokens=True, + clean_up_tokenization_spaces=False, + ) + + results = [d.strip() for d in decoded] + return results + + def answer_questions( + self, + list_of_questions: list[str], + entry: Optional[Dict[str, Any]] = None, + is_concise_answer: bool = True, + ) -> List[str]: + """ + Create answers for list of questions about image. + Args: + list_of_questions (list[str]): list of questions. + entry (dict): dictionary containing the image to be captioned. + is_concise_answer (bool): whether to generate concise answers. + Returns: + answers (list[str]): list of answers. + """ + if is_concise_answer: + gen_conf = GenerationConfig(max_new_tokens=64, do_sample=False) + for i in range(len(list_of_questions)): + if not list_of_questions[i].strip().endswith("?"): + list_of_questions[i] = list_of_questions[i].strip() + "?" 
+ if not list_of_questions[i].lower().startswith("answer concisely"): + list_of_questions[i] = "Answer concisely: " + list_of_questions[i] + else: + gen_conf = GenerationConfig(max_new_tokens=128, do_sample=False) + + question_chunk_size = 8 + answers: List[str] = [] + n = len(list_of_questions) + for i in range(0, n, question_chunk_size): + chunk = list_of_questions[i : i + question_chunk_size] + inputs = self._prepare_inputs(chunk, entry) + with torch.inference_mode(): + if self.summary_model.device == "cuda": + with torch.amp.autocast("cuda", enabled=True): + out_ids = self.summary_model.model.generate( + **inputs, generation_config=gen_conf + ) + else: + out_ids = self.summary_model.model.generate( + **inputs, generation_config=gen_conf + ) + + if "input_ids" in inputs: + in_ids = inputs["input_ids"] + trimmed_batch = [ + out_row[len(inp_row) :] for inp_row, out_row in zip(in_ids, out_ids) + ] + decoded = self.summary_model.tokenizer.batch_decode( + trimmed_batch, + skip_special_tokens=True, + clean_up_tokenization_spaces=False, + ) + else: + decoded = self.summary_model.tokenizer.batch_decode( + out_ids, + skip_special_tokens=True, + clean_up_tokenization_spaces=False, + ) + + answers.extend([d.strip() for d in decoded]) + + if len(answers) != len(list_of_questions): + raise ValueError( + f"Expected {len(list_of_questions)} answers, but got {len(answers)}, try vary amount of questions" + ) + + return answers diff --git a/ammico/model.py b/ammico/model.py new file mode 100644 index 0000000..cdc1161 --- /dev/null +++ b/ammico/model.py @@ -0,0 +1,126 @@ +import torch +import warnings +from transformers import ( + Qwen2_5_VLForConditionalGeneration, + AutoProcessor, + BitsAndBytesConfig, + AutoTokenizer, +) +from typing import Optional + + +class MultimodalSummaryModel: + DEFAULT_CUDA_MODEL = "Qwen/Qwen2.5-VL-7B-Instruct" + DEFAULT_CPU_MODEL = "Qwen/Qwen2.5-VL-3B-Instruct" + + def __init__( + self, + model_id: Optional[str] = None, + device: Optional[str] = None, + cache_dir: Optional[str] = None, + ) -> None: + """ + Class for QWEN-2.5-VL model loading and inference. + Args: + model_id: Type of model to load, defaults to a smaller version for CPU if device is "cpu". + device: "cuda" or "cpu" (auto-detected when None). + cache_dir: huggingface cache dir (optional). + """ + self.device = self._resolve_device(device) + + if model_id is not None and model_id not in ( + self.DEFAULT_CUDA_MODEL, + self.DEFAULT_CPU_MODEL, + ): + raise ValueError( + f"model_id must be one of {self.DEFAULT_CUDA_MODEL} or {self.DEFAULT_CPU_MODEL}" + ) + + self.model_id = model_id or ( + self.DEFAULT_CUDA_MODEL if self.device == "cuda" else self.DEFAULT_CPU_MODEL + ) + + self.cache_dir = cache_dir + self._trust_remote_code = True + self._quantize = True + + self.model = None + self.processor = None + self.tokenizer = None + + self._load_model_and_processor() + + @staticmethod + def _resolve_device(device: Optional[str]) -> str: + if device is None: + return "cuda" if torch.cuda.is_available() else "cpu" + if device.lower() not in ("cuda", "cpu"): + raise ValueError("device must be 'cuda' or 'cpu'") + if device.lower() == "cuda" and not torch.cuda.is_available(): + warnings.warn( + "Although 'cuda' was requested, no CUDA device is available. 
Using CPU instead.", + RuntimeWarning, + stacklevel=2, + ) + return "cpu" + return device.lower() + + def _load_model_and_processor(self): + load_kwargs = {"trust_remote_code": self._trust_remote_code, "use_cache": True} + if self.cache_dir: + load_kwargs["cache_dir"] = self.cache_dir + + self.processor = AutoProcessor.from_pretrained( + self.model_id, padding_side="left", **load_kwargs + ) + self.tokenizer = AutoTokenizer.from_pretrained(self.model_id, **load_kwargs) + + if self.device == "cuda": + compute_dtype = ( + torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16 + ) + bnb_config = BitsAndBytesConfig( + load_in_4bit=True, + bnb_4bit_use_double_quant=True, + bnb_4bit_quant_type="nf4", + bnb_4bit_compute_dtype=compute_dtype, + ) + load_kwargs["quantization_config"] = bnb_config + load_kwargs["device_map"] = "auto" + + else: + load_kwargs.pop("quantization_config", None) + load_kwargs.pop("device_map", None) + + self.model = Qwen2_5_VLForConditionalGeneration.from_pretrained( + self.model_id, **load_kwargs + ) + self.model.eval() + + def _close(self) -> None: + """Free model resources (helpful in long-running processes).""" + try: + if self.model is not None: + del self.model + self.model = None + if self.processor is not None: + del self.processor + self.processor = None + if self.tokenizer is not None: + del self.tokenizer + self.tokenizer = None + finally: + try: + if torch.cuda.is_available(): + torch.cuda.empty_cache() + except Exception as e: + warnings.warn( + "Failed to empty CUDA cache. This is not critical, but may lead to memory lingering: " + f"{e!r}", + RuntimeWarning, + stacklevel=2, + ) + + def close(self) -> None: + """Free model resources (helpful in long-running processes).""" + self._close() diff --git a/ammico/notebooks/DemoImageSummaryVQA.ipynb b/ammico/notebooks/DemoImageSummaryVQA.ipynb new file mode 100644 index 0000000..b067e1f --- /dev/null +++ b/ammico/notebooks/DemoImageSummaryVQA.ipynb @@ -0,0 +1,190 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "0", + "metadata": {}, + "source": [ + "# Image summary and visual question answering" + ] + }, + { + "cell_type": "markdown", + "id": "1", + "metadata": {}, + "source": [ + "This notebook shows how to generate image captions and use the visual question answering with AMMICO. \n", + "\n", + "The first cell imports `ammico`.\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "2", + "metadata": {}, + "outputs": [], + "source": [ + "import ammico" + ] + }, + { + "cell_type": "markdown", + "id": "3", + "metadata": {}, + "source": [ + "The cell below loads the model for VQA tasks. By default, it loads a large model on the GPU (if your device supports CUDA), otherwise it loads a relatively smaller model on the CPU. But you can specify other settings (e.g., a small model on the GPU) if you want." 
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "4", + "metadata": {}, + "outputs": [], + "source": [ + "model = ammico.MultimodalSummaryModel()" + ] + }, + { + "cell_type": "markdown", + "id": "5", + "metadata": {}, + "source": [ + "Here you need to provide the path to your google drive folder or local folder containing the images" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "6", + "metadata": {}, + "outputs": [], + "source": [ + "image_dict = ammico.find_files(\n", + " path=str(\"/insert/your/path/here/\"),\n", + " limit=-1, # -1 means no limit on the number of files, by default it is set to 20\n", + ")" + ] + }, + { + "cell_type": "markdown", + "id": "7", + "metadata": {}, + "source": [ + "The cell below creates an object that analyzes images and generates a summary using a specific model and image data." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "8", + "metadata": {}, + "outputs": [], + "source": [ + "img = ammico.ImageSummaryDetector(summary_model=model, subdict=image_dict)" + ] + }, + { + "cell_type": "markdown", + "id": "9", + "metadata": {}, + "source": [ + "## Image summary " + ] + }, + { + "cell_type": "markdown", + "id": "10", + "metadata": {}, + "source": [ + "To start your work with images, you should call the `analyse_images` method.\n", + "\n", + "You can specify what kind of analysis you want to perform with `analysis_type`. `\"summary\"` will generate a summary for all pictures in your dictionary, `\"questions\"` will prepare answers to your questions for all pictures, and `\"summary_and_questions\"` will do both.\n", + "\n", + "Parameter `\"is_concise_summary\"` regulates the length of an answer.\n", + "\n", + "Here we want to get a long summary on each object in our image dictionary." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "11", + "metadata": {}, + "outputs": [], + "source": [ + "summaries = img.analyse_images(analysis_type=\"summary\", is_concise_summary=False)" + ] + }, + { + "cell_type": "markdown", + "id": "12", + "metadata": {}, + "source": [ + "## VQA" + ] + }, + { + "cell_type": "markdown", + "id": "13", + "metadata": {}, + "source": [ + "In addition to analyzing images in `ammico`, the same model can be used in VQA mode. To do this, you need to define the questions that will be applied to all images from your dict." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "14", + "metadata": {}, + "outputs": [], + "source": [ + "questions = [\"Are there any visible signs of violence?\", \"Is it safe to be there?\"]" + ] + }, + { + "cell_type": "markdown", + "id": "15", + "metadata": {}, + "source": [ + "Here is an example of VQA mode usage. You can specify whether you want to receive short answers (recommended option) or not." 
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "16", + "metadata": {}, + "outputs": [], + "source": [ + "vqa_results = img.analyse_images(\n", + " analysis_type=\"questions\",\n", + " list_of_questions=questions,\n", + " is_concise_answer=True,\n", + ")" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "ammico-dev", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.11.13" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/ammico/notebooks/DemoNotebook_ammico.ipynb b/ammico/notebooks/DemoNotebook_ammico.ipynb index fc8fe22..e17860c 100644 --- a/ammico/notebooks/DemoNotebook_ammico.ipynb +++ b/ammico/notebooks/DemoNotebook_ammico.ipynb @@ -104,7 +104,8 @@ "import ammico\n", "\n", "# for displaying a progress bar\n", - "from tqdm import tqdm" + "from tqdm import tqdm\n", + "import os" ] }, { @@ -140,7 +141,9 @@ "metadata": {}, "outputs": [], "source": [ - "# os.environ[\"GOOGLE_APPLICATION_CREDENTIALS\"] = \"/content/drive/MyDrive/misinformation-data/misinformation-campaign-981aa55a3b13.json\"" + "os.environ[\"GOOGLE_APPLICATION_CREDENTIALS\"] = (\n", + " \"/home/inga/projects/misinformation-project/misinformation-notes/misinformation-campaign-981aa55a3b13.json\"\n", + ")" ] }, { @@ -171,6 +174,7 @@ "metadata": {}, "outputs": [], "source": [ + "data_path = \"./data-test\"\n", "image_dict = ammico.find_files(\n", " # path = \"/content/drive/MyDrive/misinformation-data/\",\n", " path=str(data_path),\n", @@ -337,7 +341,7 @@ " enumerate(image_dict.keys()), total=len(image_dict)\n", "): # loop through all images\n", " image_dict[key] = ammico.TextDetector(\n", - " image_dict[key], analyse_text=True\n", + " image_dict[key]\n", " ).analyse_image() # analyse image with EmotionDetector and update dict\n", "\n", " if (\n", @@ -361,23 +365,12 @@ "outputs": [], "source": [ "# initialize the models\n", - "image_summary_detector = ammico.SummaryDetector(\n", - " subdict=image_dict, analysis_type=\"summary\", model_type=\"base\"\n", + "model = ammico.MultimodalSummaryModel()\n", + "image_summary_detector = ammico.ImageSummaryDetector(\n", + " subdict=image_dict, summary_model=model\n", ")\n", "\n", - "# run the analysis without having to re-iniatialize the model\n", - "for num, key in tqdm(\n", - " enumerate(image_dict.keys()), total=len(image_dict)\n", - "): # loop through all images\n", - " image_dict[key] = image_summary_detector.analyse_image(\n", - " subdict=image_dict[key], analysis_type=\"summary\"\n", - " ) # analyse image with SummaryDetector and update dict\n", - "\n", - " if (\n", - " num % dump_every == 0 | num == len(image_dict) - 1\n", - " ): # save results every dump_every to dump_file\n", - " image_df = ammico.get_dataframe(image_dict)\n", - " image_df.to_csv(dump_file)" + "image_summary_detector.analyse_images(analysis_type=\"summary\")" ] }, { @@ -394,6 +387,7 @@ "outputs": [], "source": [ "# initialize the models\n", + "# currently this does not work because of the way the summary detector is implemented\n", "image_summary_detector = ammico.SummaryDetector(\n", " subdict=image_dict, analysis_type=\"summary\", model_type=\"base\"\n", ")\n", diff --git a/ammico/notebooks/DemoVideoSummaryVQA.ipynb b/ammico/notebooks/DemoVideoSummaryVQA.ipynb new file mode 100644 index 0000000..2e80165 --- /dev/null 
+++ b/ammico/notebooks/DemoVideoSummaryVQA.ipynb @@ -0,0 +1,96 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Video summary and visual question answering" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import ammico" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Currently this module supports only video summarization, but it will be updated in the nearest future" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "video_dict = ammico.find_videos(\n", + " path=str(\"/insert/your/path/here/\"), # path to the folder with images\n", + " limit=-1, # -1 means no limit on the number of files, by default it is set to 20\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "model = ammico.MultimodalSummaryModel()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "vid_summary_model = ammico.VideoSummaryDetector(summary_model=model, subdict=video_dict)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "summary_dict = vid_summary_model.analyse_video()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "summary_dict[\"summary\"]" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "ammico-dev", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.11.13" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/ammico/test/conftest.py b/ammico/test/conftest.py index cb42774..2010e1e 100644 --- a/ammico/test/conftest.py +++ b/ammico/test/conftest.py @@ -1,5 +1,6 @@ import os import pytest +from ammico.model import MultimodalSummaryModel @pytest.fixture @@ -46,3 +47,12 @@ def get_test_my_dict(get_path): }, } return test_my_dict + + +@pytest.fixture(scope="session") +def model(): + m = MultimodalSummaryModel(device="cpu") + try: + yield m + finally: + m.close() diff --git a/ammico/test/test_display.py b/ammico/test/test_display.py index 83d53dc..3cdb333 100644 --- a/ammico/test/test_display.py +++ b/ammico/test/test_display.py @@ -50,6 +50,8 @@ def test_right_output_analysis_emotions(get_AE, get_options, monkeypatch): get_options[3], get_options[0], "EmotionDetector", + "summary", + "Some question", True, "SOME_VAR", 50, diff --git a/ammico/test/test_image_summary.py b/ammico/test/test_image_summary.py new file mode 100644 index 0000000..ad48298 --- /dev/null +++ b/ammico/test/test_image_summary.py @@ -0,0 +1,37 @@ +from ammico.image_summary import ImageSummaryDetector + +import pytest + + +@pytest.mark.long +def test_image_summary_detector(model, get_testdict): + detector = ImageSummaryDetector(summary_model=model, subdict=get_testdict) + results = detector.analyse_images_from_dict(analysis_type="summary") + assert len(results) == 2 + for key in get_testdict.keys(): + assert key in results + assert "caption" in results[key] + assert isinstance(results[key]["caption"], str) + assert len(results[key]["caption"]) > 0 + + +@pytest.mark.long +def test_image_summary_detector_questions(model, 
get_testdict): + list_of_questions = [ + "What is happening in the image?", + "How many cars are in the image in total?", + ] + detector = ImageSummaryDetector(summary_model=model, subdict=get_testdict) + results = detector.analyse_images_from_dict( + analysis_type="questions", list_of_questions=list_of_questions + ) + assert len(results) == 2 + for key in get_testdict.keys(): + assert "vqa" in results[key] + if key == "IMG_2746": + assert "marathon" in results[key]["vqa"][0].lower() + + if key == "IMG_2809": + assert ( + "two" in results[key]["vqa"][1].lower() or "2" in results[key]["vqa"][1] + ) diff --git a/ammico/test/test_model.py b/ammico/test/test_model.py new file mode 100644 index 0000000..d82dd86 --- /dev/null +++ b/ammico/test/test_model.py @@ -0,0 +1,30 @@ +import pytest +from ammico.model import MultimodalSummaryModel + + +@pytest.mark.long +def test_model_init(model): + assert model.model is not None + assert model.processor is not None + assert model.tokenizer is not None + assert model.device is not None + + +@pytest.mark.long +def test_model_invalid_device(): + with pytest.raises(ValueError): + MultimodalSummaryModel(device="invalid_device") + + +@pytest.mark.long +def test_model_invalid_model_id(): + with pytest.raises(ValueError): + MultimodalSummaryModel(model_id="non_existent_model", device="cpu") + + +@pytest.mark.long +def test_free_resources(): + model = MultimodalSummaryModel(device="cpu") + model.close() + assert model.model is None + assert model.processor is None diff --git a/ammico/test/test_text.py b/ammico/test/test_text.py index cffb321..cd9f863 100644 --- a/ammico/test/test_text.py +++ b/ammico/test/test_text.py @@ -52,24 +52,16 @@ def test_privacy_statement(monkeypatch): def test_TextDetector(set_testdict, accepted): for item in set_testdict: test_obj = tt.TextDetector(set_testdict[item], accept_privacy=accepted) - assert not test_obj.analyse_text assert not test_obj.skip_extraction assert test_obj.subdict["filename"] == set_testdict[item]["filename"] - test_obj = tt.TextDetector( - {}, analyse_text=True, skip_extraction=True, accept_privacy=accepted - ) - assert test_obj.analyse_text + test_obj = tt.TextDetector({}, skip_extraction=True, accept_privacy=accepted) assert test_obj.skip_extraction - with pytest.raises(ValueError): - tt.TextDetector({}, analyse_text=1.0, accept_privacy=accepted) with pytest.raises(ValueError): tt.TextDetector({}, skip_extraction=1.0, accept_privacy=accepted) def test_run_spacy(set_testdict, get_path, accepted): - test_obj = tt.TextDetector( - set_testdict["IMG_3755"], analyse_text=True, accept_privacy=accepted - ) + test_obj = tt.TextDetector(set_testdict["IMG_3755"], accept_privacy=accepted) ref_file = get_path + "text_IMG_3755.txt" with open(ref_file, "r") as file: reference_text = file.read() @@ -108,15 +100,11 @@ def test_analyse_image(set_testdict, set_environ, accepted): for item in set_testdict: test_obj = tt.TextDetector(set_testdict[item], accept_privacy=accepted) test_obj.analyse_image() - test_obj = tt.TextDetector( - set_testdict[item], analyse_text=True, accept_privacy=accepted - ) + test_obj = tt.TextDetector(set_testdict[item], accept_privacy=accepted) test_obj.analyse_image() testdict = {} testdict["text"] = 20000 * "m" - test_obj = tt.TextDetector( - testdict, skip_extraction=True, analyse_text=True, accept_privacy=accepted - ) + test_obj = tt.TextDetector(testdict, skip_extraction=True, accept_privacy=accepted) test_obj.analyse_image() assert test_obj.subdict["text_truncated"] == 5000 * "m" assert 
test_obj.subdict["text"] == 20000 * "m" diff --git a/ammico/text.py b/ammico/text.py index bf39cc6..4bec28c 100644 --- a/ammico/text.py +++ b/ammico/text.py @@ -67,7 +67,6 @@ class TextDetector(AnalysisMethod): def __init__( self, subdict: dict, - analyse_text: bool = False, skip_extraction: bool = False, accept_privacy: str = "PRIVACY_AMMICO", ) -> None: @@ -76,8 +75,6 @@ class TextDetector(AnalysisMethod): Args: subdict (dict): Dictionary containing file name/path, and possibly previous analysis results from other modules. - analyse_text (bool, optional): Decide if extracted text will be further subject - to analysis. Defaults to False. skip_extraction (bool, optional): Decide if text will be extracted from images or is already provided via a csv. Defaults to False. accept_privacy (str, optional): Environment variable to accept the privacy @@ -96,17 +93,13 @@ class TextDetector(AnalysisMethod): "Privacy disclosure not accepted - skipping text detection." ) self.translator = Translator(raise_exception=True) - if not isinstance(analyse_text, bool): - raise ValueError("analyse_text needs to be set to true or false") - self.analyse_text = analyse_text self.skip_extraction = skip_extraction if not isinstance(skip_extraction, bool): raise ValueError("skip_extraction needs to be set to true or false") if self.skip_extraction: print("Skipping text extraction from image.") print("Reading text directly from provided dictionary.") - if self.analyse_text: - self._initialize_spacy() + self._initialize_spacy() def set_keys(self) -> dict: """Set the default keys for text analysis. @@ -183,7 +176,7 @@ class TextDetector(AnalysisMethod): self._truncate_text() self.translate_text() self.remove_linebreaks() - if self.analyse_text and self.subdict["text_english"]: + if self.subdict["text_english"]: self._run_spacy() return self.subdict diff --git a/ammico/utils.py b/ammico/utils.py index 39a0ecb..07bae90 100644 --- a/ammico/utils.py +++ b/ammico/utils.py @@ -5,6 +5,9 @@ import pooch import importlib_resources import collections import random +from enum import Enum +from typing import List, Tuple, Optional, Union +import re pkg = importlib_resources.files("ammico") @@ -40,6 +43,41 @@ def ammico_prefetch_models(): res.get() +class AnalysisType(str, Enum): + SUMMARY = "summary" + QUESTIONS = "questions" + SUMMARY_AND_QUESTIONS = "summary_and_questions" + + @classmethod + def _validate_analysis_type( + cls, + analysis_type: Union["AnalysisType", str], + list_of_questions: Optional[List[str]], + ) -> Tuple[str, bool, bool]: + max_questions_per_image = 15 # safety cap to avoid too many questions + if isinstance(analysis_type, AnalysisType): + analysis_type = analysis_type.value + + allowed = {item.value for item in AnalysisType} + if analysis_type not in allowed: + raise ValueError(f"analysis_type must be one of {allowed}") + + if analysis_type in ("questions", "summary_and_questions"): + if not list_of_questions: + raise ValueError( + "list_of_questions must be provided for QUESTIONS analysis type." + ) + + if len(list_of_questions) > max_questions_per_image: + raise ValueError( + f"Number of questions per image ({len(list_of_questions)}) exceeds safety cap ({max_questions_per_image}). Reduce questions or increase max_questions_per_image." 
+ ) + + is_summary = analysis_type in ("summary", "summary_and_questions") + is_questions = analysis_type in ("questions", "summary_and_questions") + return analysis_type, is_summary, is_questions + + class AnalysisMethod: """Base class to be inherited by all analysis methods.""" @@ -94,6 +132,113 @@ def _limit_results(results, limit): return results +def _categorize_outputs( + collected: List[Tuple[float, str]], + include_questions: bool = False, +) -> Tuple[List[str], List[str]]: + """ + Categorize collected outputs into summary bullets and VQA bullets. + Args: + collected (List[Tuple[float, str]]): List of tuples containing timestamps and generated texts. + Returns: + Tuple[List[str], List[str]]: A tuple containing two lists - summary bullets and VQA bullets. + """ + MAX_CAPTIONS_FOR_SUMMARY = 600 # TODO For now, this is a constant value, but later we need to make it adjustable, with the idea of cutting out the most similar frames to reduce the load on the system. + caps_for_summary_vqa = ( + collected[-MAX_CAPTIONS_FOR_SUMMARY:] + if len(collected) > MAX_CAPTIONS_FOR_SUMMARY + else collected + ) + bullets_summary = [] + bullets_vqa = [] + + for t, c in caps_for_summary_vqa: + if include_questions: + result_sections = c.strip() + m = re.search( + r"Summary\s*:\s*(.*?)\s*(?:VQA\s+Answers\s*:\s*(.*))?$", + result_sections, + flags=re.IGNORECASE | re.DOTALL, + ) + if m: + summary_text = ( + m.group(1).replace("\n", " ").strip() if m.group(1) else None + ) + vqa_text = m.group(2).strip() if m.group(2) else None + if not summary_text or not vqa_text: + raise ValueError( + f"Model output is missing either summary or VQA answers: {c}" + ) + bullets_summary.append(f"- [{t:.3f}s] {summary_text}") + bullets_vqa.append(f"- [{t:.3f}s] {vqa_text}") + else: + raise ValueError( + f"Failed to parse summary and VQA answers from model output: {c}" + ) + else: + snippet = c.replace("\n", " ").strip() + bullets_summary.append(f"- [{t:.3f}s] {snippet}") + return bullets_summary, bullets_vqa + + +def _normalize_whitespace(s: str) -> str: + return re.sub(r"\s+", " ", s).strip() + + +def _strip_prompt_prefix_literal(decoded: str, prompt: str) -> str: + """ + Remove any literal prompt prefix from decoded text using a normalized-substring match. + Guarantees no prompt text remains at the start of returned string (best-effort). 
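+
+    Example (illustrative):
+        _strip_prompt_prefix_literal("user: Describe the image A red car.", "user: Describe the image")
+        returns "A red car."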
+ """ + if not decoded: + return "" + if not prompt: + return decoded.strip() + + d_norm = _normalize_whitespace(decoded) + p_norm = _normalize_whitespace(prompt) + + idx = d_norm.find(p_norm) + if idx != -1: + running = [] + for i, ch in enumerate(decoded): + running.append(ch if not ch.isspace() else " ") + cur_norm = _normalize_whitespace("".join(running)) + if cur_norm.endswith(p_norm): + return decoded[i + 1 :].lstrip() if i + 1 < len(decoded) else "" + m = re.match( + r"^(?:\s*(system|user|assistant)[:\s-]*\n?)+", decoded, flags=re.IGNORECASE + ) + if m: + return decoded[m.end() :].lstrip() + + return decoded.lstrip("\n\r ").lstrip(":;- ").strip() + + +def find_videos( + path: str = None, + pattern=["mp4"], # TODO: test with more video formats + recursive: bool = True, + limit=5, + random_seed: int = None, +) -> dict: + """Find video files on the file system.""" + if path is None: + path = os.environ.get("AMMICO_DATA_HOME", ".") + if isinstance(pattern, str): + pattern = [pattern] + results = [] + for p in pattern: + results.extend(_match_pattern(path, p, recursive=recursive)) + if len(results) == 0: + raise FileNotFoundError(f"No files found in {path} with pattern '{pattern}'") + if random_seed is not None: + random.seed(random_seed) + random.shuffle(results) + videos = _limit_results(results, limit) + return initialize_dict(videos) + + def find_files( path: str = None, pattern=["png", "jpg", "jpeg", "gif", "webp", "avif", "tiff"], diff --git a/ammico/video_summary.py b/ammico/video_summary.py new file mode 100644 index 0000000..391f2d0 --- /dev/null +++ b/ammico/video_summary.py @@ -0,0 +1,514 @@ +import re +import math +import torch +import warnings +from PIL import Image +from torchcodec.decoders import VideoDecoder + +from ammico.model import MultimodalSummaryModel +from ammico.utils import ( + AnalysisMethod, + AnalysisType, + _categorize_outputs, + _strip_prompt_prefix_literal, +) + +from typing import List, Dict, Any, Generator, Tuple, Union, Optional +from transformers import GenerationConfig + + +class VideoSummaryDetector(AnalysisMethod): + MAX_SAMPLES_CAP = 1000 # safety cap for total extracted frames + + def __init__( + self, + summary_model: MultimodalSummaryModel, + subdict: Optional[Dict[str, Any]] = None, + ) -> None: + """ + Class for analysing videos using QWEN-2.5-VL model. + It provides methods for generating captions and answering questions about videos. + + Args: + summary_model ([type], optional): An instance of MultimodalSummaryModel to be used for analysis. + subdict (dict, optional): Dictionary containing the video to be analysed. Defaults to {}. + + Returns: + None. + """ + if subdict is None: + subdict = {} + + super().__init__(subdict) + self.summary_model = summary_model + + def _frame_batch_generator( + self, + timestamps: torch.Tensor, + batch_size: int, + video_decoder: VideoDecoder, + ) -> Generator[Tuple[torch.Tensor, torch.Tensor], None, None]: + """ + Yield batches of (frames, timestamps) for given frame indices. + - frames are returned as a torch.Tensor with shape (B, C, H, W). + - timestamps is a 1D torch.Tensor with B elements. 
+ """ + total = int(timestamps.numel()) + + for start in range(0, total, batch_size): + batch_secs = timestamps[start : start + batch_size].tolist() + fb = video_decoder.get_frames_played_at(batch_secs) + frames = fb.data + + if not frames.is_contiguous(): + frames = frames.contiguous() + pts = fb.pts_seconds + pts_out = ( + pts.cpu().to(dtype=torch.float32) + if isinstance(pts, torch.Tensor) + else torch.tensor(pts, dtype=torch.float32) + ) + yield frames, pts_out + + def _extract_video_frames( + self, + entry: Dict[str, Any], + frame_rate_per_second: float = 2.0, + batch_size: int = 32, + ) -> Generator[Tuple[torch.Tensor, torch.Tensor], None, None]: + """ + Extract frames from a video at a specified frame rate and return them as a generator of batches. + Args: + filename (Union[str, os.PathLike]): Path to the video file. + frame_rate_per_second (float, optional): Frame extraction rate in frames per second. Default is 2. + batch_size (int, optional): Number of frames to include in each batch. Default is 32. + Returns: + Generator[Tuple[torch.Tensor, torch.Tensor], None, None]: A generator yielding tuples of + (frames, timestamps), where frames is a tensor of shape (B, C, H, W) and timestamps is a 1D tensor of length B. + """ + + filename = entry.get("filename") + if not filename: + raise ValueError("entry must contain key 'filename'") + + video_decoder = VideoDecoder(filename) + meta = video_decoder.metadata + + video_fps = getattr(meta, "average_fps", None) + if video_fps is None or not ( + isinstance(video_fps, (int, float)) and video_fps > 0 + ): + video_fps = 30.0 + + begin_stream_seconds = getattr(meta, "begin_stream_seconds", None) + end_stream_seconds = getattr(meta, "end_stream_seconds", None) + nframes = len(video_decoder) + if getattr(meta, "duration_seconds", None) is not None: + duration = float(meta.duration_seconds) + elif begin_stream_seconds is not None and end_stream_seconds is not None: + duration = float(end_stream_seconds) - float(begin_stream_seconds) + elif nframes: + duration = float(nframes) / float(video_fps) + else: + duration = 0.0 + + if frame_rate_per_second <= 0: + raise ValueError("frame_rate_per_second must be > 0") + + n_samples = max(1, int(math.floor(duration * frame_rate_per_second))) + n_samples = min(n_samples, self.MAX_SAMPLES_CAP) + + if begin_stream_seconds is not None and end_stream_seconds is not None: + sample_times = torch.linspace( + float(begin_stream_seconds), float(end_stream_seconds), steps=n_samples + ) + if sample_times.numel() > 1: + sample_times = torch.clamp( + sample_times, + min=float(begin_stream_seconds), + max=float(end_stream_seconds) - 1e-6, + ) + else: + sample_times = torch.linspace(0.0, max(0.0, duration), steps=n_samples) + + sample_times = sample_times.to(dtype=torch.float32, device="cpu") + generator = self._frame_batch_generator(sample_times, batch_size, video_decoder) + + return generator + + def _decode_trimmed_outputs( + self, + generated_ids: torch.Tensor, + inputs: Dict[str, torch.Tensor], + tokenizer, + prompt_texts: List[str], + ) -> List[str]: + """ + Trim prompt tokens using attention_mask/input_ids when available and decode to strings. + Then remove any literal prompt prefix using prompt_texts (one per batch element). 
+ """ + + batch_size = generated_ids.shape[0] + + if "input_ids" in inputs: + token_for_padding = ( + tokenizer.pad_token_id + if getattr(tokenizer, "pad_token_id", None) is not None + else getattr(tokenizer, "eos_token_id", None) + ) + if token_for_padding is None: + lengths = [int(inputs["input_ids"].shape[1])] * batch_size + else: + lengths = inputs["input_ids"].ne(token_for_padding).sum(dim=1).tolist() + else: + lengths = [0] * batch_size + + trimmed_ids = [] + for i in range(batch_size): + out_ids = generated_ids[i] + in_len = int(lengths[i]) if i < len(lengths) else 0 + if out_ids.size(0) > in_len: + t = out_ids[in_len:] + else: + t = out_ids.new_empty((0,), dtype=out_ids.dtype) + t_cpu = t.to("cpu") + trimmed_ids.append(t_cpu.tolist()) + + decoded = tokenizer.batch_decode( + trimmed_ids, skip_special_tokens=True, clean_up_tokenization_spaces=True + ) + decoded_results = [] + for ptext, raw in zip(prompt_texts, decoded): + cleaned = _strip_prompt_prefix_literal(raw, ptext) + decoded_results.append(cleaned) + return decoded_results + + def _generate_from_processor_inputs( + self, + processor_inputs: Dict[str, torch.Tensor], + prompt_texts: List[str], + tokenizer, + ): + """ + Run model.generate on already-processed processor_inputs (tensors moved to device), + then decode and trim prompt tokens & remove literal prompt prefixes using prompt_texts. + """ + gen_conf = GenerationConfig( + max_new_tokens=64, + do_sample=False, + num_return_sequences=1, + ) + + for k, v in list(processor_inputs.items()): + if isinstance(v, torch.Tensor): + processor_inputs[k] = v.to(self.summary_model.device) + + with torch.inference_mode(): + try: + if self.summary_model.device == "cuda": + with torch.amp.autocast("cuda", enabled=True): + generated_ids = self.summary_model.model.generate( + **processor_inputs, generation_config=gen_conf + ) + else: + generated_ids = self.summary_model.model.generate( + **processor_inputs, generation_config=gen_conf + ) + except RuntimeError as e: + warnings.warn( + f"Generation failed with error: {e}. Retrying with cuDNN disabled.", + RuntimeWarning, + ) + cudnn_was_enabled = ( + torch.backends.cudnn.is_available() and torch.backends.cudnn.enabled + ) + if cudnn_was_enabled: + torch.backends.cudnn.enabled = False + try: + generated_ids = self.summary_model.model.generate( + **processor_inputs, generation_config=gen_conf + ) + except Exception as retry_error: + raise RuntimeError( + f"Failed to generate ids after retry: {retry_error}" + ) from retry_error + finally: + if cudnn_was_enabled: + torch.backends.cudnn.enabled = True + + decoded = self._decode_trimmed_outputs( + generated_ids, processor_inputs, tokenizer, prompt_texts + ) + return decoded + + def _tensor_batch_to_pil_list(self, batch: torch.Tensor) -> List[Image.Image]: + """ + Convert a uint8 torch tensor batch (B, C, H, W) on CPU to list of PIL images (RGB). + The conversion is done on CPU and returns PIL.Image objects. 
+ """ + if batch.device.type != "cpu": + batch = batch.to("cpu") + + batch = batch.contiguous() + if batch.dtype.is_floating_point: + batch = (batch.clamp(0.0, 1.0) * 255.0).to(torch.uint8) + elif batch.dtype != torch.uint8: + batch = batch.to(torch.uint8) + pil_list: List[Image.Image] = [] + for frame in batch: + arr = frame.permute(1, 2, 0).numpy() + pil_list.append(Image.fromarray(arr)) + return pil_list + + def make_captions_from_extracted_frames( + self, + extracted_video_gen: Generator[Tuple[torch.Tensor, torch.Tensor], None, None], + list_of_questions: Optional[List[str]] = None, + ) -> Dict[str, Any]: + """ + Generate captions for all extracted frames and then produce a concise summary of the video. + Args: + extracted_video_dict (Dict[str, Any]): Dictionary containing the frame generator and number of frames. + summary_instruction (str, optional): Instruction for summarizing the captions. Defaults to a concise paragraph. + Returns: + Dict[str, Any]: A dictionary containing the list of captions with timestamps and the final summary. + """ + + caption_instruction = "Describe this image in one concise caption." + include_questions = bool(list_of_questions) + if include_questions: + q_block = "\n".join( + [f"{i + 1}. {q.strip()}" for i, q in enumerate(list_of_questions)] + ) + caption_instruction += ( + " In addition to the concise caption, also answer the following questions based ONLY on the image. Answers must be very brief and concise." + " Produce exactly two labeled sections: \n\n" + "Summary: \n\n" + "VQA Answers: \n1. \n2. \n etc." + "\nReturn only those two sections for each image (do not add extra commentary)." + "\nIf the answer cannot be determined based on the provided answer blocks," + ' reply with the line "The answer cannot be determined based on the information provided."' + f"\n\nQuestions:\n{q_block}" + ) + collected: List[Tuple[float, str]] = [] + proc = self.summary_model.processor + try: + for batch_frames, batch_times in extracted_video_gen: + pil_list = self._tensor_batch_to_pil_list(batch_frames.cpu()) + + prompt_texts = [] + for p in pil_list: + messages = [ + { + "role": "user", + "content": [ + {"type": "image", "image": p}, + {"type": "text", "text": caption_instruction}, + ], + } + ] + + prompt_text = proc.apply_chat_template( + messages, tokenize=False, add_generation_prompt=True + ) + prompt_texts.append(prompt_text) + + processor_inputs = proc( + text=prompt_texts, + images=pil_list, + return_tensors="pt", + padding=True, + ) + captions = self._generate_from_processor_inputs( + processor_inputs, + prompt_texts, + self.summary_model.tokenizer, + ) + + if isinstance(batch_times, torch.Tensor): + batch_times_list = batch_times.cpu().tolist() + else: + batch_times_list = list(batch_times) + for t, c in zip(batch_times_list, captions): + collected.append((float(t), c)) + finally: + try: + extracted_video_gen.close() + except Exception: + warnings.warn("Failed to close video frame generator.", RuntimeWarning) + + collected.sort(key=lambda x: x[0]) + bullets_summary, bullets_vqa = _categorize_outputs(collected, include_questions) + + return { + "summary_bullets": bullets_summary, + "vqa_bullets": bullets_vqa, + } # TODO consider taking out time stamps from the returned structure + + def final_summary(self, summary_dict: Dict[str, Any]) -> Dict[str, Any]: + """ + Produce a concise summary of the video, based on generated captions for all extracted frames. + Args: + summary_dict (Dict[str, Any]): Dictionary containing captions for the frames. 
+        Returns:
+            Dict[str, Any]: A dictionary containing the final video summary under the key "summary".
+        """
+        summary_instruction = "Analyze the following captions from multiple frames of the same video and summarize the overall content of the video in one concise paragraph (1-3 sentences). Focus on the key themes, actions, or events across the video, not just the individual frames."
+        proc = self.summary_model.processor
+
+        bullets = summary_dict.get("summary_bullets", [])
+        if not bullets:
+            raise ValueError("No captions available for summary generation.")
+
+        combined_captions_text = "\n".join(bullets)
+        summary_user_text = summary_instruction + "\n\n" + combined_captions_text + "\n"
+
+        messages = [
+            {
+                "role": "user",
+                "content": [{"type": "text", "text": summary_user_text}],
+            }
+        ]
+
+        summary_prompt_text = proc.apply_chat_template(
+            messages, tokenize=False, add_generation_prompt=True
+        )
+
+        summary_inputs = proc(
+            text=[summary_prompt_text], return_tensors="pt", padding=True
+        )
+
+        summary_inputs = {
+            k: v.to(self.summary_model.device) if isinstance(v, torch.Tensor) else v
+            for k, v in summary_inputs.items()
+        }
+        final_summary_list = self._generate_from_processor_inputs(
+            summary_inputs,
+            [summary_prompt_text],
+            self.summary_model.tokenizer,
+        )
+        final_summary = final_summary_list[0].strip() if final_summary_list else ""
+
+        return {
+            "summary": final_summary,
+        }
+
+    def final_answers(
+        self,
+        answers_dict: Dict[str, Any],
+        list_of_questions: List[str],
+    ) -> Dict[str, Any]:
+        """
+        Answer the list of questions for the video based on the VQA bullets from the frames.
+        Args:
+            answers_dict (Dict[str, Any]): Dictionary containing the per-frame VQA bullets.
+            list_of_questions (List[str]): Questions to answer about the video.
+        Returns:
+            Dict[str, Any]: A dictionary containing the list of answers under the key "vqa_answers".
+        """
+        vqa_bullets = answers_dict.get("vqa_bullets", [])
+        if not vqa_bullets:
+            raise ValueError(
+                "No per-frame VQA bullets are available for answering the questions."
+            )
+
+        include_questions = bool(list_of_questions)
+        if include_questions:
+            q_block = "\n".join(
+                [f"{i + 1}. {q.strip()}" for i, q in enumerate(list_of_questions)]
+            )
+            prompt = (
+                "You are provided with a set of short VQA-captions, each of which is a block of short answers"
+                " extracted from individual frames of the same video.\n\n"
+                "VQA-captions (use ONLY these to answer):\n"
+                f"{vqa_bullets}\n\n"
+                "Answer the following questions briefly, based ONLY on the lists of answers provided above. The VQA-captions above contain answers"
+                " to the same questions you are about to answer. If the answer cannot be determined based on the provided answer blocks,"
+                ' reply with the line "The answer cannot be determined based on the information provided."\n\n'
+                "Questions:\n"
+                f"{q_block}\n\n"
+                "Produce an ordered list with answers in the same order as the questions. Your output must have exactly this structure: "
+                "Answers: \n1. \n2. \n etc.\n"
+                "Return ONLY the ordered list with answers and NOTHING else — no commentary, no explanation, no surrounding markdown."
+            )
+        else:
+            raise ValueError(
+                "list_of_questions must be provided for making final answers."
+            )
+
+        proc = self.summary_model.processor
+        messages = [
+            {
+                "role": "user",
+                "content": [{"type": "text", "text": prompt}],
+            }
+        ]
+        final_vqa_prompt_text = proc.apply_chat_template(
+            messages, tokenize=False, add_generation_prompt=True
+        )
+        final_vqa_inputs = proc(
+            text=[final_vqa_prompt_text], return_tensors="pt", padding=True
+        )
+        final_vqa_inputs = {
+            k: v.to(self.summary_model.device) if isinstance(v, torch.Tensor) else v
+            for k, v in final_vqa_inputs.items()
+        }
+
+        final_vqa_list = self._generate_from_processor_inputs(
+            final_vqa_inputs,
+            [final_vqa_prompt_text],
+            self.summary_model.tokenizer,
+        )
+
+        final_vqa_output = final_vqa_list[0].strip() if final_vqa_list else ""
+        vqa_answers = []
+        answer_matches = re.findall(
+            r"\d+\.\s*(.*?)(?=\n\d+\.|$)", final_vqa_output, flags=re.DOTALL
+        )
+        for answer in answer_matches:
+            vqa_answers.append(answer.strip())
+        return {
+            "vqa_answers": vqa_answers,
+        }
+
+    def analyse_videos_from_dict(
+        self,
+        analysis_type: Union[AnalysisType, str] = AnalysisType.SUMMARY,
+        frame_rate_per_second: float = 2.0,
+        list_of_questions: Optional[List[str]] = None,
+    ) -> Dict[str, Any]:
+        """
+        Analyse all videos in self.subdict using frame extraction and captioning.
+        Args:
+            analysis_type (Union[AnalysisType, str], optional): Type of analysis to perform. Defaults to AnalysisType.SUMMARY.
+            frame_rate_per_second (float): Frame extraction rate in frames per second. Default is 2.0.
+            list_of_questions (List[str], optional): List of questions to answer about the video. Required if analysis_type includes questions.
+        Returns:
+            Dict[str, Any]: A dictionary containing the analysis results per video, including the summary and the answers to the provided questions (if any).
+        """
+
+        all_answers = {}
+        analysis_type, is_summary, is_questions = AnalysisType._validate_analysis_type(
+            analysis_type, list_of_questions
+        )
+
+        for video_key, entry in self.subdict.items():
+            summary_and_vqa = {}
+            frames_generator = self._extract_video_frames(
+                entry, frame_rate_per_second=frame_rate_per_second
+            )
+
+            answers_dict = self.make_captions_from_extracted_frames(
+                frames_generator, list_of_questions=list_of_questions
+            )
+            # TODO: captions will need post-processing once audio analysis is added
+            # TODO: the combined captions and answers may exceed the context limit of the summarizing model; consider a hierarchical approach.
+            if is_summary:
+                answer = self.final_summary(answers_dict)
+                summary_and_vqa["summary"] = answer["summary"]
+            if is_questions:
+                answer = self.final_answers(answers_dict, list_of_questions)
+                summary_and_vqa["vqa_answers"] = answer["vqa_answers"]
+
+            all_answers[video_key] = summary_and_vqa
+
+        return all_answers
diff --git a/environment.yml b/environment.yml
new file mode 100644
index 0000000..58cd20e
--- /dev/null
+++ b/environment.yml
@@ -0,0 +1,20 @@
+name: ammico-dev
+channels:
+  - pytorch
+  - nvidia
+  - rapidsai
+  - conda-forge
+  - defaults
+
+dependencies:
+  - python=3.11
+  - cudatoolkit=11.8
+  - pytorch=2.5.1
+  - pytorch-cuda=11.8
+  - torchvision=0.20.1
+  - torchaudio=2.5.1
+  - faiss-gpu-raft=1.8.0
+  - ipykernel
+  - jupyterlab
+  - jupyterlab_widgets
+  - ffmpeg<8
\ No newline at end of file
diff --git a/pyproject.toml b/pyproject.toml
index 4cad313..b672d06 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -18,16 +18,19 @@ classifiers = [
     "Operating System :: OS Independent",
     "License :: OSI Approved :: MIT License",
 ]
-
+
 dependencies = [
+    "accelerate>=0.22",
+    "bitsandbytes",
     "colorgram.py",
     "colour-science",
     "dash",
     "dash-bootstrap-components",
     "deepface",
     "google-cloud-vision",
+    "googletrans-py", # temporary replacement for googletrans==4.0.0rc1, which is incompatible with jupyterlab
     "grpcio",
+    "huggingface-hub>=0.34.0",
     "importlib_metadata",
     "importlib_resources",
     "matplotlib",
@@ -36,12 +39,17 @@ dependencies = [
     "pandas",
     "Pillow",
     "pooch",
+    "qwen-vl-utils",
     "retina_face",
+    "safetensors>=0.6.2",
     "setuptools",
     "spacy",
-    "tensorflow<=2.16.0",
+    "tensorflow<2.15", # pinned below 2.15 (previously <=2.16.0) for compatibility with CUDA 11.8; may change after a CUDA update.
     "tf-keras",
+    "torchvision",
     "tqdm",
+    "transformers>=4.54",
+    "torchcodec<0.2",
     "webcolors",
 ]
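
A note on the decoding helpers above: `_decode_trimmed_outputs` cuts the prompt tokens off each generated sequence before decoding, using the number of non-pad prompt tokens as the cut point. The snippet below is a minimal, self-contained sketch of that idea with invented token ids and an invented pad id; the real values come from the Hugging Face tokenizer and processor.

```python
import torch

# Toy stand-ins for tokenizer/processor outputs.
pad_token_id = 0
input_ids = torch.tensor(
    [[11, 12, 13, 14, 15],
     [21, 22, 23, 0, 0]]  # second prompt is right-padded to the batch width
)
# generate() on a decoder-only model returns prompt (incl. padding) + new tokens.
generated_ids = torch.tensor(
    [[11, 12, 13, 14, 15, 101, 102],
     [21, 22, 23, 0, 0, 201, 202]]
)

# Per-sample prompt length = number of non-pad prompt tokens.
prompt_lengths = input_ids.ne(pad_token_id).sum(dim=1).tolist()  # [5, 3]

# Keep only what was generated after the prompt; leftover pad ids are removed
# later by tokenizer.batch_decode(..., skip_special_tokens=True).
trimmed = [
    ids[n:] if ids.size(0) > n else ids.new_empty((0,), dtype=ids.dtype)
    for ids, n in zip(generated_ids, prompt_lengths)
]
print([t.tolist() for t in trimmed])  # [[101, 102], [0, 0, 201, 202]]
```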
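
The captioning loop in `make_captions_from_extracted_frames` renders one chat-template prompt per frame and then batches all frames through the processor in a single padded call. The sketch below condenses that pattern; the checkpoint name is an assumption for illustration only (the actual model is whichever one `MultimodalSummaryModel` loads), and running it downloads the processor files.

```python
from PIL import Image
from transformers import AutoProcessor

model_id = "Qwen/Qwen2.5-VL-3B-Instruct"  # assumed checkpoint, for illustration only
processor = AutoProcessor.from_pretrained(model_id)

# Two synthetic frames standing in for the PIL images produced by _tensor_batch_to_pil_list.
frames = [Image.new("RGB", (224, 224), color) for color in ("red", "blue")]
caption_instruction = "Describe this image in one concise caption."

prompt_texts = []
for frame in frames:
    messages = [
        {
            "role": "user",
            "content": [
                {"type": "image", "image": frame},
                {"type": "text", "text": caption_instruction},
            ],
        }
    ]
    # Render the chat template to text; tokenization happens in the batched call below.
    prompt_texts.append(
        processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    )

# One padded batch for all frames; these are the tensors that
# _generate_from_processor_inputs moves to the model device.
inputs = processor(text=prompt_texts, images=frames, return_tensors="pt", padding=True)
print(sorted(inputs.keys()))
```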
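
`final_summary` and `final_answers` act as the reduce step of a map-reduce pattern: the per-frame outputs are concatenated into a single text-only chat message and passed through the model once more, without images. A small sketch of how the summary prompt is assembled, using invented caption bullets in place of `summary_dict["summary_bullets"]` and an abridged instruction:

```python
# Invented caption bullets, standing in for summary_dict["summary_bullets"].
summary_bullets = [
    "A crowd gathers in a town square at dusk.",
    "A band sets up instruments on a small stage.",
    "People start dancing as the stage lights come on.",
]

# Abridged version of the summary_instruction used in final_summary.
summary_instruction = (
    "Analyze the following captions from multiple frames of the same video and "
    "summarize the overall content of the video in one concise paragraph (1-3 sentences)."
)

# The reduce-step prompt is plain text; no images are attached on this pass.
summary_user_text = summary_instruction + "\n\n" + "\n".join(summary_bullets) + "\n"
messages = [{"role": "user", "content": [{"type": "text", "text": summary_user_text}]}]
print(summary_user_text)
```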
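
`final_answers` recovers the individual answers from the model's ordered-list output with a regex. A quick, self-contained check of that pattern against an invented model response shows that multi-line answers are kept intact thanks to `re.DOTALL`:

```python
import re

# Invented model output in the "Answers:" format requested by the prompt.
final_vqa_output = (
    "Answers:\n"
    "1. Yes, two people are visible.\n"
    "2. A street festival at night,\n"
    "   with a band on stage."
)

answer_matches = re.findall(
    r"\d+\.\s*(.*?)(?=\n\d+\.|$)", final_vqa_output, flags=re.DOTALL
)
vqa_answers = [answer.strip() for answer in answer_matches]
print(vqa_answers)
# ['Yes, two people are visible.', 'A street festival at night,\n   with a band on stage.']
```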
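
Finally, a hedged end-to-end sketch of how the new video pipeline is intended to be driven from user code. Only `analyse_videos_from_dict` and its parameters are taken from the code above; the constructor calls, argument order, and the directory path are assumptions for illustration.

```python
import ammico

# Assumed setup: exact constructor signatures may differ from this sketch.
model = ammico.MultimodalSummaryModel()
videos = ammico.find_videos("data/videos")  # new helper exported in this PR

detector = ammico.VideoSummaryDetector(videos, summary_model=model)

results = detector.analyse_videos_from_dict(
    analysis_type="summary_and_questions",  # or an ammico.AnalysisType value
    frame_rate_per_second=1.0,
    list_of_questions=[
        "Are there people in the video?",
        "Where does the video take place?",
    ],
)

for video_key, answers in results.items():
    print(video_key)
    print("  summary:", answers.get("summary"))
    print("  vqa_answers:", answers.get("vqa_answers"))
```

The GPU stack this targets can be set up from the new environment.yml with `conda env create -f environment.yml` before installing ammico itself.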