diff --git a/ammico/model.py b/ammico/model.py index 80cc31f..cdc1161 100644 --- a/ammico/model.py +++ b/ammico/model.py @@ -27,6 +27,15 @@ class MultimodalSummaryModel: cache_dir: huggingface cache dir (optional). """ self.device = self._resolve_device(device) + + if model_id is not None and model_id not in ( + self.DEFAULT_CUDA_MODEL, + self.DEFAULT_CPU_MODEL, + ): + raise ValueError( + f"model_id must be one of {self.DEFAULT_CUDA_MODEL} or {self.DEFAULT_CPU_MODEL}" + ) + self.model_id = model_id or ( self.DEFAULT_CUDA_MODEL if self.device == "cuda" else self.DEFAULT_CPU_MODEL ) @@ -94,6 +103,12 @@ class MultimodalSummaryModel: if self.model is not None: del self.model self.model = None + if self.processor is not None: + del self.processor + self.processor = None + if self.tokenizer is not None: + del self.tokenizer + self.tokenizer = None finally: try: if torch.cuda.is_available(): diff --git a/ammico/test/conftest.py b/ammico/test/conftest.py index cb42774..2010e1e 100644 --- a/ammico/test/conftest.py +++ b/ammico/test/conftest.py @@ -1,5 +1,6 @@ import os import pytest +from ammico.model import MultimodalSummaryModel @pytest.fixture @@ -46,3 +47,12 @@ def get_test_my_dict(get_path): }, } return test_my_dict + + +@pytest.fixture(scope="session") +def model(): + m = MultimodalSummaryModel(device="cpu") + try: + yield m + finally: + m.close() diff --git a/ammico/test/test_model.py b/ammico/test/test_model.py new file mode 100644 index 0000000..ac652c0 --- /dev/null +++ b/ammico/test/test_model.py @@ -0,0 +1,27 @@ +import pytest +import torch +from ammico.model import MultimodalSummaryModel + + +def test_model_init(model): + assert model.model is not None + assert model.processor is not None + assert model.tokenizer is not None + assert model.device is not None + + +def test_model_invalid_device(): + with pytest.raises(ValueError): + MultimodalSummaryModel(device="invalid_device") + + +def test_model_invalid_model_id(): + with pytest.raises(ValueError): + MultimodalSummaryModel(model_id="non_existent_model", device="cpu") + + +def test_free_resources(): + model = MultimodalSummaryModel(device="cpu") + model.close() + assert model.model is None + assert model.processor is None diff --git a/pyproject.toml b/pyproject.toml index c0f5440..4f29217 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -26,10 +26,9 @@ dependencies = [ "colour-science", "dash", "dash-bootstrap-components", - "decord", "deepface", "google-cloud-vision", - "googletrans-py", # instead of googletrans4.0.0rc1, for a temporary solution due the incompatibility with jupyterlab + "googletrans-py", # instead of googletrans4.0.0rc1, for a temporary solution due to incompatibility with jupyterlab "grpcio", "huggingface-hub>=0.34.0", "importlib_metadata",