AMMICO/ammico/test/conftest.py
2025-10-27 09:42:12 +01:00

107 строки
3.0 KiB
Python

import os
import pytest
from unittest.mock import Mock, MagicMock
from ammico.model import MultimodalSummaryModel
@pytest.fixture
def get_path(request):
    """Return the ``data/`` directory that sits next to the requesting test module.

    The returned string always ends with ``/`` so callers can append a bare
    file name directly (e.g. ``get_path + "IMG_2746.png"``).
    """
    module_file = request.module.__file__
    return f"{os.path.dirname(module_file)}/data/"
@pytest.fixture
def set_environ(request, monkeypatch):
    """Point ``GOOGLE_APPLICATION_CREDENTIALS`` at the repository's test key file.

    The path is resolved relative to the requesting test module
    (``<module dir>/../../data/seismic-bonfire-329406-412821a70264.json``).

    The variable is set through pytest's ``monkeypatch`` fixture, which
    restores the previous environment when the requesting test finishes.
    Writing ``os.environ`` directly would leak the setting into every test
    that runs afterwards, including tests that never asked for it.
    """
    module_dir = os.path.dirname(request.module.__file__)
    monkeypatch.setenv(
        "GOOGLE_APPLICATION_CREDENTIALS",
        module_dir + "/../../data/seismic-bonfire-329406-412821a70264.json",
    )
    # Echo the resolved path so it shows up in captured output on failures.
    print(os.environ.get("GOOGLE_APPLICATION_CREDENTIALS"))
@pytest.fixture
def get_testdict(get_path):
    """Map each test image id to ``{"filename": <path to its .png>}``.

    Paths are built from the ``get_path`` fixture (the module's ``data/`` dir).
    """
    image_ids = ("IMG_2746", "IMG_2809")
    return {
        image_id: {"filename": get_path + image_id + ".png"}
        for image_id in image_ids
    }
@pytest.fixture
def get_test_my_dict(get_path):
    """Build a per-image dict with filename plus fixed rank/score entries.

    Each entry carries a ``"rank A bus"`` / ``"A bus"`` pair and a
    ``"rank <path>/IMG_3758.png"`` / ``"<path>/IMG_3758.png"`` pair.
    NOTE(review): these presumably are expected ranks and similarity scores
    for a text query ("A bus") and an image query (IMG_3758.png). Confirm
    against the tests that consume this fixture.
    """
    reference_image = get_path + "IMG_3758.png"

    def _entry(image_id, bus_rank, bus_score, ref_rank, ref_score):
        # Insertion order matches the hand-written literal this replaced.
        return {
            "filename": get_path + image_id + ".png",
            "rank A bus": bus_rank,
            "A bus": bus_score,
            "rank " + reference_image: ref_rank,
            reference_image: ref_score,
        }

    return {
        "IMG_2746": _entry(
            "IMG_2746", 1, 0.15640679001808167, 1, 0.7533495426177979
        ),
        "IMG_2809": _entry(
            "IMG_2809", 0, 0.1970970332622528, 0, 0.8907483816146851
        ),
    }
@pytest.fixture(scope="session")
def model():
    """Provide one real CPU ``MultimodalSummaryModel`` shared by the whole session.

    The instance is created the first time a test requests it. Its ``close()``
    method is called during session teardown, guarded by ``finally``.
    """
    summary_model = MultimodalSummaryModel(device="cpu")
    try:
        yield summary_model
    finally:
        summary_model.close()
@pytest.fixture
def mock_model():
    """
    Mock model fixture that doesn't load the actual model.
    Useful for faster unit tests that don't need actual model inference.

    Returns a ``Mock`` exposing ``model``, ``processor``, ``tokenizer``,
    ``device`` (``"cpu"``) and ``close``:

    - ``model``: a MagicMock limited to ``generate`` and ``eval``.
      ``eval()`` returns the model itself.
    - ``processor``:
      - ``apply_chat_template(...)`` returns ``"processed_text"``.
      - Calling ``processor(text=..., images=...)`` returns a dict of CPU
        tensors (``input_ids``, ``pixel_values``, ``attention_mask``).
        The dict also supports attribute access and a no-op ``.to(...)``.
    - ``tokenizer.batch_decode(ids)`` returns ``"mock caption"`` once per
      element of ``ids``.
    """
    import torch

    class _MockBatch(dict):
        """Dict of tensors mimicking a processor output.

        Supports ``batch["input_ids"]``, ``batch.input_ids`` and ``batch.to(...)``.
        """

        def __getattr__(self, name):
            try:
                return self[name]
            except KeyError:
                raise AttributeError(name) from None

        def to(self, *args, **kwargs):
            # All tensors are created on CPU; moving them is a no-op here.
            return self

    # Create a mock model object
    mock_model_obj = MagicMock(spec=["generate", "eval"])
    mock_model_obj.device = "cpu"
    mock_model_obj.eval = MagicMock(return_value=mock_model_obj)

    # Create mock processor with necessary methods
    mock_processor = MagicMock()
    mock_processor.apply_chat_template = MagicMock(
        side_effect=lambda messages, **kwargs: "processed_text"
    )

    def mock_processor_call(text, images=None, **kwargs):
        """Return random tensors whose batch size follows ``len(text)``."""
        batch_size = len(text) if isinstance(text, list) else 1
        return _MockBatch(
            input_ids=torch.randint(0, 1000, (batch_size, 10)),
            pixel_values=torch.randn(batch_size, 3, 224, 224),
            attention_mask=torch.ones(batch_size, 10),
        )

    # Assigning ``mock_processor.__call__`` would have no effect: Python looks up
    # ``__call__`` on the type, so ``mock_processor(...)`` would ignore it and
    # return a fresh MagicMock. ``side_effect`` is what a Mock consults on call.
    mock_processor.side_effect = mock_processor_call

    # Create mock tokenizer
    mock_tokenizer = MagicMock()
    mock_tokenizer.batch_decode = MagicMock(
        side_effect=lambda ids, **kwargs: ["mock caption" for _ in range(len(ids))]
    )

    # Create the mock model instance
    mock_m = Mock()
    mock_m.model = mock_model_obj
    mock_m.processor = mock_processor
    mock_m.tokenizer = mock_tokenizer
    mock_m.device = "cpu"
    mock_m.close = MagicMock()
    return mock_m