Mirror of https://github.com/ssciwr/AMMICO.git
Synced 2025-10-29 21:16:06 +02:00

Merge pull request #8 from DimasfromLavoisier/add_video_vqa

add functionality for video vqa

Commit 3e036514cd
@@ -6,7 +6,7 @@ import torch
from PIL import Image
import warnings

from typing import List, Optional, Union, Dict, Any, Tuple
from typing import List, Optional, Union, Dict, Any
from collections.abc import Sequence as _Sequence
from transformers import GenerationConfig
from qwen_vl_utils import process_vision_info
@@ -120,42 +120,11 @@ class ImageSummaryDetector(AnalysisMethod):

        return inputs

    def _validate_analysis_type(
        self,
        analysis_type: Union["AnalysisType", str],
        list_of_questions: Optional[List[str]],
        max_questions_per_image: int,
    ) -> Tuple[str, List[str], bool, bool]:
        if isinstance(analysis_type, AnalysisType):
            analysis_type = analysis_type.value

        allowed = {"summary", "questions", "summary_and_questions"}
        if analysis_type not in allowed:
            raise ValueError(f"analysis_type must be one of {allowed}")

        if list_of_questions is None:
            list_of_questions = [
                "Are there people in the image?",
                "What is this picture about?",
            ]

        if analysis_type in ("questions", "summary_and_questions"):
            if len(list_of_questions) > max_questions_per_image:
                raise ValueError(
                    f"Number of questions per image ({len(list_of_questions)}) exceeds safety cap ({max_questions_per_image}). Reduce questions or increase max_questions_per_image."
                )

        is_summary = analysis_type in ("summary", "summary_and_questions")
        is_questions = analysis_type in ("questions", "summary_and_questions")

        return analysis_type, list_of_questions, is_summary, is_questions

    def analyse_image(
        self,
        entry: dict,
        analysis_type: Union[str, AnalysisType] = AnalysisType.SUMMARY_AND_QUESTIONS,
        analysis_type: Union[str, AnalysisType] = AnalysisType.SUMMARY,
        list_of_questions: Optional[List[str]] = None,
        max_questions_per_image: int = 32,
        is_concise_summary: bool = True,
        is_concise_answer: bool = True,
    ) -> Dict[str, Any]:
@@ -165,10 +134,8 @@ class ImageSummaryDetector(AnalysisMethod):
            - 'vqa' (dict) if questions requested
        """
        self.subdict = entry
        analysis_type, list_of_questions, is_summary, is_questions = (
            self._validate_analysis_type(
                analysis_type, list_of_questions, max_questions_per_image
            )
        analysis_type, is_summary, is_questions = AnalysisType._validate_analysis_type(
            analysis_type, list_of_questions
        )

        if is_summary:
@@ -195,9 +162,8 @@ class ImageSummaryDetector(AnalysisMethod):

    def analyse_images_from_dict(
        self,
        analysis_type: Union[AnalysisType, str] = AnalysisType.SUMMARY_AND_QUESTIONS,
        analysis_type: Union[AnalysisType, str] = AnalysisType.SUMMARY,
        list_of_questions: Optional[List[str]] = None,
        max_questions_per_image: int = 32,
        keys_batch_size: int = 16,
        is_concise_summary: bool = True,
        is_concise_answer: bool = True,
@@ -208,7 +174,6 @@ class ImageSummaryDetector(AnalysisMethod):
        Args:
            analysis_type (str): type of the analysis.
            list_of_questions (list[str]): list of questions.
            max_questions_per_image (int): maximum number of questions per image. We recommend to keep it low to avoid long processing times and high memory usage.
            keys_batch_size (int): number of images to process in a batch.
            is_concise_summary (bool): whether to generate concise summary.
            is_concise_answer (bool): whether to generate concise answers.
@@ -216,10 +181,8 @@ class ImageSummaryDetector(AnalysisMethod):
            self.subdict (dict): dictionary with analysis results.
        """
        # TODO: add option to ask multiple questions per image as one batch.
        analysis_type, list_of_questions, is_summary, is_questions = (
            self._validate_analysis_type(
                analysis_type, list_of_questions, max_questions_per_image
            )
        analysis_type, is_summary, is_questions = AnalysisType._validate_analysis_type(
            analysis_type, list_of_questions
        )

        keys = list(self.subdict.keys())
@@ -284,7 +247,7 @@ class ImageSummaryDetector(AnalysisMethod):
        with torch.inference_mode():
            try:
                if self.summary_model.device == "cuda":
                    with torch.cuda.amp.autocast(enabled=True):
                    with torch.amp.autocast("cuda", enabled=True):
                        generated_ids = self.summary_model.model.generate(
                            **inputs, generation_config=gen_conf
                        )
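The autocast change in this hunk and the next one tracks the deprecation of the CUDA-specific context manager in recent PyTorch releases; torch.amp.autocast is device-agnostic and takes the device type as its first argument. A minimal sketch of the two spellings:

import torch

# Deprecated spelling (emits a deprecation warning on recent PyTorch versions):
#   with torch.cuda.amp.autocast(enabled=True): ...
# Replacement used in this diff:
with torch.amp.autocast("cuda", enabled=True):
    pass  # run mixed-precision generate()/forward() calls here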
@@ -366,7 +329,7 @@ class ImageSummaryDetector(AnalysisMethod):
            inputs = self._prepare_inputs(chunk, entry)
            with torch.inference_mode():
                if self.summary_model.device == "cuda":
                    with torch.cuda.amp.autocast(enabled=True):
                    with torch.amp.autocast("cuda", enabled=True):
                        out_ids = self.summary_model.model.generate(
                            **inputs, generation_config=gen_conf
                        )
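With these changes, summary-only analysis becomes the default for images and questions must be passed explicitly; validation now lives on AnalysisType (see ammico/utils.py below). A hedged usage sketch; the detector construction and the entry keys are assumptions:

# Hypothetical call sites for the updated ImageSummaryDetector API.
detector = ImageSummaryDetector(summary_model=model)  # construction details assumed

# The default analysis_type is now AnalysisType.SUMMARY, so no questions are needed:
out = detector.analyse_image(entry={"filename": "image.png"})  # entry keys assumed

# Questions now require an explicit list (capped at 15 by the shared validator):
out = detector.analyse_image(
    entry={"filename": "image.png"},
    analysis_type="summary_and_questions",
    list_of_questions=["Are there people in the image?"],
)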
							
								
								
									
ammico/utils.py (114 lines changed)

@@ -6,6 +6,8 @@ import importlib_resources
import collections
import random
from enum import Enum
from typing import List, Tuple, Optional, Union
import re


pkg = importlib_resources.files("ammico")
@@ -46,6 +48,35 @@ class AnalysisType(str, Enum):
    QUESTIONS = "questions"
    SUMMARY_AND_QUESTIONS = "summary_and_questions"

    @classmethod
    def _validate_analysis_type(
        cls,
        analysis_type: Union["AnalysisType", str],
        list_of_questions: Optional[List[str]],
    ) -> Tuple[str, bool, bool]:
        max_questions_per_image = 15  # safety cap to avoid too many questions
        if isinstance(analysis_type, AnalysisType):
            analysis_type = analysis_type.value

        allowed = {item.value for item in AnalysisType}
        if analysis_type not in allowed:
            raise ValueError(f"analysis_type must be one of {allowed}")

        if analysis_type in ("questions", "summary_and_questions"):
            if not list_of_questions:
                raise ValueError(
                    "list_of_questions must be provided for QUESTIONS analysis type."
                )

            if len(list_of_questions) > max_questions_per_image:
                raise ValueError(
                    f"Number of questions per image ({len(list_of_questions)}) exceeds safety cap ({max_questions_per_image}). Reduce questions or increase max_questions_per_image."
                )

        is_summary = analysis_type in ("summary", "summary_and_questions")
        is_questions = analysis_type in ("questions", "summary_and_questions")
        return analysis_type, is_summary, is_questions


class AnalysisMethod:
    """Base class to be inherited by all analysis methods."""
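To make the behaviour of the new classmethod concrete, a few example calls and their outcomes, derived directly from the code above:

# Returns a (analysis_type, is_summary, is_questions) tuple of plain values.
AnalysisType._validate_analysis_type(AnalysisType.SUMMARY, None)
# -> ("summary", True, False)

AnalysisType._validate_analysis_type("summary_and_questions", ["What is shown?"])
# -> ("summary_and_questions", True, True)

# Raises ValueError: questions requested but no list_of_questions given.
AnalysisType._validate_analysis_type("questions", None)

# Raises ValueError: exceeds the hard-coded cap of 15 questions.
AnalysisType._validate_analysis_type("questions", [f"q{i}?" for i in range(16)])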
@@ -101,6 +132,89 @@ def _limit_results(results, limit):
    return results


def _categorize_outputs(
    collected: List[Tuple[float, str]],
    include_questions: bool = False,
) -> Tuple[List[str], List[str]]:
    """
    Categorize collected outputs into summary bullets and VQA bullets.
    Args:
        collected (List[Tuple[float, str]]): List of tuples containing timestamps and generated texts.
    Returns:
        Tuple[List[str], List[str]]: A tuple containing two lists - summary bullets and VQA bullets.
    """
    MAX_CAPTIONS_FOR_SUMMARY = 600  # TODO For now, this is a constant value, but later we need to make it adjustable, with the idea of cutting out the most similar frames to reduce the load on the system.
    caps_for_summary_vqa = (
        collected[-MAX_CAPTIONS_FOR_SUMMARY:]
        if len(collected) > MAX_CAPTIONS_FOR_SUMMARY
        else collected
    )
    bullets_summary = []
    bullets_vqa = []

    for t, c in caps_for_summary_vqa:
        if include_questions:
            result_sections = c.strip()
            m = re.search(
                r"Summary\s*:\s*(.*?)\s*(?:VQA\s+Answers\s*:\s*(.*))?$",
                result_sections,
                flags=re.IGNORECASE | re.DOTALL,
            )
            if m:
                summary_text = (
                    m.group(1).replace("\n", " ").strip() if m.group(1) else None
                )
                vqa_text = m.group(2).strip() if m.group(2) else None
                if not summary_text or not vqa_text:
                    raise ValueError(
                        f"Model output is missing either summary or VQA answers: {c}"
                    )
                bullets_summary.append(f"- [{t:.3f}s] {summary_text}")
                bullets_vqa.append(f"- [{t:.3f}s] {vqa_text}")
            else:
                raise ValueError(
                    f"Failed to parse summary and VQA answers from model output: {c}"
                )
        else:
            snippet = c.replace("\n", " ").strip()
            bullets_summary.append(f"- [{t:.3f}s] {snippet}")
    return bullets_summary, bullets_vqa


def _normalize_whitespace(s: str) -> str:
    return re.sub(r"\s+", " ", s).strip()


def _strip_prompt_prefix_literal(decoded: str, prompt: str) -> str:
    """
    Remove any literal prompt prefix from decoded text using a normalized-substring match.
    Guarantees no prompt text remains at the start of returned string (best-effort).
    """
    if not decoded:
        return ""
    if not prompt:
        return decoded.strip()

    d_norm = _normalize_whitespace(decoded)
    p_norm = _normalize_whitespace(prompt)

    idx = d_norm.find(p_norm)
    if idx != -1:
        running = []
        for i, ch in enumerate(decoded):
            running.append(ch if not ch.isspace() else " ")
            cur_norm = _normalize_whitespace("".join(running))
            if cur_norm.endswith(p_norm):
                return decoded[i + 1 :].lstrip() if i + 1 < len(decoded) else ""
    m = re.match(
        r"^(?:\s*(system|user|assistant)[:\s-]*\n?)+", decoded, flags=re.IGNORECASE
    )
    if m:
        return decoded[m.end() :].lstrip()

    return decoded.lstrip("\n\r ").lstrip(":;- ").strip()


def find_videos(
    path: str = None,
    pattern=["mp4"],  # TODO: test with more video formats
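A small sketch of the parsing contract of _categorize_outputs with hand-written model outputs; when include_questions is True, each caption must contain the two labeled sections that the captioning prompt in the video detector below asks for:

# Invented per-frame outputs in the expected "Summary: ... VQA Answers: ..." shape.
collected = [
    (0.0, "Summary: A person walks a dog in a park.\n\nVQA Answers:\n1. Yes\n2. A park"),
    (0.5, "Summary: The dog chases a ball.\n\nVQA Answers:\n1. Yes\n2. A park"),
]
summary_bullets, vqa_bullets = _categorize_outputs(collected, include_questions=True)
# summary_bullets -> ["- [0.000s] A person walks a dog in a park.", "- [0.500s] The dog chases a ball."]
# vqa_bullets     -> ["- [0.000s] 1. Yes\n2. A park", "- [0.500s] 1. Yes\n2. A park"]

# Without questions, only the summary side is filled and nothing is parsed:
summary_only, empty = _categorize_outputs([(0.0, "A person walks a dog.")])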
@@ -6,17 +6,24 @@ from PIL import Image
from torchcodec.decoders import VideoDecoder

from ammico.model import MultimodalSummaryModel
from ammico.utils import AnalysisMethod
from ammico.utils import (
    AnalysisMethod,
    AnalysisType,
    _categorize_outputs,
    _strip_prompt_prefix_literal,
)

from typing import List, Dict, Any, Generator, Tuple
from typing import List, Dict, Any, Generator, Tuple, Union, Optional
from transformers import GenerationConfig


class VideoSummaryDetector(AnalysisMethod):
    MAX_SAMPLES_CAP = 1000  # safety cap for total extracted frames

    def __init__(
        self,
        summary_model: MultimodalSummaryModel,
        subdict: dict = {},
        subdict: Optional[Dict[str, Any]] = None,
    ) -> None:
        """
        Class for analysing videos using QWEN-2.5-VL model.
@@ -29,6 +36,8 @@ class VideoSummaryDetector(AnalysisMethod):
        Returns:
            None.
        """
        if subdict is None:
            subdict = {}

        super().__init__(subdict)
        self.summary_model = summary_model
@@ -107,6 +116,7 @@ class VideoSummaryDetector(AnalysisMethod):
            raise ValueError("frame_rate_per_second must be > 0")

        n_samples = max(1, int(math.floor(duration * frame_rate_per_second)))
        n_samples = min(n_samples, self.MAX_SAMPLES_CAP)

        if begin_stream_seconds is not None and end_stream_seconds is not None:
            sample_times = torch.linspace(
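The new MAX_SAMPLES_CAP bounds the number of extracted frames regardless of video length. A quick check of the arithmetic in this hunk, wrapped in an illustrative helper (the helper itself is not part of the code):

import math

def planned_samples(duration_s: float, fps: float, cap: int = 1000) -> int:
    # mirrors: n_samples = min(max(1, floor(duration * frame_rate_per_second)), MAX_SAMPLES_CAP)
    return min(max(1, int(math.floor(duration_s * fps))), cap)

planned_samples(90, 2.0)    # -> 180 frames for a 90 s clip at 2 fps
planned_samples(0.2, 2.0)   # -> 1, never fewer than one frame
planned_samples(1200, 2.0)  # -> 1000, capped instead of 2400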
@@ -126,38 +136,6 @@ class VideoSummaryDetector(AnalysisMethod):

        return generator

    def _normalize_whitespace(self, s: str) -> str:
        return re.sub(r"\s+", " ", s).strip()

    def _strip_prompt_prefix_literal(self, decoded: str, prompt: str) -> str:
        """
        Remove any literal prompt prefix from decoded text using a normalized-substring match.
        Guarantees no prompt text remains at the start of returned string (best-effort).
        """
        if not decoded:
            return ""
        if not prompt:
            return decoded.strip()

        d_norm = self._normalize_whitespace(decoded)
        p_norm = self._normalize_whitespace(prompt)

        idx = d_norm.find(p_norm)
        if idx != -1:
            running = []
            for i, ch in enumerate(decoded):
                running.append(ch if not ch.isspace() else " ")
                cur_norm = self._normalize_whitespace("".join(running))
                if cur_norm.endswith(p_norm):
                    return decoded[i + 1 :].lstrip() if i + 1 < len(decoded) else ""
        m = re.match(
            r"^(?:\s*(system|user|assistant)[:\s-]*\n?)+", decoded, flags=re.IGNORECASE
        )
        if m:
            return decoded[m.end() :].lstrip()

        return decoded.lstrip("\n\r ").lstrip(":;- ").strip()

    def _decode_trimmed_outputs(
        self,
        generated_ids: torch.Tensor,
@@ -170,20 +148,18 @@ class VideoSummaryDetector(AnalysisMethod):
        Then remove any literal prompt prefix using prompt_texts (one per batch element).
        """

        decoded_results = []
        batch_size = generated_ids.shape[0]

        if "input_ids" in inputs:
            lengths = (
                inputs["input_ids"]
                .ne(
                    tokenizer.pad_token_id
                    if tokenizer.pad_token_id is not None
                    else tokenizer.eos_token_id
                )
                .sum(dim=1)
                .tolist()
            token_for_padding = (
                tokenizer.pad_token_id
                if getattr(tokenizer, "pad_token_id", None) is not None
                else getattr(tokenizer, "eos_token_id", None)
            )
            if token_for_padding is None:
                lengths = [int(inputs["input_ids"].shape[1])] * batch_size
            else:
                lengths = inputs["input_ids"].ne(token_for_padding).sum(dim=1).tolist()
        else:
            lengths = [0] * batch_size

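The reworked block counts non-padding tokens per row so that the echoed prompt can be sliced off each generated sequence before decoding. A toy illustration with invented token ids; any padding left inside the slice is later dropped by batch_decode(skip_special_tokens=True):

import torch

pad_id = 0
# One prompt with 3 real tokens, padded to length 4:
input_ids = torch.tensor([[11, 12, 13, pad_id]])
lengths = input_ids.ne(pad_id).sum(dim=1).tolist()  # -> [3]

# Suppose generate() echoed the padded prompt and appended two new tokens:
generated_ids = torch.tensor([[11, 12, 13, pad_id, 91, 92]])
trimmed = [out[n:] for out, n in zip(generated_ids, lengths)]
# -> [tensor([0, 91, 92])]; the leftover pad id is removed during decoding.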
@@ -195,13 +171,15 @@ class VideoSummaryDetector(AnalysisMethod):
                t = out_ids[in_len:]
            else:
                t = out_ids.new_empty((0,), dtype=out_ids.dtype)
            trimmed_ids.append(t)
            t_cpu = t.to("cpu")
            trimmed_ids.append(t_cpu.tolist())

        decoded = tokenizer.batch_decode(
            trimmed_ids, skip_special_tokens=True, clean_up_tokenization_spaces=True
        )
        decoded_results = []
        for ptext, raw in zip(prompt_texts, decoded):
            cleaned = self._strip_prompt_prefix_literal(raw, ptext)
            cleaned = _strip_prompt_prefix_literal(raw, ptext)
            decoded_results.append(cleaned)
        return decoded_results

@@ -209,7 +187,6 @@ class VideoSummaryDetector(AnalysisMethod):
        self,
        processor_inputs: Dict[str, torch.Tensor],
        prompt_texts: List[str],
        model,
        tokenizer,
    ):
        """
@@ -224,7 +201,7 @@ class VideoSummaryDetector(AnalysisMethod):

        for k, v in list(processor_inputs.items()):
            if isinstance(v, torch.Tensor):
                processor_inputs[k] = v.to(model.device)
                processor_inputs[k] = v.to(self.summary_model.device)

        with torch.inference_mode():
            try:
@@ -239,8 +216,8 @@ class VideoSummaryDetector(AnalysisMethod):
                    )
            except RuntimeError as e:
                warnings.warn(
                    "Retry without autocast failed: %s. Attempting cudnn-disabled retry.",
                    e,
                    f"Generation failed with error: {e}. Retrying with cuDNN disabled.",
                    RuntimeWarning,
                )
                cudnn_was_enabled = (
                    torch.backends.cudnn.is_available() and torch.backends.cudnn.enabled
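The corrected warning call matters because warnings.warn has no lazy %-style formatting like the logging module: its second positional parameter is the warning category, so the old call passed the exception object where a Warning subclass is expected. Sketch of the fixed pattern:

import warnings

try:
    raise RuntimeError("boom")
except RuntimeError as e:
    # The message is formatted eagerly; the category goes in the second slot.
    warnings.warn(f"Generation failed with error: {e}. Retrying with cuDNN disabled.", RuntimeWarning)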
@@ -273,7 +250,9 @@ class VideoSummaryDetector(AnalysisMethod):
            batch = batch.to("cpu")

        batch = batch.contiguous()
        if batch.dtype != torch.uint8:
        if batch.dtype.is_floating_point:
            batch = (batch.clamp(0.0, 1.0) * 255.0).to(torch.uint8)
        elif batch.dtype != torch.uint8:
            batch = batch.to(torch.uint8)
        pil_list: List[Image.Image] = []
        for frame in batch:
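The new branch guards against decoders that return float frames in [0, 1]; Image.fromarray needs uint8 in [0, 255], and a plain cast would truncate nearly every value to zero. A small sketch of the difference (shape and values invented):

import torch

frame = torch.rand(4, 4, 3)  # float32 in [0, 1], as some decoders return
bad = frame.to(torch.uint8)  # plain cast: almost every pixel becomes 0
good = (frame.clamp(0.0, 1.0) * 255.0).to(torch.uint8)  # scaled to the full 0-255 range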
@@ -281,10 +260,10 @@ class VideoSummaryDetector(AnalysisMethod):
            pil_list.append(Image.fromarray(arr))
        return pil_list

    def brute_force_summary(
    def make_captions_from_extracted_frames(
        self,
        extracted_video_gen: Generator[Tuple[torch.Tensor, torch.Tensor], None, None],
        summary_instruction: str = "Analyze the following captions from multiple frames of the same video and summarize the overall content of the video in one concise paragraph (1-3 sentences). Focus on the key themes, actions, or events across the video, not just the individual frames.",
        list_of_questions: Optional[List[str]] = None,
    ) -> Dict[str, Any]:
        """
        Generate captions for all extracted frames and then produce a concise summary of the video.
@@ -296,68 +275,93 @@ class VideoSummaryDetector(AnalysisMethod):
        """

        caption_instruction = "Describe this image in one concise caption."
        include_questions = bool(list_of_questions)
        if include_questions:
            q_block = "\n".join(
                [f"{i + 1}. {q.strip()}" for i, q in enumerate(list_of_questions)]
            )
            caption_instruction += (
                " In addition to the concise caption, also answer the following questions based ONLY on the image. Answers must be very brief and concise."
                " Produce exactly two labeled sections: \n\n"
                "Summary: <concise summary>\n\n"
                "VQA Answers: \n1. <answer to question 1>\n2. <answer to question 2>\n etc."
                "\nReturn only those two sections for each image (do not add extra commentary)."
                "\nIf the answer cannot be determined based on the provided answer blocks,"
                ' reply with the line "The answer cannot be determined based on the information provided."'
                f"\n\nQuestions:\n{q_block}"
            )
        collected: List[Tuple[float, str]] = []
        proc = self.summary_model.processor
        try:
            for batch_frames, batch_times in extracted_video_gen:
                pil_list = self._tensor_batch_to_pil_list(batch_frames.cpu())

        for batch_frames, batch_times in extracted_video_gen:
            pil_list = self._tensor_batch_to_pil_list(batch_frames.cpu())
                prompt_texts = []
                for p in pil_list:
                    messages = [
                        {
                            "role": "user",
                            "content": [
                                {"type": "image", "image": p},
                                {"type": "text", "text": caption_instruction},
                            ],
                        }
                    ]

            prompt_texts = []
            for p in pil_list:
                messages = [
                    {
                        "role": "user",
                        "content": [
                            {"type": "image", "image": p},
                            {"type": "text", "text": caption_instruction},
                        ],
                    }
                ]
                    prompt_text = proc.apply_chat_template(
                        messages, tokenize=False, add_generation_prompt=True
                    )
                    prompt_texts.append(prompt_text)

                prompt_text = proc.apply_chat_template(
                    messages, tokenize=False, add_generation_prompt=True
                processor_inputs = proc(
                    text=prompt_texts,
                    images=pil_list,
                    return_tensors="pt",
                    padding=True,
                )
                captions = self._generate_from_processor_inputs(
                    processor_inputs,
                    prompt_texts,
                    self.summary_model.tokenizer,
                )
                prompt_texts.append(prompt_text)

            processor_inputs = proc(
                text=prompt_texts, images=pil_list, return_tensors="pt", padding=True
            )
            captions = self._generate_from_processor_inputs(
                processor_inputs,
                prompt_texts,
                self.summary_model,
                self.summary_model.tokenizer,
            )

            if isinstance(batch_times, torch.Tensor):
                batch_times_list = batch_times.cpu().tolist()
            else:
                batch_times_list = list(batch_times)
            for t, c in zip(batch_times_list, captions):
                collected.append((float(t), c))
                if isinstance(batch_times, torch.Tensor):
                    batch_times_list = batch_times.cpu().tolist()
                else:
                    batch_times_list = list(batch_times)
                for t, c in zip(batch_times_list, captions):
                    collected.append((float(t), c))
        finally:
            try:
                extracted_video_gen.close()
            except Exception:
                warnings.warn("Failed to close video frame generator.", RuntimeWarning)

        collected.sort(key=lambda x: x[0])
        extracted_video_gen.close()
        bullets_summary, bullets_vqa = _categorize_outputs(collected, include_questions)

        MAX_CAPTIONS_FOR_SUMMARY = 200
        caps_for_summary = (
            collected[-MAX_CAPTIONS_FOR_SUMMARY:]
            if len(collected) > MAX_CAPTIONS_FOR_SUMMARY
            else collected
        )
        return {
            "summary_bullets": bullets_summary,
            "vqa_bullets": bullets_vqa,
        }  # TODO consider taking out time stamps from the returned structure

        bullets = []
        for t, c in caps_for_summary:
            snippet = c.replace("\n", " ").strip()
            bullets.append(f"- [{t:.3f}s] {snippet}")
    def final_summary(self, summary_dict: Dict[str, Any]) -> Dict[str, Any]:
        """
        Produce a concise summary of the video, based on generated captions for all extracted frames.
        Args:
            summary_dict (Dict[str, Any]): Dictionary containing captions for the frames.
        Returns:
            Dict[str, Any]: A dictionary containing the list of captions with timestamps and the final summary.
        """
        summary_instruction = "Analyze the following captions from multiple frames of the same video and summarize the overall content of the video in one concise paragraph (1-3 sentences). Focus on the key themes, actions, or events across the video, not just the individual frames."
        proc = self.summary_model.processor

        bullets = summary_dict.get("summary_bullets", [])
        if not bullets:
            raise ValueError("No captions available for summary generation.")

        combined_captions_text = "\n".join(bullets)
        summary_user_text = (
            summary_instruction
            + "\n\n"
            + combined_captions_text
            + "\n\nPlease produce a single concise paragraph."
        )
        summary_user_text = summary_instruction + "\n\n" + combined_captions_text + "\n"

        messages = [
            {
@@ -381,40 +385,130 @@ class VideoSummaryDetector(AnalysisMethod):
        final_summary_list = self._generate_from_processor_inputs(
            summary_inputs,
            [summary_prompt_text],
            self.summary_model.model,
            self.summary_model.tokenizer,
        )
        final_summary = final_summary_list[0].strip() if final_summary_list else ""

        return {
            "captions": collected,
            "summary": final_summary,
        }

    def final_answers(
        self,
        answers_dict: Dict[str, Any],
        list_of_questions: List[str],
    ) -> Dict[str, Any]:
        """
        Answer the list of questions for the video based on the VQA bullets from the frames.
        Args:
            answers_dict (Dict[str, Any]): Dictionary containing the VQA bullets.
        Returns:
            Dict[str, Any]: A dictionary containing the list of answers to the questions.
        """
        vqa_bullets = answers_dict.get("vqa_bullets", [])
        if not vqa_bullets:
            raise ValueError(
                "No VQA bullets generated for single frames available for answering questions."
            )

        include_questions = bool(list_of_questions)
        if include_questions:
            q_block = "\n".join(
                [f"{i + 1}. {q.strip()}" for i, q in enumerate(list_of_questions)]
            )
            prompt = (
                "You are provided with a set of short VQA-captions, each of which is a block of short answers"
                " extracted from individual frames of the same video.\n\n"
                "VQA-captions (use ONLY these to answer):\n"
                f"{vqa_bullets}\n\n"
                "Answer the following questions briefly, based ONLY on the lists of answers provided above. The VQA-captions above contain answers"
                " to the same questions you are about to answer. If the answer cannot be determined based on the provided answer blocks,"
                ' reply with the line "The answer cannot be determined based on the information provided."'
                "Questions:\n"
                f"{q_block}\n\n"
                "Produce an ordered list with answers in the same order as the questions. You must have this structure of your output: "
                "Answers: \n1. <answer to question 1>\n2. <answer to question 2>\n etc."
                "Return ONLY the ordered list with answers and NOTHING else — no commentary, no explanation, no surrounding markdown."
            )
        else:
            raise ValueError(
                "list_of_questions must be provided for making final answers."
            )

        proc = self.summary_model.processor
        messages = [
            {
                "role": "user",
                "content": [{"type": "text", "text": prompt}],
            }
        ]
        final_vqa_prompt_text = proc.apply_chat_template(
            messages, tokenize=False, add_generation_prompt=True
        )
        final_vqa_inputs = proc(
            text=[final_vqa_prompt_text], return_tensors="pt", padding=True
        )
        final_vqa_inputs = {
            k: v.to(self.summary_model.device) if isinstance(v, torch.Tensor) else v
            for k, v in final_vqa_inputs.items()
        }

        final_vqa_list = self._generate_from_processor_inputs(
            final_vqa_inputs,
            [final_vqa_prompt_text],
            self.summary_model.tokenizer,
        )

        final_vqa_output = final_vqa_list[0].strip() if final_vqa_list else ""
        vqa_answers = []
        answer_matches = re.findall(
            r"\d+\.\s*(.*?)(?=\n\d+\.|$)", final_vqa_output, flags=re.DOTALL
        )
        for answer in answer_matches:
            vqa_answers.append(answer.strip())
        return {
            "vqa_answers": vqa_answers,
        }

    def analyse_videos_from_dict(
        self, frame_rate_per_second: float = 2.0
        self,
        analysis_type: Union[AnalysisType, str] = AnalysisType.SUMMARY,
        frame_rate_per_second: float = 2.0,
        list_of_questions: Optional[List[str]] = None,
    ) -> Dict[str, Any]:
        """
        Analyse the video specified in self.subdict using frame extraction and captioning.
        For short videos (<=100 frames at the specified frame rate), it uses brute-force captioning.
        For longer videos, it currently defaults to brute-force captioning, but can be extended for more complex methods.

        Args:
            analysis_type (Union[AnalysisType, str], optional): Type of analysis to perform. Defaults to AnalysisType.SUMMARY.
            frame_rate_per_second (float): Frame extraction rate in frames per second. Default is 2.0.
            list_of_questions (List[str], optional): List of questions to answer about the video. Required if analysis_type includes questions.
        Returns:
            Dict[str, Any]: A dictionary containing the analysis results, including captions and summary.
            Dict[str, Any]: A dictionary containing the analysis results, including summary and answers for provided questions(if any).
        """

        all_answers = {}
        # TODO: add support for answering questions about videos
        for video_key in list(self.subdict.keys()):
            entry = self.subdict[video_key]
            extracted_video_gen = self._extract_video_frames(
        analysis_type, is_summary, is_questions = AnalysisType._validate_analysis_type(
            analysis_type, list_of_questions
        )

        for video_key, entry in self.subdict.items():
            summary_and_vqa = {}
            frames_generator = self._extract_video_frames(
                entry, frame_rate_per_second=frame_rate_per_second
            )

            answer = self.brute_force_summary(extracted_video_gen)
            all_answers[video_key] = {"summary": answer["summary"]}
            answers_dict = self.make_captions_from_extracted_frames(
                frames_generator, list_of_questions=list_of_questions
            )
            # TODO: captions has to be post-processed with foreseeing audio analysis
            # TODO: captions and answers may lead to prompt, that superior model limits. Consider hierarchical approach.
            if is_summary:
                answer = self.final_summary(answers_dict)
                summary_and_vqa["summary"] = answer["summary"]
            if is_questions:
                answer = self.final_answers(answers_dict, list_of_questions)
                summary_and_vqa["vqa_answers"] = answer["vqa_answers"]

        return answer
            all_answers[video_key] = summary_and_vqa

        return all_answers
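For reference, the numbered-list parsing at the end of final_answers turns the model's "Answers:" block into one string per question. A quick sketch with a hand-written model output:

import re

final_vqa_output = "Answers:\n1. Yes, two people are visible.\n2. In a park."
answers = [
    a.strip()
    for a in re.findall(r"\d+\.\s*(.*?)(?=\n\d+\.|$)", final_vqa_output, flags=re.DOTALL)
]
# -> ["Yes, two people are visible.", "In a park."]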