import re
import math
import torch
import warnings
from PIL import Image
from torchcodec.decoders import VideoDecoder

from ammico.model import MultimodalSummaryModel
from ammico.utils import (
    AnalysisMethod,
    AnalysisType,
    _categorize_outputs,
    _strip_prompt_prefix_literal,
)

from typing import List, Dict, Any, Generator, Tuple, Union, Optional
from transformers import GenerationConfig


class VideoSummaryDetector(AnalysisMethod):
    MAX_SAMPLES_CAP = 1000  # safety cap for total extracted frames

    def __init__(
        self,
        summary_model: MultimodalSummaryModel,
        subdict: Optional[Dict[str, Any]] = None,
    ) -> None:
        """
        Class for analysing videos using the Qwen2.5-VL model.
        It provides methods for generating captions and answering questions about videos.

        Args:
            summary_model (MultimodalSummaryModel): An instance of MultimodalSummaryModel to be used for analysis.
            subdict (dict, optional): Dictionary containing the videos to be analysed. Defaults to {}.

        Returns:
            None.
        """
        if subdict is None:
            subdict = {}

        super().__init__(subdict)
        self.summary_model = summary_model
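
    # Minimal usage sketch (hedged; the constructor arguments of MultimodalSummaryModel
    # and the file path are illustrative assumptions, not part of this module):
    #
    #   model = MultimodalSummaryModel(...)  # loads the Qwen2.5-VL checkpoint
    #   detector = VideoSummaryDetector(
    #       summary_model=model,
    #       subdict={"my_video": {"filename": "clips/my_video.mp4"}},
    #   )
    #   results = detector.analyse_videos_from_dict()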

    def _frame_batch_generator(
        self,
        timestamps: torch.Tensor,
        batch_size: int,
        video_decoder: VideoDecoder,
    ) -> Generator[Tuple[torch.Tensor, torch.Tensor], None, None]:
        """
        Yield batches of (frames, timestamps) for the given sample timestamps (in seconds).
        - frames are returned as a torch.Tensor with shape (B, C, H, W).
        - timestamps is a 1D torch.Tensor with B elements.
        """
        total = int(timestamps.numel())

        for start in range(0, total, batch_size):
            batch_secs = timestamps[start : start + batch_size].tolist()
            fb = video_decoder.get_frames_played_at(batch_secs)
            frames = fb.data

            if not frames.is_contiguous():
                frames = frames.contiguous()
            pts = fb.pts_seconds
            pts_out = (
                pts.cpu().to(dtype=torch.float32)
                if isinstance(pts, torch.Tensor)
                else torch.tensor(pts, dtype=torch.float32)
            )
            yield frames, pts_out
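
    # Illustrative use of _frame_batch_generator above (hedged; "decoder", the file
    # path, and the timestamps are assumptions for the example):
    #
    #   decoder = VideoDecoder("clip.mp4")
    #   times = torch.tensor([0.0, 0.5, 1.0], dtype=torch.float32)
    #   for frames, pts in detector._frame_batch_generator(times, 2, decoder):
    #       print(frames.shape, pts)  # (2, C, H, W) then (1, C, H, W)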

    def _extract_video_frames(
        self,
        entry: Dict[str, Any],
        frame_rate_per_second: float = 2.0,
        batch_size: int = 32,
    ) -> Generator[Tuple[torch.Tensor, torch.Tensor], None, None]:
        """
        Extract frames from a video at a specified frame rate and return them as a generator of batches.
        Args:
            entry (Dict[str, Any]): Dictionary describing the video; must contain the key "filename" with the path to the video file.
            frame_rate_per_second (float, optional): Frame extraction rate in frames per second. Default is 2.0.
            batch_size (int, optional): Number of frames to include in each batch. Default is 32.
        Returns:
            Generator[Tuple[torch.Tensor, torch.Tensor], None, None]: A generator yielding tuples of
            (frames, timestamps), where frames is a tensor of shape (B, C, H, W) and timestamps is a 1D tensor of length B.
        """

        filename = entry.get("filename")
        if not filename:
            raise ValueError("entry must contain key 'filename'")

        video_decoder = VideoDecoder(filename)
        meta = video_decoder.metadata

        video_fps = getattr(meta, "average_fps", None)
        if video_fps is None or not (
            isinstance(video_fps, (int, float)) and video_fps > 0
        ):
            video_fps = 30.0

        begin_stream_seconds = getattr(meta, "begin_stream_seconds", None)
        end_stream_seconds = getattr(meta, "end_stream_seconds", None)
        nframes = len(video_decoder)
        if getattr(meta, "duration_seconds", None) is not None:
            duration = float(meta.duration_seconds)
        elif begin_stream_seconds is not None and end_stream_seconds is not None:
            duration = float(end_stream_seconds) - float(begin_stream_seconds)
        elif nframes:
            duration = float(nframes) / float(video_fps)
        else:
            duration = 0.0

        if frame_rate_per_second <= 0:
            raise ValueError("frame_rate_per_second must be > 0")

        n_samples = max(1, int(math.floor(duration * frame_rate_per_second)))
        n_samples = min(n_samples, self.MAX_SAMPLES_CAP)

        if begin_stream_seconds is not None and end_stream_seconds is not None:
            sample_times = torch.linspace(
                float(begin_stream_seconds), float(end_stream_seconds), steps=n_samples
            )
            if sample_times.numel() > 1:
                sample_times = torch.clamp(
                    sample_times,
                    min=float(begin_stream_seconds),
                    max=float(end_stream_seconds) - 1e-6,
                )
        else:
            sample_times = torch.linspace(0.0, max(0.0, duration), steps=n_samples)

        sample_times = sample_times.to(dtype=torch.float32, device="cpu")
        generator = self._frame_batch_generator(sample_times, batch_size, video_decoder)

        return generator
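
    # Sampling arithmetic of _extract_video_frames above, illustrated with hedged
    # example values: a 10.5 s stream at frame_rate_per_second=2.0 gives
    # n_samples = floor(10.5 * 2.0) = 21 (capped at MAX_SAMPLES_CAP), and the
    # timestamps are spread evenly over the stream:
    #
    #   torch.linspace(0.0, 10.5, steps=21)  # 21 timestamps from 0.0 s to 10.5 s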

    def _decode_trimmed_outputs(
        self,
        generated_ids: torch.Tensor,
        inputs: Dict[str, torch.Tensor],
        tokenizer,
        prompt_texts: List[str],
    ) -> List[str]:
        """
        Trim prompt tokens (their count inferred from input_ids and the pad/eos token) when available and decode to strings.
        Then remove any literal prompt prefix using prompt_texts (one per batch element).
        """

        batch_size = generated_ids.shape[0]

        if "input_ids" in inputs:
            token_for_padding = (
                tokenizer.pad_token_id
                if getattr(tokenizer, "pad_token_id", None) is not None
                else getattr(tokenizer, "eos_token_id", None)
            )
            if token_for_padding is None:
                lengths = [int(inputs["input_ids"].shape[1])] * batch_size
            else:
                lengths = inputs["input_ids"].ne(token_for_padding).sum(dim=1).tolist()
        else:
            lengths = [0] * batch_size

        trimmed_ids = []
        for i in range(batch_size):
            out_ids = generated_ids[i]
            in_len = int(lengths[i]) if i < len(lengths) else 0
            if out_ids.size(0) > in_len:
                t = out_ids[in_len:]
            else:
                t = out_ids.new_empty((0,), dtype=out_ids.dtype)
            t_cpu = t.to("cpu")
            trimmed_ids.append(t_cpu.tolist())

        decoded = tokenizer.batch_decode(
            trimmed_ids, skip_special_tokens=True, clean_up_tokenization_spaces=True
        )
        decoded_results = []
        for ptext, raw in zip(prompt_texts, decoded):
            cleaned = _strip_prompt_prefix_literal(raw, ptext)
            decoded_results.append(cleaned)
        return decoded_results
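
    # Why _decode_trimmed_outputs trims by prompt length (schematic illustration):
    # for a causal LM, generate() returns the prompt tokens followed by the newly
    # generated tokens, so the continuation for sample i is generated_ids[i][in_len:]:
    #
    #   generated_ids[i] = [p_1, ..., p_in_len, g_1, g_2, ...]
    #   trimmed_ids[i]   =                     [g_1, g_2, ...]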

    def _generate_from_processor_inputs(
        self,
        processor_inputs: Dict[str, torch.Tensor],
        prompt_texts: List[str],
        tokenizer,
    ):
        """
        Run model.generate on already-processed processor_inputs (tensors moved to device),
        then decode and trim prompt tokens & remove literal prompt prefixes using prompt_texts.
        """
        gen_conf = GenerationConfig(
            max_new_tokens=64,
            do_sample=False,
            num_return_sequences=1,
        )

        for k, v in list(processor_inputs.items()):
            if isinstance(v, torch.Tensor):
                processor_inputs[k] = v.to(self.summary_model.device)

        with torch.inference_mode():
            try:
                if self.summary_model.device == "cuda":
                    with torch.amp.autocast("cuda", enabled=True):
                        generated_ids = self.summary_model.model.generate(
                            **processor_inputs, generation_config=gen_conf
                        )
                else:
                    generated_ids = self.summary_model.model.generate(
                        **processor_inputs, generation_config=gen_conf
                    )
            except RuntimeError as e:
                warnings.warn(
                    f"Generation failed with error: {e}. Retrying with cuDNN disabled.",
                    RuntimeWarning,
                )
                cudnn_was_enabled = (
                    torch.backends.cudnn.is_available() and torch.backends.cudnn.enabled
                )
                if cudnn_was_enabled:
                    torch.backends.cudnn.enabled = False
                try:
                    generated_ids = self.summary_model.model.generate(
                        **processor_inputs, generation_config=gen_conf
                    )
                except Exception as retry_error:
                    raise RuntimeError(
                        f"Failed to generate ids after retry: {retry_error}"
                    ) from retry_error
                finally:
                    if cudnn_was_enabled:
                        torch.backends.cudnn.enabled = True

        decoded = self._decode_trimmed_outputs(
            generated_ids, processor_inputs, tokenizer, prompt_texts
        )
        return decoded
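
    # Hedged sketch of how _generate_from_processor_inputs is invoked inside this
    # class (prompt_text and pil_image are placeholders for the example):
    #
    #   inputs = self.summary_model.processor(
    #       text=[prompt_text], images=[pil_image], return_tensors="pt", padding=True
    #   )
    #   texts = self._generate_from_processor_inputs(
    #       inputs, [prompt_text], self.summary_model.tokenizer
    #   )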

    def _tensor_batch_to_pil_list(self, batch: torch.Tensor) -> List[Image.Image]:
        """
        Convert a torch tensor batch (B, C, H, W) to a list of RGB PIL images.
        The conversion is done on CPU; float inputs are clamped to [0, 1] and cast to uint8.
        """
        if batch.device.type != "cpu":
            batch = batch.to("cpu")

        batch = batch.contiguous()
        if batch.dtype.is_floating_point:
            batch = (batch.clamp(0.0, 1.0) * 255.0).to(torch.uint8)
        elif batch.dtype != torch.uint8:
            batch = batch.to(torch.uint8)
        pil_list: List[Image.Image] = []
        for frame in batch:
            arr = frame.permute(1, 2, 0).numpy()
            pil_list.append(Image.fromarray(arr))
        return pil_list
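
    # Example of the conversion done by _tensor_batch_to_pil_list (hedged; random
    # data for illustration only):
    #
    #   batch = torch.rand(2, 3, 224, 224)            # float frames in [0, 1]
    #   pils = detector._tensor_batch_to_pil_list(batch)
    #   pils[0].size                                  # -> (224, 224)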

    def make_captions_from_extracted_frames(
        self,
        extracted_video_gen: Generator[Tuple[torch.Tensor, torch.Tensor], None, None],
        list_of_questions: Optional[List[str]] = None,
    ) -> Dict[str, Any]:
        """
        Generate captions for all extracted frames and, optionally, per-frame answers to a list of questions.
        Args:
            extracted_video_gen (Generator[Tuple[torch.Tensor, torch.Tensor], None, None]): Generator yielding batches of (frames, timestamps).
            list_of_questions (List[str], optional): Questions to answer for each frame in addition to the caption. Defaults to None.
        Returns:
            Dict[str, Any]: A dictionary containing the caption bullets ("summary_bullets") and the per-frame VQA bullets ("vqa_bullets", empty if no questions were given).
        """

        caption_instruction = "Describe this image in one concise caption."
        include_questions = bool(list_of_questions)
        if include_questions:
            q_block = "\n".join(
                [f"{i + 1}. {q.strip()}" for i, q in enumerate(list_of_questions)]
            )
            caption_instruction += (
                " In addition to the concise caption, also answer the following questions based ONLY on the image. Answers must be very brief and concise."
                " Produce exactly two labeled sections: \n\n"
                "Summary: <concise summary>\n\n"
                "VQA Answers: \n1. <answer to question 1>\n2. <answer to question 2>\n etc."
                "\nReturn only those two sections for each image (do not add extra commentary)."
                "\nIf the answer cannot be determined based on the provided answer blocks,"
                ' reply with the line "The answer cannot be determined based on the information provided."'
                f"\n\nQuestions:\n{q_block}"
            )
        collected: List[Tuple[float, str]] = []
        proc = self.summary_model.processor
        try:
            for batch_frames, batch_times in extracted_video_gen:
                pil_list = self._tensor_batch_to_pil_list(batch_frames.cpu())

                prompt_texts = []
                for p in pil_list:
                    messages = [
                        {
                            "role": "user",
                            "content": [
                                {"type": "image", "image": p},
                                {"type": "text", "text": caption_instruction},
                            ],
                        }
                    ]

                    prompt_text = proc.apply_chat_template(
                        messages, tokenize=False, add_generation_prompt=True
                    )
                    prompt_texts.append(prompt_text)

                processor_inputs = proc(
                    text=prompt_texts,
                    images=pil_list,
                    return_tensors="pt",
                    padding=True,
                )
                captions = self._generate_from_processor_inputs(
                    processor_inputs,
                    prompt_texts,
                    self.summary_model.tokenizer,
                )

                if isinstance(batch_times, torch.Tensor):
                    batch_times_list = batch_times.cpu().tolist()
                else:
                    batch_times_list = list(batch_times)
                for t, c in zip(batch_times_list, captions):
                    collected.append((float(t), c))
        finally:
            try:
                extracted_video_gen.close()
            except Exception:
                warnings.warn("Failed to close video frame generator.", RuntimeWarning)

        collected.sort(key=lambda x: x[0])
        bullets_summary, bullets_vqa = _categorize_outputs(collected, include_questions)

        return {
            "summary_bullets": bullets_summary,
            "vqa_bullets": bullets_vqa,
        }  # TODO consider taking out time stamps from the returned structure
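
    # Shape of the dictionary returned by make_captions_from_extracted_frames
    # (hedged; the bullet strings are illustrative and their exact format depends
    # on _categorize_outputs):
    #
    #   {
    #       "summary_bullets": ["A person walks a dog in a park.", ...],
    #       "vqa_bullets": ["1. Yes\n2. Two people", ...],  # empty if no questions
    #   }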

    def final_summary(self, summary_dict: Dict[str, Any]) -> Dict[str, Any]:
        """
        Produce a concise summary of the video, based on the captions generated for the extracted frames.
        Args:
            summary_dict (Dict[str, Any]): Dictionary containing the frame captions under "summary_bullets".
        Returns:
            Dict[str, Any]: A dictionary containing the final summary under "summary".
        """
        summary_instruction = "Analyze the following captions from multiple frames of the same video and summarize the overall content of the video in one concise paragraph (1-3 sentences). Focus on the key themes, actions, or events across the video, not just the individual frames."
        proc = self.summary_model.processor

        bullets = summary_dict.get("summary_bullets", [])
        if not bullets:
            raise ValueError("No captions available for summary generation.")

        combined_captions_text = "\n".join(bullets)
        summary_user_text = summary_instruction + "\n\n" + combined_captions_text + "\n"

        messages = [
            {
                "role": "user",
                "content": [{"type": "text", "text": summary_user_text}],
            }
        ]

        summary_prompt_text = proc.apply_chat_template(
            messages, tokenize=False, add_generation_prompt=True
        )

        summary_inputs = proc(
            text=[summary_prompt_text], return_tensors="pt", padding=True
        )

        summary_inputs = {
            k: v.to(self.summary_model.device) if isinstance(v, torch.Tensor) else v
            for k, v in summary_inputs.items()
        }
        final_summary_list = self._generate_from_processor_inputs(
            summary_inputs,
            [summary_prompt_text],
            self.summary_model.tokenizer,
        )
        final_summary = final_summary_list[0].strip() if final_summary_list else ""

        return {
            "summary": final_summary,
        }
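
    # The two-stage flow around final_summary (hedged sketch; "detector" and the
    # file path are illustrative):
    #
    #   frames_gen = detector._extract_video_frames({"filename": "clip.mp4"})
    #   bullets = detector.make_captions_from_extracted_frames(frames_gen)
    #   detector.final_summary(bullets)  # -> {"summary": "..."}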

    def final_answers(
        self,
        answers_dict: Dict[str, Any],
        list_of_questions: List[str],
    ) -> Dict[str, Any]:
        """
        Answer the list of questions for the video based on the VQA bullets from the frames.
        Args:
            answers_dict (Dict[str, Any]): Dictionary containing the VQA bullets under "vqa_bullets".
            list_of_questions (List[str]): Questions to answer about the video.
        Returns:
            Dict[str, Any]: A dictionary containing the list of answers to the questions.
        """
        vqa_bullets = answers_dict.get("vqa_bullets", [])
        if not vqa_bullets:
            raise ValueError(
                "No VQA bullets generated for single frames available for answering questions."
            )

        include_questions = bool(list_of_questions)
        if include_questions:
            q_block = "\n".join(
                [f"{i + 1}. {q.strip()}" for i, q in enumerate(list_of_questions)]
            )
            prompt = (
                "You are provided with a set of short VQA-captions, each of which is a block of short answers"
                " extracted from individual frames of the same video.\n\n"
                "VQA-captions (use ONLY these to answer):\n"
                f"{vqa_bullets}\n\n"
                "Answer the following questions briefly, based ONLY on the lists of answers provided above. The VQA-captions above contain answers"
                " to the same questions you are about to answer. If the answer cannot be determined based on the provided answer blocks,"
                ' reply with the line "The answer cannot be determined based on the information provided."'
                "Questions:\n"
                f"{q_block}\n\n"
                "Produce an ordered list with answers in the same order as the questions. You must have this structure of your output: "
                "Answers: \n1. <answer to question 1>\n2. <answer to question 2>\n etc."
                "Return ONLY the ordered list with answers and NOTHING else — no commentary, no explanation, no surrounding markdown."
            )
        else:
            raise ValueError(
                "list_of_questions must be provided for making final answers."
            )

        proc = self.summary_model.processor
        messages = [
            {
                "role": "user",
                "content": [{"type": "text", "text": prompt}],
            }
        ]
        final_vqa_prompt_text = proc.apply_chat_template(
            messages, tokenize=False, add_generation_prompt=True
        )
        final_vqa_inputs = proc(
            text=[final_vqa_prompt_text], return_tensors="pt", padding=True
        )
        final_vqa_inputs = {
            k: v.to(self.summary_model.device) if isinstance(v, torch.Tensor) else v
            for k, v in final_vqa_inputs.items()
        }

        final_vqa_list = self._generate_from_processor_inputs(
            final_vqa_inputs,
            [final_vqa_prompt_text],
            self.summary_model.tokenizer,
        )

        final_vqa_output = final_vqa_list[0].strip() if final_vqa_list else ""
        vqa_answers = []
        answer_matches = re.findall(
            r"\d+\.\s*(.*?)(?=\n\d+\.|$)", final_vqa_output, flags=re.DOTALL
        )
        for answer in answer_matches:
            vqa_answers.append(answer.strip())
        return {
            "vqa_answers": vqa_answers,
        }
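
    # How the regex in final_answers splits the model output (worked example):
    #
    #   text = "Answers:\n1. Yes\n2. Two people\n3. Outdoors"
    #   re.findall(r"\d+\.\s*(.*?)(?=\n\d+\.|$)", text, flags=re.DOTALL)
    #   # -> ["Yes", "Two people", "Outdoors"]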

    def analyse_videos_from_dict(
        self,
        analysis_type: Union[AnalysisType, str] = AnalysisType.SUMMARY,
        frame_rate_per_second: float = 2.0,
        list_of_questions: Optional[List[str]] = None,
    ) -> Dict[str, Any]:
        """
        Analyse the videos specified in self.subdict using frame extraction and captioning.
        Args:
            analysis_type (Union[AnalysisType, str], optional): Type of analysis to perform. Defaults to AnalysisType.SUMMARY.
            frame_rate_per_second (float): Frame extraction rate in frames per second. Default is 2.0.
            list_of_questions (List[str], optional): List of questions to answer about the video. Required if analysis_type includes questions.
        Returns:
            Dict[str, Any]: A dictionary containing the analysis results per video, including the summary and the answers to the provided questions (if any).
        """

        all_answers = {}
        analysis_type, is_summary, is_questions = AnalysisType._validate_analysis_type(
            analysis_type, list_of_questions
        )

        for video_key, entry in self.subdict.items():
            summary_and_vqa = {}
            frames_generator = self._extract_video_frames(
                entry, frame_rate_per_second=frame_rate_per_second
            )

            answers_dict = self.make_captions_from_extracted_frames(
                frames_generator, list_of_questions=list_of_questions
            )
            # TODO: captions have to be post-processed together with the foreseen audio analysis
            # TODO: captions and answers may lead to a prompt that exceeds the limits of the superior model. Consider a hierarchical approach.
            if is_summary:
                answer = self.final_summary(answers_dict)
                summary_and_vqa["summary"] = answer["summary"]
            if is_questions:
                answer = self.final_answers(answers_dict, list_of_questions)
                summary_and_vqa["vqa_answers"] = answer["vqa_answers"]

            all_answers[video_key] = summary_and_vqa

        return all_answers
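
# Hedged end-to-end sketch (not part of the module API; the constructor arguments of
# MultimodalSummaryModel and the video path are illustrative assumptions):
#
# if __name__ == "__main__":
#     model = MultimodalSummaryModel(...)
#     detector = VideoSummaryDetector(
#         model, subdict={"clip": {"filename": "clip.mp4"}}
#     )
#     out = detector.analyse_videos_from_dict(analysis_type=AnalysisType.SUMMARY)
#     print(out["clip"]["summary"])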