Этот коммит содержится в:
DimasfromLavoisier 2025-10-10 17:05:48 +02:00
родитель 3018800ed4
Коммит d810dbc366
4 изменённых файлов: 53 добавлений и 2 удалений

Просмотреть файл

@ -27,6 +27,15 @@ class MultimodalSummaryModel:
cache_dir: huggingface cache dir (optional).
"""
self.device = self._resolve_device(device)
if model_id is not None and model_id not in (
self.DEFAULT_CUDA_MODEL,
self.DEFAULT_CPU_MODEL,
):
raise ValueError(
f"model_id must be one of {self.DEFAULT_CUDA_MODEL} or {self.DEFAULT_CPU_MODEL}"
)
self.model_id = model_id or (
self.DEFAULT_CUDA_MODEL if self.device == "cuda" else self.DEFAULT_CPU_MODEL
)
@ -94,6 +103,12 @@ class MultimodalSummaryModel:
if self.model is not None:
del self.model
self.model = None
if self.processor is not None:
del self.processor
self.processor = None
if self.tokenizer is not None:
del self.tokenizer
self.tokenizer = None
finally:
try:
if torch.cuda.is_available():

Просмотреть файл

@ -1,5 +1,6 @@
import os
import pytest
from ammico.model import MultimodalSummaryModel
@pytest.fixture
@ -46,3 +47,12 @@ def get_test_my_dict(get_path):
},
}
return test_my_dict
@pytest.fixture(scope="session")
def model():
m = MultimodalSummaryModel(device="cpu")
try:
yield m
finally:
m.close()

27
ammico/test/test_model.py Обычный файл
Просмотреть файл

@ -0,0 +1,27 @@
import pytest
import torch
from ammico.model import MultimodalSummaryModel
def test_model_init(model):
assert model.model is not None
assert model.processor is not None
assert model.tokenizer is not None
assert model.device is not None
def test_model_invalid_device():
with pytest.raises(ValueError):
MultimodalSummaryModel(device="invalid_device")
def test_model_invalid_model_id():
with pytest.raises(ValueError):
MultimodalSummaryModel(model_id="non_existent_model", device="cpu")
def test_free_resources():
model = MultimodalSummaryModel(device="cpu")
model.close()
assert model.model is None
assert model.processor is None

Просмотреть файл

@ -26,10 +26,9 @@ dependencies = [
"colour-science",
"dash",
"dash-bootstrap-components",
"decord",
"deepface",
"google-cloud-vision",
"googletrans-py", # instead of googletrans4.0.0rc1, for a temporary solution due the incompatibility with jupyterlab
"googletrans-py", # instead of googletrans4.0.0rc1, for a temporary solution due to incompatibility with jupyterlab
"grpcio",
"huggingface-hub>=0.34.0",
"importlib_metadata",