From c2b2079d4e6d70820042f5f78edded37d8eb0e62 Mon Sep 17 00:00:00 2001 From: DimasfromLavoisier Date: Wed, 22 Oct 2025 17:12:19 +0200 Subject: [PATCH] add functionality for video vqa --- ammico/image_summary.py | 55 ++----- ammico/utils.py | 114 ++++++++++++++ ammico/video_summary.py | 328 ++++++++++++++++++++++++++-------------- 3 files changed, 334 insertions(+), 163 deletions(-) diff --git a/ammico/image_summary.py b/ammico/image_summary.py index 203ef21..adc01af 100644 --- a/ammico/image_summary.py +++ b/ammico/image_summary.py @@ -6,7 +6,7 @@ import torch from PIL import Image import warnings -from typing import List, Optional, Union, Dict, Any, Tuple +from typing import List, Optional, Union, Dict, Any from collections.abc import Sequence as _Sequence from transformers import GenerationConfig from qwen_vl_utils import process_vision_info @@ -120,42 +120,11 @@ class ImageSummaryDetector(AnalysisMethod): return inputs - def _validate_analysis_type( - self, - analysis_type: Union["AnalysisType", str], - list_of_questions: Optional[List[str]], - max_questions_per_image: int, - ) -> Tuple[str, List[str], bool, bool]: - if isinstance(analysis_type, AnalysisType): - analysis_type = analysis_type.value - - allowed = {"summary", "questions", "summary_and_questions"} - if analysis_type not in allowed: - raise ValueError(f"analysis_type must be one of {allowed}") - - if list_of_questions is None: - list_of_questions = [ - "Are there people in the image?", - "What is this picture about?", - ] - - if analysis_type in ("questions", "summary_and_questions"): - if len(list_of_questions) > max_questions_per_image: - raise ValueError( - f"Number of questions per image ({len(list_of_questions)}) exceeds safety cap ({max_questions_per_image}). Reduce questions or increase max_questions_per_image." - ) - - is_summary = analysis_type in ("summary", "summary_and_questions") - is_questions = analysis_type in ("questions", "summary_and_questions") - - return analysis_type, list_of_questions, is_summary, is_questions - def analyse_image( self, entry: dict, - analysis_type: Union[str, AnalysisType] = AnalysisType.SUMMARY_AND_QUESTIONS, + analysis_type: Union[str, AnalysisType] = AnalysisType.SUMMARY, list_of_questions: Optional[List[str]] = None, - max_questions_per_image: int = 32, is_concise_summary: bool = True, is_concise_answer: bool = True, ) -> Dict[str, Any]: @@ -165,10 +134,8 @@ class ImageSummaryDetector(AnalysisMethod): - 'vqa' (dict) if questions requested """ self.subdict = entry - analysis_type, list_of_questions, is_summary, is_questions = ( - self._validate_analysis_type( - analysis_type, list_of_questions, max_questions_per_image - ) + analysis_type, is_summary, is_questions = AnalysisType._validate_analysis_type( + analysis_type, list_of_questions ) if is_summary: @@ -195,9 +162,8 @@ class ImageSummaryDetector(AnalysisMethod): def analyse_images_from_dict( self, - analysis_type: Union[AnalysisType, str] = AnalysisType.SUMMARY_AND_QUESTIONS, + analysis_type: Union[AnalysisType, str] = AnalysisType.SUMMARY, list_of_questions: Optional[List[str]] = None, - max_questions_per_image: int = 32, keys_batch_size: int = 16, is_concise_summary: bool = True, is_concise_answer: bool = True, @@ -208,7 +174,6 @@ class ImageSummaryDetector(AnalysisMethod): Args: analysis_type (str): type of the analysis. list_of_questions (list[str]): list of questions. - max_questions_per_image (int): maximum number of questions per image. We recommend to keep it low to avoid long processing times and high memory usage. 
keys_batch_size (int): number of images to process in a batch. is_concise_summary (bool): whether to generate concise summary. is_concise_answer (bool): whether to generate concise answers. @@ -216,10 +181,8 @@ class ImageSummaryDetector(AnalysisMethod): self.subdict (dict): dictionary with analysis results. """ # TODO: add option to ask multiple questions per image as one batch. - analysis_type, list_of_questions, is_summary, is_questions = ( - self._validate_analysis_type( - analysis_type, list_of_questions, max_questions_per_image - ) + analysis_type, is_summary, is_questions = AnalysisType._validate_analysis_type( + analysis_type, list_of_questions ) keys = list(self.subdict.keys()) @@ -284,7 +247,7 @@ class ImageSummaryDetector(AnalysisMethod): with torch.inference_mode(): try: if self.summary_model.device == "cuda": - with torch.cuda.amp.autocast(enabled=True): + with torch.amp.autocast("cuda", enabled=True): generated_ids = self.summary_model.model.generate( **inputs, generation_config=gen_conf ) @@ -366,7 +329,7 @@ class ImageSummaryDetector(AnalysisMethod): inputs = self._prepare_inputs(chunk, entry) with torch.inference_mode(): if self.summary_model.device == "cuda": - with torch.cuda.amp.autocast(enabled=True): + with torch.amp.autocast("cuda", enabled=True): out_ids = self.summary_model.model.generate( **inputs, generation_config=gen_conf ) diff --git a/ammico/utils.py b/ammico/utils.py index d698708..07bae90 100644 --- a/ammico/utils.py +++ b/ammico/utils.py @@ -6,6 +6,8 @@ import importlib_resources import collections import random from enum import Enum +from typing import List, Tuple, Optional, Union +import re pkg = importlib_resources.files("ammico") @@ -46,6 +48,35 @@ class AnalysisType(str, Enum): QUESTIONS = "questions" SUMMARY_AND_QUESTIONS = "summary_and_questions" + @classmethod + def _validate_analysis_type( + cls, + analysis_type: Union["AnalysisType", str], + list_of_questions: Optional[List[str]], + ) -> Tuple[str, bool, bool]: + max_questions_per_image = 15 # safety cap to avoid too many questions + if isinstance(analysis_type, AnalysisType): + analysis_type = analysis_type.value + + allowed = {item.value for item in AnalysisType} + if analysis_type not in allowed: + raise ValueError(f"analysis_type must be one of {allowed}") + + if analysis_type in ("questions", "summary_and_questions"): + if not list_of_questions: + raise ValueError( + "list_of_questions must be provided for the QUESTIONS and SUMMARY_AND_QUESTIONS analysis types." + ) + + if len(list_of_questions) > max_questions_per_image: + raise ValueError( + f"Number of questions per image ({len(list_of_questions)}) exceeds the safety cap ({max_questions_per_image}). Reduce the number of questions." + ) + + is_summary = analysis_type in ("summary", "summary_and_questions") + is_questions = analysis_type in ("questions", "summary_and_questions") + return analysis_type, is_summary, is_questions + class AnalysisMethod: """Base class to be inherited by all analysis methods.""" @@ -101,6 +132,89 @@ def _limit_results(results, limit): return results +def _categorize_outputs( + collected: List[Tuple[float, str]], + include_questions: bool = False, +) -> Tuple[List[str], List[str]]: + """ + Categorize collected outputs into summary bullets and VQA bullets. + Args: + collected (List[Tuple[float, str]]): List of tuples containing timestamps and generated texts. + include_questions (bool): Whether the generated texts contain labeled Summary and VQA Answers sections to split apart. + Returns: + Tuple[List[str], List[str]]: A tuple containing two lists: summary bullets and VQA bullets.
+ """ + MAX_CAPTIONS_FOR_SUMMARY = 600 # TODO For now, this is a constant value, but later we need to make it adjustable, with the idea of cutting out the most similar frames to reduce the load on the system. + caps_for_summary_vqa = ( + collected[-MAX_CAPTIONS_FOR_SUMMARY:] + if len(collected) > MAX_CAPTIONS_FOR_SUMMARY + else collected + ) + bullets_summary = [] + bullets_vqa = [] + + for t, c in caps_for_summary_vqa: + if include_questions: + result_sections = c.strip() + m = re.search( + r"Summary\s*:\s*(.*?)\s*(?:VQA\s+Answers\s*:\s*(.*))?$", + result_sections, + flags=re.IGNORECASE | re.DOTALL, + ) + if m: + summary_text = ( + m.group(1).replace("\n", " ").strip() if m.group(1) else None + ) + vqa_text = m.group(2).strip() if m.group(2) else None + if not summary_text or not vqa_text: + raise ValueError( + f"Model output is missing either summary or VQA answers: {c}" + ) + bullets_summary.append(f"- [{t:.3f}s] {summary_text}") + bullets_vqa.append(f"- [{t:.3f}s] {vqa_text}") + else: + raise ValueError( + f"Failed to parse summary and VQA answers from model output: {c}" + ) + else: + snippet = c.replace("\n", " ").strip() + bullets_summary.append(f"- [{t:.3f}s] {snippet}") + return bullets_summary, bullets_vqa + + +def _normalize_whitespace(s: str) -> str: + return re.sub(r"\s+", " ", s).strip() + + +def _strip_prompt_prefix_literal(decoded: str, prompt: str) -> str: + """ + Remove any literal prompt prefix from decoded text using a normalized-substring match. + Guarantees no prompt text remains at the start of returned string (best-effort). + """ + if not decoded: + return "" + if not prompt: + return decoded.strip() + + d_norm = _normalize_whitespace(decoded) + p_norm = _normalize_whitespace(prompt) + + idx = d_norm.find(p_norm) + if idx != -1: + running = [] + for i, ch in enumerate(decoded): + running.append(ch if not ch.isspace() else " ") + cur_norm = _normalize_whitespace("".join(running)) + if cur_norm.endswith(p_norm): + return decoded[i + 1 :].lstrip() if i + 1 < len(decoded) else "" + m = re.match( + r"^(?:\s*(system|user|assistant)[:\s-]*\n?)+", decoded, flags=re.IGNORECASE + ) + if m: + return decoded[m.end() :].lstrip() + + return decoded.lstrip("\n\r ").lstrip(":;- ").strip() + + def find_videos( path: str = None, pattern=["mp4"], # TODO: test with more video formats diff --git a/ammico/video_summary.py b/ammico/video_summary.py index 666c373..391f2d0 100644 --- a/ammico/video_summary.py +++ b/ammico/video_summary.py @@ -6,17 +6,24 @@ from PIL import Image from torchcodec.decoders import VideoDecoder from ammico.model import MultimodalSummaryModel -from ammico.utils import AnalysisMethod +from ammico.utils import ( + AnalysisMethod, + AnalysisType, + _categorize_outputs, + _strip_prompt_prefix_literal, +) -from typing import List, Dict, Any, Generator, Tuple +from typing import List, Dict, Any, Generator, Tuple, Union, Optional from transformers import GenerationConfig class VideoSummaryDetector(AnalysisMethod): + MAX_SAMPLES_CAP = 1000 # safety cap for total extracted frames + def __init__( self, summary_model: MultimodalSummaryModel, - subdict: dict = {}, + subdict: Optional[Dict[str, Any]] = None, ) -> None: """ Class for analysing videos using QWEN-2.5-VL model. @@ -29,6 +36,8 @@ class VideoSummaryDetector(AnalysisMethod): Returns: None. 
""" + if subdict is None: + subdict = {} super().__init__(subdict) self.summary_model = summary_model @@ -107,6 +116,7 @@ class VideoSummaryDetector(AnalysisMethod): raise ValueError("frame_rate_per_second must be > 0") n_samples = max(1, int(math.floor(duration * frame_rate_per_second))) + n_samples = min(n_samples, self.MAX_SAMPLES_CAP) if begin_stream_seconds is not None and end_stream_seconds is not None: sample_times = torch.linspace( @@ -126,38 +136,6 @@ class VideoSummaryDetector(AnalysisMethod): return generator - def _normalize_whitespace(self, s: str) -> str: - return re.sub(r"\s+", " ", s).strip() - - def _strip_prompt_prefix_literal(self, decoded: str, prompt: str) -> str: - """ - Remove any literal prompt prefix from decoded text using a normalized-substring match. - Guarantees no prompt text remains at the start of returned string (best-effort). - """ - if not decoded: - return "" - if not prompt: - return decoded.strip() - - d_norm = self._normalize_whitespace(decoded) - p_norm = self._normalize_whitespace(prompt) - - idx = d_norm.find(p_norm) - if idx != -1: - running = [] - for i, ch in enumerate(decoded): - running.append(ch if not ch.isspace() else " ") - cur_norm = self._normalize_whitespace("".join(running)) - if cur_norm.endswith(p_norm): - return decoded[i + 1 :].lstrip() if i + 1 < len(decoded) else "" - m = re.match( - r"^(?:\s*(system|user|assistant)[:\s-]*\n?)+", decoded, flags=re.IGNORECASE - ) - if m: - return decoded[m.end() :].lstrip() - - return decoded.lstrip("\n\r ").lstrip(":;- ").strip() - def _decode_trimmed_outputs( self, generated_ids: torch.Tensor, @@ -170,20 +148,18 @@ class VideoSummaryDetector(AnalysisMethod): Then remove any literal prompt prefix using prompt_texts (one per batch element). """ - decoded_results = [] batch_size = generated_ids.shape[0] if "input_ids" in inputs: - lengths = ( - inputs["input_ids"] - .ne( - tokenizer.pad_token_id - if tokenizer.pad_token_id is not None - else tokenizer.eos_token_id - ) - .sum(dim=1) - .tolist() + token_for_padding = ( + tokenizer.pad_token_id + if getattr(tokenizer, "pad_token_id", None) is not None + else getattr(tokenizer, "eos_token_id", None) ) + if token_for_padding is None: + lengths = [int(inputs["input_ids"].shape[1])] * batch_size + else: + lengths = inputs["input_ids"].ne(token_for_padding).sum(dim=1).tolist() else: lengths = [0] * batch_size @@ -195,13 +171,15 @@ class VideoSummaryDetector(AnalysisMethod): t = out_ids[in_len:] else: t = out_ids.new_empty((0,), dtype=out_ids.dtype) - trimmed_ids.append(t) + t_cpu = t.to("cpu") + trimmed_ids.append(t_cpu.tolist()) decoded = tokenizer.batch_decode( trimmed_ids, skip_special_tokens=True, clean_up_tokenization_spaces=True ) + decoded_results = [] for ptext, raw in zip(prompt_texts, decoded): - cleaned = self._strip_prompt_prefix_literal(raw, ptext) + cleaned = _strip_prompt_prefix_literal(raw, ptext) decoded_results.append(cleaned) return decoded_results @@ -209,7 +187,6 @@ class VideoSummaryDetector(AnalysisMethod): self, processor_inputs: Dict[str, torch.Tensor], prompt_texts: List[str], - model, tokenizer, ): """ @@ -224,7 +201,7 @@ class VideoSummaryDetector(AnalysisMethod): for k, v in list(processor_inputs.items()): if isinstance(v, torch.Tensor): - processor_inputs[k] = v.to(model.device) + processor_inputs[k] = v.to(self.summary_model.device) with torch.inference_mode(): try: @@ -239,8 +216,8 @@ class VideoSummaryDetector(AnalysisMethod): ) except RuntimeError as e: warnings.warn( - "Retry without autocast failed: %s. 
Attempting cudnn-disabled retry.", - e, + f"Generation failed with error: {e}. Retrying with cuDNN disabled.", + RuntimeWarning, ) cudnn_was_enabled = ( torch.backends.cudnn.is_available() and torch.backends.cudnn.enabled @@ -273,7 +250,9 @@ class VideoSummaryDetector(AnalysisMethod): batch = batch.to("cpu") batch = batch.contiguous() - if batch.dtype != torch.uint8: + if batch.dtype.is_floating_point: + batch = (batch.clamp(0.0, 1.0) * 255.0).to(torch.uint8) + elif batch.dtype != torch.uint8: batch = batch.to(torch.uint8) pil_list: List[Image.Image] = [] for frame in batch: @@ -281,10 +260,10 @@ class VideoSummaryDetector(AnalysisMethod): pil_list.append(Image.fromarray(arr)) return pil_list - def brute_force_summary( + def make_captions_from_extracted_frames( self, extracted_video_gen: Generator[Tuple[torch.Tensor, torch.Tensor], None, None], - summary_instruction: str = "Analyze the following captions from multiple frames of the same video and summarize the overall content of the video in one concise paragraph (1-3 sentences). Focus on the key themes, actions, or events across the video, not just the individual frames.", + list_of_questions: Optional[List[str]] = None, ) -> Dict[str, Any]: """ Generate captions for all extracted frames and then produce a concise summary of the video. @@ -296,68 +275,93 @@ class VideoSummaryDetector(AnalysisMethod): """ caption_instruction = "Describe this image in one concise caption." + include_questions = bool(list_of_questions) + if include_questions: + q_block = "\n".join( + [f"{i + 1}. {q.strip()}" for i, q in enumerate(list_of_questions)] + ) + caption_instruction += ( + " In addition to the concise caption, also answer the following questions based ONLY on the image. Answers must be very brief and concise." + " Produce exactly two labeled sections: \n\n" + "Summary: \n\n" + "VQA Answers: \n1. \n2. \n etc." + "\nReturn only those two sections for each image (do not add extra commentary)." 
+ "\nIf the answer cannot be determined based on the provided answer blocks," + ' reply with the line "The answer cannot be determined based on the information provided."' + f"\n\nQuestions:\n{q_block}" + ) collected: List[Tuple[float, str]] = [] proc = self.summary_model.processor + try: + for batch_frames, batch_times in extracted_video_gen: + pil_list = self._tensor_batch_to_pil_list(batch_frames.cpu()) - for batch_frames, batch_times in extracted_video_gen: - pil_list = self._tensor_batch_to_pil_list(batch_frames.cpu()) + prompt_texts = [] + for p in pil_list: + messages = [ + { + "role": "user", + "content": [ + {"type": "image", "image": p}, + {"type": "text", "text": caption_instruction}, + ], + } + ] - prompt_texts = [] - for p in pil_list: - messages = [ - { - "role": "user", - "content": [ - {"type": "image", "image": p}, - {"type": "text", "text": caption_instruction}, - ], - } - ] + prompt_text = proc.apply_chat_template( + messages, tokenize=False, add_generation_prompt=True + ) + prompt_texts.append(prompt_text) - prompt_text = proc.apply_chat_template( - messages, tokenize=False, add_generation_prompt=True + processor_inputs = proc( + text=prompt_texts, + images=pil_list, + return_tensors="pt", + padding=True, + ) + captions = self._generate_from_processor_inputs( + processor_inputs, + prompt_texts, + self.summary_model.tokenizer, ) - prompt_texts.append(prompt_text) - processor_inputs = proc( - text=prompt_texts, images=pil_list, return_tensors="pt", padding=True - ) - captions = self._generate_from_processor_inputs( - processor_inputs, - prompt_texts, - self.summary_model, - self.summary_model.tokenizer, - ) - - if isinstance(batch_times, torch.Tensor): - batch_times_list = batch_times.cpu().tolist() - else: - batch_times_list = list(batch_times) - for t, c in zip(batch_times_list, captions): - collected.append((float(t), c)) + if isinstance(batch_times, torch.Tensor): + batch_times_list = batch_times.cpu().tolist() + else: + batch_times_list = list(batch_times) + for t, c in zip(batch_times_list, captions): + collected.append((float(t), c)) + finally: + try: + extracted_video_gen.close() + except Exception: + warnings.warn("Failed to close video frame generator.", RuntimeWarning) collected.sort(key=lambda x: x[0]) - extracted_video_gen.close() + bullets_summary, bullets_vqa = _categorize_outputs(collected, include_questions) - MAX_CAPTIONS_FOR_SUMMARY = 200 - caps_for_summary = ( - collected[-MAX_CAPTIONS_FOR_SUMMARY:] - if len(collected) > MAX_CAPTIONS_FOR_SUMMARY - else collected - ) + return { + "summary_bullets": bullets_summary, + "vqa_bullets": bullets_vqa, + } # TODO consider taking out time stamps from the returned structure - bullets = [] - for t, c in caps_for_summary: - snippet = c.replace("\n", " ").strip() - bullets.append(f"- [{t:.3f}s] {snippet}") + def final_summary(self, summary_dict: Dict[str, Any]) -> Dict[str, Any]: + """ + Produce a concise summary of the video, based on generated captions for all extracted frames. + Args: + summary_dict (Dict[str, Any]): Dictionary containing captions for the frames. + Returns: + Dict[str, Any]: A dictionary containing the list of captions with timestamps and the final summary. + """ + summary_instruction = "Analyze the following captions from multiple frames of the same video and summarize the overall content of the video in one concise paragraph (1-3 sentences). Focus on the key themes, actions, or events across the video, not just the individual frames." 
+ proc = self.summary_model.processor + + bullets = summary_dict.get("summary_bullets", []) + if not bullets: + raise ValueError("No captions available for summary generation.") combined_captions_text = "\n".join(bullets) - summary_user_text = ( - summary_instruction - + "\n\n" - + combined_captions_text - + "\n\nPlease produce a single concise paragraph." - ) + summary_user_text = summary_instruction + "\n\n" + combined_captions_text + "\n" messages = [ { @@ -381,40 +385,130 @@ class VideoSummaryDetector(AnalysisMethod): final_summary_list = self._generate_from_processor_inputs( summary_inputs, [summary_prompt_text], - self.summary_model.model, self.summary_model.tokenizer, ) final_summary = final_summary_list[0].strip() if final_summary_list else "" return { - "captions": collected, "summary": final_summary, } + def final_answers( + self, + answers_dict: Dict[str, Any], + list_of_questions: List[str], + ) -> Dict[str, Any]: + """ + Answer the list of questions for the video based on the VQA bullets from the frames. + Args: + answers_dict (Dict[str, Any]): Dictionary containing the VQA bullets. + list_of_questions (List[str]): Questions to answer about the video. + Returns: + Dict[str, Any]: A dictionary containing the list of answers to the questions. + """ + vqa_bullets = answers_dict.get("vqa_bullets", []) + if not vqa_bullets: + raise ValueError( + "No per-frame VQA bullets are available for answering questions." + ) + + include_questions = bool(list_of_questions) + if include_questions: + q_block = "\n".join( + [f"{i + 1}. {q.strip()}" for i, q in enumerate(list_of_questions)] + ) + vqa_block = "\n".join(vqa_bullets) + prompt = ( + "You are provided with a set of short VQA-captions, each of which is a block of short answers" + " extracted from individual frames of the same video.\n\n" + "VQA-captions (use ONLY these to answer):\n" + f"{vqa_block}\n\n" + "Answer the following questions briefly, based ONLY on the lists of answers provided above. The VQA-captions above contain answers" + " to the same questions you are about to answer. If the answer cannot be determined based on the provided answer blocks," + ' reply with the line "The answer cannot be determined based on the information provided."' + "\n\nQuestions:\n" + f"{q_block}\n\n" + "Produce an ordered list of answers in the same order as the questions. Your output must have this structure: " + "Answers: \n1. \n2. \n etc.\n" + "Return ONLY the ordered list of answers and NOTHING else: no commentary, no explanation, no surrounding markdown." + ) + else: + raise ValueError( + "list_of_questions must be provided for making final answers."
+ ) + + proc = self.summary_model.processor + messages = [ + { + "role": "user", + "content": [{"type": "text", "text": prompt}], + } + ] + final_vqa_prompt_text = proc.apply_chat_template( + messages, tokenize=False, add_generation_prompt=True + ) + final_vqa_inputs = proc( + text=[final_vqa_prompt_text], return_tensors="pt", padding=True + ) + final_vqa_inputs = { + k: v.to(self.summary_model.device) if isinstance(v, torch.Tensor) else v + for k, v in final_vqa_inputs.items() + } + + final_vqa_list = self._generate_from_processor_inputs( + final_vqa_inputs, + [final_vqa_prompt_text], + self.summary_model.tokenizer, + ) + + final_vqa_output = final_vqa_list[0].strip() if final_vqa_list else "" + vqa_answers = [] + answer_matches = re.findall( + r"\d+\.\s*(.*?)(?=\n\d+\.|$)", final_vqa_output, flags=re.DOTALL + ) + for answer in answer_matches: + vqa_answers.append(answer.strip()) + return { + "vqa_answers": vqa_answers, + } + def analyse_videos_from_dict( - self, frame_rate_per_second: float = 2.0 + self, + analysis_type: Union[AnalysisType, str] = AnalysisType.SUMMARY, + frame_rate_per_second: float = 2.0, + list_of_questions: Optional[List[str]] = None, ) -> Dict[str, Any]: """ Analyse the video specified in self.subdict using frame extraction and captioning. - For short videos (<=100 frames at the specified frame rate), it uses brute-force captioning. - For longer videos, it currently defaults to brute-force captioning, but can be extended for more complex methods. - Args: + analysis_type (Union[AnalysisType, str], optional): Type of analysis to perform. Defaults to AnalysisType.SUMMARY. frame_rate_per_second (float): Frame extraction rate in frames per second. Default is 2.0. + list_of_questions (List[str], optional): List of questions to answer about the video. Required if analysis_type includes questions. Returns: - Dict[str, Any]: A dictionary containing the analysis results, including captions and summary. + Dict[str, Any]: A dictionary containing the analysis results, including the summary and answers to the provided questions (if any). """ all_answers = {} - # TODO: add support for answering questions about videos - for video_key in list(self.subdict.keys()): - entry = self.subdict[video_key] - extracted_video_gen = self._extract_video_frames( + analysis_type, is_summary, is_questions = AnalysisType._validate_analysis_type( + analysis_type, list_of_questions + ) + + for video_key, entry in self.subdict.items(): + summary_and_vqa = {} + frames_generator = self._extract_video_frames( entry, frame_rate_per_second=frame_rate_per_second ) - answer = self.brute_force_summary(extracted_video_gen) - all_answers[video_key] = {"summary": answer["summary"]} + answers_dict = self.make_captions_from_extracted_frames( + frames_generator, list_of_questions=list_of_questions + ) # TODO: post-process the captions once audio analysis is added + # TODO: the combined captions and answers may exceed the model's prompt limit. Consider a hierarchical approach. + if is_summary: + answer = self.final_summary(answers_dict) + summary_and_vqa["summary"] = answer["summary"] + if is_questions: + answer = self.final_answers(answers_dict, list_of_questions) + summary_and_vqa["vqa_answers"] = answer["vqa_answers"] - return answer + all_answers[video_key] = summary_and_vqa + + return all_answers
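After this patch, the new video VQA entry points can be exercised as in the minimal sketch below. The `MultimodalSummaryModel` constructor defaults, the entry layout with a "filename" key, and the sample path are illustrative assumptions, not part of the patch:

    from ammico.model import MultimodalSummaryModel
    from ammico.utils import AnalysisType
    from ammico.video_summary import VideoSummaryDetector

    # Hypothetical input dict; in ammico this would normally be built via find_videos().
    videos = {"vid1": {"filename": "clips/example.mp4"}}

    model = MultimodalSummaryModel()  # assumed default construction
    detector = VideoSummaryDetector(summary_model=model, subdict=videos)

    # Summary only (the new default analysis type):
    results = detector.analyse_videos_from_dict(frame_rate_per_second=1.0)
    print(results["vid1"]["summary"])

    # Summary plus VQA; list_of_questions is required for the question-based types
    # and is capped at 15 questions by AnalysisType._validate_analysis_type:
    results = detector.analyse_videos_from_dict(
        analysis_type=AnalysisType.SUMMARY_AND_QUESTIONS,
        frame_rate_per_second=1.0,
        list_of_questions=["Are there people in the video?"],
    )
    print(results["vid1"]["vqa_answers"])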