This commit is contained in:
DimasfromLavoisier 2025-08-22 15:43:38 +02:00
parent efe1851fea
commit 36a0f90a76

ammico/model.py New file (101 added lines)

@@ -0,0 +1,101 @@
import torch
import warnings

from transformers import (
    Qwen2_5_VLForConditionalGeneration,
    AutoProcessor,
    BitsAndBytesConfig,
    AutoTokenizer,
)
from typing import Optional


class MultimodalSummaryModel:
    DEFAULT_CUDA_MODEL = "Qwen/Qwen2.5-VL-7B-Instruct"
    DEFAULT_CPU_MODEL = "Qwen/Qwen2.5-VL-3B-Instruct"

    def __init__(
        self,
        model_id: Optional[str] = None,
        device: Optional[str] = None,
        cache_dir: Optional[str] = None,
    ) -> None:
        """
        Class for Qwen2.5-VL model loading and inference.

        Args:
            model_id: Model checkpoint to load; defaults to the smaller 3B variant when the device is "cpu".
            device: "cuda" or "cpu" (auto-detected when None).
            cache_dir: Hugging Face cache directory (optional).
        """
        self.device = self._resolve_device(device)
        self.model_id = model_id or (
            self.DEFAULT_CUDA_MODEL if self.device == "cuda" else self.DEFAULT_CPU_MODEL
        )
        self.cache_dir = cache_dir
        self._trust_remote_code = True
        self._quantize = True

        self.model = None
        self.processor = None
        self.tokenizer = None

        self._load_model_and_processor()

    @staticmethod
    def _resolve_device(device: Optional[str]) -> str:
        if device is None:
            return "cuda" if torch.cuda.is_available() else "cpu"
        if device.lower() not in ("cuda", "cpu"):
            raise ValueError("device must be 'cuda' or 'cpu'")
        if device.lower() == "cuda" and not torch.cuda.is_available():
            # Fall back to CPU and warn when CUDA was requested but is unavailable.
            warnings.warn(
                "Although 'cuda' was requested, no CUDA device is available. Using CPU instead."
            )
            return "cpu"
        return device.lower()

    def _load_model_and_processor(self):
        load_kwargs = dict(trust_remote_code=self._trust_remote_code, use_cache=True)
        if self.cache_dir:
            load_kwargs["cache_dir"] = self.cache_dir

        self.processor = AutoProcessor.from_pretrained(
            self.model_id, padding_side="left", **load_kwargs
        )
        self.tokenizer = AutoTokenizer.from_pretrained(self.model_id, **load_kwargs)

        if self.device == "cuda":
            # Quantize to 4-bit NF4 on GPU; prefer bfloat16 compute where supported.
            compute_dtype = (
                torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16
            )
            bnb_config = BitsAndBytesConfig(
                load_in_4bit=True,
                bnb_4bit_use_double_quant=True,
                bnb_4bit_quant_type="nf4",
                bnb_4bit_compute_dtype=compute_dtype,
            )
            load_kwargs["quantization_config"] = bnb_config
            load_kwargs["device_map"] = "auto"
        else:
            # On CPU, load the weights without quantization or device mapping.
            load_kwargs.pop("quantization_config", None)
            load_kwargs.pop("device_map", None)

        self.model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
            self.model_id, **load_kwargs
        )
        self.model.eval()

    def _close(self) -> None:
        """Free model resources (helpful in long-running processes)."""
        try:
            if self.model is not None:
                del self.model
                self.model = None
        finally:
            try:
                if torch.cuda.is_available():
                    torch.cuda.empty_cache()
            except Exception:
                warnings.warn(
                    "Failed to empty CUDA cache. This is not critical, but may lead to memory leaks in long-running processes."
                )
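
A minimal usage sketch of the class added above (assumptions: the file is importable as ammico.model, and the default checkpoints can be downloaded from the Hugging Face Hub on first use):

from ammico.model import MultimodalSummaryModel

# Device is auto-detected; pass device="cpu" to force the smaller 3B checkpoint.
model = MultimodalSummaryModel()
print(model.model_id, model.device)  # e.g. "Qwen/Qwen2.5-VL-3B-Instruct", "cpu"

# Release model weights and clear the CUDA cache when done.
model._close()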