Mirror of
https://github.com/ssciwr/AMMICO.git
Synced 2025-10-29 13:06:04 +02:00
Merge pull request #8 from DimasfromLavoisier/add_video_vqa
add functionality for video vqa
This commit is contained in:
Commit
3e036514cd
@@ -6,7 +6,7 @@ import torch
 from PIL import Image
 import warnings
 
-from typing import List, Optional, Union, Dict, Any, Tuple
+from typing import List, Optional, Union, Dict, Any
 from collections.abc import Sequence as _Sequence
 from transformers import GenerationConfig
 from qwen_vl_utils import process_vision_info
@@ -120,42 +120,11 @@ class ImageSummaryDetector(AnalysisMethod):
 
         return inputs
 
-    def _validate_analysis_type(
-        self,
-        analysis_type: Union["AnalysisType", str],
-        list_of_questions: Optional[List[str]],
-        max_questions_per_image: int,
-    ) -> Tuple[str, List[str], bool, bool]:
-        if isinstance(analysis_type, AnalysisType):
-            analysis_type = analysis_type.value
-
-        allowed = {"summary", "questions", "summary_and_questions"}
-        if analysis_type not in allowed:
-            raise ValueError(f"analysis_type must be one of {allowed}")
-
-        if list_of_questions is None:
-            list_of_questions = [
-                "Are there people in the image?",
-                "What is this picture about?",
-            ]
-
-        if analysis_type in ("questions", "summary_and_questions"):
-            if len(list_of_questions) > max_questions_per_image:
-                raise ValueError(
-                    f"Number of questions per image ({len(list_of_questions)}) exceeds safety cap ({max_questions_per_image}). Reduce questions or increase max_questions_per_image."
-                )
-
-        is_summary = analysis_type in ("summary", "summary_and_questions")
-        is_questions = analysis_type in ("questions", "summary_and_questions")
-
-        return analysis_type, list_of_questions, is_summary, is_questions
-
     def analyse_image(
         self,
         entry: dict,
-        analysis_type: Union[str, AnalysisType] = AnalysisType.SUMMARY_AND_QUESTIONS,
+        analysis_type: Union[str, AnalysisType] = AnalysisType.SUMMARY,
         list_of_questions: Optional[List[str]] = None,
-        max_questions_per_image: int = 32,
         is_concise_summary: bool = True,
         is_concise_answer: bool = True,
     ) -> Dict[str, Any]:
@@ -165,10 +134,8 @@ class ImageSummaryDetector(AnalysisMethod):
             - 'vqa' (dict) if questions requested
         """
         self.subdict = entry
-        analysis_type, list_of_questions, is_summary, is_questions = (
-            self._validate_analysis_type(
-                analysis_type, list_of_questions, max_questions_per_image
-            )
+        analysis_type, is_summary, is_questions = AnalysisType._validate_analysis_type(
+            analysis_type, list_of_questions
         )
 
         if is_summary:
@@ -195,9 +162,8 @@ class ImageSummaryDetector(AnalysisMethod):
 
     def analyse_images_from_dict(
         self,
-        analysis_type: Union[AnalysisType, str] = AnalysisType.SUMMARY_AND_QUESTIONS,
+        analysis_type: Union[AnalysisType, str] = AnalysisType.SUMMARY,
         list_of_questions: Optional[List[str]] = None,
-        max_questions_per_image: int = 32,
         keys_batch_size: int = 16,
         is_concise_summary: bool = True,
         is_concise_answer: bool = True,
@@ -208,7 +174,6 @@ class ImageSummaryDetector(AnalysisMethod):
         Args:
             analysis_type (str): type of the analysis.
             list_of_questions (list[str]): list of questions.
-            max_questions_per_image (int): maximum number of questions per image. We recommend to keep it low to avoid long processing times and high memory usage.
             keys_batch_size (int): number of images to process in a batch.
             is_concise_summary (bool): whether to generate concise summary.
             is_concise_answer (bool): whether to generate concise answers.
@@ -216,10 +181,8 @@ class ImageSummaryDetector(AnalysisMethod):
             self.subdict (dict): dictionary with analysis results.
         """
         # TODO: add option to ask multiple questions per image as one batch.
-        analysis_type, list_of_questions, is_summary, is_questions = (
-            self._validate_analysis_type(
-                analysis_type, list_of_questions, max_questions_per_image
-            )
+        analysis_type, is_summary, is_questions = AnalysisType._validate_analysis_type(
+            analysis_type, list_of_questions
         )
 
         keys = list(self.subdict.keys())
@@ -284,7 +247,7 @@ class ImageSummaryDetector(AnalysisMethod):
         with torch.inference_mode():
             try:
                 if self.summary_model.device == "cuda":
-                    with torch.cuda.amp.autocast(enabled=True):
+                    with torch.amp.autocast("cuda", enabled=True):
                         generated_ids = self.summary_model.model.generate(
                             **inputs, generation_config=gen_conf
                         )
@@ -366,7 +329,7 @@ class ImageSummaryDetector(AnalysisMethod):
             inputs = self._prepare_inputs(chunk, entry)
             with torch.inference_mode():
                 if self.summary_model.device == "cuda":
-                    with torch.cuda.amp.autocast(enabled=True):
+                    with torch.amp.autocast("cuda", enabled=True):
                         out_ids = self.summary_model.model.generate(
                             **inputs, generation_config=gen_conf
                         )
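Both hunks above swap the deprecated torch.cuda.amp.autocast(enabled=True) context manager for the device-agnostic torch.amp.autocast("cuda", enabled=True) form; behavior on CUDA is unchanged. A minimal sketch of the same pattern outside AMMICO (model, inputs and gen_conf are placeholders, not AMMICO objects):

    import torch

    def generate_mixed_precision(model, inputs, gen_conf):
        # Run generation under inference mode; enable autocast only when CUDA is available.
        with torch.inference_mode():
            if torch.cuda.is_available():
                with torch.amp.autocast("cuda", enabled=True):
                    return model.generate(**inputs, generation_config=gen_conf)
            return model.generate(**inputs, generation_config=gen_conf)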
ammico/utils.py
@@ -6,6 +6,8 @@ import importlib_resources
 import collections
 import random
 from enum import Enum
+from typing import List, Tuple, Optional, Union
+import re
 
 
 pkg = importlib_resources.files("ammico")
@@ -46,6 +48,35 @@ class AnalysisType(str, Enum):
     QUESTIONS = "questions"
     SUMMARY_AND_QUESTIONS = "summary_and_questions"
 
+    @classmethod
+    def _validate_analysis_type(
+        cls,
+        analysis_type: Union["AnalysisType", str],
+        list_of_questions: Optional[List[str]],
+    ) -> Tuple[str, bool, bool]:
+        max_questions_per_image = 15  # safety cap to avoid too many questions
+        if isinstance(analysis_type, AnalysisType):
+            analysis_type = analysis_type.value
+
+        allowed = {item.value for item in AnalysisType}
+        if analysis_type not in allowed:
+            raise ValueError(f"analysis_type must be one of {allowed}")
+
+        if analysis_type in ("questions", "summary_and_questions"):
+            if not list_of_questions:
+                raise ValueError(
+                    "list_of_questions must be provided for QUESTIONS analysis type."
+                )
+
+            if len(list_of_questions) > max_questions_per_image:
+                raise ValueError(
+                    f"Number of questions per image ({len(list_of_questions)}) exceeds safety cap ({max_questions_per_image}). Reduce questions or increase max_questions_per_image."
+                )
+
+        is_summary = analysis_type in ("summary", "summary_and_questions")
+        is_questions = analysis_type in ("questions", "summary_and_questions")
+        return analysis_type, is_summary, is_questions
+
 
 class AnalysisMethod:
     """Base class to be inherited by all analysis methods."""
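A hedged usage sketch of the new classmethod (standalone; the question list is illustrative):

    from ammico.utils import AnalysisType

    # Enum members and plain strings are both accepted; questions are required
    # whenever the analysis type includes "questions".
    analysis_type, is_summary, is_questions = AnalysisType._validate_analysis_type(
        AnalysisType.SUMMARY_AND_QUESTIONS,
        ["Are there people in the image?", "What is this picture about?"],
    )
    assert analysis_type == "summary_and_questions"
    assert is_summary and is_questions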
@@ -101,6 +132,89 @@ def _limit_results(results, limit):
     return results
 
 
+def _categorize_outputs(
+    collected: List[Tuple[float, str]],
+    include_questions: bool = False,
+) -> Tuple[List[str], List[str]]:
+    """
+    Categorize collected outputs into summary bullets and VQA bullets.
+    Args:
+        collected (List[Tuple[float, str]]): List of tuples containing timestamps and generated texts.
+    Returns:
+        Tuple[List[str], List[str]]: A tuple containing two lists - summary bullets and VQA bullets.
+    """
+    MAX_CAPTIONS_FOR_SUMMARY = 600  # TODO For now, this is a constant value, but later we need to make it adjustable, with the idea of cutting out the most similar frames to reduce the load on the system.
+    caps_for_summary_vqa = (
+        collected[-MAX_CAPTIONS_FOR_SUMMARY:]
+        if len(collected) > MAX_CAPTIONS_FOR_SUMMARY
+        else collected
+    )
+    bullets_summary = []
+    bullets_vqa = []
+
+    for t, c in caps_for_summary_vqa:
+        if include_questions:
+            result_sections = c.strip()
+            m = re.search(
+                r"Summary\s*:\s*(.*?)\s*(?:VQA\s+Answers\s*:\s*(.*))?$",
+                result_sections,
+                flags=re.IGNORECASE | re.DOTALL,
+            )
+            if m:
+                summary_text = (
+                    m.group(1).replace("\n", " ").strip() if m.group(1) else None
+                )
+                vqa_text = m.group(2).strip() if m.group(2) else None
+                if not summary_text or not vqa_text:
+                    raise ValueError(
+                        f"Model output is missing either summary or VQA answers: {c}"
+                    )
+                bullets_summary.append(f"- [{t:.3f}s] {summary_text}")
+                bullets_vqa.append(f"- [{t:.3f}s] {vqa_text}")
+            else:
+                raise ValueError(
+                    f"Failed to parse summary and VQA answers from model output: {c}"
+                )
+        else:
+            snippet = c.replace("\n", " ").strip()
+            bullets_summary.append(f"- [{t:.3f}s] {snippet}")
+    return bullets_summary, bullets_vqa
+
+
+def _normalize_whitespace(s: str) -> str:
+    return re.sub(r"\s+", " ", s).strip()
+
+
+def _strip_prompt_prefix_literal(decoded: str, prompt: str) -> str:
+    """
+    Remove any literal prompt prefix from decoded text using a normalized-substring match.
+    Guarantees no prompt text remains at the start of returned string (best-effort).
+    """
+    if not decoded:
+        return ""
+    if not prompt:
+        return decoded.strip()
+
+    d_norm = _normalize_whitespace(decoded)
+    p_norm = _normalize_whitespace(prompt)
+
+    idx = d_norm.find(p_norm)
+    if idx != -1:
+        running = []
+        for i, ch in enumerate(decoded):
+            running.append(ch if not ch.isspace() else " ")
+            cur_norm = _normalize_whitespace("".join(running))
+            if cur_norm.endswith(p_norm):
+                return decoded[i + 1 :].lstrip() if i + 1 < len(decoded) else ""
+    m = re.match(
+        r"^(?:\s*(system|user|assistant)[:\s-]*\n?)+", decoded, flags=re.IGNORECASE
+    )
+    if m:
+        return decoded[m.end() :].lstrip()
+
+    return decoded.lstrip("\n\r ").lstrip(":;- ").strip()
+
+
 def find_videos(
     path: str = None,
     pattern=["mp4"],  # TODO: test with more video formats
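To illustrate what the new parser expects, here is a small self-contained sketch (the frame captions are made-up examples, not real model output):

    from ammico.utils import _categorize_outputs

    collected = [
        (0.0, "Summary: A person walks a dog in a park.\nVQA Answers:\n1. Yes\n2. A park"),
        (0.5, "Summary: The dog chases a ball.\nVQA Answers:\n1. Yes\n2. A park"),
    ]
    summary_bullets, vqa_bullets = _categorize_outputs(collected, include_questions=True)
    # summary_bullets -> ["- [0.000s] A person walks a dog in a park.", "- [0.500s] The dog chases a ball."]
    # vqa_bullets keeps the numbered answer block for each timestamp.

With include_questions=False the whole caption is kept as a single summary bullet and vqa_bullets stays empty.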
@@ -6,17 +6,24 @@ from PIL import Image
 from torchcodec.decoders import VideoDecoder
 
 from ammico.model import MultimodalSummaryModel
-from ammico.utils import AnalysisMethod
+from ammico.utils import (
+    AnalysisMethod,
+    AnalysisType,
+    _categorize_outputs,
+    _strip_prompt_prefix_literal,
+)
 
-from typing import List, Dict, Any, Generator, Tuple
+from typing import List, Dict, Any, Generator, Tuple, Union, Optional
 from transformers import GenerationConfig
 
 
 class VideoSummaryDetector(AnalysisMethod):
+    MAX_SAMPLES_CAP = 1000  # safety cap for total extracted frames
+
     def __init__(
         self,
         summary_model: MultimodalSummaryModel,
-        subdict: dict = {},
+        subdict: Optional[Dict[str, Any]] = None,
     ) -> None:
         """
         Class for analysing videos using QWEN-2.5-VL model.
@@ -29,6 +36,8 @@ class VideoSummaryDetector(AnalysisMethod):
         Returns:
             None.
         """
+        if subdict is None:
+            subdict = {}
 
         super().__init__(subdict)
         self.summary_model = summary_model
@@ -107,6 +116,7 @@ class VideoSummaryDetector(AnalysisMethod):
             raise ValueError("frame_rate_per_second must be > 0")
 
         n_samples = max(1, int(math.floor(duration * frame_rate_per_second)))
+        n_samples = min(n_samples, self.MAX_SAMPLES_CAP)
 
         if begin_stream_seconds is not None and end_stream_seconds is not None:
             sample_times = torch.linspace(
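As a quick check of the new cap (numbers are illustrative): a 10-minute clip sampled at 2.0 frames per second would request floor(600 * 2.0) = 1200 frames, which the added line clamps to MAX_SAMPLES_CAP = 1000.

    import math

    duration = 600.0              # seconds, hypothetical clip length
    frame_rate_per_second = 2.0
    MAX_SAMPLES_CAP = 1000

    n_samples = max(1, int(math.floor(duration * frame_rate_per_second)))  # 1200
    n_samples = min(n_samples, MAX_SAMPLES_CAP)                            # 1000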
@@ -126,38 +136,6 @@ class VideoSummaryDetector(AnalysisMethod):
 
         return generator
 
-    def _normalize_whitespace(self, s: str) -> str:
-        return re.sub(r"\s+", " ", s).strip()
-
-    def _strip_prompt_prefix_literal(self, decoded: str, prompt: str) -> str:
-        """
-        Remove any literal prompt prefix from decoded text using a normalized-substring match.
-        Guarantees no prompt text remains at the start of returned string (best-effort).
-        """
-        if not decoded:
-            return ""
-        if not prompt:
-            return decoded.strip()
-
-        d_norm = self._normalize_whitespace(decoded)
-        p_norm = self._normalize_whitespace(prompt)
-
-        idx = d_norm.find(p_norm)
-        if idx != -1:
-            running = []
-            for i, ch in enumerate(decoded):
-                running.append(ch if not ch.isspace() else " ")
-                cur_norm = self._normalize_whitespace("".join(running))
-                if cur_norm.endswith(p_norm):
-                    return decoded[i + 1 :].lstrip() if i + 1 < len(decoded) else ""
-        m = re.match(
-            r"^(?:\s*(system|user|assistant)[:\s-]*\n?)+", decoded, flags=re.IGNORECASE
-        )
-        if m:
-            return decoded[m.end() :].lstrip()
-
-        return decoded.lstrip("\n\r ").lstrip(":;- ").strip()
-
     def _decode_trimmed_outputs(
         self,
         generated_ids: torch.Tensor,
@@ -170,20 +148,18 @@ class VideoSummaryDetector(AnalysisMethod):
         Then remove any literal prompt prefix using prompt_texts (one per batch element).
         """
 
-        decoded_results = []
         batch_size = generated_ids.shape[0]
 
         if "input_ids" in inputs:
-            lengths = (
-                inputs["input_ids"]
-                .ne(
-                    tokenizer.pad_token_id
-                    if tokenizer.pad_token_id is not None
-                    else tokenizer.eos_token_id
-                )
-                .sum(dim=1)
-                .tolist()
-            )
+            token_for_padding = (
+                tokenizer.pad_token_id
+                if getattr(tokenizer, "pad_token_id", None) is not None
+                else getattr(tokenizer, "eos_token_id", None)
+            )
+            if token_for_padding is None:
+                lengths = [int(inputs["input_ids"].shape[1])] * batch_size
+            else:
+                lengths = inputs["input_ids"].ne(token_for_padding).sum(dim=1).tolist()
         else:
             lengths = [0] * batch_size
 
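These per-sample lengths are used to slice the prompt tokens off the front of each generated sequence before decoding. A toy sketch of that trimming step (values are made up):

    import torch

    out_ids = torch.tensor([101, 102, 103, 7, 8, 9, 10])  # prompt tokens + generated tokens
    in_len = 3                                             # non-padding prompt length
    trimmed = out_ids[in_len:]                             # tensor([ 7,  8,  9, 10])
    trimmed_list = trimmed.to("cpu").tolist()              # ready for tokenizer.batch_decode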
@@ -195,13 +171,15 @@ class VideoSummaryDetector(AnalysisMethod):
                 t = out_ids[in_len:]
             else:
                 t = out_ids.new_empty((0,), dtype=out_ids.dtype)
-            trimmed_ids.append(t)
+            t_cpu = t.to("cpu")
+            trimmed_ids.append(t_cpu.tolist())
 
         decoded = tokenizer.batch_decode(
             trimmed_ids, skip_special_tokens=True, clean_up_tokenization_spaces=True
         )
+        decoded_results = []
         for ptext, raw in zip(prompt_texts, decoded):
-            cleaned = self._strip_prompt_prefix_literal(raw, ptext)
+            cleaned = _strip_prompt_prefix_literal(raw, ptext)
             decoded_results.append(cleaned)
         return decoded_results
 
@@ -209,7 +187,6 @@ class VideoSummaryDetector(AnalysisMethod):
         self,
         processor_inputs: Dict[str, torch.Tensor],
         prompt_texts: List[str],
-        model,
         tokenizer,
     ):
         """
@@ -224,7 +201,7 @@ class VideoSummaryDetector(AnalysisMethod):
 
         for k, v in list(processor_inputs.items()):
             if isinstance(v, torch.Tensor):
-                processor_inputs[k] = v.to(model.device)
+                processor_inputs[k] = v.to(self.summary_model.device)
 
         with torch.inference_mode():
             try:
@@ -239,8 +216,8 @@ class VideoSummaryDetector(AnalysisMethod):
                 )
             except RuntimeError as e:
                 warnings.warn(
-                    "Retry without autocast failed: %s. Attempting cudnn-disabled retry.",
-                    e,
+                    f"Generation failed with error: {e}. Retrying with cuDNN disabled.",
+                    RuntimeWarning,
                 )
                 cudnn_was_enabled = (
                     torch.backends.cudnn.is_available() and torch.backends.cudnn.enabled
@@ -273,7 +250,9 @@ class VideoSummaryDetector(AnalysisMethod):
             batch = batch.to("cpu")
 
         batch = batch.contiguous()
-        if batch.dtype != torch.uint8:
+        if batch.dtype.is_floating_point:
+            batch = (batch.clamp(0.0, 1.0) * 255.0).to(torch.uint8)
+        elif batch.dtype != torch.uint8:
             batch = batch.to(torch.uint8)
         pil_list: List[Image.Image] = []
         for frame in batch:
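The new branch assumes float frames are normalized to [0, 1] and rescales them before casting to uint8; other integer dtypes are still cast directly. A standalone sketch with a dummy tensor:

    import torch

    frame = torch.rand(3, 224, 224)                              # float32 in [0, 1]
    frame_u8 = (frame.clamp(0.0, 1.0) * 255.0).to(torch.uint8)   # 0..255, ready for PIL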
@@ -281,10 +260,10 @@ class VideoSummaryDetector(AnalysisMethod):
             pil_list.append(Image.fromarray(arr))
         return pil_list
 
-    def brute_force_summary(
+    def make_captions_from_extracted_frames(
         self,
         extracted_video_gen: Generator[Tuple[torch.Tensor, torch.Tensor], None, None],
-        summary_instruction: str = "Analyze the following captions from multiple frames of the same video and summarize the overall content of the video in one concise paragraph (1-3 sentences). Focus on the key themes, actions, or events across the video, not just the individual frames.",
+        list_of_questions: Optional[List[str]] = None,
     ) -> Dict[str, Any]:
         """
         Generate captions for all extracted frames and then produce a concise summary of the video.
@@ -296,68 +275,93 @@ class VideoSummaryDetector(AnalysisMethod):
         """
 
         caption_instruction = "Describe this image in one concise caption."
+        include_questions = bool(list_of_questions)
+        if include_questions:
+            q_block = "\n".join(
+                [f"{i + 1}. {q.strip()}" for i, q in enumerate(list_of_questions)]
+            )
+            caption_instruction += (
+                " In addition to the concise caption, also answer the following questions based ONLY on the image. Answers must be very brief and concise."
+                " Produce exactly two labeled sections: \n\n"
+                "Summary: <concise summary>\n\n"
+                "VQA Answers: \n1. <answer to question 1>\n2. <answer to question 2>\n etc."
+                "\nReturn only those two sections for each image (do not add extra commentary)."
+                "\nIf the answer cannot be determined based on the provided answer blocks,"
+                ' reply with the line "The answer cannot be determined based on the information provided."'
+                f"\n\nQuestions:\n{q_block}"
+            )
         collected: List[Tuple[float, str]] = []
         proc = self.summary_model.processor
-        for batch_frames, batch_times in extracted_video_gen:
-            pil_list = self._tensor_batch_to_pil_list(batch_frames.cpu())
-
-            prompt_texts = []
-            for p in pil_list:
-                messages = [
-                    {
-                        "role": "user",
-                        "content": [
-                            {"type": "image", "image": p},
-                            {"type": "text", "text": caption_instruction},
-                        ],
-                    }
-                ]
-
-                prompt_text = proc.apply_chat_template(
-                    messages, tokenize=False, add_generation_prompt=True
-                )
-                prompt_texts.append(prompt_text)
-
-            processor_inputs = proc(
-                text=prompt_texts, images=pil_list, return_tensors="pt", padding=True
-            )
-            captions = self._generate_from_processor_inputs(
-                processor_inputs,
-                prompt_texts,
-                self.summary_model,
-                self.summary_model.tokenizer,
-            )
-
-            if isinstance(batch_times, torch.Tensor):
-                batch_times_list = batch_times.cpu().tolist()
-            else:
-                batch_times_list = list(batch_times)
-            for t, c in zip(batch_times_list, captions):
-                collected.append((float(t), c))
+        try:
+            for batch_frames, batch_times in extracted_video_gen:
+                pil_list = self._tensor_batch_to_pil_list(batch_frames.cpu())
+
+                prompt_texts = []
+                for p in pil_list:
+                    messages = [
+                        {
+                            "role": "user",
+                            "content": [
+                                {"type": "image", "image": p},
+                                {"type": "text", "text": caption_instruction},
+                            ],
+                        }
+                    ]
+                    prompt_text = proc.apply_chat_template(
+                        messages, tokenize=False, add_generation_prompt=True
+                    )
+                    prompt_texts.append(prompt_text)
+
+                processor_inputs = proc(
+                    text=prompt_texts,
+                    images=pil_list,
+                    return_tensors="pt",
+                    padding=True,
+                )
+                captions = self._generate_from_processor_inputs(
+                    processor_inputs,
+                    prompt_texts,
+                    self.summary_model.tokenizer,
+                )
+
+                if isinstance(batch_times, torch.Tensor):
+                    batch_times_list = batch_times.cpu().tolist()
+                else:
+                    batch_times_list = list(batch_times)
+                for t, c in zip(batch_times_list, captions):
+                    collected.append((float(t), c))
+        finally:
+            try:
+                extracted_video_gen.close()
+            except Exception:
+                warnings.warn("Failed to close video frame generator.", RuntimeWarning)
 
         collected.sort(key=lambda x: x[0])
-        extracted_video_gen.close()
-
-        MAX_CAPTIONS_FOR_SUMMARY = 200
-        caps_for_summary = (
-            collected[-MAX_CAPTIONS_FOR_SUMMARY:]
-            if len(collected) > MAX_CAPTIONS_FOR_SUMMARY
-            else collected
-        )
-
-        bullets = []
-        for t, c in caps_for_summary:
-            snippet = c.replace("\n", " ").strip()
-            bullets.append(f"- [{t:.3f}s] {snippet}")
+        bullets_summary, bullets_vqa = _categorize_outputs(collected, include_questions)
+
+        return {
+            "summary_bullets": bullets_summary,
+            "vqa_bullets": bullets_vqa,
+        }  # TODO consider taking out time stamps from the returned structure
+
+    def final_summary(self, summary_dict: Dict[str, Any]) -> Dict[str, Any]:
+        """
+        Produce a concise summary of the video, based on generated captions for all extracted frames.
+        Args:
+            summary_dict (Dict[str, Any]): Dictionary containing captions for the frames.
+        Returns:
+            Dict[str, Any]: A dictionary containing the list of captions with timestamps and the final summary.
+        """
+        summary_instruction = "Analyze the following captions from multiple frames of the same video and summarize the overall content of the video in one concise paragraph (1-3 sentences). Focus on the key themes, actions, or events across the video, not just the individual frames."
+        proc = self.summary_model.processor
+
+        bullets = summary_dict.get("summary_bullets", [])
+        if not bullets:
+            raise ValueError("No captions available for summary generation.")
 
         combined_captions_text = "\n".join(bullets)
-        summary_user_text = (
-            summary_instruction
-            + "\n\n"
-            + combined_captions_text
-            + "\n\nPlease produce a single concise paragraph."
-        )
+        summary_user_text = summary_instruction + "\n\n" + combined_captions_text + "\n"
 
         messages = [
             {
@@ -381,40 +385,130 @@ class VideoSummaryDetector(AnalysisMethod):
         final_summary_list = self._generate_from_processor_inputs(
             summary_inputs,
             [summary_prompt_text],
-            self.summary_model.model,
             self.summary_model.tokenizer,
         )
         final_summary = final_summary_list[0].strip() if final_summary_list else ""
 
         return {
-            "captions": collected,
             "summary": final_summary,
         }
 
+    def final_answers(
+        self,
+        answers_dict: Dict[str, Any],
+        list_of_questions: List[str],
+    ) -> Dict[str, Any]:
+        """
+        Answer the list of questions for the video based on the VQA bullets from the frames.
+        Args:
+            answers_dict (Dict[str, Any]): Dictionary containing the VQA bullets.
+        Returns:
+            Dict[str, Any]: A dictionary containing the list of answers to the questions.
+        """
+        vqa_bullets = answers_dict.get("vqa_bullets", [])
+        if not vqa_bullets:
+            raise ValueError(
+                "No VQA bullets generated for single frames available for answering questions."
+            )
+
+        include_questions = bool(list_of_questions)
+        if include_questions:
+            q_block = "\n".join(
+                [f"{i + 1}. {q.strip()}" for i, q in enumerate(list_of_questions)]
+            )
+            prompt = (
+                "You are provided with a set of short VQA-captions, each of which is a block of short answers"
+                " extracted from individual frames of the same video.\n\n"
+                "VQA-captions (use ONLY these to answer):\n"
+                f"{vqa_bullets}\n\n"
+                "Answer the following questions briefly, based ONLY on the lists of answers provided above. The VQA-captions above contain answers"
+                " to the same questions you are about to answer. If the answer cannot be determined based on the provided answer blocks,"
+                ' reply with the line "The answer cannot be determined based on the information provided."'
+                "Questions:\n"
+                f"{q_block}\n\n"
+                "Produce an ordered list with answers in the same order as the questions. You must have this structure of your output: "
+                "Answers: \n1. <answer to question 1>\n2. <answer to question 2>\n etc."
+                "Return ONLY the ordered list with answers and NOTHING else — no commentary, no explanation, no surrounding markdown."
+            )
+        else:
+            raise ValueError(
+                "list_of_questions must be provided for making final answers."
+            )
+
+        proc = self.summary_model.processor
+        messages = [
+            {
+                "role": "user",
+                "content": [{"type": "text", "text": prompt}],
+            }
+        ]
+        final_vqa_prompt_text = proc.apply_chat_template(
+            messages, tokenize=False, add_generation_prompt=True
+        )
+        final_vqa_inputs = proc(
+            text=[final_vqa_prompt_text], return_tensors="pt", padding=True
+        )
+        final_vqa_inputs = {
+            k: v.to(self.summary_model.device) if isinstance(v, torch.Tensor) else v
+            for k, v in final_vqa_inputs.items()
+        }
+
+        final_vqa_list = self._generate_from_processor_inputs(
+            final_vqa_inputs,
+            [final_vqa_prompt_text],
+            self.summary_model.tokenizer,
+        )
+
+        final_vqa_output = final_vqa_list[0].strip() if final_vqa_list else ""
+        vqa_answers = []
+        answer_matches = re.findall(
+            r"\d+\.\s*(.*?)(?=\n\d+\.|$)", final_vqa_output, flags=re.DOTALL
+        )
+        for answer in answer_matches:
+            vqa_answers.append(answer.strip())
+        return {
+            "vqa_answers": vqa_answers,
+        }
+
     def analyse_videos_from_dict(
-        self, frame_rate_per_second: float = 2.0
+        self,
+        analysis_type: Union[AnalysisType, str] = AnalysisType.SUMMARY,
+        frame_rate_per_second: float = 2.0,
+        list_of_questions: Optional[List[str]] = None,
     ) -> Dict[str, Any]:
         """
         Analyse the video specified in self.subdict using frame extraction and captioning.
         For short videos (<=100 frames at the specified frame rate), it uses brute-force captioning.
         For longer videos, it currently defaults to brute-force captioning, but can be extended for more complex methods.
 
         Args:
+            analysis_type (Union[AnalysisType, str], optional): Type of analysis to perform. Defaults to AnalysisType.SUMMARY.
             frame_rate_per_second (float): Frame extraction rate in frames per second. Default is 2.0.
+            list_of_questions (List[str], optional): List of questions to answer about the video. Required if analysis_type includes questions.
         Returns:
-            Dict[str, Any]: A dictionary containing the analysis results, including captions and summary.
+            Dict[str, Any]: A dictionary containing the analysis results, including summary and answers for provided questions(if any).
         """
 
         all_answers = {}
-        # TODO: add support for answering questions about videos
-        for video_key in list(self.subdict.keys()):
-            entry = self.subdict[video_key]
-            extracted_video_gen = self._extract_video_frames(
+        analysis_type, is_summary, is_questions = AnalysisType._validate_analysis_type(
+            analysis_type, list_of_questions
+        )
+
+        for video_key, entry in self.subdict.items():
+            summary_and_vqa = {}
+            frames_generator = self._extract_video_frames(
                 entry, frame_rate_per_second=frame_rate_per_second
             )
 
-            answer = self.brute_force_summary(extracted_video_gen)
-            all_answers[video_key] = {"summary": answer["summary"]}
+            answers_dict = self.make_captions_from_extracted_frames(
+                frames_generator, list_of_questions=list_of_questions
+            )
+            # TODO: captions has to be post-processed with foreseeing audio analysis
+            # TODO: captions and answers may lead to prompt, that superior model limits. Consider hierarchical approach.
+            if is_summary:
+                answer = self.final_summary(answers_dict)
+                summary_and_vqa["summary"] = answer["summary"]
+            if is_questions:
+                answer = self.final_answers(answers_dict, list_of_questions)
+                summary_and_vqa["vqa_answers"] = answer["vqa_answers"]
 
-            return answer
+            all_answers[video_key] = summary_and_vqa
+
+        return all_answers
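Taken together, the PR turns VideoSummaryDetector into a two-stage pipeline: per-frame captions (optionally with per-frame VQA answers) followed by a final summary and/or a final answer pass. A hedged end-to-end sketch of the new call pattern (the module name, model constructor arguments, and subdict layout are assumptions for illustration, not confirmed by this diff):

    from ammico.model import MultimodalSummaryModel
    from ammico.video_summary import VideoSummaryDetector  # module name assumed from the class above
    from ammico.utils import AnalysisType

    model = MultimodalSummaryModel()  # hypothetical default construction
    videos = {"clip1": {"filename": "clip1.mp4"}}  # entry layout assumed

    detector = VideoSummaryDetector(summary_model=model, subdict=videos)
    results = detector.analyse_videos_from_dict(
        analysis_type=AnalysisType.SUMMARY_AND_QUESTIONS,
        frame_rate_per_second=2.0,
        list_of_questions=["Are there people in the video?"],
    )
    # results["clip1"] -> {"summary": "...", "vqa_answers": ["..."]}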