зеркало из
https://github.com/ssciwr/AMMICO.git
synced 2025-10-29 21:16:06 +02:00
add base model tests
Этот коммит содержится в:
родитель
3018800ed4
Коммит
d810dbc366
@ -27,6 +27,15 @@ class MultimodalSummaryModel:
|
||||
cache_dir: huggingface cache dir (optional).
|
||||
"""
|
||||
self.device = self._resolve_device(device)
|
||||
|
||||
if model_id is not None and model_id not in (
|
||||
self.DEFAULT_CUDA_MODEL,
|
||||
self.DEFAULT_CPU_MODEL,
|
||||
):
|
||||
raise ValueError(
|
||||
f"model_id must be one of {self.DEFAULT_CUDA_MODEL} or {self.DEFAULT_CPU_MODEL}"
|
||||
)
|
||||
|
||||
self.model_id = model_id or (
|
||||
self.DEFAULT_CUDA_MODEL if self.device == "cuda" else self.DEFAULT_CPU_MODEL
|
||||
)
|
||||
@ -94,6 +103,12 @@ class MultimodalSummaryModel:
|
||||
if self.model is not None:
|
||||
del self.model
|
||||
self.model = None
|
||||
if self.processor is not None:
|
||||
del self.processor
|
||||
self.processor = None
|
||||
if self.tokenizer is not None:
|
||||
del self.tokenizer
|
||||
self.tokenizer = None
|
||||
finally:
|
||||
try:
|
||||
if torch.cuda.is_available():
|
||||
|
||||
@ -1,5 +1,6 @@
|
||||
import os
|
||||
import pytest
|
||||
from ammico.model import MultimodalSummaryModel
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
@ -46,3 +47,12 @@ def get_test_my_dict(get_path):
|
||||
},
|
||||
}
|
||||
return test_my_dict
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
def model():
|
||||
m = MultimodalSummaryModel(device="cpu")
|
||||
try:
|
||||
yield m
|
||||
finally:
|
||||
m.close()
|
||||
|
||||
27
ammico/test/test_model.py
Обычный файл
27
ammico/test/test_model.py
Обычный файл
@ -0,0 +1,27 @@
|
||||
import pytest
|
||||
import torch
|
||||
from ammico.model import MultimodalSummaryModel
|
||||
|
||||
|
||||
def test_model_init(model):
|
||||
assert model.model is not None
|
||||
assert model.processor is not None
|
||||
assert model.tokenizer is not None
|
||||
assert model.device is not None
|
||||
|
||||
|
||||
def test_model_invalid_device():
|
||||
with pytest.raises(ValueError):
|
||||
MultimodalSummaryModel(device="invalid_device")
|
||||
|
||||
|
||||
def test_model_invalid_model_id():
|
||||
with pytest.raises(ValueError):
|
||||
MultimodalSummaryModel(model_id="non_existent_model", device="cpu")
|
||||
|
||||
|
||||
def test_free_resources():
|
||||
model = MultimodalSummaryModel(device="cpu")
|
||||
model.close()
|
||||
assert model.model is None
|
||||
assert model.processor is None
|
||||
@ -26,10 +26,9 @@ dependencies = [
|
||||
"colour-science",
|
||||
"dash",
|
||||
"dash-bootstrap-components",
|
||||
"decord",
|
||||
"deepface",
|
||||
"google-cloud-vision",
|
||||
"googletrans-py", # instead of googletrans4.0.0rc1, for a temporary solution due the incompatibility with jupyterlab
|
||||
"googletrans-py", # instead of googletrans4.0.0rc1, for a temporary solution due to incompatibility with jupyterlab
|
||||
"grpcio",
|
||||
"huggingface-hub>=0.34.0",
|
||||
"importlib_metadata",
|
||||
|
||||
Загрузка…
x
Ссылка в новой задаче
Block a user