Mirror of https://github.com/ssciwr/AMMICO.git, synced 2025-10-30 05:26:05 +02:00

Commit c2b2079d4e: add functionality for video vqa
Parent: 8c26a8de5e
@@ -6,7 +6,7 @@ import torch
 from PIL import Image
 import warnings
 
-from typing import List, Optional, Union, Dict, Any, Tuple
+from typing import List, Optional, Union, Dict, Any
 from collections.abc import Sequence as _Sequence
 from transformers import GenerationConfig
 from qwen_vl_utils import process_vision_info
@@ -120,42 +120,11 @@ class ImageSummaryDetector(AnalysisMethod):
 
         return inputs
 
-    def _validate_analysis_type(
-        self,
-        analysis_type: Union["AnalysisType", str],
-        list_of_questions: Optional[List[str]],
-        max_questions_per_image: int,
-    ) -> Tuple[str, List[str], bool, bool]:
-        if isinstance(analysis_type, AnalysisType):
-            analysis_type = analysis_type.value
-
-        allowed = {"summary", "questions", "summary_and_questions"}
-        if analysis_type not in allowed:
-            raise ValueError(f"analysis_type must be one of {allowed}")
-
-        if list_of_questions is None:
-            list_of_questions = [
-                "Are there people in the image?",
-                "What is this picture about?",
-            ]
-
-        if analysis_type in ("questions", "summary_and_questions"):
-            if len(list_of_questions) > max_questions_per_image:
-                raise ValueError(
-                    f"Number of questions per image ({len(list_of_questions)}) exceeds safety cap ({max_questions_per_image}). Reduce questions or increase max_questions_per_image."
-                )
-
-        is_summary = analysis_type in ("summary", "summary_and_questions")
-        is_questions = analysis_type in ("questions", "summary_and_questions")
-
-        return analysis_type, list_of_questions, is_summary, is_questions
-
     def analyse_image(
         self,
         entry: dict,
-        analysis_type: Union[str, AnalysisType] = AnalysisType.SUMMARY_AND_QUESTIONS,
+        analysis_type: Union[str, AnalysisType] = AnalysisType.SUMMARY,
         list_of_questions: Optional[List[str]] = None,
-        max_questions_per_image: int = 32,
         is_concise_summary: bool = True,
         is_concise_answer: bool = True,
     ) -> Dict[str, Any]:
@@ -165,10 +134,8 @@ class ImageSummaryDetector(AnalysisMethod):
             - 'vqa' (dict) if questions requested
         """
         self.subdict = entry
-        analysis_type, list_of_questions, is_summary, is_questions = (
-            self._validate_analysis_type(
-                analysis_type, list_of_questions, max_questions_per_image
-            )
-        )
+        analysis_type, is_summary, is_questions = AnalysisType._validate_analysis_type(
+            analysis_type, list_of_questions
+        )
 
         if is_summary:
@@ -195,9 +162,8 @@ class ImageSummaryDetector(AnalysisMethod):
 
     def analyse_images_from_dict(
         self,
-        analysis_type: Union[AnalysisType, str] = AnalysisType.SUMMARY_AND_QUESTIONS,
+        analysis_type: Union[AnalysisType, str] = AnalysisType.SUMMARY,
         list_of_questions: Optional[List[str]] = None,
-        max_questions_per_image: int = 32,
         keys_batch_size: int = 16,
         is_concise_summary: bool = True,
         is_concise_answer: bool = True,
@@ -208,7 +174,6 @@ class ImageSummaryDetector(AnalysisMethod):
         Args:
             analysis_type (str): type of the analysis.
             list_of_questions (list[str]): list of questions.
-            max_questions_per_image (int): maximum number of questions per image. We recommend to keep it low to avoid long processing times and high memory usage.
             keys_batch_size (int): number of images to process in a batch.
             is_concise_summary (bool): whether to generate concise summary.
             is_concise_answer (bool): whether to generate concise answers.
@@ -216,10 +181,8 @@ class ImageSummaryDetector(AnalysisMethod):
             self.subdict (dict): dictionary with analysis results.
         """
         # TODO: add option to ask multiple questions per image as one batch.
-        analysis_type, list_of_questions, is_summary, is_questions = (
-            self._validate_analysis_type(
-                analysis_type, list_of_questions, max_questions_per_image
-            )
-        )
+        analysis_type, is_summary, is_questions = AnalysisType._validate_analysis_type(
+            analysis_type, list_of_questions
+        )
 
         keys = list(self.subdict.keys())
@@ -284,7 +247,7 @@ class ImageSummaryDetector(AnalysisMethod):
         with torch.inference_mode():
             try:
                 if self.summary_model.device == "cuda":
-                    with torch.cuda.amp.autocast(enabled=True):
+                    with torch.amp.autocast("cuda", enabled=True):
                         generated_ids = self.summary_model.model.generate(
                             **inputs, generation_config=gen_conf
                         )
@@ -366,7 +329,7 @@ class ImageSummaryDetector(AnalysisMethod):
         inputs = self._prepare_inputs(chunk, entry)
         with torch.inference_mode():
             if self.summary_model.device == "cuda":
-                with torch.cuda.amp.autocast(enabled=True):
+                with torch.amp.autocast("cuda", enabled=True):
                     out_ids = self.summary_model.model.generate(
                         **inputs, generation_config=gen_conf
                     )
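Note: both autocast hunks in this file move from the deprecated `torch.cuda.amp.autocast(enabled=True)` to the device-parameterized `torch.amp.autocast`. A minimal sketch of the replacement API (plain PyTorch, independent of AMMICO):

    import torch

    # torch.amp.autocast takes the device type as an explicit first argument;
    # torch.cuda.amp.autocast is the older, CUDA-only spelling and is deprecated.
    device = "cuda" if torch.cuda.is_available() else "cpu"
    x = torch.randn(4, 8, device=device)
    w = torch.randn(8, 2, device=device)
    with torch.amp.autocast(device, enabled=(device == "cuda")):
        y = x @ w  # runs in reduced precision (e.g. float16) when autocast is active
    print(y.dtype)  # torch.float16 under CUDA autocast, torch.float32 otherwise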
ammico/utils.py (114 changed lines)
@@ -6,6 +6,8 @@ import importlib_resources
 import collections
 import random
 from enum import Enum
+from typing import List, Tuple, Optional, Union
+import re
 
 
 pkg = importlib_resources.files("ammico")
@@ -46,6 +48,35 @@ class AnalysisType(str, Enum):
     QUESTIONS = "questions"
     SUMMARY_AND_QUESTIONS = "summary_and_questions"
 
+    @classmethod
+    def _validate_analysis_type(
+        cls,
+        analysis_type: Union["AnalysisType", str],
+        list_of_questions: Optional[List[str]],
+    ) -> Tuple[str, bool, bool]:
+        max_questions_per_image = 15  # safety cap to avoid too many questions
+        if isinstance(analysis_type, AnalysisType):
+            analysis_type = analysis_type.value
+
+        allowed = {item.value for item in AnalysisType}
+        if analysis_type not in allowed:
+            raise ValueError(f"analysis_type must be one of {allowed}")
+
+        if analysis_type in ("questions", "summary_and_questions"):
+            if not list_of_questions:
+                raise ValueError(
+                    "list_of_questions must be provided for QUESTIONS analysis type."
+                )
+
+            if len(list_of_questions) > max_questions_per_image:
+                raise ValueError(
+                    f"Number of questions per image ({len(list_of_questions)}) exceeds the safety cap ({max_questions_per_image}). Reduce the number of questions."
+                )
+
+        is_summary = analysis_type in ("summary", "summary_and_questions")
+        is_questions = analysis_type in ("questions", "summary_and_questions")
+        return analysis_type, is_summary, is_questions
+
 
 class AnalysisMethod:
     """Base class to be inherited by all analysis methods."""
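Note: validation now lives on the `AnalysisType` enum itself and returns only the normalized type plus two flags; it no longer supplies default questions, so callers must pass their own list whenever questions are requested. A hedged usage sketch of the new classmethod:

    from ammico.utils import AnalysisType

    # Accepts the enum member or its string value; returns (value, is_summary, is_questions).
    atype, is_summary, is_questions = AnalysisType._validate_analysis_type(
        AnalysisType.SUMMARY_AND_QUESTIONS,
        ["Are there people in the image?", "What is this picture about?"],
    )
    assert (atype, is_summary, is_questions) == ("summary_and_questions", True, True)

    try:
        AnalysisType._validate_analysis_type("questions", None)
    except ValueError as err:
        print(err)  # questions requested but no list_of_questions given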
@@ -101,6 +132,89 @@ def _limit_results(results, limit):
     return results
 
 
+def _categorize_outputs(
+    collected: List[Tuple[float, str]],
+    include_questions: bool = False,
+) -> Tuple[List[str], List[str]]:
+    """
+    Categorize collected outputs into summary bullets and VQA bullets.
+
+    Args:
+        collected (List[Tuple[float, str]]): List of tuples containing timestamps and generated texts.
+
+    Returns:
+        Tuple[List[str], List[str]]: A tuple containing two lists - summary bullets and VQA bullets.
+    """
+    MAX_CAPTIONS_FOR_SUMMARY = 600  # TODO: constant for now; make it adjustable later, e.g. by dropping near-duplicate frames to reduce the load on the system.
+    caps_for_summary_vqa = (
+        collected[-MAX_CAPTIONS_FOR_SUMMARY:]
+        if len(collected) > MAX_CAPTIONS_FOR_SUMMARY
+        else collected
+    )
+    bullets_summary = []
+    bullets_vqa = []
+
+    for t, c in caps_for_summary_vqa:
+        if include_questions:
+            result_sections = c.strip()
+            m = re.search(
+                r"Summary\s*:\s*(.*?)\s*(?:VQA\s+Answers\s*:\s*(.*))?$",
+                result_sections,
+                flags=re.IGNORECASE | re.DOTALL,
+            )
+            if m:
+                summary_text = (
+                    m.group(1).replace("\n", " ").strip() if m.group(1) else None
+                )
+                vqa_text = m.group(2).strip() if m.group(2) else None
+                if not summary_text or not vqa_text:
+                    raise ValueError(
+                        f"Model output is missing either summary or VQA answers: {c}"
+                    )
+                bullets_summary.append(f"- [{t:.3f}s] {summary_text}")
+                bullets_vqa.append(f"- [{t:.3f}s] {vqa_text}")
+            else:
+                raise ValueError(
+                    f"Failed to parse summary and VQA answers from model output: {c}"
+                )
+        else:
+            snippet = c.replace("\n", " ").strip()
+            bullets_summary.append(f"- [{t:.3f}s] {snippet}")
+    return bullets_summary, bullets_vqa
+
+
+def _normalize_whitespace(s: str) -> str:
+    return re.sub(r"\s+", " ", s).strip()
+
+
+def _strip_prompt_prefix_literal(decoded: str, prompt: str) -> str:
+    """
+    Remove any literal prompt prefix from decoded text using a normalized-substring match.
+
+    Guarantees no prompt text remains at the start of the returned string (best-effort).
+    """
+    if not decoded:
+        return ""
+    if not prompt:
+        return decoded.strip()
+
+    d_norm = _normalize_whitespace(decoded)
+    p_norm = _normalize_whitespace(prompt)
+
+    idx = d_norm.find(p_norm)
+    if idx != -1:
+        running = []
+        for i, ch in enumerate(decoded):
+            running.append(ch if not ch.isspace() else " ")
+            cur_norm = _normalize_whitespace("".join(running))
+            if cur_norm.endswith(p_norm):
+                return decoded[i + 1 :].lstrip() if i + 1 < len(decoded) else ""
+    m = re.match(
+        r"^(?:\s*(system|user|assistant)[:\s-]*\n?)+", decoded, flags=re.IGNORECASE
+    )
+    if m:
+        return decoded[m.end() :].lstrip()
+
+    return decoded.lstrip("\n\r ").lstrip(":;- ").strip()
+
+
 def find_videos(
     path: str = None,
     pattern=["mp4"],  # TODO: test with more video formats
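Note: `_normalize_whitespace` and `_strip_prompt_prefix_literal` are now plain module-level helpers, so the prefix matching can be checked without a model. A small self-contained sketch of the normalized matching idea (re-stating the helper locally rather than importing it):

    import re

    def _normalize_whitespace(s: str) -> str:
        return re.sub(r"\s+", " ", s).strip()

    # The helper finds the prompt inside the decoded text after collapsing
    # whitespace, so differing spacing/newlines do not defeat the match.
    decoded = "user: Describe   this image.\nA red bus on a rainy street."
    prompt = "user: Describe this image."
    d_norm, p_norm = _normalize_whitespace(decoded), _normalize_whitespace(prompt)
    print(d_norm.find(p_norm))  # 0 -> prompt prefix found despite spacing differences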
@@ -6,17 +6,24 @@ from PIL import Image
 from torchcodec.decoders import VideoDecoder
 
 from ammico.model import MultimodalSummaryModel
-from ammico.utils import AnalysisMethod
+from ammico.utils import (
+    AnalysisMethod,
+    AnalysisType,
+    _categorize_outputs,
+    _strip_prompt_prefix_literal,
+)
 
-from typing import List, Dict, Any, Generator, Tuple
+from typing import List, Dict, Any, Generator, Tuple, Union, Optional
 from transformers import GenerationConfig
 
 
 class VideoSummaryDetector(AnalysisMethod):
+    MAX_SAMPLES_CAP = 1000  # safety cap for total extracted frames
+
     def __init__(
         self,
         summary_model: MultimodalSummaryModel,
-        subdict: dict = {},
+        subdict: Optional[Dict[str, Any]] = None,
     ) -> None:
         """
         Class for analysing videos using QWEN-2.5-VL model.
@@ -29,6 +36,8 @@ class VideoSummaryDetector(AnalysisMethod):
         Returns:
             None.
         """
+        if subdict is None:
+            subdict = {}
 
         super().__init__(subdict)
         self.summary_model = summary_model
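Note: the `subdict: dict = {}` default was a latent bug: Python evaluates default arguments once, so every detector built without an explicit `subdict` would share (and mutate) the same dict. A minimal illustration of the pitfall and the `None` idiom used in the fix:

    class Broken:
        def __init__(self, store: dict = {}):  # one dict object shared by all calls
            self.store = store

    a, b = Broken(), Broken()
    a.store["k"] = 1
    print(b.store)  # {'k': 1} -- b observes a's write

    class Fixed:
        def __init__(self, store: dict = None):
            self.store = {} if store is None else store  # fresh dict per instance

    c, d = Fixed(), Fixed()
    c.store["k"] = 1
    print(d.store)  # {}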
@@ -107,6 +116,7 @@ class VideoSummaryDetector(AnalysisMethod):
             raise ValueError("frame_rate_per_second must be > 0")
 
         n_samples = max(1, int(math.floor(duration * frame_rate_per_second)))
+        n_samples = min(n_samples, self.MAX_SAMPLES_CAP)
 
         if begin_stream_seconds is not None and end_stream_seconds is not None:
             sample_times = torch.linspace(
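Note: the cap bounds the total number of extracted frames independent of clip length; for example, a 10-minute clip sampled at the default 2.0 fps would otherwise request 1200 frames:

    import math

    MAX_SAMPLES_CAP = 1000  # mirrors VideoSummaryDetector.MAX_SAMPLES_CAP
    duration, frame_rate_per_second = 600.0, 2.0  # 10 minutes at 2 fps
    n_samples = max(1, int(math.floor(duration * frame_rate_per_second)))  # 1200
    n_samples = min(n_samples, MAX_SAMPLES_CAP)
    print(n_samples)  # 1000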
@@ -126,38 +136,6 @@ class VideoSummaryDetector(AnalysisMethod):
 
         return generator
 
-    def _normalize_whitespace(self, s: str) -> str:
-        return re.sub(r"\s+", " ", s).strip()
-
-    def _strip_prompt_prefix_literal(self, decoded: str, prompt: str) -> str:
-        """
-        Remove any literal prompt prefix from decoded text using a normalized-substring match.
-        Guarantees no prompt text remains at the start of returned string (best-effort).
-        """
-        if not decoded:
-            return ""
-        if not prompt:
-            return decoded.strip()
-
-        d_norm = self._normalize_whitespace(decoded)
-        p_norm = self._normalize_whitespace(prompt)
-
-        idx = d_norm.find(p_norm)
-        if idx != -1:
-            running = []
-            for i, ch in enumerate(decoded):
-                running.append(ch if not ch.isspace() else " ")
-                cur_norm = self._normalize_whitespace("".join(running))
-                if cur_norm.endswith(p_norm):
-                    return decoded[i + 1 :].lstrip() if i + 1 < len(decoded) else ""
-        m = re.match(
-            r"^(?:\s*(system|user|assistant)[:\s-]*\n?)+", decoded, flags=re.IGNORECASE
-        )
-        if m:
-            return decoded[m.end() :].lstrip()
-
-        return decoded.lstrip("\n\r ").lstrip(":;- ").strip()
-
     def _decode_trimmed_outputs(
         self,
         generated_ids: torch.Tensor,
@@ -170,20 +148,18 @@ class VideoSummaryDetector(AnalysisMethod):
         Then remove any literal prompt prefix using prompt_texts (one per batch element).
         """
 
-        decoded_results = []
         batch_size = generated_ids.shape[0]
 
         if "input_ids" in inputs:
-            lengths = (
-                inputs["input_ids"]
-                .ne(
-                    tokenizer.pad_token_id
-                    if tokenizer.pad_token_id is not None
-                    else tokenizer.eos_token_id
-                )
-                .sum(dim=1)
-                .tolist()
+            token_for_padding = (
+                tokenizer.pad_token_id
+                if getattr(tokenizer, "pad_token_id", None) is not None
+                else getattr(tokenizer, "eos_token_id", None)
             )
+            if token_for_padding is None:
+                lengths = [int(inputs["input_ids"].shape[1])] * batch_size
+            else:
+                lengths = inputs["input_ids"].ne(token_for_padding).sum(dim=1).tolist()
         else:
             lengths = [0] * batch_size
 
@@ -195,13 +171,15 @@ class VideoSummaryDetector(AnalysisMethod):
                 t = out_ids[in_len:]
             else:
                 t = out_ids.new_empty((0,), dtype=out_ids.dtype)
-            trimmed_ids.append(t)
+            t_cpu = t.to("cpu")
+            trimmed_ids.append(t_cpu.tolist())
 
         decoded = tokenizer.batch_decode(
             trimmed_ids, skip_special_tokens=True, clean_up_tokenization_spaces=True
         )
+        decoded_results = []
         for ptext, raw in zip(prompt_texts, decoded):
-            cleaned = self._strip_prompt_prefix_literal(raw, ptext)
+            cleaned = _strip_prompt_prefix_literal(raw, ptext)
             decoded_results.append(cleaned)
         return decoded_results
 
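Note: the rewritten length logic tolerates tokenizers that define neither `pad_token_id` nor `eos_token_id` by falling back to the full padded width. The trimming itself is the standard "slice off the prompt tokens" pattern; a toy sketch with made-up token ids:

    import torch

    pad_token_id = 0
    # Two right-padded prompts of true lengths 3 and 2 (width 4), then 2 new tokens each.
    input_ids = torch.tensor([[5, 6, 7, 0], [8, 9, 0, 0]])
    generated_ids = torch.tensor([[5, 6, 7, 0, 11, 12], [8, 9, 0, 0, 13, 14]])

    lengths = input_ids.ne(pad_token_id).sum(dim=1).tolist()  # [3, 2]
    trimmed = [out[n:].tolist() for out, n in zip(generated_ids, lengths)]
    print(trimmed)  # [[0, 11, 12], [0, 0, 13, 14]] -- leftover pads are dropped later by skip_special_tokens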
@@ -209,7 +187,6 @@ class VideoSummaryDetector(AnalysisMethod):
         self,
         processor_inputs: Dict[str, torch.Tensor],
         prompt_texts: List[str],
-        model,
         tokenizer,
     ):
         """
@@ -224,7 +201,7 @@ class VideoSummaryDetector(AnalysisMethod):
 
         for k, v in list(processor_inputs.items()):
             if isinstance(v, torch.Tensor):
-                processor_inputs[k] = v.to(model.device)
+                processor_inputs[k] = v.to(self.summary_model.device)
 
         with torch.inference_mode():
             try:
@@ -239,8 +216,8 @@ class VideoSummaryDetector(AnalysisMethod):
                 )
             except RuntimeError as e:
                 warnings.warn(
-                    "Retry without autocast failed: %s. Attempting cudnn-disabled retry.",
-                    e,
+                    f"Generation failed with error: {e}. Retrying with cuDNN disabled.",
+                    RuntimeWarning,
                 )
                 cudnn_was_enabled = (
                     torch.backends.cudnn.is_available() and torch.backends.cudnn.enabled
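Note: the old call borrowed logging's `%s` style, but `warnings.warn` has no formatting arguments: its second positional parameter is the warning category and must be a `Warning` subclass, so passing the exception instance raises a `TypeError`. The fix interpolates the error and passes an explicit category:

    import warnings

    err = RuntimeError("cuDNN failure")
    # warnings.warn("failed: %s", err) would raise TypeError: category must be a Warning subclass
    warnings.warn(
        f"Generation failed with error: {err}. Retrying with cuDNN disabled.",
        RuntimeWarning,
    )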
@@ -273,7 +250,9 @@ class VideoSummaryDetector(AnalysisMethod):
         batch = batch.to("cpu")
 
         batch = batch.contiguous()
-        if batch.dtype != torch.uint8:
+        if batch.dtype.is_floating_point:
+            batch = (batch.clamp(0.0, 1.0) * 255.0).to(torch.uint8)
+        elif batch.dtype != torch.uint8:
             batch = batch.to(torch.uint8)
         pil_list: List[Image.Image] = []
         for frame in batch:
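Note: decoders commonly return float frames scaled to [0, 1]; casting those directly to uint8 truncates almost every pixel to 0, so the new branch rescales first. A standalone check:

    import torch

    frame = torch.rand(2, 2, 3)  # float frame in [0, 1], as many decoders produce
    bad = frame.to(torch.uint8)  # old path: truncates to all zeros
    good = (frame.clamp(0.0, 1.0) * 255.0).to(torch.uint8)  # new path: rescale first
    print(bad.unique())         # tensor([0], dtype=torch.uint8)
    print(good.float().mean())  # roughly 127 for uniform random input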
@@ -281,10 +260,10 @@ class VideoSummaryDetector(AnalysisMethod):
             pil_list.append(Image.fromarray(arr))
         return pil_list
 
-    def brute_force_summary(
+    def make_captions_from_extracted_frames(
         self,
         extracted_video_gen: Generator[Tuple[torch.Tensor, torch.Tensor], None, None],
-        summary_instruction: str = "Analyze the following captions from multiple frames of the same video and summarize the overall content of the video in one concise paragraph (1-3 sentences). Focus on the key themes, actions, or events across the video, not just the individual frames.",
+        list_of_questions: Optional[List[str]] = None,
     ) -> Dict[str, Any]:
         """
         Generate captions for all extracted frames and then produce a concise summary of the video.
@@ -296,68 +275,93 @@ class VideoSummaryDetector(AnalysisMethod):
         """
 
         caption_instruction = "Describe this image in one concise caption."
+        include_questions = bool(list_of_questions)
+        if include_questions:
+            q_block = "\n".join(
+                [f"{i + 1}. {q.strip()}" for i, q in enumerate(list_of_questions)]
+            )
+            caption_instruction += (
+                " In addition to the concise caption, also answer the following questions based ONLY on the image. Answers must be very brief and concise."
+                " Produce exactly two labeled sections: \n\n"
+                "Summary: <concise summary>\n\n"
+                "VQA Answers: \n1. <answer to question 1>\n2. <answer to question 2>\n etc."
+                "\nReturn only those two sections for each image (do not add extra commentary)."
+                "\nIf the answer cannot be determined based on the provided answer blocks,"
+                ' reply with the line "The answer cannot be determined based on the information provided."'
+                f"\n\nQuestions:\n{q_block}"
+            )
         collected: List[Tuple[float, str]] = []
         proc = self.summary_model.processor
+        try:
+            for batch_frames, batch_times in extracted_video_gen:
+                pil_list = self._tensor_batch_to_pil_list(batch_frames.cpu())
 
-        for batch_frames, batch_times in extracted_video_gen:
-            pil_list = self._tensor_batch_to_pil_list(batch_frames.cpu())
-
-            prompt_texts = []
-            for p in pil_list:
-                messages = [
-                    {
-                        "role": "user",
-                        "content": [
-                            {"type": "image", "image": p},
-                            {"type": "text", "text": caption_instruction},
-                        ],
-                    }
-                ]
-
-                prompt_text = proc.apply_chat_template(
-                    messages, tokenize=False, add_generation_prompt=True
-                )
-                prompt_texts.append(prompt_text)
+                prompt_texts = []
+                for p in pil_list:
+                    messages = [
+                        {
+                            "role": "user",
+                            "content": [
+                                {"type": "image", "image": p},
+                                {"type": "text", "text": caption_instruction},
+                            ],
+                        }
+                    ]
 
-            processor_inputs = proc(
-                text=prompt_texts, images=pil_list, return_tensors="pt", padding=True
-            )
-            captions = self._generate_from_processor_inputs(
-                processor_inputs,
-                prompt_texts,
-                self.summary_model,
-                self.summary_model.tokenizer,
-            )
+                    prompt_text = proc.apply_chat_template(
+                        messages, tokenize=False, add_generation_prompt=True
+                    )
+                    prompt_texts.append(prompt_text)
 
-            if isinstance(batch_times, torch.Tensor):
-                batch_times_list = batch_times.cpu().tolist()
-            else:
-                batch_times_list = list(batch_times)
-            for t, c in zip(batch_times_list, captions):
-                collected.append((float(t), c))
+                processor_inputs = proc(
+                    text=prompt_texts,
+                    images=pil_list,
+                    return_tensors="pt",
+                    padding=True,
+                )
+                captions = self._generate_from_processor_inputs(
+                    processor_inputs,
+                    prompt_texts,
+                    self.summary_model.tokenizer,
+                )
+
+                if isinstance(batch_times, torch.Tensor):
+                    batch_times_list = batch_times.cpu().tolist()
+                else:
+                    batch_times_list = list(batch_times)
+                for t, c in zip(batch_times_list, captions):
+                    collected.append((float(t), c))
+        finally:
+            try:
+                extracted_video_gen.close()
+            except Exception:
+                warnings.warn("Failed to close video frame generator.", RuntimeWarning)
 
         collected.sort(key=lambda x: x[0])
-        extracted_video_gen.close()
+        bullets_summary, bullets_vqa = _categorize_outputs(collected, include_questions)
 
-        MAX_CAPTIONS_FOR_SUMMARY = 200
-        caps_for_summary = (
-            collected[-MAX_CAPTIONS_FOR_SUMMARY:]
-            if len(collected) > MAX_CAPTIONS_FOR_SUMMARY
-            else collected
-        )
+        return {
+            "summary_bullets": bullets_summary,
+            "vqa_bullets": bullets_vqa,
+        }  # TODO consider taking out time stamps from the returned structure
 
-        bullets = []
-        for t, c in caps_for_summary:
-            snippet = c.replace("\n", " ").strip()
-            bullets.append(f"- [{t:.3f}s] {snippet}")
+    def final_summary(self, summary_dict: Dict[str, Any]) -> Dict[str, Any]:
+        """
+        Produce a concise summary of the video, based on generated captions for all extracted frames.
+
+        Args:
+            summary_dict (Dict[str, Any]): Dictionary containing captions for the frames.
+
+        Returns:
+            Dict[str, Any]: A dictionary containing the list of captions with timestamps and the final summary.
+        """
+        summary_instruction = "Analyze the following captions from multiple frames of the same video and summarize the overall content of the video in one concise paragraph (1-3 sentences). Focus on the key themes, actions, or events across the video, not just the individual frames."
+        proc = self.summary_model.processor
+
+        bullets = summary_dict.get("summary_bullets", [])
+        if not bullets:
+            raise ValueError("No captions available for summary generation.")
 
         combined_captions_text = "\n".join(bullets)
-        summary_user_text = (
-            summary_instruction
-            + "\n\n"
-            + combined_captions_text
-            + "\n\nPlease produce a single concise paragraph."
-        )
+        summary_user_text = summary_instruction + "\n\n" + combined_captions_text + "\n"
 
         messages = [
             {
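Note: the per-frame prompt and `_categorize_outputs` form a contract: every caption must contain the two labeled sections that the parser's regex expects. A quick standalone check of that parse, using the same pattern as the new helper:

    import re

    caption = "Summary: A dog runs on a beach.\n\nVQA Answers:\n1. Yes\n2. One dog"
    m = re.search(
        r"Summary\s*:\s*(.*?)\s*(?:VQA\s+Answers\s*:\s*(.*))?$",
        caption.strip(),
        flags=re.IGNORECASE | re.DOTALL,
    )
    print(m.group(1))  # "A dog runs on a beach."
    print(m.group(2))  # "1. Yes\n2. One dog"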
@@ -381,40 +385,130 @@ class VideoSummaryDetector(AnalysisMethod):
         final_summary_list = self._generate_from_processor_inputs(
             summary_inputs,
             [summary_prompt_text],
-            self.summary_model.model,
             self.summary_model.tokenizer,
         )
         final_summary = final_summary_list[0].strip() if final_summary_list else ""
 
         return {
-            "captions": collected,
             "summary": final_summary,
         }
 
+    def final_answers(
+        self,
+        answers_dict: Dict[str, Any],
+        list_of_questions: List[str],
+    ) -> Dict[str, Any]:
+        """
+        Answer the list of questions for the video based on the VQA bullets from the frames.
+
+        Args:
+            answers_dict (Dict[str, Any]): Dictionary containing the VQA bullets.
+
+        Returns:
+            Dict[str, Any]: A dictionary containing the list of answers to the questions.
+        """
+        vqa_bullets = answers_dict.get("vqa_bullets", [])
+        if not vqa_bullets:
+            raise ValueError(
+                "No VQA bullets generated for single frames available for answering questions."
+            )
+
+        include_questions = bool(list_of_questions)
+        if include_questions:
+            q_block = "\n".join(
+                [f"{i + 1}. {q.strip()}" for i, q in enumerate(list_of_questions)]
+            )
+            prompt = (
+                "You are provided with a set of short VQA-captions, each of which is a block of short answers"
+                " extracted from individual frames of the same video.\n\n"
+                "VQA-captions (use ONLY these to answer):\n"
+                f"{vqa_bullets}\n\n"
+                "Answer the following questions briefly, based ONLY on the lists of answers provided above. The VQA-captions above contain answers"
+                " to the same questions you are about to answer. If the answer cannot be determined based on the provided answer blocks,"
+                ' reply with the line "The answer cannot be determined based on the information provided."'
+                "Questions:\n"
+                f"{q_block}\n\n"
+                "Produce an ordered list with answers in the same order as the questions. You must have this structure of your output: "
+                "Answers: \n1. <answer to question 1>\n2. <answer to question 2>\n etc."
+                "Return ONLY the ordered list with answers and NOTHING else — no commentary, no explanation, no surrounding markdown."
+            )
+        else:
+            raise ValueError(
+                "list_of_questions must be provided for making final answers."
+            )
+
+        proc = self.summary_model.processor
+        messages = [
+            {
+                "role": "user",
+                "content": [{"type": "text", "text": prompt}],
+            }
+        ]
+        final_vqa_prompt_text = proc.apply_chat_template(
+            messages, tokenize=False, add_generation_prompt=True
+        )
+        final_vqa_inputs = proc(
+            text=[final_vqa_prompt_text], return_tensors="pt", padding=True
+        )
+        final_vqa_inputs = {
+            k: v.to(self.summary_model.device) if isinstance(v, torch.Tensor) else v
+            for k, v in final_vqa_inputs.items()
+        }
+
+        final_vqa_list = self._generate_from_processor_inputs(
+            final_vqa_inputs,
+            [final_vqa_prompt_text],
+            self.summary_model.tokenizer,
+        )
+
+        final_vqa_output = final_vqa_list[0].strip() if final_vqa_list else ""
+        vqa_answers = []
+        answer_matches = re.findall(
+            r"\d+\.\s*(.*?)(?=\n\d+\.|$)", final_vqa_output, flags=re.DOTALL
+        )
+        for answer in answer_matches:
+            vqa_answers.append(answer.strip())
+        return {
+            "vqa_answers": vqa_answers,
+        }
+
     def analyse_videos_from_dict(
-        self, frame_rate_per_second: float = 2.0
+        self,
+        analysis_type: Union[AnalysisType, str] = AnalysisType.SUMMARY,
+        frame_rate_per_second: float = 2.0,
+        list_of_questions: Optional[List[str]] = None,
     ) -> Dict[str, Any]:
         """
         Analyse the video specified in self.subdict using frame extraction and captioning.
-        For short videos (<=100 frames at the specified frame rate), it uses brute-force captioning.
-        For longer videos, it currently defaults to brute-force captioning, but can be extended for more complex methods.
 
         Args:
+            analysis_type (Union[AnalysisType, str], optional): Type of analysis to perform. Defaults to AnalysisType.SUMMARY.
             frame_rate_per_second (float): Frame extraction rate in frames per second. Default is 2.0.
+            list_of_questions (List[str], optional): List of questions to answer about the video. Required if analysis_type includes questions.
         Returns:
-            Dict[str, Any]: A dictionary containing the analysis results, including captions and summary.
+            Dict[str, Any]: A dictionary containing the analysis results, including the summary and answers to the provided questions (if any).
         """
 
         all_answers = {}
-        # TODO: add support for answering questions about videos
-        for video_key in list(self.subdict.keys()):
-            entry = self.subdict[video_key]
-            extracted_video_gen = self._extract_video_frames(
+        analysis_type, is_summary, is_questions = AnalysisType._validate_analysis_type(
+            analysis_type, list_of_questions
+        )
+
+        for video_key, entry in self.subdict.items():
+            summary_and_vqa = {}
+            frames_generator = self._extract_video_frames(
                 entry, frame_rate_per_second=frame_rate_per_second
             )
 
-            answer = self.brute_force_summary(extracted_video_gen)
-            all_answers[video_key] = {"summary": answer["summary"]}
+            answers_dict = self.make_captions_from_extracted_frames(
+                frames_generator, list_of_questions=list_of_questions
+            )
             # TODO: captions has to be post-processed with foreseeing audio analysis
+            # TODO: captions and answers may produce a prompt that exceeds the superior model's limits. Consider a hierarchical approach.
+            if is_summary:
+                answer = self.final_summary(answers_dict)
+                summary_and_vqa["summary"] = answer["summary"]
+            if is_questions:
+                answer = self.final_answers(answers_dict, list_of_questions)
+                summary_and_vqa["vqa_answers"] = answer["vqa_answers"]
 
-            return answer
+            all_answers[video_key] = summary_and_vqa
 
+        return all_answers
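Note: taken together, the new flow per video is extract frames, then per-frame captions (optionally with per-frame VQA blocks), then one final summary and/or one final answer list. A hedged end-to-end sketch; the module path, the `MultimodalSummaryModel` constructor arguments, and the entry layout (`filename` key) are assumptions for illustration, not taken from this diff:

    from ammico.model import MultimodalSummaryModel
    from ammico.utils import AnalysisType
    from ammico.video_summary import VideoSummaryDetector  # module name assumed

    model = MultimodalSummaryModel()  # constructor defaults assumed
    videos = {"clip1": {"filename": "clip1.mp4"}}  # hypothetical entry layout
    detector = VideoSummaryDetector(summary_model=model, subdict=videos)
    results = detector.analyse_videos_from_dict(
        analysis_type=AnalysisType.SUMMARY_AND_QUESTIONS,
        frame_rate_per_second=1.0,
        list_of_questions=["Are there people in the video?"],
    )
    print(results["clip1"]["summary"])
    print(results["clip1"]["vqa_answers"])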