Mirror of
https://github.com/ssciwr/AMMICO.git
synced 2025-10-30 13:36:04 +02:00
Merge 3e036514cd3c05e70bd74eddeabe5bdb09c38890 into 3f9e855aebddf6eddaa81de1cd883bc5bcf5d3bc
This commit is contained in:
Commit 2ff2dffc59

2	.github/workflows/ci.yml	Vendored
@@ -31,7 +31,7 @@ jobs:
       - name: Run pytest
         run: |
           cd ammico
-          python -m pytest -svv -m "not gcv" --cov=. --cov-report=xml
+          python -m pytest -svv -m "not gcv and not long" --cov=. --cov-report=xml
       - name: Upload coverage
         if: matrix.os == 'ubuntu-latest' && matrix.python-version == '3.11'
         uses: codecov/codecov-action@v3
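The CI change excludes tests marked "long" (the new model and image-summary tests added below carry @pytest.mark.long). This diff does not show where that marker is registered; a minimal sketch, assuming it is declared in the test suite's conftest.py, could look like this:

# Hypothetical marker registration (not part of this diff): pytest warns about
# unknown markers unless "long" is declared, e.g. via the pytest_configure hook.
def pytest_configure(config):
    config.addinivalue_line(
        "markers", "long: marks tests that download and run the Qwen2.5-VL model"
    )

The excluded tests can then still be run locally with `python -m pytest -m long`.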
ammico/__init__.py

@@ -1,7 +1,10 @@
 from ammico.display import AnalysisExplorer
 from ammico.faces import EmotionDetector, ethical_disclosure
+from ammico.model import MultimodalSummaryModel
 from ammico.text import TextDetector, TextAnalyzer, privacy_disclosure
-from ammico.utils import find_files, get_dataframe
+from ammico.image_summary import ImageSummaryDetector
+from ammico.utils import find_files, get_dataframe, AnalysisType, find_videos
+from ammico.video_summary import VideoSummaryDetector
 
 # Export the version defined in project metadata
 try:
@@ -12,11 +15,16 @@ except ImportError:
     __version__ = "unknown"
 
 __all__ = [
+    "AnalysisType",
     "AnalysisExplorer",
     "EmotionDetector",
+    "MultimodalSummaryModel",
     "TextDetector",
     "TextAnalyzer",
+    "ImageSummaryDetector",
+    "VideoSummaryDetector",
     "find_files",
+    "find_videos",
     "get_dataframe",
     "ethical_disclosure",
     "privacy_disclosure",
ammico/display.py

@@ -1,10 +1,14 @@
 import ammico.faces as faces
 import ammico.text as text
 import ammico.colors as colors
+import ammico.image_summary as image_summary
+from ammico.model import MultimodalSummaryModel
 import pandas as pd
 from dash import html, Input, Output, dcc, State, Dash
 from PIL import Image
 import dash_bootstrap_components as dbc
+import warnings
+from typing import Dict, Any, List, Optional
 
 
 COLOR_SCHEMES = [
@@ -94,6 +98,8 @@ class AnalysisExplorer:
             State("left_select_id", "options"),
             State("left_select_id", "value"),
             State("Dropdown_select_Detector", "value"),
+            State("Dropdown_analysis_type", "value"),
+            State("textarea_questions", "value"),
             State("setting_Text_analyse_text", "value"),
             State("setting_privacy_env_var", "value"),
             State("setting_Emotion_emotion_threshold", "value"),
@@ -108,9 +114,15 @@ class AnalysisExplorer:
             Output("settings_TextDetector", "style"),
             Output("settings_EmotionDetector", "style"),
             Output("settings_ColorDetector", "style"),
+            Output("settings_VQA", "style"),
             Input("Dropdown_select_Detector", "value"),
         )(self._update_detector_setting)
 
+        self.app.callback(
+            Output("textarea_questions", "style"),
+            Input("Dropdown_analysis_type", "value"),
+        )(self._show_questions_textarea_on_demand)
+
     # I split the different sections into subfunctions for better clarity
     def _top_file_explorer(self, mydict: dict) -> html.Div:
         """Initialize the file explorer dropdown for selecting the file to be analyzed.
@@ -157,14 +169,6 @@ class AnalysisExplorer:
             id="settings_TextDetector",
             style={"display": "none"},
             children=[
-                dbc.Row(
-                    dcc.Checklist(
-                        ["Analyse text"],
-                        ["Analyse text"],
-                        id="setting_Text_analyse_text",
-                        style={"margin-bottom": "10px"},
-                    ),
-                ),
                 # row 1
                 dbc.Row(
                     dbc.Col(
@@ -272,8 +276,69 @@ class AnalysisExplorer:
                         )
                     ],
                 ),
+                # start VQA settings
+                html.Div(
+                    id="settings_VQA",
+                    style={"display": "none"},
+                    children=[
+                        dbc.Card(
+                            [
+                                dbc.CardBody(
+                                    [
+                                        dbc.Row(
+                                            dbc.Col(
+                                                dcc.Dropdown(
+                                                    id="Dropdown_analysis_type",
+                                                    options=[
+                                                        {"label": v, "value": v}
+                                                        for v in SUMMARY_ANALYSIS_TYPE
+                                                    ],
+                                                    value="summary_and_questions",
+                                                    clearable=False,
+                                                    style={
+                                                        "width": "100%",
+                                                        "minWidth": "240px",
+                                                        "maxWidth": "520px",
+                                                    },
+                                                ),
+                                            ),
+                                            justify="start",
+                                        ),
+                                        html.Div(style={"height": "8px"}),
+                                        dbc.Row(
+                                            [
+                                                dbc.Col(
+                                                    dcc.Textarea(
+                                                        id="textarea_questions",
+                                                        value="Are there people in the image?\nWhat is this picture about?",
+                                                        placeholder="One question per line...",
+                                                        style={
+                                                            "width": "100%",
+                                                            "minHeight": "160px",
+                                                            "height": "220px",
+                                                            "resize": "vertical",
+                                                            "overflow": "auto",
+                                                        },
+                                                        rows=8,
+                                                    ),
+                                                    width=12,
+                                                ),
+                                            ],
+                                            justify="start",
+                                        ),
+                                    ]
+                                )
+                            ],
+                            style={
+                                "width": "100%",
+                                "marginTop": "10px",
+                                "zIndex": 2000,
+                            },
+                        )
+                    ],
+                ),
             ],
-            style={"width": "100%", "display": "inline-block"},
+            style={"width": "100%", "display": "inline-block", "overflow": "visible"},
         )
         return settings_layout
 
@@ -293,6 +358,7 @@ class AnalysisExplorer:
                 "TextDetector",
                 "EmotionDetector",
                 "ColorDetector",
+                "VQA",
             ],
             value="TextDetector",
             id="Dropdown_select_Detector",
@@ -344,7 +410,7 @@ class AnalysisExplorer:
             port (int, optional): The port number to run the server on (default: 8050).
         """
 
-        self.app.run_server(debug=True, port=port)
+        self.app.run(debug=True, port=port)
 
     # Dash callbacks
     def update_picture(self, img_path: str):
@@ -379,19 +445,27 @@ class AnalysisExplorer:
 
         if setting_input == "EmotionDetector":
            return display_none, display_flex, display_none, display_none
 
         if setting_input == "ColorDetector":
             return display_none, display_none, display_flex, display_none
-
+        if setting_input == "VQA":
+            return display_none, display_none, display_none, display_flex
         else:
             return display_none, display_none, display_none, display_none
 
+    def _parse_questions(self, text: Optional[str]) -> Optional[List[str]]:
+        if not text:
+            return None
+        qs = [q.strip() for q in text.splitlines() if q.strip()]
+        return qs if qs else None
+
     def _right_output_analysis(
         self,
         n_clicks,
         all_img_options: dict,
         current_img_value: str,
         detector_value: str,
+        analysis_type_value: str,
+        textarea_questions_value: str,
         settings_text_analyse_text: list,
         setting_privacy_env_var: str,
         setting_emotion_emotion_threshold: int,
@@ -413,54 +487,75 @@ class AnalysisExplorer:
             "EmotionDetector": faces.EmotionDetector,
             "TextDetector": text.TextDetector,
             "ColorDetector": colors.ColorDetector,
+            "VQA": image_summary.ImageSummaryDetector,
         }
 
         # Get image ID from dropdown value, which is the filepath
         if current_img_value is None:
             return {}
         image_id = all_img_options[current_img_value]
-        # copy image so prvious runs don't leave their default values in the dict
-        image_copy = self.mydict[image_id].copy()
-
-        # detector value is the string name of the chosen detector
-        identify_function = identify_dict[detector_value]
-        if detector_value == "TextDetector":
-            analyse_text = (
-                True if settings_text_analyse_text == ["Analyse text"] else False
-            )
-            detector_class = identify_function(
-                image_copy,
-                analyse_text=analyse_text,
-                accept_privacy=(
-                    setting_privacy_env_var
-                    if setting_privacy_env_var
-                    else "PRIVACY_AMMICO"
-                ),
-            )
-        elif detector_value == "EmotionDetector":
-            detector_class = identify_function(
-                image_copy,
-                emotion_threshold=setting_emotion_emotion_threshold,
-                race_threshold=setting_emotion_race_threshold,
-                gender_threshold=setting_emotion_gender_threshold,
-                accept_disclosure=(
-                    setting_emotion_env_var
-                    if setting_emotion_env_var
-                    else "DISCLOSURE_AMMICO"
-                ),
-            )
-        elif detector_value == "ColorDetector":
-            detector_class = identify_function(
-                image_copy,
-                delta_e_method=setting_color_delta_e_method,
-            )
-        else:
-            detector_class = identify_function(image_copy)
-        analysis_dict = detector_class.analyse_image()
-
-        # Initialize an empty dictionary
-        new_analysis_dict = {}
+        image_copy = self.mydict.get(image_id, {}).copy()
+
+        analysis_dict: Dict[str, Any] = {}
+        if detector_value == "VQA":
+            try:
+                qwen_model = MultimodalSummaryModel(
+                    model_id="Qwen/Qwen2.5-VL-3B-Instruct"
+                )  # TODO: allow user to specify model
+                vqa_cls = identify_dict.get("VQA")
+                vqa_detector = vqa_cls(qwen_model, subdict={})
+                questions_list = self._parse_questions(textarea_questions_value)
+                analysis_result = vqa_detector.analyse_image(
+                    image_copy,
+                    analysis_type=analysis_type_value,
+                    list_of_questions=questions_list,
+                    is_concise_summary=True,
+                    is_concise_answer=True,
+                )
+                analysis_dict = analysis_result or {}
+            except Exception as e:
+                warnings.warn(f"VQA/Image tasks failed: {e}")
+                analysis_dict = {"image_tasks_error": str(e)}
+        else:
+            # detector value is the string name of the chosen detector
+            identify_function = identify_dict[detector_value]
+
+            if detector_value == "TextDetector":
+                analyse_text = (
+                    True if settings_text_analyse_text == ["Analyse text"] else False
+                )
+                detector_class = identify_function(
+                    image_copy,
+                    analyse_text=analyse_text,
+                    accept_privacy=(
+                        setting_privacy_env_var
+                        if setting_privacy_env_var
+                        else "PRIVACY_AMMICO"
+                    ),
+                )
+            elif detector_value == "EmotionDetector":
+                detector_class = identify_function(
+                    image_copy,
+                    emotion_threshold=setting_emotion_emotion_threshold,
+                    race_threshold=setting_emotion_race_threshold,
+                    gender_threshold=setting_emotion_gender_threshold,
+                    accept_disclosure=(
+                        setting_emotion_env_var
+                        if setting_emotion_env_var
+                        else "DISCLOSURE_AMMICO"
+                    ),
+                )
+            elif detector_value == "ColorDetector":
+                detector_class = identify_function(
+                    image_copy,
+                    delta_e_method=setting_color_delta_e_method,
+                )
+            else:
+                detector_class = identify_function(image_copy)
+
+            analysis_dict = detector_class.analyse_image()
+
+        new_analysis_dict: Dict[str, Any] = {}
 
         # Iterate over the items in the original dictionary
         for k, v in analysis_dict.items():
@@ -480,3 +575,9 @@ class AnalysisExplorer:
         return dbc.Table.from_dataframe(
             df, striped=True, bordered=True, hover=True, index=True
         )
+
+    def _show_questions_textarea_on_demand(self, analysis_type_value: str) -> dict:
+        if analysis_type_value in ("questions", "summary_and_questions"):
+            return {"display": "block", "width": "100%"}
+        else:
+            return {"display": "none"}
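With these changes the explorer exposes a "VQA" entry in the detector dropdown and routes it through ImageSummaryDetector. For orientation, a minimal launch sketch; the constructor argument and the run_server method name are assumptions based on the mydict and port parameters visible in this diff, not something the diff itself shows:

import ammico

# Hypothetical launch script: image_dict comes from ammico.find_files(...),
# and the explorer is assumed to take that dictionary at construction time.
image_dict = ammico.find_files(path="data/", limit=20)
explorer = ammico.AnalysisExplorer(image_dict)
explorer.run_server(port=8050)  # internally calls self.app.run(debug=True, port=port)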
365	ammico/image_summary.py	Normal file

@@ -0,0 +1,365 @@
from ammico.utils import AnalysisMethod, AnalysisType
from ammico.model import MultimodalSummaryModel

import os
import torch
from PIL import Image
import warnings

from typing import List, Optional, Union, Dict, Any
from collections.abc import Sequence as _Sequence
from transformers import GenerationConfig
from qwen_vl_utils import process_vision_info


class ImageSummaryDetector(AnalysisMethod):
    def __init__(
        self,
        summary_model: MultimodalSummaryModel,
        subdict: Optional[Dict[str, Any]] = None,
    ) -> None:
        """
        Class for analysing images using QWEN-2.5-VL model.
        It provides methods for generating captions and answering questions about images.

        Args:
            summary_model ([type], optional): An instance of MultimodalSummaryModel to be used for analysis.
            subdict (dict, optional): Dictionary containing the image to be analysed. Defaults to {}.

        Returns:
            None.
        """
        if subdict is None:
            subdict = {}

        super().__init__(subdict)
        self.summary_model = summary_model

    def _load_pil_if_needed(
        self, filename: Union[str, os.PathLike, Image.Image]
    ) -> Image.Image:
        if isinstance(filename, (str, os.PathLike)):
            return Image.open(filename).convert("RGB")
        elif isinstance(filename, Image.Image):
            return filename.convert("RGB")
        else:
            raise ValueError("filename must be a path or PIL.Image")

    @staticmethod
    def _is_sequence_but_not_str(obj: Any) -> bool:
        """True for sequence-like but not a string/bytes/PIL.Image."""
        return isinstance(obj, _Sequence) and not isinstance(
            obj, (str, bytes, Image.Image)
        )

    def _prepare_inputs(
        self, list_of_questions: list[str], entry: Optional[Dict[str, Any]] = None
    ) -> Dict[str, torch.Tensor]:
        filename = entry.get("filename")
        if filename is None:
            raise ValueError("entry must contain key 'filename'")

        if isinstance(filename, (str, os.PathLike, Image.Image)):
            images_context = self._load_pil_if_needed(filename)
        elif self._is_sequence_but_not_str(filename):
            images_context = [self._load_pil_if_needed(i) for i in filename]
        else:
            raise ValueError(
                "Unsupported 'filename' entry: expected path, PIL.Image, or sequence."
            )

        images_only_messages = [
            {
                "role": "user",
                "content": [
                    *(
                        [{"type": "image", "image": img} for img in images_context]
                        if isinstance(images_context, list)
                        else [{"type": "image", "image": images_context}]
                    )
                ],
            }
        ]

        try:
            image_inputs, _ = process_vision_info(images_only_messages)
        except Exception as e:
            raise RuntimeError(f"Image processing failed: {e}")

        texts: List[str] = []
        for q in list_of_questions:
            messages = [
                {
                    "role": "user",
                    "content": [
                        *(
                            [
                                {"type": "image", "image": image}
                                for image in images_context
                            ]
                            if isinstance(images_context, list)
                            else [{"type": "image", "image": images_context}]
                        ),
                        {"type": "text", "text": q},
                    ],
                }
            ]
            text = self.summary_model.processor.apply_chat_template(
                messages, tokenize=False, add_generation_prompt=True
            )
            texts.append(text)

        images_batch = [image_inputs] * len(texts)
        inputs = self.summary_model.processor(
            text=texts,
            images=images_batch,
            padding=True,
            return_tensors="pt",
        )
        inputs = {k: v.to(self.summary_model.device) for k, v in inputs.items()}

        return inputs

    def analyse_image(
        self,
        entry: dict,
        analysis_type: Union[str, AnalysisType] = AnalysisType.SUMMARY,
        list_of_questions: Optional[List[str]] = None,
        is_concise_summary: bool = True,
        is_concise_answer: bool = True,
    ) -> Dict[str, Any]:
        """
        Analyse a single image entry. Returns dict with keys depending on analysis_type:
          - 'caption' (str) if summary requested
          - 'vqa' (dict) if questions requested
        """
        self.subdict = entry
        analysis_type, is_summary, is_questions = AnalysisType._validate_analysis_type(
            analysis_type, list_of_questions
        )

        if is_summary:
            try:
                caps = self.generate_caption(
                    entry,
                    num_return_sequences=1,
                    is_concise_summary=is_concise_summary,
                )
                self.subdict["caption"] = caps[0] if caps else ""
            except Exception as e:
                warnings.warn(f"Caption generation failed: {e}")

        if is_questions:
            try:
                vqa_map = self.answer_questions(
                    list_of_questions, entry, is_concise_answer
                )
                self.subdict["vqa"] = vqa_map
            except Exception as e:
                warnings.warn(f"VQA failed: {e}")

        return self.subdict

    def analyse_images_from_dict(
        self,
        analysis_type: Union[AnalysisType, str] = AnalysisType.SUMMARY,
        list_of_questions: Optional[List[str]] = None,
        keys_batch_size: int = 16,
        is_concise_summary: bool = True,
        is_concise_answer: bool = True,
    ) -> Dict[str, dict]:
        """
        Analyse image with model.

        Args:
            analysis_type (str): type of the analysis.
            list_of_questions (list[str]): list of questions.
            keys_batch_size (int): number of images to process in a batch.
            is_concise_summary (bool): whether to generate concise summary.
            is_concise_answer (bool): whether to generate concise answers.
        Returns:
            self.subdict (dict): dictionary with analysis results.
        """
        # TODO: add option to ask multiple questions per image as one batch.
        analysis_type, is_summary, is_questions = AnalysisType._validate_analysis_type(
            analysis_type, list_of_questions
        )

        keys = list(self.subdict.keys())
        for batch_start in range(0, len(keys), keys_batch_size):
            batch_keys = keys[batch_start : batch_start + keys_batch_size]
            for key in batch_keys:
                entry = self.subdict[key]
                if is_summary:
                    try:
                        caps = self.generate_caption(
                            entry,
                            num_return_sequences=1,
                            is_concise_summary=is_concise_summary,
                        )
                        entry["caption"] = caps[0] if caps else ""
                    except Exception as e:
                        warnings.warn(f"Caption generation failed: {e}")

                if is_questions:
                    try:
                        vqa_map = self.answer_questions(
                            list_of_questions, entry, is_concise_answer
                        )
                        entry["vqa"] = vqa_map
                    except Exception as e:
                        warnings.warn(f"VQA failed: {e}")

                self.subdict[key] = entry
        return self.subdict

    def generate_caption(
        self,
        entry: Optional[Dict[str, Any]] = None,
        num_return_sequences: int = 1,
        is_concise_summary: bool = True,
    ) -> List[str]:
        """
        Create caption for image. Depending on is_concise_summary it will be either concise or detailed.

        Args:
            entry (dict): dictionary containing the image to be captioned.
            num_return_sequences (int): number of captions to generate.
            is_concise_summary (bool): whether to generate concise summary.

        Returns:
            results (list[str]): list of generated captions.
        """
        if is_concise_summary:
            prompt = ["Describe this image in one concise caption."]
            max_new_tokens = 64
        else:
            prompt = ["Describe this image."]
            max_new_tokens = 256
        inputs = self._prepare_inputs(prompt, entry)

        gen_conf = GenerationConfig(
            max_new_tokens=max_new_tokens,
            do_sample=False,
            num_return_sequences=num_return_sequences,
        )

        with torch.inference_mode():
            try:
                if self.summary_model.device == "cuda":
                    with torch.amp.autocast("cuda", enabled=True):
                        generated_ids = self.summary_model.model.generate(
                            **inputs, generation_config=gen_conf
                        )
                else:
                    generated_ids = self.summary_model.model.generate(
                        **inputs, generation_config=gen_conf
                    )
            except RuntimeError as e:
                warnings.warn(
                    f"Retry without autocast failed: {e}. Attempting cudnn-disabled retry."
                )
                cudnn_was_enabled = (
                    torch.backends.cudnn.is_available() and torch.backends.cudnn.enabled
                )
                if cudnn_was_enabled:
                    torch.backends.cudnn.enabled = False
                try:
                    generated_ids = self.summary_model.model.generate(
                        **inputs, generation_config=gen_conf
                    )
                except Exception as retry_error:
                    raise RuntimeError(
                        f"Failed to generate ids after retry: {retry_error}"
                    ) from retry_error
                finally:
                    if cudnn_was_enabled:
                        torch.backends.cudnn.enabled = True

        decoded = None
        if "input_ids" in inputs:
            in_ids = inputs["input_ids"]
            trimmed = [
                out_ids[len(inp_ids) :]
                for inp_ids, out_ids in zip(in_ids, generated_ids)
            ]
            decoded = self.summary_model.tokenizer.batch_decode(
                trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
            )
        else:
            decoded = self.summary_model.tokenizer.batch_decode(
                generated_ids,
                skip_special_tokens=True,
                clean_up_tokenization_spaces=False,
            )

        results = [d.strip() for d in decoded]
        return results

    def answer_questions(
        self,
        list_of_questions: list[str],
        entry: Optional[Dict[str, Any]] = None,
        is_concise_answer: bool = True,
    ) -> List[str]:
        """
        Create answers for list of questions about image.
        Args:
            list_of_questions (list[str]): list of questions.
            entry (dict): dictionary containing the image to be captioned.
            is_concise_answer (bool): whether to generate concise answers.
        Returns:
            answers (list[str]): list of answers.
        """
        if is_concise_answer:
            gen_conf = GenerationConfig(max_new_tokens=64, do_sample=False)
            for i in range(len(list_of_questions)):
                if not list_of_questions[i].strip().endswith("?"):
                    list_of_questions[i] = list_of_questions[i].strip() + "?"
                if not list_of_questions[i].lower().startswith("answer concisely"):
                    list_of_questions[i] = "Answer concisely: " + list_of_questions[i]
        else:
            gen_conf = GenerationConfig(max_new_tokens=128, do_sample=False)

        question_chunk_size = 8
        answers: List[str] = []
        n = len(list_of_questions)
        for i in range(0, n, question_chunk_size):
            chunk = list_of_questions[i : i + question_chunk_size]
            inputs = self._prepare_inputs(chunk, entry)
            with torch.inference_mode():
                if self.summary_model.device == "cuda":
                    with torch.amp.autocast("cuda", enabled=True):
                        out_ids = self.summary_model.model.generate(
                            **inputs, generation_config=gen_conf
                        )
                else:
                    out_ids = self.summary_model.model.generate(
                        **inputs, generation_config=gen_conf
                    )

            if "input_ids" in inputs:
                in_ids = inputs["input_ids"]
                trimmed_batch = [
                    out_row[len(inp_row) :] for inp_row, out_row in zip(in_ids, out_ids)
                ]
                decoded = self.summary_model.tokenizer.batch_decode(
                    trimmed_batch,
                    skip_special_tokens=True,
                    clean_up_tokenization_spaces=False,
                )
            else:
                decoded = self.summary_model.tokenizer.batch_decode(
                    out_ids,
                    skip_special_tokens=True,
                    clean_up_tokenization_spaces=False,
                )

            answers.extend([d.strip() for d in decoded])

        if len(answers) != len(list_of_questions):
            raise ValueError(
                f"Expected {len(list_of_questions)} answers, but got {len(answers)}, try vary amount of questions"
            )

        return answers
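For reference, a short usage sketch of the new detector, mirroring what the tests and the demo notebook in this commit do (the image folder path is a placeholder):

import ammico

# Load the Qwen2.5-VL backend; the smaller 3B checkpoint is picked automatically on CPU.
model = ammico.MultimodalSummaryModel()

# Collect images and run captioning plus VQA over the whole dictionary.
image_dict = ammico.find_files(path="/insert/your/path/here/", limit=-1)
detector = ammico.ImageSummaryDetector(summary_model=model, subdict=image_dict)
results = detector.analyse_images_from_dict(
    analysis_type="summary_and_questions",
    list_of_questions=["Are there people in the image?"],
    is_concise_summary=True,
    is_concise_answer=True,
)
# Each entry now carries a "caption" string and a "vqa" field with the answers.
model.close()  # free model resources when done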
126	ammico/model.py	Normal file

@@ -0,0 +1,126 @@
import torch
import warnings
from transformers import (
    Qwen2_5_VLForConditionalGeneration,
    AutoProcessor,
    BitsAndBytesConfig,
    AutoTokenizer,
)
from typing import Optional


class MultimodalSummaryModel:
    DEFAULT_CUDA_MODEL = "Qwen/Qwen2.5-VL-7B-Instruct"
    DEFAULT_CPU_MODEL = "Qwen/Qwen2.5-VL-3B-Instruct"

    def __init__(
        self,
        model_id: Optional[str] = None,
        device: Optional[str] = None,
        cache_dir: Optional[str] = None,
    ) -> None:
        """
        Class for QWEN-2.5-VL model loading and inference.
        Args:
            model_id: Type of model to load, defaults to a smaller version for CPU if device is "cpu".
            device: "cuda" or "cpu" (auto-detected when None).
            cache_dir: huggingface cache dir (optional).
        """
        self.device = self._resolve_device(device)

        if model_id is not None and model_id not in (
            self.DEFAULT_CUDA_MODEL,
            self.DEFAULT_CPU_MODEL,
        ):
            raise ValueError(
                f"model_id must be one of {self.DEFAULT_CUDA_MODEL} or {self.DEFAULT_CPU_MODEL}"
            )

        self.model_id = model_id or (
            self.DEFAULT_CUDA_MODEL if self.device == "cuda" else self.DEFAULT_CPU_MODEL
        )

        self.cache_dir = cache_dir
        self._trust_remote_code = True
        self._quantize = True

        self.model = None
        self.processor = None
        self.tokenizer = None

        self._load_model_and_processor()

    @staticmethod
    def _resolve_device(device: Optional[str]) -> str:
        if device is None:
            return "cuda" if torch.cuda.is_available() else "cpu"
        if device.lower() not in ("cuda", "cpu"):
            raise ValueError("device must be 'cuda' or 'cpu'")
        if device.lower() == "cuda" and not torch.cuda.is_available():
            warnings.warn(
                "Although 'cuda' was requested, no CUDA device is available. Using CPU instead.",
                RuntimeWarning,
                stacklevel=2,
            )
            return "cpu"
        return device.lower()

    def _load_model_and_processor(self):
        load_kwargs = {"trust_remote_code": self._trust_remote_code, "use_cache": True}
        if self.cache_dir:
            load_kwargs["cache_dir"] = self.cache_dir

        self.processor = AutoProcessor.from_pretrained(
            self.model_id, padding_side="left", **load_kwargs
        )
        self.tokenizer = AutoTokenizer.from_pretrained(self.model_id, **load_kwargs)

        if self.device == "cuda":
            compute_dtype = (
                torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16
            )
            bnb_config = BitsAndBytesConfig(
                load_in_4bit=True,
                bnb_4bit_use_double_quant=True,
                bnb_4bit_quant_type="nf4",
                bnb_4bit_compute_dtype=compute_dtype,
            )
            load_kwargs["quantization_config"] = bnb_config
            load_kwargs["device_map"] = "auto"

        else:
            load_kwargs.pop("quantization_config", None)
            load_kwargs.pop("device_map", None)

        self.model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
            self.model_id, **load_kwargs
        )
        self.model.eval()

    def _close(self) -> None:
        """Free model resources (helpful in long-running processes)."""
        try:
            if self.model is not None:
                del self.model
                self.model = None
            if self.processor is not None:
                del self.processor
                self.processor = None
            if self.tokenizer is not None:
                del self.tokenizer
                self.tokenizer = None
        finally:
            try:
                if torch.cuda.is_available():
                    torch.cuda.empty_cache()
            except Exception as e:
                warnings.warn(
                    "Failed to empty CUDA cache. This is not critical, but may lead to memory lingering: "
                    f"{e!r}",
                    RuntimeWarning,
                    stacklevel=2,
                )

    def close(self) -> None:
        """Free model resources (helpful in long-running processes)."""
        self._close()
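The loader quantizes to 4-bit NF4 via bitsandbytes only on CUDA and falls back to a plain full-precision load on CPU. A small sketch of selecting the backend explicitly (both model ids are the defaults defined in the class above):

from ammico.model import MultimodalSummaryModel

# CPU-only environments: the smaller 3B checkpoint is used and quantization is skipped.
cpu_model = MultimodalSummaryModel(device="cpu")

# CUDA machines default to the 7B checkpoint with 4-bit NF4 quantization;
# the 3B id may also be passed explicitly if GPU memory is tight.
gpu_model = MultimodalSummaryModel(model_id="Qwen/Qwen2.5-VL-3B-Instruct", device="cuda")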
190	ammico/notebooks/DemoImageSummaryVQA.ipynb	Normal file

@@ -0,0 +1,190 @@
{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "0",
   "metadata": {},
   "source": [
    "# Image summary and visual question answering"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "1",
   "metadata": {},
   "source": [
    "This notebook shows how to generate image captions and use the visual question answering with AMMICO. \n",
    "\n",
    "The first cell imports `ammico`.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2",
   "metadata": {},
   "outputs": [],
   "source": [
    "import ammico"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "3",
   "metadata": {},
   "source": [
    "The cell below loads the model for VQA tasks. By default, it loads a large model on the GPU (if your device supports CUDA), otherwise it loads a relatively smaller model on the CPU. But you can specify other settings (e.g., a small model on the GPU) if you want."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4",
   "metadata": {},
   "outputs": [],
   "source": [
    "model = ammico.MultimodalSummaryModel()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "5",
   "metadata": {},
   "source": [
    "Here you need to provide the path to your google drive folder or local folder containing the images"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6",
   "metadata": {},
   "outputs": [],
   "source": [
    "image_dict = ammico.find_files(\n",
    "    path=str(\"/insert/your/path/here/\"),\n",
    "    limit=-1,  # -1 means no limit on the number of files, by default it is set to 20\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "7",
   "metadata": {},
   "source": [
    "The cell below creates an object that analyzes images and generates a summary using a specific model and image data."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8",
   "metadata": {},
   "outputs": [],
   "source": [
    "img = ammico.ImageSummaryDetector(summary_model=model, subdict=image_dict)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "9",
   "metadata": {},
   "source": [
    "## Image summary "
   ]
  },
  {
   "cell_type": "markdown",
   "id": "10",
   "metadata": {},
   "source": [
    "To start your work with images, you should call the `analyse_images` method.\n",
    "\n",
    "You can specify what kind of analysis you want to perform with `analysis_type`. `\"summary\"` will generate a summary for all pictures in your dictionary, `\"questions\"` will prepare answers to your questions for all pictures, and `\"summary_and_questions\"` will do both.\n",
    "\n",
    "Parameter `\"is_concise_summary\"` regulates the length of an answer.\n",
    "\n",
    "Here we want to get a long summary on each object in our image dictionary."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "11",
   "metadata": {},
   "outputs": [],
   "source": [
    "summaries = img.analyse_images(analysis_type=\"summary\", is_concise_summary=False)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "12",
   "metadata": {},
   "source": [
    "## VQA"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "13",
   "metadata": {},
   "source": [
    "In addition to analyzing images in `ammico`, the same model can be used in VQA mode. To do this, you need to define the questions that will be applied to all images from your dict."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "14",
   "metadata": {},
   "outputs": [],
   "source": [
    "questions = [\"Are there any visible signs of violence?\", \"Is it safe to be there?\"]"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "15",
   "metadata": {},
   "source": [
    "Here is an example of VQA mode usage. You can specify whether you want to receive short answers (recommended option) or not."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "16",
   "metadata": {},
   "outputs": [],
   "source": [
    "vqa_results = img.analyse_images(\n",
    "    analysis_type=\"questions\",\n",
    "    list_of_questions=questions,\n",
    "    is_concise_answer=True,\n",
    ")"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "ammico-dev",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.13"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
@@ -104,7 +104,8 @@
    "import ammico\n",
    "\n",
    "# for displaying a progress bar\n",
-   "from tqdm import tqdm"
+   "from tqdm import tqdm\n",
+   "import os"
   ]
  },
  {
@@ -140,7 +141,9 @@
   "metadata": {},
   "outputs": [],
   "source": [
-   "# os.environ[\"GOOGLE_APPLICATION_CREDENTIALS\"] = \"/content/drive/MyDrive/misinformation-data/misinformation-campaign-981aa55a3b13.json\""
+   "os.environ[\"GOOGLE_APPLICATION_CREDENTIALS\"] = (\n",
+   "    \"/home/inga/projects/misinformation-project/misinformation-notes/misinformation-campaign-981aa55a3b13.json\"\n",
+   ")"
   ]
  },
  {
@@ -171,6 +174,7 @@
   "metadata": {},
   "outputs": [],
   "source": [
+   "data_path = \"./data-test\"\n",
    "image_dict = ammico.find_files(\n",
    "    # path = \"/content/drive/MyDrive/misinformation-data/\",\n",
    "    path=str(data_path),\n",
@@ -337,7 +341,7 @@
    "    enumerate(image_dict.keys()), total=len(image_dict)\n",
    "):  # loop through all images\n",
    "    image_dict[key] = ammico.TextDetector(\n",
-   "        image_dict[key], analyse_text=True\n",
+   "        image_dict[key]\n",
    "    ).analyse_image()  # analyse image with EmotionDetector and update dict\n",
    "\n",
    "    if (\n",
@@ -361,23 +365,12 @@
   "outputs": [],
   "source": [
    "# initialize the models\n",
-   "image_summary_detector = ammico.SummaryDetector(\n",
-   "    subdict=image_dict, analysis_type=\"summary\", model_type=\"base\"\n",
+   "model = ammico.MultimodalSummaryModel()\n",
+   "image_summary_detector = ammico.ImageSummaryDetector(\n",
+   "    subdict=image_dict, summary_model=model\n",
    ")\n",
    "\n",
-   "# run the analysis without having to re-iniatialize the model\n",
-   "for num, key in tqdm(\n",
-   "    enumerate(image_dict.keys()), total=len(image_dict)\n",
-   "):  # loop through all images\n",
-   "    image_dict[key] = image_summary_detector.analyse_image(\n",
-   "        subdict=image_dict[key], analysis_type=\"summary\"\n",
-   "    )  # analyse image with SummaryDetector and update dict\n",
-   "\n",
-   "    if (\n",
-   "        num % dump_every == 0 | num == len(image_dict) - 1\n",
-   "    ):  # save results every dump_every to dump_file\n",
-   "        image_df = ammico.get_dataframe(image_dict)\n",
-   "        image_df.to_csv(dump_file)"
+   "image_summary_detector.analyse_images(analysis_type=\"summary\")"
   ]
  },
  {
@@ -394,6 +387,7 @@
   "outputs": [],
   "source": [
    "# initialize the models\n",
+   "# currently this does not work because of the way the summary detector is implemented\n",
    "image_summary_detector = ammico.SummaryDetector(\n",
    "    subdict=image_dict, analysis_type=\"summary\", model_type=\"base\"\n",
    ")\n",
96	ammico/notebooks/DemoVideoSummaryVQA.ipynb	Normal file

@@ -0,0 +1,96 @@
{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Video summary and visual question answering"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import ammico"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Currently this module supports only video summarization, but it will be updated in the nearest future"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "video_dict = ammico.find_videos(\n",
    "    path=str(\"/insert/your/path/here/\"),  # path to the folder with images\n",
    "    limit=-1,  # -1 means no limit on the number of files, by default it is set to 20\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "model = ammico.MultimodalSummaryModel()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "vid_summary_model = ammico.VideoSummaryDetector(summary_model=model, subdict=video_dict)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "summary_dict = vid_summary_model.analyse_video()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "summary_dict[\"summary\"]"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "ammico-dev",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.13"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
@@ -1,5 +1,6 @@
 import os
 import pytest
+from ammico.model import MultimodalSummaryModel
 
 
 @pytest.fixture
@@ -46,3 +47,12 @@ def get_test_my_dict(get_path):
         },
     }
     return test_my_dict
+
+
+@pytest.fixture(scope="session")
+def model():
+    m = MultimodalSummaryModel(device="cpu")
+    try:
+        yield m
+    finally:
+        m.close()
@@ -50,6 +50,8 @@ def test_right_output_analysis_emotions(get_AE, get_options, monkeypatch):
         get_options[3],
         get_options[0],
         "EmotionDetector",
+        "summary",
+        "Some question",
         True,
         "SOME_VAR",
         50,
37	ammico/test/test_image_summary.py	Normal file

@@ -0,0 +1,37 @@
from ammico.image_summary import ImageSummaryDetector

import pytest


@pytest.mark.long
def test_image_summary_detector(model, get_testdict):
    detector = ImageSummaryDetector(summary_model=model, subdict=get_testdict)
    results = detector.analyse_images_from_dict(analysis_type="summary")
    assert len(results) == 2
    for key in get_testdict.keys():
        assert key in results
        assert "caption" in results[key]
        assert isinstance(results[key]["caption"], str)
        assert len(results[key]["caption"]) > 0


@pytest.mark.long
def test_image_summary_detector_questions(model, get_testdict):
    list_of_questions = [
        "What is happening in the image?",
        "How many cars are in the image in total?",
    ]
    detector = ImageSummaryDetector(summary_model=model, subdict=get_testdict)
    results = detector.analyse_images_from_dict(
        analysis_type="questions", list_of_questions=list_of_questions
    )
    assert len(results) == 2
    for key in get_testdict.keys():
        assert "vqa" in results[key]
        if key == "IMG_2746":
            assert "marathon" in results[key]["vqa"][0].lower()

        if key == "IMG_2809":
            assert (
                "two" in results[key]["vqa"][1].lower() or "2" in results[key]["vqa"][1]
            )
30	ammico/test/test_model.py	Normal file

@@ -0,0 +1,30 @@
import pytest
from ammico.model import MultimodalSummaryModel


@pytest.mark.long
def test_model_init(model):
    assert model.model is not None
    assert model.processor is not None
    assert model.tokenizer is not None
    assert model.device is not None


@pytest.mark.long
def test_model_invalid_device():
    with pytest.raises(ValueError):
        MultimodalSummaryModel(device="invalid_device")


@pytest.mark.long
def test_model_invalid_model_id():
    with pytest.raises(ValueError):
        MultimodalSummaryModel(model_id="non_existent_model", device="cpu")


@pytest.mark.long
def test_free_resources():
    model = MultimodalSummaryModel(device="cpu")
    model.close()
    assert model.model is None
    assert model.processor is None
@@ -52,24 +52,16 @@ def test_privacy_statement(monkeypatch):
 def test_TextDetector(set_testdict, accepted):
     for item in set_testdict:
         test_obj = tt.TextDetector(set_testdict[item], accept_privacy=accepted)
-        assert not test_obj.analyse_text
         assert not test_obj.skip_extraction
         assert test_obj.subdict["filename"] == set_testdict[item]["filename"]
-    test_obj = tt.TextDetector(
-        {}, analyse_text=True, skip_extraction=True, accept_privacy=accepted
-    )
-    assert test_obj.analyse_text
+    test_obj = tt.TextDetector({}, skip_extraction=True, accept_privacy=accepted)
     assert test_obj.skip_extraction
-    with pytest.raises(ValueError):
-        tt.TextDetector({}, analyse_text=1.0, accept_privacy=accepted)
     with pytest.raises(ValueError):
         tt.TextDetector({}, skip_extraction=1.0, accept_privacy=accepted)
 
 
 def test_run_spacy(set_testdict, get_path, accepted):
-    test_obj = tt.TextDetector(
-        set_testdict["IMG_3755"], analyse_text=True, accept_privacy=accepted
-    )
+    test_obj = tt.TextDetector(set_testdict["IMG_3755"], accept_privacy=accepted)
     ref_file = get_path + "text_IMG_3755.txt"
     with open(ref_file, "r") as file:
         reference_text = file.read()
@@ -108,15 +100,11 @@ def test_analyse_image(set_testdict, set_environ, accepted):
     for item in set_testdict:
         test_obj = tt.TextDetector(set_testdict[item], accept_privacy=accepted)
         test_obj.analyse_image()
-        test_obj = tt.TextDetector(
-            set_testdict[item], analyse_text=True, accept_privacy=accepted
-        )
+        test_obj = tt.TextDetector(set_testdict[item], accept_privacy=accepted)
         test_obj.analyse_image()
     testdict = {}
     testdict["text"] = 20000 * "m"
-    test_obj = tt.TextDetector(
-        testdict, skip_extraction=True, analyse_text=True, accept_privacy=accepted
-    )
+    test_obj = tt.TextDetector(testdict, skip_extraction=True, accept_privacy=accepted)
     test_obj.analyse_image()
     assert test_obj.subdict["text_truncated"] == 5000 * "m"
     assert test_obj.subdict["text"] == 20000 * "m"
ammico/text.py
@@ -67,7 +67,6 @@ class TextDetector(AnalysisMethod):
     def __init__(
         self,
         subdict: dict,
-        analyse_text: bool = False,
         skip_extraction: bool = False,
         accept_privacy: str = "PRIVACY_AMMICO",
     ) -> None:
@@ -76,8 +75,6 @@ class TextDetector(AnalysisMethod):
         Args:
             subdict (dict): Dictionary containing file name/path, and possibly previous
                 analysis results from other modules.
-            analyse_text (bool, optional): Decide if extracted text will be further subject
-                to analysis. Defaults to False.
             skip_extraction (bool, optional): Decide if text will be extracted from images or
                 is already provided via a csv. Defaults to False.
             accept_privacy (str, optional): Environment variable to accept the privacy
@@ -96,17 +93,13 @@ class TextDetector(AnalysisMethod):
                 "Privacy disclosure not accepted - skipping text detection."
             )
         self.translator = Translator(raise_exception=True)
-        if not isinstance(analyse_text, bool):
-            raise ValueError("analyse_text needs to be set to true or false")
-        self.analyse_text = analyse_text
         self.skip_extraction = skip_extraction
         if not isinstance(skip_extraction, bool):
             raise ValueError("skip_extraction needs to be set to true or false")
         if self.skip_extraction:
             print("Skipping text extraction from image.")
             print("Reading text directly from provided dictionary.")
-        if self.analyse_text:
-            self._initialize_spacy()
+        self._initialize_spacy()

     def set_keys(self) -> dict:
         """Set the default keys for text analysis.
@@ -183,7 +176,7 @@ class TextDetector(AnalysisMethod):
         self._truncate_text()
         self.translate_text()
         self.remove_linebreaks()
-        if self.analyse_text and self.subdict["text_english"]:
+        if self.subdict["text_english"]:
             self._run_spacy()
         return self.subdict
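Note: with the analyse_text flag removed, spaCy processing now always runs whenever English text is available. A minimal, hedged usage sketch under that assumption (the file path and dictionary entry below are placeholders, not taken from this diff):

    # Hedged sketch: assumes the AMMICO privacy disclosure has already been
    # accepted via the accept_privacy environment variable; otherwise detection is skipped.
    import ammico.text as tt

    mydict = {"filename": "data/example.png"}  # placeholder input entry
    detector = tt.TextDetector(mydict, accept_privacy="PRIVACY_AMMICO")
    result = detector.analyse_image()  # extraction, translation and spaCy now always run
    print(result.get("text_english"))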
ammico/utils.py (145 changed lines)
@@ -5,6 +5,9 @@ import pooch
 import importlib_resources
 import collections
 import random
+from enum import Enum
+from typing import List, Tuple, Optional, Union
+import re


 pkg = importlib_resources.files("ammico")
@@ -40,6 +43,41 @@ def ammico_prefetch_models():
         res.get()


+class AnalysisType(str, Enum):
+    SUMMARY = "summary"
+    QUESTIONS = "questions"
+    SUMMARY_AND_QUESTIONS = "summary_and_questions"
+
+    @classmethod
+    def _validate_analysis_type(
+        cls,
+        analysis_type: Union["AnalysisType", str],
+        list_of_questions: Optional[List[str]],
+    ) -> Tuple[str, bool, bool]:
+        max_questions_per_image = 15  # safety cap to avoid too many questions
+        if isinstance(analysis_type, AnalysisType):
+            analysis_type = analysis_type.value
+
+        allowed = {item.value for item in AnalysisType}
+        if analysis_type not in allowed:
+            raise ValueError(f"analysis_type must be one of {allowed}")
+
+        if analysis_type in ("questions", "summary_and_questions"):
+            if not list_of_questions:
+                raise ValueError(
+                    "list_of_questions must be provided for QUESTIONS analysis type."
+                )
+            if len(list_of_questions) > max_questions_per_image:
+                raise ValueError(
+                    f"Number of questions per image ({len(list_of_questions)}) exceeds safety cap ({max_questions_per_image}). Reduce questions or increase max_questions_per_image."
+                )
+
+        is_summary = analysis_type in ("summary", "summary_and_questions")
+        is_questions = analysis_type in ("questions", "summary_and_questions")
+        return analysis_type, is_summary, is_questions
+
+
 class AnalysisMethod:
     """Base class to be inherited by all analysis methods."""

@@ -94,6 +132,113 @@ def _limit_results(results, limit):
     return results


+def _categorize_outputs(
+    collected: List[Tuple[float, str]],
+    include_questions: bool = False,
+) -> Tuple[List[str], List[str]]:
+    """
+    Categorize collected outputs into summary bullets and VQA bullets.
+    Args:
+        collected (List[Tuple[float, str]]): List of tuples containing timestamps and generated texts.
+    Returns:
+        Tuple[List[str], List[str]]: A tuple containing two lists - summary bullets and VQA bullets.
+    """
+    MAX_CAPTIONS_FOR_SUMMARY = 600  # TODO For now, this is a constant value, but later we need to make it adjustable, with the idea of cutting out the most similar frames to reduce the load on the system.
+    caps_for_summary_vqa = (
+        collected[-MAX_CAPTIONS_FOR_SUMMARY:]
+        if len(collected) > MAX_CAPTIONS_FOR_SUMMARY
+        else collected
+    )
+    bullets_summary = []
+    bullets_vqa = []
+
+    for t, c in caps_for_summary_vqa:
+        if include_questions:
+            result_sections = c.strip()
+            m = re.search(
+                r"Summary\s*:\s*(.*?)\s*(?:VQA\s+Answers\s*:\s*(.*))?$",
+                result_sections,
+                flags=re.IGNORECASE | re.DOTALL,
+            )
+            if m:
+                summary_text = (
+                    m.group(1).replace("\n", " ").strip() if m.group(1) else None
+                )
+                vqa_text = m.group(2).strip() if m.group(2) else None
+                if not summary_text or not vqa_text:
+                    raise ValueError(
+                        f"Model output is missing either summary or VQA answers: {c}"
+                    )
+                bullets_summary.append(f"- [{t:.3f}s] {summary_text}")
+                bullets_vqa.append(f"- [{t:.3f}s] {vqa_text}")
+            else:
+                raise ValueError(
+                    f"Failed to parse summary and VQA answers from model output: {c}"
+                )
+        else:
+            snippet = c.replace("\n", " ").strip()
+            bullets_summary.append(f"- [{t:.3f}s] {snippet}")
+    return bullets_summary, bullets_vqa
+
+
+def _normalize_whitespace(s: str) -> str:
+    return re.sub(r"\s+", " ", s).strip()
+
+
+def _strip_prompt_prefix_literal(decoded: str, prompt: str) -> str:
+    """
+    Remove any literal prompt prefix from decoded text using a normalized-substring match.
+    Guarantees no prompt text remains at the start of returned string (best-effort).
+    """
+    if not decoded:
+        return ""
+    if not prompt:
+        return decoded.strip()
+
+    d_norm = _normalize_whitespace(decoded)
+    p_norm = _normalize_whitespace(prompt)
+
+    idx = d_norm.find(p_norm)
+    if idx != -1:
+        running = []
+        for i, ch in enumerate(decoded):
+            running.append(ch if not ch.isspace() else " ")
+            cur_norm = _normalize_whitespace("".join(running))
+            if cur_norm.endswith(p_norm):
+                return decoded[i + 1 :].lstrip() if i + 1 < len(decoded) else ""
+    m = re.match(
+        r"^(?:\s*(system|user|assistant)[:\s-]*\n?)+", decoded, flags=re.IGNORECASE
+    )
+    if m:
+        return decoded[m.end() :].lstrip()
+
+    return decoded.lstrip("\n\r ").lstrip(":;- ").strip()
+
+
+def find_videos(
+    path: str = None,
+    pattern=["mp4"],  # TODO: test with more video formats
+    recursive: bool = True,
+    limit=5,
+    random_seed: int = None,
+) -> dict:
+    """Find video files on the file system."""
+    if path is None:
+        path = os.environ.get("AMMICO_DATA_HOME", ".")
+    if isinstance(pattern, str):
+        pattern = [pattern]
+    results = []
+    for p in pattern:
+        results.extend(_match_pattern(path, p, recursive=recursive))
+    if len(results) == 0:
+        raise FileNotFoundError(f"No files found in {path} with pattern '{pattern}'")
+    if random_seed is not None:
+        random.seed(random_seed)
+        random.shuffle(results)
+    videos = _limit_results(results, limit)
+    return initialize_dict(videos)
+
+
 def find_files(
     path: str = None,
     pattern=["png", "jpg", "jpeg", "gif", "webp", "avif", "tiff"],
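A hedged sketch of how the new utils helpers could be combined; the directory path and questions below are illustrative placeholders, not taken from this diff:

    # Hedged usage sketch for the new video helpers in ammico.utils.
    from ammico.utils import AnalysisType, find_videos

    videos = find_videos(path="/data/videos", pattern=["mp4"], limit=2)  # dict keyed per video file
    analysis_type, is_summary, is_questions = AnalysisType._validate_analysis_type(
        AnalysisType.SUMMARY_AND_QUESTIONS, ["What is shown?", "How many people appear?"]
    )
    print(analysis_type, is_summary, is_questions)  # summary_and_questions True True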
ammico/video_summary.py (new file, 514 lines)

import re
import math
import torch
import warnings
from PIL import Image
from torchcodec.decoders import VideoDecoder

from ammico.model import MultimodalSummaryModel
from ammico.utils import (
    AnalysisMethod,
    AnalysisType,
    _categorize_outputs,
    _strip_prompt_prefix_literal,
)

from typing import List, Dict, Any, Generator, Tuple, Union, Optional
from transformers import GenerationConfig


class VideoSummaryDetector(AnalysisMethod):
    MAX_SAMPLES_CAP = 1000  # safety cap for total extracted frames

    def __init__(
        self,
        summary_model: MultimodalSummaryModel,
        subdict: Optional[Dict[str, Any]] = None,
    ) -> None:
        """
        Class for analysing videos using QWEN-2.5-VL model.
        It provides methods for generating captions and answering questions about videos.

        Args:
            summary_model ([type], optional): An instance of MultimodalSummaryModel to be used for analysis.
            subdict (dict, optional): Dictionary containing the video to be analysed. Defaults to {}.

        Returns:
            None.
        """
        if subdict is None:
            subdict = {}

        super().__init__(subdict)
        self.summary_model = summary_model

    def _frame_batch_generator(
        self,
        timestamps: torch.Tensor,
        batch_size: int,
        video_decoder: VideoDecoder,
    ) -> Generator[Tuple[torch.Tensor, torch.Tensor], None, None]:
        """
        Yield batches of (frames, timestamps) for given frame indices.
        - frames are returned as a torch.Tensor with shape (B, C, H, W).
        - timestamps is a 1D torch.Tensor with B elements.
        """
        total = int(timestamps.numel())

        for start in range(0, total, batch_size):
            batch_secs = timestamps[start : start + batch_size].tolist()
            fb = video_decoder.get_frames_played_at(batch_secs)
            frames = fb.data

            if not frames.is_contiguous():
                frames = frames.contiguous()
            pts = fb.pts_seconds
            pts_out = (
                pts.cpu().to(dtype=torch.float32)
                if isinstance(pts, torch.Tensor)
                else torch.tensor(pts, dtype=torch.float32)
            )
            yield frames, pts_out

    def _extract_video_frames(
        self,
        entry: Dict[str, Any],
        frame_rate_per_second: float = 2.0,
        batch_size: int = 32,
    ) -> Generator[Tuple[torch.Tensor, torch.Tensor], None, None]:
        """
        Extract frames from a video at a specified frame rate and return them as a generator of batches.
        Args:
            filename (Union[str, os.PathLike]): Path to the video file.
            frame_rate_per_second (float, optional): Frame extraction rate in frames per second. Default is 2.
            batch_size (int, optional): Number of frames to include in each batch. Default is 32.
        Returns:
            Generator[Tuple[torch.Tensor, torch.Tensor], None, None]: A generator yielding tuples of
            (frames, timestamps), where frames is a tensor of shape (B, C, H, W) and timestamps is a 1D tensor of length B.
        """

        filename = entry.get("filename")
        if not filename:
            raise ValueError("entry must contain key 'filename'")

        video_decoder = VideoDecoder(filename)
        meta = video_decoder.metadata

        video_fps = getattr(meta, "average_fps", None)
        if video_fps is None or not (
            isinstance(video_fps, (int, float)) and video_fps > 0
        ):
            video_fps = 30.0

        begin_stream_seconds = getattr(meta, "begin_stream_seconds", None)
        end_stream_seconds = getattr(meta, "end_stream_seconds", None)
        nframes = len(video_decoder)
        if getattr(meta, "duration_seconds", None) is not None:
            duration = float(meta.duration_seconds)
        elif begin_stream_seconds is not None and end_stream_seconds is not None:
            duration = float(end_stream_seconds) - float(begin_stream_seconds)
        elif nframes:
            duration = float(nframes) / float(video_fps)
        else:
            duration = 0.0

        if frame_rate_per_second <= 0:
            raise ValueError("frame_rate_per_second must be > 0")

        n_samples = max(1, int(math.floor(duration * frame_rate_per_second)))
        n_samples = min(n_samples, self.MAX_SAMPLES_CAP)

        if begin_stream_seconds is not None and end_stream_seconds is not None:
            sample_times = torch.linspace(
                float(begin_stream_seconds), float(end_stream_seconds), steps=n_samples
            )
            if sample_times.numel() > 1:
                sample_times = torch.clamp(
                    sample_times,
                    min=float(begin_stream_seconds),
                    max=float(end_stream_seconds) - 1e-6,
                )
        else:
            sample_times = torch.linspace(0.0, max(0.0, duration), steps=n_samples)

        sample_times = sample_times.to(dtype=torch.float32, device="cpu")
        generator = self._frame_batch_generator(sample_times, batch_size, video_decoder)

        return generator

    def _decode_trimmed_outputs(
        self,
        generated_ids: torch.Tensor,
        inputs: Dict[str, torch.Tensor],
        tokenizer,
        prompt_texts: List[str],
    ) -> List[str]:
        """
        Trim prompt tokens using attention_mask/input_ids when available and decode to strings.
        Then remove any literal prompt prefix using prompt_texts (one per batch element).
        """

        batch_size = generated_ids.shape[0]

        if "input_ids" in inputs:
            token_for_padding = (
                tokenizer.pad_token_id
                if getattr(tokenizer, "pad_token_id", None) is not None
                else getattr(tokenizer, "eos_token_id", None)
            )
            if token_for_padding is None:
                lengths = [int(inputs["input_ids"].shape[1])] * batch_size
            else:
                lengths = inputs["input_ids"].ne(token_for_padding).sum(dim=1).tolist()
        else:
            lengths = [0] * batch_size

        trimmed_ids = []
        for i in range(batch_size):
            out_ids = generated_ids[i]
            in_len = int(lengths[i]) if i < len(lengths) else 0
            if out_ids.size(0) > in_len:
                t = out_ids[in_len:]
            else:
                t = out_ids.new_empty((0,), dtype=out_ids.dtype)
            t_cpu = t.to("cpu")
            trimmed_ids.append(t_cpu.tolist())

        decoded = tokenizer.batch_decode(
            trimmed_ids, skip_special_tokens=True, clean_up_tokenization_spaces=True
        )
        decoded_results = []
        for ptext, raw in zip(prompt_texts, decoded):
            cleaned = _strip_prompt_prefix_literal(raw, ptext)
            decoded_results.append(cleaned)
        return decoded_results

    def _generate_from_processor_inputs(
        self,
        processor_inputs: Dict[str, torch.Tensor],
        prompt_texts: List[str],
        tokenizer,
    ):
        """
        Run model.generate on already-processed processor_inputs (tensors moved to device),
        then decode and trim prompt tokens & remove literal prompt prefixes using prompt_texts.
        """
        gen_conf = GenerationConfig(
            max_new_tokens=64,
            do_sample=False,
            num_return_sequences=1,
        )

        for k, v in list(processor_inputs.items()):
            if isinstance(v, torch.Tensor):
                processor_inputs[k] = v.to(self.summary_model.device)

        with torch.inference_mode():
            try:
                if self.summary_model.device == "cuda":
                    with torch.amp.autocast("cuda", enabled=True):
                        generated_ids = self.summary_model.model.generate(
                            **processor_inputs, generation_config=gen_conf
                        )
                else:
                    generated_ids = self.summary_model.model.generate(
                        **processor_inputs, generation_config=gen_conf
                    )
            except RuntimeError as e:
                warnings.warn(
                    f"Generation failed with error: {e}. Retrying with cuDNN disabled.",
                    RuntimeWarning,
                )
                cudnn_was_enabled = (
                    torch.backends.cudnn.is_available() and torch.backends.cudnn.enabled
                )
                if cudnn_was_enabled:
                    torch.backends.cudnn.enabled = False
                try:
                    generated_ids = self.summary_model.model.generate(
                        **processor_inputs, generation_config=gen_conf
                    )
                except Exception as retry_error:
                    raise RuntimeError(
                        f"Failed to generate ids after retry: {retry_error}"
                    ) from retry_error
                finally:
                    if cudnn_was_enabled:
                        torch.backends.cudnn.enabled = True

        decoded = self._decode_trimmed_outputs(
            generated_ids, processor_inputs, tokenizer, prompt_texts
        )
        return decoded

    def _tensor_batch_to_pil_list(self, batch: torch.Tensor) -> List[Image.Image]:
        """
        Convert a uint8 torch tensor batch (B, C, H, W) on CPU to list of PIL images (RGB).
        The conversion is done on CPU and returns PIL.Image objects.
        """
        if batch.device.type != "cpu":
            batch = batch.to("cpu")

        batch = batch.contiguous()
        if batch.dtype.is_floating_point:
            batch = (batch.clamp(0.0, 1.0) * 255.0).to(torch.uint8)
        elif batch.dtype != torch.uint8:
            batch = batch.to(torch.uint8)
        pil_list: List[Image.Image] = []
        for frame in batch:
            arr = frame.permute(1, 2, 0).numpy()
            pil_list.append(Image.fromarray(arr))
        return pil_list

    def make_captions_from_extracted_frames(
        self,
        extracted_video_gen: Generator[Tuple[torch.Tensor, torch.Tensor], None, None],
        list_of_questions: Optional[List[str]] = None,
    ) -> Dict[str, Any]:
        """
        Generate captions for all extracted frames and then produce a concise summary of the video.
        Args:
            extracted_video_dict (Dict[str, Any]): Dictionary containing the frame generator and number of frames.
            summary_instruction (str, optional): Instruction for summarizing the captions. Defaults to a concise paragraph.
        Returns:
            Dict[str, Any]: A dictionary containing the list of captions with timestamps and the final summary.
        """

        caption_instruction = "Describe this image in one concise caption."
        include_questions = bool(list_of_questions)
        if include_questions:
            q_block = "\n".join(
                [f"{i + 1}. {q.strip()}" for i, q in enumerate(list_of_questions)]
            )
            caption_instruction += (
                " In addition to the concise caption, also answer the following questions based ONLY on the image. Answers must be very brief and concise."
                " Produce exactly two labeled sections: \n\n"
                "Summary: <concise summary>\n\n"
                "VQA Answers: \n1. <answer to question 1>\n2. <answer to question 2>\n etc."
                "\nReturn only those two sections for each image (do not add extra commentary)."
                "\nIf the answer cannot be determined based on the provided answer blocks,"
                ' reply with the line "The answer cannot be determined based on the information provided."'
                f"\n\nQuestions:\n{q_block}"
            )
        collected: List[Tuple[float, str]] = []
        proc = self.summary_model.processor
        try:
            for batch_frames, batch_times in extracted_video_gen:
                pil_list = self._tensor_batch_to_pil_list(batch_frames.cpu())

                prompt_texts = []
                for p in pil_list:
                    messages = [
                        {
                            "role": "user",
                            "content": [
                                {"type": "image", "image": p},
                                {"type": "text", "text": caption_instruction},
                            ],
                        }
                    ]

                    prompt_text = proc.apply_chat_template(
                        messages, tokenize=False, add_generation_prompt=True
                    )
                    prompt_texts.append(prompt_text)

                processor_inputs = proc(
                    text=prompt_texts,
                    images=pil_list,
                    return_tensors="pt",
                    padding=True,
                )
                captions = self._generate_from_processor_inputs(
                    processor_inputs,
                    prompt_texts,
                    self.summary_model.tokenizer,
                )

                if isinstance(batch_times, torch.Tensor):
                    batch_times_list = batch_times.cpu().tolist()
                else:
                    batch_times_list = list(batch_times)
                for t, c in zip(batch_times_list, captions):
                    collected.append((float(t), c))
        finally:
            try:
                extracted_video_gen.close()
            except Exception:
                warnings.warn("Failed to close video frame generator.", RuntimeWarning)

        collected.sort(key=lambda x: x[0])
        bullets_summary, bullets_vqa = _categorize_outputs(collected, include_questions)

        return {
            "summary_bullets": bullets_summary,
            "vqa_bullets": bullets_vqa,
        }  # TODO consider taking out time stamps from the returned structure

    def final_summary(self, summary_dict: Dict[str, Any]) -> Dict[str, Any]:
        """
        Produce a concise summary of the video, based on generated captions for all extracted frames.
        Args:
            summary_dict (Dict[str, Any]): Dictionary containing captions for the frames.
        Returns:
            Dict[str, Any]: A dictionary containing the list of captions with timestamps and the final summary.
        """
        summary_instruction = "Analyze the following captions from multiple frames of the same video and summarize the overall content of the video in one concise paragraph (1-3 sentences). Focus on the key themes, actions, or events across the video, not just the individual frames."
        proc = self.summary_model.processor

        bullets = summary_dict.get("summary_bullets", [])
        if not bullets:
            raise ValueError("No captions available for summary generation.")

        combined_captions_text = "\n".join(bullets)
        summary_user_text = summary_instruction + "\n\n" + combined_captions_text + "\n"

        messages = [
            {
                "role": "user",
                "content": [{"type": "text", "text": summary_user_text}],
            }
        ]

        summary_prompt_text = proc.apply_chat_template(
            messages, tokenize=False, add_generation_prompt=True
        )

        summary_inputs = proc(
            text=[summary_prompt_text], return_tensors="pt", padding=True
        )

        summary_inputs = {
            k: v.to(self.summary_model.device) if isinstance(v, torch.Tensor) else v
            for k, v in summary_inputs.items()
        }
        final_summary_list = self._generate_from_processor_inputs(
            summary_inputs,
            [summary_prompt_text],
            self.summary_model.tokenizer,
        )
        final_summary = final_summary_list[0].strip() if final_summary_list else ""

        return {
            "summary": final_summary,
        }

    def final_answers(
        self,
        answers_dict: Dict[str, Any],
        list_of_questions: List[str],
    ) -> Dict[str, Any]:
        """
        Answer the list of questions for the video based on the VQA bullets from the frames.
        Args:
            answers_dict (Dict[str, Any]): Dictionary containing the VQA bullets.
        Returns:
            Dict[str, Any]: A dictionary containing the list of answers to the questions.
        """
        vqa_bullets = answers_dict.get("vqa_bullets", [])
        if not vqa_bullets:
            raise ValueError(
                "No VQA bullets generated for single frames available for answering questions."
            )

        include_questions = bool(list_of_questions)
        if include_questions:
            q_block = "\n".join(
                [f"{i + 1}. {q.strip()}" for i, q in enumerate(list_of_questions)]
            )
            prompt = (
                "You are provided with a set of short VQA-captions, each of which is a block of short answers"
                " extracted from individual frames of the same video.\n\n"
                "VQA-captions (use ONLY these to answer):\n"
                f"{vqa_bullets}\n\n"
                "Answer the following questions briefly, based ONLY on the lists of answers provided above. The VQA-captions above contain answers"
                " to the same questions you are about to answer. If the answer cannot be determined based on the provided answer blocks,"
                ' reply with the line "The answer cannot be determined based on the information provided."'
                "Questions:\n"
                f"{q_block}\n\n"
                "Produce an ordered list with answers in the same order as the questions. You must have this structure of your output: "
                "Answers: \n1. <answer to question 1>\n2. <answer to question 2>\n etc."
                "Return ONLY the ordered list with answers and NOTHING else — no commentary, no explanation, no surrounding markdown."
            )
        else:
            raise ValueError(
                "list_of_questions must be provided for making final answers."
            )

        proc = self.summary_model.processor
        messages = [
            {
                "role": "user",
                "content": [{"type": "text", "text": prompt}],
            }
        ]
        final_vqa_prompt_text = proc.apply_chat_template(
            messages, tokenize=False, add_generation_prompt=True
        )
        final_vqa_inputs = proc(
            text=[final_vqa_prompt_text], return_tensors="pt", padding=True
        )
        final_vqa_inputs = {
            k: v.to(self.summary_model.device) if isinstance(v, torch.Tensor) else v
            for k, v in final_vqa_inputs.items()
        }

        final_vqa_list = self._generate_from_processor_inputs(
            final_vqa_inputs,
            [final_vqa_prompt_text],
            self.summary_model.tokenizer,
        )

        final_vqa_output = final_vqa_list[0].strip() if final_vqa_list else ""
        vqa_answers = []
        answer_matches = re.findall(
            r"\d+\.\s*(.*?)(?=\n\d+\.|$)", final_vqa_output, flags=re.DOTALL
        )
        for answer in answer_matches:
            vqa_answers.append(answer.strip())
        return {
            "vqa_answers": vqa_answers,
        }

    def analyse_videos_from_dict(
        self,
        analysis_type: Union[AnalysisType, str] = AnalysisType.SUMMARY,
        frame_rate_per_second: float = 2.0,
        list_of_questions: Optional[List[str]] = None,
    ) -> Dict[str, Any]:
        """
        Analyse the video specified in self.subdict using frame extraction and captioning.
        Args:
            analysis_type (Union[AnalysisType, str], optional): Type of analysis to perform. Defaults to AnalysisType.SUMMARY.
            frame_rate_per_second (float): Frame extraction rate in frames per second. Default is 2.0.
            list_of_questions (List[str], optional): List of questions to answer about the video. Required if analysis_type includes questions.
        Returns:
            Dict[str, Any]: A dictionary containing the analysis results, including summary and answers for provided questions (if any).
        """

        all_answers = {}
        analysis_type, is_summary, is_questions = AnalysisType._validate_analysis_type(
            analysis_type, list_of_questions
        )

        for video_key, entry in self.subdict.items():
            summary_and_vqa = {}
            frames_generator = self._extract_video_frames(
                entry, frame_rate_per_second=frame_rate_per_second
            )

            answers_dict = self.make_captions_from_extracted_frames(
                frames_generator, list_of_questions=list_of_questions
            )
            # TODO: captions has to be post-processed with foreseeing audio analysis
            # TODO: captions and answers may lead to prompt, that superior model limits. Consider hierarchical approach.
            if is_summary:
                answer = self.final_summary(answers_dict)
                summary_and_vqa["summary"] = answer["summary"]
            if is_questions:
                answer = self.final_answers(answers_dict, list_of_questions)
                summary_and_vqa["vqa_answers"] = answer["vqa_answers"]

            all_answers[video_key] = summary_and_vqa

        return all_answers
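A hedged end-to-end sketch of the new video pipeline. The MultimodalSummaryModel constructor arguments are assumed (its definition lives in ammico/model.py and is not part of this diff), and the video path is a placeholder:

    # Hedged sketch: assumes MultimodalSummaryModel() can be constructed with defaults.
    from ammico.model import MultimodalSummaryModel
    from ammico.utils import AnalysisType, find_videos
    from ammico.video_summary import VideoSummaryDetector

    videos = find_videos(path="/data/videos", limit=1)  # placeholder path
    model = MultimodalSummaryModel()  # assumed default construction
    detector = VideoSummaryDetector(summary_model=model, subdict=videos)
    results = detector.analyse_videos_from_dict(
        analysis_type=AnalysisType.SUMMARY_AND_QUESTIONS,
        frame_rate_per_second=1.0,
        list_of_questions=["What is happening in the video?"],
    )
    for key, res in results.items():
        print(key, res.get("summary"), res.get("vqa_answers"))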
environment.yml (new file, 20 lines)

name: ammico-dev
channels:
  - pytorch
  - nvidia
  - rapidsai
  - conda-forge
  - defaults

dependencies:
  - python=3.11
  - cudatoolkit=11.8
  - pytorch=2.5.1
  - pytorch-cuda=11.8
  - torchvision=0.20.1
  - torchaudio=2.5.1
  - faiss-gpu-raft=1.8.0
  - ipykernel
  - jupyterlab
  - jupyterlab_widgets
  - ffmpeg<8
pyproject.toml
@@ -18,16 +18,19 @@ classifiers = [
     "Operating System :: OS Independent",
     "License :: OSI Approved :: MIT License",
 ]

 dependencies = [
+    "accelerate>=0.22",
+    "bitsandbytes",
     "colorgram.py",
     "colour-science",
     "dash",
     "dash-bootstrap-components",
     "deepface",
     "google-cloud-vision",
-    "googletrans==4.0.0rc1",
+    "googletrans-py",  # instead of googletrans==4.0.0rc1, a temporary solution due to incompatibility with jupyterlab
     "grpcio",
+    "huggingface-hub>=0.34.0",
     "importlib_metadata",
     "importlib_resources",
     "matplotlib",
@@ -36,12 +39,17 @@ dependencies = [
     "pandas",
     "Pillow",
     "pooch",
+    "qwen-vl-utils",
     "retina_face",
+    "safetensors>=0.6.2",
     "setuptools",
     "spacy",
-    "tensorflow<=2.16.0",
+    "tensorflow<2.15",  # instead of <=2.16.0 to make it compatible with CUDA 11.8, may change after updating CUDA version.
     "tf-keras",
+    "torchvision",
     "tqdm",
+    "transformers>=4.54",
+    "torchcodec<0.2",
     "webcolors",
 ]