Mirror of https://github.com/ssciwr/AMMICO.git
synced 2025-10-29 21:16:06 +02:00
Merge 3e036514cd3c05e70bd74eddeabe5bdb09c38890 into 3f9e855aebddf6eddaa81de1cd883bc5bcf5d3bc
This commit is contained in:
Commit 2ff2dffc59
2 .github/workflows/ci.yml (vendored)
@@ -31,7 +31,7 @@ jobs:
       - name: Run pytest
         run: |
           cd ammico
-          python -m pytest -svv -m "not gcv" --cov=. --cov-report=xml
+          python -m pytest -svv -m "not gcv and not long" --cov=. --cov-report=xml
       - name: Upload coverage
         if: matrix.os == 'ubuntu-latest' && matrix.python-version == '3.11'
         uses: codecov/codecov-action@v3
ammico/__init__.py
@@ -1,7 +1,10 @@
 from ammico.display import AnalysisExplorer
 from ammico.faces import EmotionDetector, ethical_disclosure
+from ammico.model import MultimodalSummaryModel
 from ammico.text import TextDetector, TextAnalyzer, privacy_disclosure
-from ammico.utils import find_files, get_dataframe
+from ammico.image_summary import ImageSummaryDetector
+from ammico.utils import find_files, get_dataframe, AnalysisType, find_videos
+from ammico.video_summary import VideoSummaryDetector
 
 # Export the version defined in project metadata
 try:
@@ -12,11 +15,16 @@ except ImportError:
     __version__ = "unknown"
 
 __all__ = [
+    "AnalysisType",
     "AnalysisExplorer",
     "EmotionDetector",
+    "MultimodalSummaryModel",
     "TextDetector",
     "TextAnalyzer",
+    "ImageSummaryDetector",
+    "VideoSummaryDetector",
     "find_files",
+    "find_videos",
     "get_dataframe",
     "ethical_disclosure",
     "privacy_disclosure",
ammico/display.py
@ -1,10 +1,14 @@
|
||||
import ammico.faces as faces
|
||||
import ammico.text as text
|
||||
import ammico.colors as colors
|
||||
import ammico.image_summary as image_summary
|
||||
from ammico.model import MultimodalSummaryModel
|
||||
import pandas as pd
|
||||
from dash import html, Input, Output, dcc, State, Dash
|
||||
from PIL import Image
|
||||
import dash_bootstrap_components as dbc
|
||||
import warnings
|
||||
from typing import Dict, Any, List, Optional
|
||||
|
||||
|
||||
COLOR_SCHEMES = [
|
||||
@ -94,6 +98,8 @@ class AnalysisExplorer:
|
||||
State("left_select_id", "options"),
|
||||
State("left_select_id", "value"),
|
||||
State("Dropdown_select_Detector", "value"),
|
||||
State("Dropdown_analysis_type", "value"),
|
||||
State("textarea_questions", "value"),
|
||||
State("setting_Text_analyse_text", "value"),
|
||||
State("setting_privacy_env_var", "value"),
|
||||
State("setting_Emotion_emotion_threshold", "value"),
|
||||
@ -108,9 +114,15 @@ class AnalysisExplorer:
|
||||
Output("settings_TextDetector", "style"),
|
||||
Output("settings_EmotionDetector", "style"),
|
||||
Output("settings_ColorDetector", "style"),
|
||||
Output("settings_VQA", "style"),
|
||||
Input("Dropdown_select_Detector", "value"),
|
||||
)(self._update_detector_setting)
|
||||
|
||||
self.app.callback(
|
||||
Output("textarea_questions", "style"),
|
||||
Input("Dropdown_analysis_type", "value"),
|
||||
)(self._show_questions_textarea_on_demand)
|
||||
|
||||
# I split the different sections into subfunctions for better clarity
|
||||
def _top_file_explorer(self, mydict: dict) -> html.Div:
|
||||
"""Initialize the file explorer dropdown for selecting the file to be analyzed.
|
||||
@ -157,14 +169,6 @@ class AnalysisExplorer:
|
||||
id="settings_TextDetector",
|
||||
style={"display": "none"},
|
||||
children=[
|
||||
dbc.Row(
|
||||
dcc.Checklist(
|
||||
["Analyse text"],
|
||||
["Analyse text"],
|
||||
id="setting_Text_analyse_text",
|
||||
style={"margin-bottom": "10px"},
|
||||
),
|
||||
),
|
||||
# row 1
|
||||
dbc.Row(
|
||||
dbc.Col(
|
||||
@ -272,8 +276,69 @@ class AnalysisExplorer:
|
||||
)
|
||||
],
|
||||
),
|
||||
# start VQA settings
|
||||
html.Div(
|
||||
id="settings_VQA",
|
||||
style={"display": "none"},
|
||||
children=[
|
||||
dbc.Card(
|
||||
[
|
||||
dbc.CardBody(
|
||||
[
|
||||
dbc.Row(
|
||||
dbc.Col(
|
||||
dcc.Dropdown(
|
||||
id="Dropdown_analysis_type",
|
||||
options=[
|
||||
{"label": v, "value": v}
|
||||
for v in SUMMARY_ANALYSIS_TYPE
|
||||
],
|
||||
value="summary_and_questions",
|
||||
clearable=False,
|
||||
style={
|
||||
"width": "100%",
|
||||
"minWidth": "240px",
|
||||
"maxWidth": "520px",
|
||||
},
|
||||
),
|
||||
),
|
||||
justify="start",
|
||||
),
|
||||
html.Div(style={"height": "8px"}),
|
||||
dbc.Row(
|
||||
[
|
||||
dbc.Col(
|
||||
dcc.Textarea(
|
||||
id="textarea_questions",
|
||||
value="Are there people in the image?\nWhat is this picture about?",
|
||||
placeholder="One question per line...",
|
||||
style={
|
||||
"width": "100%",
|
||||
"minHeight": "160px",
|
||||
"height": "220px",
|
||||
"resize": "vertical",
|
||||
"overflow": "auto",
|
||||
},
|
||||
rows=8,
|
||||
),
|
||||
width=12,
|
||||
),
|
||||
],
|
||||
justify="start",
|
||||
),
|
||||
]
|
||||
)
|
||||
],
|
||||
style={
|
||||
"width": "100%",
|
||||
"marginTop": "10px",
|
||||
"zIndex": 2000,
|
||||
},
|
||||
)
|
||||
],
|
||||
),
|
||||
],
|
||||
style={"width": "100%", "display": "inline-block"},
|
||||
style={"width": "100%", "display": "inline-block", "overflow": "visible"},
|
||||
)
|
||||
return settings_layout
|
||||
|
||||
@ -293,6 +358,7 @@ class AnalysisExplorer:
|
||||
"TextDetector",
|
||||
"EmotionDetector",
|
||||
"ColorDetector",
|
||||
"VQA",
|
||||
],
|
||||
value="TextDetector",
|
||||
id="Dropdown_select_Detector",
|
||||
@ -344,7 +410,7 @@ class AnalysisExplorer:
|
||||
port (int, optional): The port number to run the server on (default: 8050).
|
||||
"""
|
||||
|
||||
self.app.run_server(debug=True, port=port)
|
||||
self.app.run(debug=True, port=port)
|
||||
|
||||
# Dash callbacks
|
||||
def update_picture(self, img_path: str):
|
||||
@ -379,19 +445,27 @@ class AnalysisExplorer:
|
||||
|
||||
if setting_input == "EmotionDetector":
|
||||
return display_none, display_flex, display_none, display_none
|
||||
|
||||
if setting_input == "ColorDetector":
|
||||
return display_none, display_none, display_flex, display_none
|
||||
|
||||
if setting_input == "VQA":
|
||||
return display_none, display_none, display_none, display_flex
|
||||
else:
|
||||
return display_none, display_none, display_none, display_none
|
||||
|
||||
def _parse_questions(self, text: Optional[str]) -> Optional[List[str]]:
|
||||
if not text:
|
||||
return None
|
||||
qs = [q.strip() for q in text.splitlines() if q.strip()]
|
||||
return qs if qs else None
|
||||
|
||||
def _right_output_analysis(
|
||||
self,
|
||||
n_clicks,
|
||||
all_img_options: dict,
|
||||
current_img_value: str,
|
||||
detector_value: str,
|
||||
analysis_type_value: str,
|
||||
textarea_questions_value: str,
|
||||
settings_text_analyse_text: list,
|
||||
setting_privacy_env_var: str,
|
||||
setting_emotion_emotion_threshold: int,
|
||||
@ -413,54 +487,75 @@ class AnalysisExplorer:
|
||||
"EmotionDetector": faces.EmotionDetector,
|
||||
"TextDetector": text.TextDetector,
|
||||
"ColorDetector": colors.ColorDetector,
|
||||
"VQA": image_summary.ImageSummaryDetector,
|
||||
}
|
||||
|
||||
# Get image ID from dropdown value, which is the filepath
|
||||
if current_img_value is None:
|
||||
return {}
|
||||
image_id = all_img_options[current_img_value]
|
||||
# copy image so previous runs don't leave their default values in the dict
|
||||
image_copy = self.mydict[image_id].copy()
|
||||
image_copy = self.mydict.get(image_id, {}).copy()
|
||||
|
||||
# detector value is the string name of the chosen detector
|
||||
identify_function = identify_dict[detector_value]
|
||||
|
||||
if detector_value == "TextDetector":
|
||||
analyse_text = (
|
||||
True if settings_text_analyse_text == ["Analyse text"] else False
|
||||
)
|
||||
detector_class = identify_function(
|
||||
image_copy,
|
||||
analyse_text=analyse_text,
|
||||
accept_privacy=(
|
||||
setting_privacy_env_var
|
||||
if setting_privacy_env_var
|
||||
else "PRIVACY_AMMICO"
|
||||
),
|
||||
)
|
||||
elif detector_value == "EmotionDetector":
|
||||
detector_class = identify_function(
|
||||
image_copy,
|
||||
emotion_threshold=setting_emotion_emotion_threshold,
|
||||
race_threshold=setting_emotion_race_threshold,
|
||||
gender_threshold=setting_emotion_gender_threshold,
|
||||
accept_disclosure=(
|
||||
setting_emotion_env_var
|
||||
if setting_emotion_env_var
|
||||
else "DISCLOSURE_AMMICO"
|
||||
),
|
||||
)
|
||||
elif detector_value == "ColorDetector":
|
||||
detector_class = identify_function(
|
||||
image_copy,
|
||||
delta_e_method=setting_color_delta_e_method,
|
||||
)
|
||||
analysis_dict: Dict[str, Any] = {}
|
||||
if detector_value == "VQA":
|
||||
try:
|
||||
qwen_model = MultimodalSummaryModel(
|
||||
model_id="Qwen/Qwen2.5-VL-3B-Instruct"
|
||||
) # TODO: allow user to specify model
|
||||
vqa_cls = identify_dict.get("VQA")
|
||||
vqa_detector = vqa_cls(qwen_model, subdict={})
|
||||
questions_list = self._parse_questions(textarea_questions_value)
|
||||
analysis_result = vqa_detector.analyse_image(
|
||||
image_copy,
|
||||
analysis_type=analysis_type_value,
|
||||
list_of_questions=questions_list,
|
||||
is_concise_summary=True,
|
||||
is_concise_answer=True,
|
||||
)
|
||||
analysis_dict = analysis_result or {}
|
||||
except Exception as e:
|
||||
warnings.warn(f"VQA/Image tasks failed: {e}")
|
||||
analysis_dict = {"image_tasks_error": str(e)}
|
||||
else:
|
||||
detector_class = identify_function(image_copy)
|
||||
analysis_dict = detector_class.analyse_image()
|
||||
# detector value is the string name of the chosen detector
|
||||
identify_function = identify_dict[detector_value]
|
||||
|
||||
# Initialize an empty dictionary
|
||||
new_analysis_dict = {}
|
||||
if detector_value == "TextDetector":
|
||||
analyse_text = (
|
||||
True if settings_text_analyse_text == ["Analyse text"] else False
|
||||
)
|
||||
detector_class = identify_function(
|
||||
image_copy,
|
||||
analyse_text=analyse_text,
|
||||
accept_privacy=(
|
||||
setting_privacy_env_var
|
||||
if setting_privacy_env_var
|
||||
else "PRIVACY_AMMICO"
|
||||
),
|
||||
)
|
||||
elif detector_value == "EmotionDetector":
|
||||
detector_class = identify_function(
|
||||
image_copy,
|
||||
emotion_threshold=setting_emotion_emotion_threshold,
|
||||
race_threshold=setting_emotion_race_threshold,
|
||||
gender_threshold=setting_emotion_gender_threshold,
|
||||
accept_disclosure=(
|
||||
setting_emotion_env_var
|
||||
if setting_emotion_env_var
|
||||
else "DISCLOSURE_AMMICO"
|
||||
),
|
||||
)
|
||||
elif detector_value == "ColorDetector":
|
||||
detector_class = identify_function(
|
||||
image_copy,
|
||||
delta_e_method=setting_color_delta_e_method,
|
||||
)
|
||||
else:
|
||||
detector_class = identify_function(image_copy)
|
||||
|
||||
analysis_dict = detector_class.analyse_image()
|
||||
|
||||
new_analysis_dict: Dict[str, Any] = {}
|
||||
|
||||
# Iterate over the items in the original dictionary
|
||||
for k, v in analysis_dict.items():
|
||||
@ -480,3 +575,9 @@ class AnalysisExplorer:
|
||||
return dbc.Table.from_dataframe(
|
||||
df, striped=True, bordered=True, hover=True, index=True
|
||||
)
|
||||
|
||||
def _show_questions_textarea_on_demand(self, analysis_type_value: str) -> dict:
|
||||
if analysis_type_value in ("questions", "summary_and_questions"):
|
||||
return {"display": "block", "width": "100%"}
|
||||
else:
|
||||
return {"display": "none"}
|
||||
|
||||
365 ammico/image_summary.py (new file)
@ -0,0 +1,365 @@
|
||||
from ammico.utils import AnalysisMethod, AnalysisType
|
||||
from ammico.model import MultimodalSummaryModel
|
||||
|
||||
import os
|
||||
import torch
|
||||
from PIL import Image
|
||||
import warnings
|
||||
|
||||
from typing import List, Optional, Union, Dict, Any
|
||||
from collections.abc import Sequence as _Sequence
|
||||
from transformers import GenerationConfig
|
||||
from qwen_vl_utils import process_vision_info
|
||||
|
||||
|
||||
class ImageSummaryDetector(AnalysisMethod):
|
||||
def __init__(
|
||||
self,
|
||||
summary_model: MultimodalSummaryModel,
|
||||
subdict: Optional[Dict[str, Any]] = None,
|
||||
) -> None:
|
||||
"""
|
||||
Class for analysing images using QWEN-2.5-VL model.
|
||||
It provides methods for generating captions and answering questions about images.
|
||||
|
||||
Args:
|
||||
summary_model (MultimodalSummaryModel): An instance of MultimodalSummaryModel to be used for analysis.
|
||||
subdict (dict, optional): Dictionary containing the image to be analysed. Defaults to {}.
|
||||
|
||||
Returns:
|
||||
None.
|
||||
"""
|
||||
if subdict is None:
|
||||
subdict = {}
|
||||
|
||||
super().__init__(subdict)
|
||||
self.summary_model = summary_model
|
||||
|
||||
def _load_pil_if_needed(
|
||||
self, filename: Union[str, os.PathLike, Image.Image]
|
||||
) -> Image.Image:
|
||||
if isinstance(filename, (str, os.PathLike)):
|
||||
return Image.open(filename).convert("RGB")
|
||||
elif isinstance(filename, Image.Image):
|
||||
return filename.convert("RGB")
|
||||
else:
|
||||
raise ValueError("filename must be a path or PIL.Image")
|
||||
|
||||
@staticmethod
|
||||
def _is_sequence_but_not_str(obj: Any) -> bool:
|
||||
"""True for sequence-like but not a string/bytes/PIL.Image."""
|
||||
return isinstance(obj, _Sequence) and not isinstance(
|
||||
obj, (str, bytes, Image.Image)
|
||||
)
|
||||
|
||||
def _prepare_inputs(
|
||||
self, list_of_questions: list[str], entry: Optional[Dict[str, Any]] = None
|
||||
) -> Dict[str, torch.Tensor]:
|
||||
filename = entry.get("filename")
|
||||
if filename is None:
|
||||
raise ValueError("entry must contain key 'filename'")
|
||||
|
||||
if isinstance(filename, (str, os.PathLike, Image.Image)):
|
||||
images_context = self._load_pil_if_needed(filename)
|
||||
elif self._is_sequence_but_not_str(filename):
|
||||
images_context = [self._load_pil_if_needed(i) for i in filename]
|
||||
else:
|
||||
raise ValueError(
|
||||
"Unsupported 'filename' entry: expected path, PIL.Image, or sequence."
|
||||
)
|
||||
|
||||
images_only_messages = [
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
*(
|
||||
[{"type": "image", "image": img} for img in images_context]
|
||||
if isinstance(images_context, list)
|
||||
else [{"type": "image", "image": images_context}]
|
||||
)
|
||||
],
|
||||
}
|
||||
]
|
||||
|
||||
try:
|
||||
image_inputs, _ = process_vision_info(images_only_messages)
|
||||
except Exception as e:
|
||||
raise RuntimeError(f"Image processing failed: {e}")
|
||||
|
||||
texts: List[str] = []
|
||||
for q in list_of_questions:
|
||||
messages = [
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
*(
|
||||
[
|
||||
{"type": "image", "image": image}
|
||||
for image in images_context
|
||||
]
|
||||
if isinstance(images_context, list)
|
||||
else [{"type": "image", "image": images_context}]
|
||||
),
|
||||
{"type": "text", "text": q},
|
||||
],
|
||||
}
|
||||
]
|
||||
text = self.summary_model.processor.apply_chat_template(
|
||||
messages, tokenize=False, add_generation_prompt=True
|
||||
)
|
||||
texts.append(text)
|
||||
|
||||
images_batch = [image_inputs] * len(texts)
|
||||
inputs = self.summary_model.processor(
|
||||
text=texts,
|
||||
images=images_batch,
|
||||
padding=True,
|
||||
return_tensors="pt",
|
||||
)
|
||||
inputs = {k: v.to(self.summary_model.device) for k, v in inputs.items()}
|
||||
|
||||
return inputs
|
||||
|
||||
def analyse_image(
|
||||
self,
|
||||
entry: dict,
|
||||
analysis_type: Union[str, AnalysisType] = AnalysisType.SUMMARY,
|
||||
list_of_questions: Optional[List[str]] = None,
|
||||
is_concise_summary: bool = True,
|
||||
is_concise_answer: bool = True,
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
Analyse a single image entry. Returns dict with keys depending on analysis_type:
|
||||
- 'caption' (str) if summary requested
|
||||
- 'vqa' (list[str]) if questions requested
|
||||
"""
|
||||
self.subdict = entry
|
||||
analysis_type, is_summary, is_questions = AnalysisType._validate_analysis_type(
|
||||
analysis_type, list_of_questions
|
||||
)
|
||||
|
||||
if is_summary:
|
||||
try:
|
||||
caps = self.generate_caption(
|
||||
entry,
|
||||
num_return_sequences=1,
|
||||
is_concise_summary=is_concise_summary,
|
||||
)
|
||||
self.subdict["caption"] = caps[0] if caps else ""
|
||||
except Exception as e:
|
||||
warnings.warn(f"Caption generation failed: {e}")
|
||||
|
||||
if is_questions:
|
||||
try:
|
||||
vqa_map = self.answer_questions(
|
||||
list_of_questions, entry, is_concise_answer
|
||||
)
|
||||
self.subdict["vqa"] = vqa_map
|
||||
except Exception as e:
|
||||
warnings.warn(f"VQA failed: {e}")
|
||||
|
||||
return self.subdict
|
||||
|
||||
def analyse_images_from_dict(
|
||||
self,
|
||||
analysis_type: Union[AnalysisType, str] = AnalysisType.SUMMARY,
|
||||
list_of_questions: Optional[List[str]] = None,
|
||||
keys_batch_size: int = 16,
|
||||
is_concise_summary: bool = True,
|
||||
is_concise_answer: bool = True,
|
||||
) -> Dict[str, dict]:
|
||||
"""
|
||||
Analyse all images in self.subdict with the model.
|
||||
|
||||
Args:
|
||||
analysis_type (str): type of the analysis.
|
||||
list_of_questions (list[str]): list of questions.
|
||||
keys_batch_size (int): number of images to process in a batch.
|
||||
is_concise_summary (bool): whether to generate concise summary.
|
||||
is_concise_answer (bool): whether to generate concise answers.
|
||||
Returns:
|
||||
self.subdict (dict): dictionary with analysis results.
|
||||
"""
|
||||
# TODO: add option to ask multiple questions per image as one batch.
|
||||
analysis_type, is_summary, is_questions = AnalysisType._validate_analysis_type(
|
||||
analysis_type, list_of_questions
|
||||
)
|
||||
|
||||
keys = list(self.subdict.keys())
|
||||
for batch_start in range(0, len(keys), keys_batch_size):
|
||||
batch_keys = keys[batch_start : batch_start + keys_batch_size]
|
||||
for key in batch_keys:
|
||||
entry = self.subdict[key]
|
||||
if is_summary:
|
||||
try:
|
||||
caps = self.generate_caption(
|
||||
entry,
|
||||
num_return_sequences=1,
|
||||
is_concise_summary=is_concise_summary,
|
||||
)
|
||||
entry["caption"] = caps[0] if caps else ""
|
||||
except Exception as e:
|
||||
warnings.warn(f"Caption generation failed: {e}")
|
||||
|
||||
if is_questions:
|
||||
try:
|
||||
vqa_map = self.answer_questions(
|
||||
list_of_questions, entry, is_concise_answer
|
||||
)
|
||||
entry["vqa"] = vqa_map
|
||||
except Exception as e:
|
||||
warnings.warn(f"VQA failed: {e}")
|
||||
|
||||
self.subdict[key] = entry
|
||||
return self.subdict
|
||||
|
||||
def generate_caption(
|
||||
self,
|
||||
entry: Optional[Dict[str, Any]] = None,
|
||||
num_return_sequences: int = 1,
|
||||
is_concise_summary: bool = True,
|
||||
) -> List[str]:
|
||||
"""
|
||||
Create caption for image. Depending on is_concise_summary it will be either concise or detailed.
|
||||
|
||||
Args:
|
||||
entry (dict): dictionary containing the image to be captioned.
|
||||
num_return_sequences (int): number of captions to generate.
|
||||
is_concise_summary (bool): whether to generate concise summary.
|
||||
|
||||
Returns:
|
||||
results (list[str]): list of generated captions.
|
||||
"""
|
||||
if is_concise_summary:
|
||||
prompt = ["Describe this image in one concise caption."]
|
||||
max_new_tokens = 64
|
||||
else:
|
||||
prompt = ["Describe this image."]
|
||||
max_new_tokens = 256
|
||||
inputs = self._prepare_inputs(prompt, entry)
|
||||
|
||||
gen_conf = GenerationConfig(
|
||||
max_new_tokens=max_new_tokens,
|
||||
do_sample=False,
|
||||
num_return_sequences=num_return_sequences,
|
||||
)
|
||||
|
||||
with torch.inference_mode():
|
||||
try:
|
||||
if self.summary_model.device == "cuda":
|
||||
with torch.amp.autocast("cuda", enabled=True):
|
||||
generated_ids = self.summary_model.model.generate(
|
||||
**inputs, generation_config=gen_conf
|
||||
)
|
||||
else:
|
||||
generated_ids = self.summary_model.model.generate(
|
||||
**inputs, generation_config=gen_conf
|
||||
)
|
||||
except RuntimeError as e:
|
||||
warnings.warn(
|
||||
f"Retry without autocast failed: {e}. Attempting cudnn-disabled retry."
|
||||
)
|
||||
cudnn_was_enabled = (
|
||||
torch.backends.cudnn.is_available() and torch.backends.cudnn.enabled
|
||||
)
|
||||
if cudnn_was_enabled:
|
||||
torch.backends.cudnn.enabled = False
|
||||
try:
|
||||
generated_ids = self.summary_model.model.generate(
|
||||
**inputs, generation_config=gen_conf
|
||||
)
|
||||
except Exception as retry_error:
|
||||
raise RuntimeError(
|
||||
f"Failed to generate ids after retry: {retry_error}"
|
||||
) from retry_error
|
||||
finally:
|
||||
if cudnn_was_enabled:
|
||||
torch.backends.cudnn.enabled = True
|
||||
|
||||
decoded = None
|
||||
if "input_ids" in inputs:
|
||||
in_ids = inputs["input_ids"]
|
||||
trimmed = [
|
||||
out_ids[len(inp_ids) :]
|
||||
for inp_ids, out_ids in zip(in_ids, generated_ids)
|
||||
]
|
||||
decoded = self.summary_model.tokenizer.batch_decode(
|
||||
trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
|
||||
)
|
||||
else:
|
||||
decoded = self.summary_model.tokenizer.batch_decode(
|
||||
generated_ids,
|
||||
skip_special_tokens=True,
|
||||
clean_up_tokenization_spaces=False,
|
||||
)
|
||||
|
||||
results = [d.strip() for d in decoded]
|
||||
return results
|
||||
|
||||
def answer_questions(
|
||||
self,
|
||||
list_of_questions: list[str],
|
||||
entry: Optional[Dict[str, Any]] = None,
|
||||
is_concise_answer: bool = True,
|
||||
) -> List[str]:
|
||||
"""
|
||||
Create answers for a list of questions about the image.
|
||||
Args:
|
||||
list_of_questions (list[str]): list of questions.
|
||||
entry (dict): dictionary containing the image to be analysed.
|
||||
is_concise_answer (bool): whether to generate concise answers.
|
||||
Returns:
|
||||
answers (list[str]): list of answers.
|
||||
"""
|
||||
if is_concise_answer:
|
||||
gen_conf = GenerationConfig(max_new_tokens=64, do_sample=False)
|
||||
for i in range(len(list_of_questions)):
|
||||
if not list_of_questions[i].strip().endswith("?"):
|
||||
list_of_questions[i] = list_of_questions[i].strip() + "?"
|
||||
if not list_of_questions[i].lower().startswith("answer concisely"):
|
||||
list_of_questions[i] = "Answer concisely: " + list_of_questions[i]
|
||||
else:
|
||||
gen_conf = GenerationConfig(max_new_tokens=128, do_sample=False)
|
||||
|
||||
question_chunk_size = 8
|
||||
answers: List[str] = []
|
||||
n = len(list_of_questions)
|
||||
for i in range(0, n, question_chunk_size):
|
||||
chunk = list_of_questions[i : i + question_chunk_size]
|
||||
inputs = self._prepare_inputs(chunk, entry)
|
||||
with torch.inference_mode():
|
||||
if self.summary_model.device == "cuda":
|
||||
with torch.amp.autocast("cuda", enabled=True):
|
||||
out_ids = self.summary_model.model.generate(
|
||||
**inputs, generation_config=gen_conf
|
||||
)
|
||||
else:
|
||||
out_ids = self.summary_model.model.generate(
|
||||
**inputs, generation_config=gen_conf
|
||||
)
|
||||
|
||||
if "input_ids" in inputs:
|
||||
in_ids = inputs["input_ids"]
|
||||
trimmed_batch = [
|
||||
out_row[len(inp_row) :] for inp_row, out_row in zip(in_ids, out_ids)
|
||||
]
|
||||
decoded = self.summary_model.tokenizer.batch_decode(
|
||||
trimmed_batch,
|
||||
skip_special_tokens=True,
|
||||
clean_up_tokenization_spaces=False,
|
||||
)
|
||||
else:
|
||||
decoded = self.summary_model.tokenizer.batch_decode(
|
||||
out_ids,
|
||||
skip_special_tokens=True,
|
||||
clean_up_tokenization_spaces=False,
|
||||
)
|
||||
|
||||
answers.extend([d.strip() for d in decoded])
|
||||
|
||||
if len(answers) != len(list_of_questions):
|
||||
raise ValueError(
|
||||
f"Expected {len(list_of_questions)} answers, but got {len(answers)}, try vary amount of questions"
|
||||
)
|
||||
|
||||
return answers
|
||||
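As a quick orientation for the detector defined above, here is a minimal usage sketch; the dictionary key, image path, and question are placeholder assumptions, and running it downloads the Qwen2.5-VL weights.

# Minimal usage sketch (hypothetical paths and questions; not part of the diff).
from ammico.model import MultimodalSummaryModel
from ammico.image_summary import ImageSummaryDetector

model = MultimodalSummaryModel(device="cpu")  # loads the smaller 3B checkpoint
images = {"img1": {"filename": "/path/to/image.jpg"}}  # placeholder entry

detector = ImageSummaryDetector(summary_model=model, subdict=images)
results = detector.analyse_images_from_dict(
    analysis_type="summary_and_questions",
    list_of_questions=["What is shown in the image?"],
)
print(results["img1"]["caption"])  # generated caption
print(results["img1"]["vqa"])      # list with one answer per question
model.close()  # free model resources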
126 ammico/model.py (new file)
@ -0,0 +1,126 @@
|
||||
import torch
|
||||
import warnings
|
||||
from transformers import (
|
||||
Qwen2_5_VLForConditionalGeneration,
|
||||
AutoProcessor,
|
||||
BitsAndBytesConfig,
|
||||
AutoTokenizer,
|
||||
)
|
||||
from typing import Optional
|
||||
|
||||
|
||||
class MultimodalSummaryModel:
|
||||
DEFAULT_CUDA_MODEL = "Qwen/Qwen2.5-VL-7B-Instruct"
|
||||
DEFAULT_CPU_MODEL = "Qwen/Qwen2.5-VL-3B-Instruct"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
model_id: Optional[str] = None,
|
||||
device: Optional[str] = None,
|
||||
cache_dir: Optional[str] = None,
|
||||
) -> None:
|
||||
"""
|
||||
Class for QWEN-2.5-VL model loading and inference.
|
||||
Args:
|
||||
model_id: Hugging Face model ID to load; defaults to the 3B model when the device is "cpu" and to the 7B model on CUDA.
|
||||
device: "cuda" or "cpu" (auto-detected when None).
|
||||
cache_dir: huggingface cache dir (optional).
|
||||
"""
|
||||
self.device = self._resolve_device(device)
|
||||
|
||||
if model_id is not None and model_id not in (
|
||||
self.DEFAULT_CUDA_MODEL,
|
||||
self.DEFAULT_CPU_MODEL,
|
||||
):
|
||||
raise ValueError(
|
||||
f"model_id must be one of {self.DEFAULT_CUDA_MODEL} or {self.DEFAULT_CPU_MODEL}"
|
||||
)
|
||||
|
||||
self.model_id = model_id or (
|
||||
self.DEFAULT_CUDA_MODEL if self.device == "cuda" else self.DEFAULT_CPU_MODEL
|
||||
)
|
||||
|
||||
self.cache_dir = cache_dir
|
||||
self._trust_remote_code = True
|
||||
self._quantize = True
|
||||
|
||||
self.model = None
|
||||
self.processor = None
|
||||
self.tokenizer = None
|
||||
|
||||
self._load_model_and_processor()
|
||||
|
||||
@staticmethod
|
||||
def _resolve_device(device: Optional[str]) -> str:
|
||||
if device is None:
|
||||
return "cuda" if torch.cuda.is_available() else "cpu"
|
||||
if device.lower() not in ("cuda", "cpu"):
|
||||
raise ValueError("device must be 'cuda' or 'cpu'")
|
||||
if device.lower() == "cuda" and not torch.cuda.is_available():
|
||||
warnings.warn(
|
||||
"Although 'cuda' was requested, no CUDA device is available. Using CPU instead.",
|
||||
RuntimeWarning,
|
||||
stacklevel=2,
|
||||
)
|
||||
return "cpu"
|
||||
return device.lower()
|
||||
|
||||
def _load_model_and_processor(self):
|
||||
load_kwargs = {"trust_remote_code": self._trust_remote_code, "use_cache": True}
|
||||
if self.cache_dir:
|
||||
load_kwargs["cache_dir"] = self.cache_dir
|
||||
|
||||
self.processor = AutoProcessor.from_pretrained(
|
||||
self.model_id, padding_side="left", **load_kwargs
|
||||
)
|
||||
self.tokenizer = AutoTokenizer.from_pretrained(self.model_id, **load_kwargs)
|
||||
|
||||
if self.device == "cuda":
|
||||
compute_dtype = (
|
||||
torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16
|
||||
)
|
||||
bnb_config = BitsAndBytesConfig(
|
||||
load_in_4bit=True,
|
||||
bnb_4bit_use_double_quant=True,
|
||||
bnb_4bit_quant_type="nf4",
|
||||
bnb_4bit_compute_dtype=compute_dtype,
|
||||
)
|
||||
load_kwargs["quantization_config"] = bnb_config
|
||||
load_kwargs["device_map"] = "auto"
|
||||
|
||||
else:
|
||||
load_kwargs.pop("quantization_config", None)
|
||||
load_kwargs.pop("device_map", None)
|
||||
|
||||
self.model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
|
||||
self.model_id, **load_kwargs
|
||||
)
|
||||
self.model.eval()
|
||||
|
||||
def _close(self) -> None:
|
||||
"""Free model resources (helpful in long-running processes)."""
|
||||
try:
|
||||
if self.model is not None:
|
||||
del self.model
|
||||
self.model = None
|
||||
if self.processor is not None:
|
||||
del self.processor
|
||||
self.processor = None
|
||||
if self.tokenizer is not None:
|
||||
del self.tokenizer
|
||||
self.tokenizer = None
|
||||
finally:
|
||||
try:
|
||||
if torch.cuda.is_available():
|
||||
torch.cuda.empty_cache()
|
||||
except Exception as e:
|
||||
warnings.warn(
|
||||
"Failed to empty CUDA cache. This is not critical, but may lead to memory lingering: "
|
||||
f"{e!r}",
|
||||
RuntimeWarning,
|
||||
stacklevel=2,
|
||||
)
|
||||
|
||||
def close(self) -> None:
|
||||
"""Free model resources (helpful in long-running processes)."""
|
||||
self._close()
|
||||
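A brief sketch of the device-resolution rules implemented above; the calls below only illustrate the expected behaviour and do not load any model weights.

# Illustrative only: _resolve_device is a staticmethod, so no model is loaded here.
from ammico.model import MultimodalSummaryModel

print(MultimodalSummaryModel._resolve_device(None))   # "cuda" if available, else "cpu"
print(MultimodalSummaryModel._resolve_device("CPU"))  # "cpu" (case-insensitive)
# Requesting "cuda" without a CUDA device emits a RuntimeWarning and returns "cpu";
# any other string (e.g. "tpu") raises ValueError.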
190 ammico/notebooks/DemoImageSummaryVQA.ipynb (new file)
@ -0,0 +1,190 @@
|
||||
{
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "0",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"# Image summary and visual question answering"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "1",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"This notebook shows how to generate image captions and use the visual question answering with AMMICO. \n",
|
||||
"\n",
|
||||
"The first cell imports `ammico`.\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "2",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"import ammico"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "3",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"The cell below loads the model for VQA tasks. By default, it loads a large model on the GPU (if your device supports CUDA), otherwise it loads a relatively smaller model on the CPU. But you can specify other settings (e.g., a small model on the GPU) if you want."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "4",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"model = ammico.MultimodalSummaryModel()"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "5",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"Here you need to provide the path to your google drive folder or local folder containing the images"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "6",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"image_dict = ammico.find_files(\n",
|
||||
" path=str(\"/insert/your/path/here/\"),\n",
|
||||
" limit=-1, # -1 means no limit on the number of files, by default it is set to 20\n",
|
||||
")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "7",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"The cell below creates an object that analyzes images and generates a summary using a specific model and image data."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "8",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"img = ammico.ImageSummaryDetector(summary_model=model, subdict=image_dict)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "9",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Image summary "
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "10",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"To start your work with images, you should call the `analyse_images` method.\n",
|
||||
"\n",
|
||||
"You can specify what kind of analysis you want to perform with `analysis_type`. `\"summary\"` will generate a summary for all pictures in your dictionary, `\"questions\"` will prepare answers to your questions for all pictures, and `\"summary_and_questions\"` will do both.\n",
|
||||
"\n",
|
||||
"Parameter `\"is_concise_summary\"` regulates the length of an answer.\n",
|
||||
"\n",
|
||||
"Here we want to get a long summary on each object in our image dictionary."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "11",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"summaries = img.analyse_images(analysis_type=\"summary\", is_concise_summary=False)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "12",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## VQA"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "13",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"In addition to analyzing images in `ammico`, the same model can be used in VQA mode. To do this, you need to define the questions that will be applied to all images from your dict."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "14",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"questions = [\"Are there any visible signs of violence?\", \"Is it safe to be there?\"]"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "15",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"Here is an example of VQA mode usage. You can specify whether you want to receive short answers (recommended option) or not."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "16",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"vqa_results = img.analyse_images(\n",
|
||||
" analysis_type=\"questions\",\n",
|
||||
" list_of_questions=questions,\n",
|
||||
" is_concise_answer=True,\n",
|
||||
")"
|
||||
]
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
"kernelspec": {
|
||||
"display_name": "ammico-dev",
|
||||
"language": "python",
|
||||
"name": "python3"
|
||||
},
|
||||
"language_info": {
|
||||
"codemirror_mode": {
|
||||
"name": "ipython",
|
||||
"version": 3
|
||||
},
|
||||
"file_extension": ".py",
|
||||
"mimetype": "text/x-python",
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.11.13"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 5
|
||||
}
|
||||
@ -104,7 +104,8 @@
|
||||
"import ammico\n",
|
||||
"\n",
|
||||
"# for displaying a progress bar\n",
|
||||
"from tqdm import tqdm"
|
||||
"from tqdm import tqdm\n",
|
||||
"import os"
|
||||
]
|
||||
},
|
||||
{
|
||||
@ -140,7 +141,9 @@
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# os.environ[\"GOOGLE_APPLICATION_CREDENTIALS\"] = \"/content/drive/MyDrive/misinformation-data/misinformation-campaign-981aa55a3b13.json\""
|
||||
"os.environ[\"GOOGLE_APPLICATION_CREDENTIALS\"] = (\n",
|
||||
" \"/home/inga/projects/misinformation-project/misinformation-notes/misinformation-campaign-981aa55a3b13.json\"\n",
|
||||
")"
|
||||
]
|
||||
},
|
||||
{
|
||||
@ -171,6 +174,7 @@
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"data_path = \"./data-test\"\n",
|
||||
"image_dict = ammico.find_files(\n",
|
||||
" # path = \"/content/drive/MyDrive/misinformation-data/\",\n",
|
||||
" path=str(data_path),\n",
|
||||
@ -337,7 +341,7 @@
|
||||
" enumerate(image_dict.keys()), total=len(image_dict)\n",
|
||||
"): # loop through all images\n",
|
||||
" image_dict[key] = ammico.TextDetector(\n",
|
||||
" image_dict[key], analyse_text=True\n",
|
||||
" image_dict[key]\n",
|
||||
" ).analyse_image() # analyse image with EmotionDetector and update dict\n",
|
||||
"\n",
|
||||
" if (\n",
|
||||
@ -361,23 +365,12 @@
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# initialize the models\n",
|
||||
"image_summary_detector = ammico.SummaryDetector(\n",
|
||||
" subdict=image_dict, analysis_type=\"summary\", model_type=\"base\"\n",
|
||||
"model = ammico.MultimodalSummaryModel()\n",
|
||||
"image_summary_detector = ammico.ImageSummaryDetector(\n",
|
||||
" subdict=image_dict, summary_model=model\n",
|
||||
")\n",
|
||||
"\n",
|
||||
"# run the analysis without having to re-iniatialize the model\n",
|
||||
"for num, key in tqdm(\n",
|
||||
" enumerate(image_dict.keys()), total=len(image_dict)\n",
|
||||
"): # loop through all images\n",
|
||||
" image_dict[key] = image_summary_detector.analyse_image(\n",
|
||||
" subdict=image_dict[key], analysis_type=\"summary\"\n",
|
||||
" ) # analyse image with SummaryDetector and update dict\n",
|
||||
"\n",
|
||||
" if (\n",
|
||||
" num % dump_every == 0 | num == len(image_dict) - 1\n",
|
||||
" ): # save results every dump_every to dump_file\n",
|
||||
" image_df = ammico.get_dataframe(image_dict)\n",
|
||||
" image_df.to_csv(dump_file)"
|
||||
"image_summary_detector.analyse_images(analysis_type=\"summary\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
@ -394,6 +387,7 @@
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# initialize the models\n",
|
||||
"# currently this does not work because of the way the summary detector is implemented\n",
|
||||
"image_summary_detector = ammico.SummaryDetector(\n",
|
||||
" subdict=image_dict, analysis_type=\"summary\", model_type=\"base\"\n",
|
||||
")\n",
|
||||
|
||||
96 ammico/notebooks/DemoVideoSummaryVQA.ipynb (new file)
@ -0,0 +1,96 @@
|
||||
{
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"# Video summary and visual question answering"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"import ammico"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"Currently this module supports only video summarization, but it will be updated in the nearest future"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"video_dict = ammico.find_videos(\n",
|
||||
" path=str(\"/insert/your/path/here/\"), # path to the folder with images\n",
|
||||
" limit=-1, # -1 means no limit on the number of files, by default it is set to 20\n",
|
||||
")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"model = ammico.MultimodalSummaryModel()"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"vid_summary_model = ammico.VideoSummaryDetector(summary_model=model, subdict=video_dict)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"summary_dict = vid_summary_model.analyse_video()"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"summary_dict[\"summary\"]"
|
||||
]
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
"kernelspec": {
|
||||
"display_name": "ammico-dev",
|
||||
"language": "python",
|
||||
"name": "python3"
|
||||
},
|
||||
"language_info": {
|
||||
"codemirror_mode": {
|
||||
"name": "ipython",
|
||||
"version": 3
|
||||
},
|
||||
"file_extension": ".py",
|
||||
"mimetype": "text/x-python",
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.11.13"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 2
|
||||
}
|
||||
ammico/test/conftest.py
@@ -1,5 +1,6 @@
 import os
 import pytest
+from ammico.model import MultimodalSummaryModel
 
 
 @pytest.fixture
@@ -46,3 +47,12 @@ def get_test_my_dict(get_path):
         },
     }
     return test_my_dict
+
+
+@pytest.fixture(scope="session")
+def model():
+    m = MultimodalSummaryModel(device="cpu")
+    try:
+        yield m
+    finally:
+        m.close()
ammico/test/test_display.py
@@ -50,6 +50,8 @@ def test_right_output_analysis_emotions(get_AE, get_options, monkeypatch):
         get_options[3],
         get_options[0],
         "EmotionDetector",
+        "summary",
+        "Some question",
         True,
         "SOME_VAR",
         50,
37 ammico/test/test_image_summary.py (new file)
@@ -0,0 +1,37 @@
from ammico.image_summary import ImageSummaryDetector

import pytest


@pytest.mark.long
def test_image_summary_detector(model, get_testdict):
    detector = ImageSummaryDetector(summary_model=model, subdict=get_testdict)
    results = detector.analyse_images_from_dict(analysis_type="summary")
    assert len(results) == 2
    for key in get_testdict.keys():
        assert key in results
        assert "caption" in results[key]
        assert isinstance(results[key]["caption"], str)
        assert len(results[key]["caption"]) > 0


@pytest.mark.long
def test_image_summary_detector_questions(model, get_testdict):
    list_of_questions = [
        "What is happening in the image?",
        "How many cars are in the image in total?",
    ]
    detector = ImageSummaryDetector(summary_model=model, subdict=get_testdict)
    results = detector.analyse_images_from_dict(
        analysis_type="questions", list_of_questions=list_of_questions
    )
    assert len(results) == 2
    for key in get_testdict.keys():
        assert "vqa" in results[key]
        if key == "IMG_2746":
            assert "marathon" in results[key]["vqa"][0].lower()

        if key == "IMG_2809":
            assert (
                "two" in results[key]["vqa"][1].lower() or "2" in results[key]["vqa"][1]
            )
30 ammico/test/test_model.py (new file)
@@ -0,0 +1,30 @@
import pytest
from ammico.model import MultimodalSummaryModel


@pytest.mark.long
def test_model_init(model):
    assert model.model is not None
    assert model.processor is not None
    assert model.tokenizer is not None
    assert model.device is not None


@pytest.mark.long
def test_model_invalid_device():
    with pytest.raises(ValueError):
        MultimodalSummaryModel(device="invalid_device")


@pytest.mark.long
def test_model_invalid_model_id():
    with pytest.raises(ValueError):
        MultimodalSummaryModel(model_id="non_existent_model", device="cpu")


@pytest.mark.long
def test_free_resources():
    model = MultimodalSummaryModel(device="cpu")
    model.close()
    assert model.model is None
    assert model.processor is None
ammico/test/test_text.py
@ -52,24 +52,16 @@ def test_privacy_statement(monkeypatch):
|
||||
def test_TextDetector(set_testdict, accepted):
|
||||
for item in set_testdict:
|
||||
test_obj = tt.TextDetector(set_testdict[item], accept_privacy=accepted)
|
||||
assert not test_obj.analyse_text
|
||||
assert not test_obj.skip_extraction
|
||||
assert test_obj.subdict["filename"] == set_testdict[item]["filename"]
|
||||
test_obj = tt.TextDetector(
|
||||
{}, analyse_text=True, skip_extraction=True, accept_privacy=accepted
|
||||
)
|
||||
assert test_obj.analyse_text
|
||||
test_obj = tt.TextDetector({}, skip_extraction=True, accept_privacy=accepted)
|
||||
assert test_obj.skip_extraction
|
||||
with pytest.raises(ValueError):
|
||||
tt.TextDetector({}, analyse_text=1.0, accept_privacy=accepted)
|
||||
with pytest.raises(ValueError):
|
||||
tt.TextDetector({}, skip_extraction=1.0, accept_privacy=accepted)
|
||||
|
||||
|
||||
def test_run_spacy(set_testdict, get_path, accepted):
|
||||
test_obj = tt.TextDetector(
|
||||
set_testdict["IMG_3755"], analyse_text=True, accept_privacy=accepted
|
||||
)
|
||||
test_obj = tt.TextDetector(set_testdict["IMG_3755"], accept_privacy=accepted)
|
||||
ref_file = get_path + "text_IMG_3755.txt"
|
||||
with open(ref_file, "r") as file:
|
||||
reference_text = file.read()
|
||||
@ -108,15 +100,11 @@ def test_analyse_image(set_testdict, set_environ, accepted):
|
||||
for item in set_testdict:
|
||||
test_obj = tt.TextDetector(set_testdict[item], accept_privacy=accepted)
|
||||
test_obj.analyse_image()
|
||||
test_obj = tt.TextDetector(
|
||||
set_testdict[item], analyse_text=True, accept_privacy=accepted
|
||||
)
|
||||
test_obj = tt.TextDetector(set_testdict[item], accept_privacy=accepted)
|
||||
test_obj.analyse_image()
|
||||
testdict = {}
|
||||
testdict["text"] = 20000 * "m"
|
||||
test_obj = tt.TextDetector(
|
||||
testdict, skip_extraction=True, analyse_text=True, accept_privacy=accepted
|
||||
)
|
||||
test_obj = tt.TextDetector(testdict, skip_extraction=True, accept_privacy=accepted)
|
||||
test_obj.analyse_image()
|
||||
assert test_obj.subdict["text_truncated"] == 5000 * "m"
|
||||
assert test_obj.subdict["text"] == 20000 * "m"
|
||||
|
||||
@ -67,7 +67,6 @@ class TextDetector(AnalysisMethod):
|
||||
def __init__(
|
||||
self,
|
||||
subdict: dict,
|
||||
analyse_text: bool = False,
|
||||
skip_extraction: bool = False,
|
||||
accept_privacy: str = "PRIVACY_AMMICO",
|
||||
) -> None:
|
||||
@ -76,8 +75,6 @@ class TextDetector(AnalysisMethod):
|
||||
Args:
|
||||
subdict (dict): Dictionary containing file name/path, and possibly previous
|
||||
analysis results from other modules.
|
||||
analyse_text (bool, optional): Decide if extracted text will be further subject
|
||||
to analysis. Defaults to False.
|
||||
skip_extraction (bool, optional): Decide if text will be extracted from images or
|
||||
is already provided via a csv. Defaults to False.
|
||||
accept_privacy (str, optional): Environment variable to accept the privacy
|
||||
@ -96,17 +93,13 @@ class TextDetector(AnalysisMethod):
|
||||
"Privacy disclosure not accepted - skipping text detection."
|
||||
)
|
||||
self.translator = Translator(raise_exception=True)
|
||||
if not isinstance(analyse_text, bool):
|
||||
raise ValueError("analyse_text needs to be set to true or false")
|
||||
self.analyse_text = analyse_text
|
||||
self.skip_extraction = skip_extraction
|
||||
if not isinstance(skip_extraction, bool):
|
||||
raise ValueError("skip_extraction needs to be set to true or false")
|
||||
if self.skip_extraction:
|
||||
print("Skipping text extraction from image.")
|
||||
print("Reading text directly from provided dictionary.")
|
||||
if self.analyse_text:
|
||||
self._initialize_spacy()
|
||||
self._initialize_spacy()
|
||||
|
||||
def set_keys(self) -> dict:
|
||||
"""Set the default keys for text analysis.
|
||||
@ -183,7 +176,7 @@ class TextDetector(AnalysisMethod):
|
||||
self._truncate_text()
|
||||
self.translate_text()
|
||||
self.remove_linebreaks()
|
||||
if self.analyse_text and self.subdict["text_english"]:
|
||||
if self.subdict["text_english"]:
|
||||
self._run_spacy()
|
||||
return self.subdict
|
||||
|
||||
|
||||
145 ammico/utils.py
@ -5,6 +5,9 @@ import pooch
|
||||
import importlib_resources
|
||||
import collections
|
||||
import random
|
||||
from enum import Enum
|
||||
from typing import List, Tuple, Optional, Union
|
||||
import re
|
||||
|
||||
|
||||
pkg = importlib_resources.files("ammico")
|
||||
@ -40,6 +43,41 @@ def ammico_prefetch_models():
|
||||
res.get()
|
||||
|
||||
|
||||
class AnalysisType(str, Enum):
|
||||
SUMMARY = "summary"
|
||||
QUESTIONS = "questions"
|
||||
SUMMARY_AND_QUESTIONS = "summary_and_questions"
|
||||
|
||||
@classmethod
|
||||
def _validate_analysis_type(
|
||||
cls,
|
||||
analysis_type: Union["AnalysisType", str],
|
||||
list_of_questions: Optional[List[str]],
|
||||
) -> Tuple[str, bool, bool]:
|
||||
max_questions_per_image = 15 # safety cap to avoid too many questions
|
||||
if isinstance(analysis_type, AnalysisType):
|
||||
analysis_type = analysis_type.value
|
||||
|
||||
allowed = {item.value for item in AnalysisType}
|
||||
if analysis_type not in allowed:
|
||||
raise ValueError(f"analysis_type must be one of {allowed}")
|
||||
|
||||
if analysis_type in ("questions", "summary_and_questions"):
|
||||
if not list_of_questions:
|
||||
raise ValueError(
|
||||
"list_of_questions must be provided for QUESTIONS analysis type."
|
||||
)
|
||||
|
||||
if len(list_of_questions) > max_questions_per_image:
|
||||
raise ValueError(
|
||||
f"Number of questions per image ({len(list_of_questions)}) exceeds safety cap ({max_questions_per_image}). Reduce questions or increase max_questions_per_image."
|
||||
)
|
||||
|
||||
is_summary = analysis_type in ("summary", "summary_and_questions")
|
||||
is_questions = analysis_type in ("questions", "summary_and_questions")
|
||||
return analysis_type, is_summary, is_questions
|
||||
|
||||
|
||||
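To make the contract of the validator above concrete, here is a short illustrative call; the question text is a placeholder.

# Illustrative call; the question is a placeholder.
from ammico.utils import AnalysisType

atype, is_summary, is_questions = AnalysisType._validate_analysis_type(
    AnalysisType.SUMMARY_AND_QUESTIONS, ["What is shown in the image?"]
)
assert (atype, is_summary, is_questions) == ("summary_and_questions", True, True)
# "questions" without a non-empty question list raises ValueError,
# as does passing more than the safety cap of 15 questions per image.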
class AnalysisMethod:
|
||||
"""Base class to be inherited by all analysis methods."""
|
||||
|
||||
@ -94,6 +132,113 @@ def _limit_results(results, limit):
|
||||
return results
|
||||
|
||||
|
||||
def _categorize_outputs(
|
||||
collected: List[Tuple[float, str]],
|
||||
include_questions: bool = False,
|
||||
) -> Tuple[List[str], List[str]]:
|
||||
"""
|
||||
Categorize collected outputs into summary bullets and VQA bullets.
|
||||
Args:
|
||||
collected (List[Tuple[float, str]]): List of tuples containing timestamps and generated texts.
|
||||
Returns:
|
||||
Tuple[List[str], List[str]]: A tuple containing two lists - summary bullets and VQA bullets.
|
||||
"""
|
||||
MAX_CAPTIONS_FOR_SUMMARY = 600  # TODO: make this adjustable; ideally drop near-duplicate frames to reduce the load on the system.
|
||||
caps_for_summary_vqa = (
|
||||
collected[-MAX_CAPTIONS_FOR_SUMMARY:]
|
||||
if len(collected) > MAX_CAPTIONS_FOR_SUMMARY
|
||||
else collected
|
||||
)
|
||||
bullets_summary = []
|
||||
bullets_vqa = []
|
||||
|
||||
for t, c in caps_for_summary_vqa:
|
||||
if include_questions:
|
||||
result_sections = c.strip()
|
||||
m = re.search(
|
||||
r"Summary\s*:\s*(.*?)\s*(?:VQA\s+Answers\s*:\s*(.*))?$",
|
||||
result_sections,
|
||||
flags=re.IGNORECASE | re.DOTALL,
|
||||
)
|
||||
if m:
|
||||
summary_text = (
|
||||
m.group(1).replace("\n", " ").strip() if m.group(1) else None
|
||||
)
|
||||
vqa_text = m.group(2).strip() if m.group(2) else None
|
||||
if not summary_text or not vqa_text:
|
||||
raise ValueError(
|
||||
f"Model output is missing either summary or VQA answers: {c}"
|
||||
)
|
||||
bullets_summary.append(f"- [{t:.3f}s] {summary_text}")
|
||||
bullets_vqa.append(f"- [{t:.3f}s] {vqa_text}")
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Failed to parse summary and VQA answers from model output: {c}"
|
||||
)
|
||||
else:
|
||||
snippet = c.replace("\n", " ").strip()
|
||||
bullets_summary.append(f"- [{t:.3f}s] {snippet}")
|
||||
return bullets_summary, bullets_vqa
|
||||
|
||||
|
||||
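A small illustration of the expected input and output of the helper above; the timestamp and caption text are made up.

# Illustrative data; the timestamp/caption pair is made up.
from ammico.utils import _categorize_outputs

collected = [(1.5, "Summary: A busy street at night.\nVQA Answers: No people are visible.")]
summary_bullets, vqa_bullets = _categorize_outputs(collected, include_questions=True)
assert summary_bullets == ["- [1.500s] A busy street at night."]
assert vqa_bullets == ["- [1.500s] No people are visible."]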
def _normalize_whitespace(s: str) -> str:
|
||||
return re.sub(r"\s+", " ", s).strip()
|
||||
|
||||
|
||||
def _strip_prompt_prefix_literal(decoded: str, prompt: str) -> str:
|
||||
"""
|
||||
Remove any literal prompt prefix from decoded text using a normalized-substring match.
|
||||
Guarantees no prompt text remains at the start of returned string (best-effort).
|
||||
"""
|
||||
if not decoded:
|
||||
return ""
|
||||
if not prompt:
|
||||
return decoded.strip()
|
||||
|
||||
d_norm = _normalize_whitespace(decoded)
|
||||
p_norm = _normalize_whitespace(prompt)
|
||||
|
||||
idx = d_norm.find(p_norm)
|
||||
if idx != -1:
|
||||
running = []
|
||||
for i, ch in enumerate(decoded):
|
||||
running.append(ch if not ch.isspace() else " ")
|
||||
cur_norm = _normalize_whitespace("".join(running))
|
||||
if cur_norm.endswith(p_norm):
|
||||
return decoded[i + 1 :].lstrip() if i + 1 < len(decoded) else ""
|
||||
m = re.match(
|
||||
r"^(?:\s*(system|user|assistant)[:\s-]*\n?)+", decoded, flags=re.IGNORECASE
|
||||
)
|
||||
if m:
|
||||
return decoded[m.end() :].lstrip()
|
||||
|
||||
return decoded.lstrip("\n\r ").lstrip(":;- ").strip()
|
||||
|
||||
|
||||
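An illustrative round trip through the prefix-stripping helper above, using made-up decoder output.

# Made-up decoder output for illustration.
from ammico.utils import _strip_prompt_prefix_literal

decoded = "user: Describe this image.\nA crowd runs a marathon downtown."
prompt = "user: Describe this image."
print(_strip_prompt_prefix_literal(decoded, prompt))
# -> "A crowd runs a marathon downtown."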
def find_videos(
|
||||
path: str = None,
|
||||
pattern=["mp4"], # TODO: test with more video formats
|
||||
recursive: bool = True,
|
||||
limit=5,
|
||||
random_seed: int = None,
|
||||
) -> dict:
|
||||
"""Find video files on the file system."""
|
||||
if path is None:
|
||||
path = os.environ.get("AMMICO_DATA_HOME", ".")
|
||||
if isinstance(pattern, str):
|
||||
pattern = [pattern]
|
||||
results = []
|
||||
for p in pattern:
|
||||
results.extend(_match_pattern(path, p, recursive=recursive))
|
||||
if len(results) == 0:
|
||||
raise FileNotFoundError(f"No files found in {path} with pattern '{pattern}'")
|
||||
if random_seed is not None:
|
||||
random.seed(random_seed)
|
||||
random.shuffle(results)
|
||||
videos = _limit_results(results, limit)
|
||||
return initialize_dict(videos)
|
||||
|
||||
|
||||
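A hypothetical call to find_videos for orientation; the directory is a placeholder, and the exact key format of the returned dictionary is whatever initialize_dict produces.

# Hypothetical call; "/data/videos" is a placeholder directory.
from ammico.utils import find_videos

video_dict = find_videos(path="/data/videos", pattern=["mp4"], limit=5, random_seed=42)
# Returns a dict with one entry per video, in the same format that
# initialize_dict produces for images (each entry carries the file path).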
def find_files(
|
||||
path: str = None,
|
||||
pattern=["png", "jpg", "jpeg", "gif", "webp", "avif", "tiff"],
|
||||
|
||||
514 ammico/video_summary.py (new file)
@ -0,0 +1,514 @@
|
||||
import re
|
||||
import math
|
||||
import torch
|
||||
import warnings
|
||||
from PIL import Image
|
||||
from torchcodec.decoders import VideoDecoder
|
||||
|
||||
from ammico.model import MultimodalSummaryModel
|
||||
from ammico.utils import (
|
||||
AnalysisMethod,
|
||||
AnalysisType,
|
||||
_categorize_outputs,
|
||||
_strip_prompt_prefix_literal,
|
||||
)
|
||||
|
||||
from typing import List, Dict, Any, Generator, Tuple, Union, Optional
|
||||
from transformers import GenerationConfig
|
||||
|
||||
|
||||
class VideoSummaryDetector(AnalysisMethod):
|
||||
MAX_SAMPLES_CAP = 1000 # safety cap for total extracted frames
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
summary_model: MultimodalSummaryModel,
|
||||
subdict: Optional[Dict[str, Any]] = None,
|
||||
) -> None:
|
||||
"""
|
||||
Class for analysing videos using QWEN-2.5-VL model.
|
||||
It provides methods for generating captions and answering questions about videos.
|
||||
|
||||
Args:
|
||||
summary_model (MultimodalSummaryModel): An instance of MultimodalSummaryModel to be used for analysis.
|
||||
subdict (dict, optional): Dictionary containing the video to be analysed. Defaults to {}.
|
||||
|
||||
Returns:
|
||||
None.
|
||||
"""
|
||||
if subdict is None:
|
||||
subdict = {}
|
||||
|
||||
super().__init__(subdict)
|
||||
self.summary_model = summary_model
|
||||
|
||||
def _frame_batch_generator(
|
||||
self,
|
||||
timestamps: torch.Tensor,
|
||||
batch_size: int,
|
||||
video_decoder: VideoDecoder,
|
||||
) -> Generator[Tuple[torch.Tensor, torch.Tensor], None, None]:
|
||||
"""
|
||||
Yield batches of (frames, timestamps) for given frame indices.
|
||||
- frames are returned as a torch.Tensor with shape (B, C, H, W).
|
||||
- timestamps is a 1D torch.Tensor with B elements.
|
||||
"""
|
||||
total = int(timestamps.numel())
|
||||
|
||||
for start in range(0, total, batch_size):
|
||||
batch_secs = timestamps[start : start + batch_size].tolist()
|
||||
fb = video_decoder.get_frames_played_at(batch_secs)
|
||||
frames = fb.data
|
||||
|
||||
if not frames.is_contiguous():
|
||||
frames = frames.contiguous()
|
||||
pts = fb.pts_seconds
|
||||
pts_out = (
|
||||
pts.cpu().to(dtype=torch.float32)
|
||||
if isinstance(pts, torch.Tensor)
|
||||
else torch.tensor(pts, dtype=torch.float32)
|
||||
)
|
||||
yield frames, pts_out
|
||||
|
||||
def _extract_video_frames(
|
||||
self,
|
||||
entry: Dict[str, Any],
|
||||
frame_rate_per_second: float = 2.0,
|
||||
batch_size: int = 32,
|
||||
) -> Generator[Tuple[torch.Tensor, torch.Tensor], None, None]:
|
||||
"""
|
||||
Extract frames from a video at a specified frame rate and return them as a generator of batches.
|
||||
Args:
|
||||
filename (Union[str, os.PathLike]): Path to the video file.
|
||||
frame_rate_per_second (float, optional): Frame extraction rate in frames per second. Default is 2.
|
||||
batch_size (int, optional): Number of frames to include in each batch. Default is 32.
|
||||
Returns:
|
||||
Generator[Tuple[torch.Tensor, torch.Tensor], None, None]: A generator yielding tuples of
|
||||
(frames, timestamps), where frames is a tensor of shape (B, C, H, W) and timestamps is a 1D tensor of length B.
|
||||
"""
|
||||
|
||||
filename = entry.get("filename")
|
||||
if not filename:
|
||||
raise ValueError("entry must contain key 'filename'")
|
||||
|
||||
video_decoder = VideoDecoder(filename)
|
||||
meta = video_decoder.metadata
|
||||
|
||||
video_fps = getattr(meta, "average_fps", None)
|
||||
if video_fps is None or not (
|
||||
isinstance(video_fps, (int, float)) and video_fps > 0
|
||||
):
|
||||
video_fps = 30.0
|
||||
|
||||
begin_stream_seconds = getattr(meta, "begin_stream_seconds", None)
|
||||
end_stream_seconds = getattr(meta, "end_stream_seconds", None)
|
||||
nframes = len(video_decoder)
|
||||
if getattr(meta, "duration_seconds", None) is not None:
|
||||
duration = float(meta.duration_seconds)
|
||||
elif begin_stream_seconds is not None and end_stream_seconds is not None:
|
||||
duration = float(end_stream_seconds) - float(begin_stream_seconds)
|
||||
elif nframes:
|
||||
duration = float(nframes) / float(video_fps)
|
||||
else:
|
||||
duration = 0.0
|
||||
|
||||
if frame_rate_per_second <= 0:
|
||||
raise ValueError("frame_rate_per_second must be > 0")
|
||||
|
||||
n_samples = max(1, int(math.floor(duration * frame_rate_per_second)))
|
||||
n_samples = min(n_samples, self.MAX_SAMPLES_CAP)
|
||||
|
||||
if begin_stream_seconds is not None and end_stream_seconds is not None:
|
||||
sample_times = torch.linspace(
|
||||
float(begin_stream_seconds), float(end_stream_seconds), steps=n_samples
|
||||
)
|
||||
if sample_times.numel() > 1:
|
||||
sample_times = torch.clamp(
|
||||
sample_times,
|
||||
min=float(begin_stream_seconds),
|
||||
max=float(end_stream_seconds) - 1e-6,
|
||||
)
|
||||
else:
|
||||
sample_times = torch.linspace(0.0, max(0.0, duration), steps=n_samples)
|
||||
|
||||
sample_times = sample_times.to(dtype=torch.float32, device="cpu")
|
||||
generator = self._frame_batch_generator(sample_times, batch_size, video_decoder)
|
||||
|
||||
return generator
|
||||
|
||||
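As a rough illustration of the sampling arithmetic above (not part of the module): a 60-second clip sampled at frame_rate_per_second=2.0 yields 120 sample times, which the batch generator then serves in four batches of 32. A minimal standalone sketch follows; the video path is a placeholder and the torchcodec import path is assumed to match the one used in this module:

import torch
from torchcodec.decoders import VideoDecoder  # assumed import location of the decoder used above

decoder = VideoDecoder("example.mp4")  # placeholder file
duration = float(decoder.metadata.duration_seconds or 0.0)
n_samples = max(1, int(duration * 2.0))  # 2 samples per second
times = torch.linspace(0.0, max(0.0, duration), steps=n_samples)
batch = decoder.get_frames_played_at(times[:32].tolist())  # first batch of up to 32 frames
print(batch.data.shape, batch.pts_seconds[:3])  # (B, C, H, W) frames and their timestamps
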
    def _decode_trimmed_outputs(
        self,
        generated_ids: torch.Tensor,
        inputs: Dict[str, torch.Tensor],
        tokenizer,
        prompt_texts: List[str],
    ) -> List[str]:
        """
        Trim prompt tokens using attention_mask/input_ids when available and decode to strings.
        Then remove any literal prompt prefix using prompt_texts (one per batch element).
        """
        batch_size = generated_ids.shape[0]

        if "input_ids" in inputs:
            token_for_padding = (
                tokenizer.pad_token_id
                if getattr(tokenizer, "pad_token_id", None) is not None
                else getattr(tokenizer, "eos_token_id", None)
            )
            if token_for_padding is None:
                lengths = [int(inputs["input_ids"].shape[1])] * batch_size
            else:
                lengths = inputs["input_ids"].ne(token_for_padding).sum(dim=1).tolist()
        else:
            lengths = [0] * batch_size

        trimmed_ids = []
        for i in range(batch_size):
            out_ids = generated_ids[i]
            in_len = int(lengths[i]) if i < len(lengths) else 0
            if out_ids.size(0) > in_len:
                t = out_ids[in_len:]
            else:
                t = out_ids.new_empty((0,), dtype=out_ids.dtype)
            t_cpu = t.to("cpu")
            trimmed_ids.append(t_cpu.tolist())

        decoded = tokenizer.batch_decode(
            trimmed_ids, skip_special_tokens=True, clean_up_tokenization_spaces=True
        )
        decoded_results = []
        for ptext, raw in zip(prompt_texts, decoded):
            cleaned = _strip_prompt_prefix_literal(raw, ptext)
            decoded_results.append(cleaned)
        return decoded_results

    def _generate_from_processor_inputs(
        self,
        processor_inputs: Dict[str, torch.Tensor],
        prompt_texts: List[str],
        tokenizer,
    ):
        """
        Run model.generate on already-processed processor_inputs (tensors moved to device),
        then decode and trim prompt tokens & remove literal prompt prefixes using prompt_texts.
        """
        gen_conf = GenerationConfig(
            max_new_tokens=64,
            do_sample=False,
            num_return_sequences=1,
        )

        for k, v in list(processor_inputs.items()):
            if isinstance(v, torch.Tensor):
                processor_inputs[k] = v.to(self.summary_model.device)

        with torch.inference_mode():
            try:
                if self.summary_model.device == "cuda":
                    with torch.amp.autocast("cuda", enabled=True):
                        generated_ids = self.summary_model.model.generate(
                            **processor_inputs, generation_config=gen_conf
                        )
                else:
                    generated_ids = self.summary_model.model.generate(
                        **processor_inputs, generation_config=gen_conf
                    )
            except RuntimeError as e:
                warnings.warn(
                    f"Generation failed with error: {e}. Retrying with cuDNN disabled.",
                    RuntimeWarning,
                )
                cudnn_was_enabled = (
                    torch.backends.cudnn.is_available() and torch.backends.cudnn.enabled
                )
                if cudnn_was_enabled:
                    torch.backends.cudnn.enabled = False
                try:
                    generated_ids = self.summary_model.model.generate(
                        **processor_inputs, generation_config=gen_conf
                    )
                except Exception as retry_error:
                    raise RuntimeError(
                        f"Failed to generate ids after retry: {retry_error}"
                    ) from retry_error
                finally:
                    if cudnn_was_enabled:
                        torch.backends.cudnn.enabled = True

        decoded = self._decode_trimmed_outputs(
            generated_ids, processor_inputs, tokenizer, prompt_texts
        )
        return decoded

    def _tensor_batch_to_pil_list(self, batch: torch.Tensor) -> List[Image.Image]:
        """
        Convert a torch tensor batch (B, C, H, W) to a list of PIL images (RGB).
        The batch is moved to CPU and cast to uint8 if necessary before conversion.
        """
        if batch.device.type != "cpu":
            batch = batch.to("cpu")

        batch = batch.contiguous()
        if batch.dtype.is_floating_point:
            batch = (batch.clamp(0.0, 1.0) * 255.0).to(torch.uint8)
        elif batch.dtype != torch.uint8:
            batch = batch.to(torch.uint8)
        pil_list: List[Image.Image] = []
        for frame in batch:
            arr = frame.permute(1, 2, 0).numpy()
            pil_list.append(Image.fromarray(arr))
        return pil_list

    def make_captions_from_extracted_frames(
        self,
        extracted_video_gen: Generator[Tuple[torch.Tensor, torch.Tensor], None, None],
        list_of_questions: Optional[List[str]] = None,
    ) -> Dict[str, Any]:
        """
        Generate a caption for every extracted frame and, if questions are given, brief per-frame answers to them.

        Args:
            extracted_video_gen (Generator[Tuple[torch.Tensor, torch.Tensor], None, None]): Generator yielding batches of (frames, timestamps).
            list_of_questions (List[str], optional): Questions to be answered for every frame. Defaults to None.

        Returns:
            Dict[str, Any]: A dictionary containing the per-frame caption bullets and, if questions were given, the per-frame VQA bullets.
        """
        caption_instruction = "Describe this image in one concise caption."
        include_questions = bool(list_of_questions)
        if include_questions:
            q_block = "\n".join(
                [f"{i + 1}. {q.strip()}" for i, q in enumerate(list_of_questions)]
            )
            caption_instruction += (
                " In addition to the concise caption, also answer the following questions based ONLY on the image. Answers must be very brief and concise."
                " Produce exactly two labeled sections: \n\n"
                "Summary: <concise summary>\n\n"
                "VQA Answers: \n1. <answer to question 1>\n2. <answer to question 2>\n etc."
                "\nReturn only those two sections for each image (do not add extra commentary)."
                "\nIf the answer cannot be determined based on the provided answer blocks,"
                ' reply with the line "The answer cannot be determined based on the information provided."'
                f"\n\nQuestions:\n{q_block}"
            )
        collected: List[Tuple[float, str]] = []
        proc = self.summary_model.processor
        try:
            for batch_frames, batch_times in extracted_video_gen:
                pil_list = self._tensor_batch_to_pil_list(batch_frames.cpu())

                prompt_texts = []
                for p in pil_list:
                    messages = [
                        {
                            "role": "user",
                            "content": [
                                {"type": "image", "image": p},
                                {"type": "text", "text": caption_instruction},
                            ],
                        }
                    ]

                    prompt_text = proc.apply_chat_template(
                        messages, tokenize=False, add_generation_prompt=True
                    )
                    prompt_texts.append(prompt_text)

                processor_inputs = proc(
                    text=prompt_texts,
                    images=pil_list,
                    return_tensors="pt",
                    padding=True,
                )
                captions = self._generate_from_processor_inputs(
                    processor_inputs,
                    prompt_texts,
                    self.summary_model.tokenizer,
                )

                if isinstance(batch_times, torch.Tensor):
                    batch_times_list = batch_times.cpu().tolist()
                else:
                    batch_times_list = list(batch_times)
                for t, c in zip(batch_times_list, captions):
                    collected.append((float(t), c))
        finally:
            try:
                extracted_video_gen.close()
            except Exception:
                warnings.warn("Failed to close video frame generator.", RuntimeWarning)

        collected.sort(key=lambda x: x[0])
        bullets_summary, bullets_vqa = _categorize_outputs(collected, include_questions)

        return {
            "summary_bullets": bullets_summary,
            "vqa_bullets": bullets_vqa,
        }  # TODO: consider dropping the timestamps from the returned structure

    def final_summary(self, summary_dict: Dict[str, Any]) -> Dict[str, Any]:
        """
        Produce a concise summary of the video, based on the captions generated for all extracted frames.

        Args:
            summary_dict (Dict[str, Any]): Dictionary containing the caption bullets for the frames.

        Returns:
            Dict[str, Any]: A dictionary containing the final summary under the key "summary".
        """
        summary_instruction = "Analyze the following captions from multiple frames of the same video and summarize the overall content of the video in one concise paragraph (1-3 sentences). Focus on the key themes, actions, or events across the video, not just the individual frames."
        proc = self.summary_model.processor

        bullets = summary_dict.get("summary_bullets", [])
        if not bullets:
            raise ValueError("No captions available for summary generation.")

        combined_captions_text = "\n".join(bullets)
        summary_user_text = summary_instruction + "\n\n" + combined_captions_text + "\n"

        messages = [
            {
                "role": "user",
                "content": [{"type": "text", "text": summary_user_text}],
            }
        ]

        summary_prompt_text = proc.apply_chat_template(
            messages, tokenize=False, add_generation_prompt=True
        )

        summary_inputs = proc(
            text=[summary_prompt_text], return_tensors="pt", padding=True
        )

        summary_inputs = {
            k: v.to(self.summary_model.device) if isinstance(v, torch.Tensor) else v
            for k, v in summary_inputs.items()
        }
        final_summary_list = self._generate_from_processor_inputs(
            summary_inputs,
            [summary_prompt_text],
            self.summary_model.tokenizer,
        )
        final_summary = final_summary_list[0].strip() if final_summary_list else ""

        return {
            "summary": final_summary,
        }

    def final_answers(
        self,
        answers_dict: Dict[str, Any],
        list_of_questions: List[str],
    ) -> Dict[str, Any]:
        """
        Answer the list of questions for the video based on the VQA bullets from the frames.

        Args:
            answers_dict (Dict[str, Any]): Dictionary containing the VQA bullets.
            list_of_questions (List[str]): Questions to be answered about the video.

        Returns:
            Dict[str, Any]: A dictionary containing the list of answers to the questions.
        """
        vqa_bullets = answers_dict.get("vqa_bullets", [])
        if not vqa_bullets:
            raise ValueError(
                "No VQA bullets generated for single frames available for answering questions."
            )

        include_questions = bool(list_of_questions)
        if include_questions:
            q_block = "\n".join(
                [f"{i + 1}. {q.strip()}" for i, q in enumerate(list_of_questions)]
            )
            prompt = (
                "You are provided with a set of short VQA-captions, each of which is a block of short answers"
                " extracted from individual frames of the same video.\n\n"
                "VQA-captions (use ONLY these to answer):\n"
                f"{vqa_bullets}\n\n"
                "Answer the following questions briefly, based ONLY on the lists of answers provided above. The VQA-captions above contain answers"
                " to the same questions you are about to answer. If the answer cannot be determined based on the provided answer blocks,"
                ' reply with the line "The answer cannot be determined based on the information provided."'
                "Questions:\n"
                f"{q_block}\n\n"
                "Produce an ordered list with answers in the same order as the questions. You must have this structure of your output: "
                "Answers: \n1. <answer to question 1>\n2. <answer to question 2>\n etc."
                "Return ONLY the ordered list with answers and NOTHING else - no commentary, no explanation, no surrounding markdown."
            )
        else:
            raise ValueError(
                "list_of_questions must be provided for making final answers."
            )

        proc = self.summary_model.processor
        messages = [
            {
                "role": "user",
                "content": [{"type": "text", "text": prompt}],
            }
        ]
        final_vqa_prompt_text = proc.apply_chat_template(
            messages, tokenize=False, add_generation_prompt=True
        )
        final_vqa_inputs = proc(
            text=[final_vqa_prompt_text], return_tensors="pt", padding=True
        )
        final_vqa_inputs = {
            k: v.to(self.summary_model.device) if isinstance(v, torch.Tensor) else v
            for k, v in final_vqa_inputs.items()
        }

        final_vqa_list = self._generate_from_processor_inputs(
            final_vqa_inputs,
            [final_vqa_prompt_text],
            self.summary_model.tokenizer,
        )

        final_vqa_output = final_vqa_list[0].strip() if final_vqa_list else ""
        vqa_answers = []
        answer_matches = re.findall(
            r"\d+\.\s*(.*?)(?=\n\d+\.|$)", final_vqa_output, flags=re.DOTALL
        )
        for answer in answer_matches:
            vqa_answers.append(answer.strip())
        return {
            "vqa_answers": vqa_answers,
        }

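As a quick sanity check of the answer-parsing step above, this is how the regular expression behaves on a hypothetical model output (the string is made up purely for illustration):

import re

out = "Answers:\n1. Two people\n2. Outdoors, in a park\n3. The answer cannot be determined based on the information provided."
print(re.findall(r"\d+\.\s*(.*?)(?=\n\d+\.|$)", out, flags=re.DOTALL))
# ['Two people', 'Outdoors, in a park', 'The answer cannot be determined based on the information provided.']
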
    def analyse_videos_from_dict(
        self,
        analysis_type: Union[AnalysisType, str] = AnalysisType.SUMMARY,
        frame_rate_per_second: float = 2.0,
        list_of_questions: Optional[List[str]] = None,
    ) -> Dict[str, Any]:
        """
        Analyse the videos specified in self.subdict using frame extraction and captioning.

        Args:
            analysis_type (Union[AnalysisType, str], optional): Type of analysis to perform. Defaults to AnalysisType.SUMMARY.
            frame_rate_per_second (float): Frame extraction rate in frames per second. Default is 2.0.
            list_of_questions (List[str], optional): List of questions to answer about the video. Required if analysis_type includes questions.

        Returns:
            Dict[str, Any]: A dictionary containing the analysis results per video, including the summary and the answers to the provided questions (if any).
        """
        all_answers = {}
        analysis_type, is_summary, is_questions = AnalysisType._validate_analysis_type(
            analysis_type, list_of_questions
        )

        for video_key, entry in self.subdict.items():
            summary_and_vqa = {}
            frames_generator = self._extract_video_frames(
                entry, frame_rate_per_second=frame_rate_per_second
            )

            answers_dict = self.make_captions_from_extracted_frames(
                frames_generator, list_of_questions=list_of_questions
            )
            # TODO: captions will need post-processing once the planned audio analysis is added
            # TODO: the combined captions and answers may exceed the prompt limit of the summarizing model; consider a hierarchical approach.
            if is_summary:
                answer = self.final_summary(answers_dict)
                summary_and_vqa["summary"] = answer["summary"]
            if is_questions:
                answer = self.final_answers(answers_dict, list_of_questions)
                summary_and_vqa["vqa_answers"] = answer["vqa_answers"]

            all_answers[video_key] = summary_and_vqa

        return all_answers
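For context, a minimal usage sketch of the workflow implemented above. It assumes the detector class defined in this file is exported as ammico.VideoSummaryDetector, that MultimodalSummaryModel can be constructed with its defaults, and that the video path is a placeholder:

from ammico import MultimodalSummaryModel, VideoSummaryDetector

# Placeholder input: every entry must carry a "filename" key, as required by _extract_video_frames.
videos = {"vid1": {"filename": "/path/to/video.mp4"}}

model = MultimodalSummaryModel()  # assumed default constructor; loads the multimodal summary model
detector = VideoSummaryDetector(summary_model=model, subdict=videos)

# The default analysis_type is AnalysisType.SUMMARY, so only a summary is produced here.
results = detector.analyse_videos_from_dict(frame_rate_per_second=1.0)
print(results["vid1"]["summary"])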
20
environment.yml
Normal file
@ -0,0 +1,20 @@
name: ammico-dev
channels:
  - pytorch
  - nvidia
  - rapidsai
  - conda-forge
  - defaults

dependencies:
  - python=3.11
  - cudatoolkit=11.8
  - pytorch=2.5.1
  - pytorch-cuda=11.8
  - torchvision=0.20.1
  - torchaudio=2.5.1
  - faiss-gpu-raft=1.8.0
  - ipykernel
  - jupyterlab
  - jupyterlab_widgets
  - ffmpeg<8
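Assuming a local conda or mamba installation, this development environment can be created with "conda env create -f environment.yml" and activated with "conda activate ammico-dev".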
@ -18,16 +18,19 @@ classifiers = [
    "Operating System :: OS Independent",
    "License :: OSI Approved :: MIT License",
]


dependencies = [
    "accelerate>=0.22",
    "bitsandbytes",
    "colorgram.py",
    "colour-science",
    "dash",
    "dash-bootstrap-components",
    "deepface",
    "google-cloud-vision",
    "googletrans==4.0.0rc1",
    "googletrans-py",  # temporary replacement for googletrans==4.0.0rc1 due to incompatibility with jupyterlab
    "grpcio",
    "huggingface-hub>=0.34.0",
    "importlib_metadata",
    "importlib_resources",
    "matplotlib",
@ -36,12 +39,17 @@ dependencies = [
    "pandas",
    "Pillow",
    "pooch",
    "qwen-vl-utils",
    "retina_face",
    "safetensors>=0.6.2",
    "setuptools",
    "spacy",
    "tensorflow<=2.16.0",
    "tensorflow<2.15",  # instead of <=2.16.0, to keep compatibility with CUDA 11.8; may change after the CUDA version is updated
    "tf-keras",
    "torchvision",
    "tqdm",
    "transformers>=4.54",
    "torchcodec<0.2",
    "webcolors",
]