# Mirror of https://github.com/ssciwr/AMMICO.git
# synced 2025-10-29 13:06:04 +02:00 (107 lines, 3.0 KiB, Python)
import os
|
|
import pytest
|
|
from unittest.mock import Mock, MagicMock
|
|
from ammico.model import MultimodalSummaryModel
|
|
|
|
|
|
@pytest.fixture
def get_path(request):
    """Return the test module's data directory path, with a trailing slash."""
    data_dir = os.path.dirname(request.module.__file__) + "/data/"
    return data_dir
|
|
|
|
|
|
@pytest.fixture
def set_environ(request):
    """Point GOOGLE_APPLICATION_CREDENTIALS at the test service-account JSON.

    NOTE(review): mutates os.environ without restoring the previous value
    after the test — confirm no later test depends on the old setting.
    """
    base_dir = os.path.dirname(request.module.__file__)
    credentials_file = (
        base_dir + "/../../data/seismic-bonfire-329406-412821a70264.json"
    )
    os.environ["GOOGLE_APPLICATION_CREDENTIALS"] = credentials_file
    # Echo the value so it shows up in captured test output.
    print(os.environ.get("GOOGLE_APPLICATION_CREDENTIALS"))
|
|
|
|
|
|
@pytest.fixture
def get_testdict(get_path):
    """Map each test image id to a dict holding its file location."""
    return {
        image_id: {"filename": get_path + image_id + ".png"}
        for image_id in ("IMG_2746", "IMG_2809")
    }
|
|
|
|
|
|
@pytest.fixture
def get_test_my_dict(get_path):
    """Precomputed rank/score results for two test images.

    Each entry carries the image filename, a rank and score against the
    text query "A bus", and a rank and score against the reference image
    IMG_3758.png.
    """
    reference = get_path + "IMG_3758.png"

    def entry(image_id, bus_rank, bus_score, ref_rank, ref_score):
        # One per-image result row: filename plus the two rank/score pairs.
        return {
            "filename": get_path + image_id + ".png",
            "rank A bus": bus_rank,
            "A bus": bus_score,
            "rank " + reference: ref_rank,
            reference: ref_score,
        }

    return {
        "IMG_2746": entry("IMG_2746", 1, 0.15640679001808167, 1, 0.7533495426177979),
        "IMG_2809": entry("IMG_2809", 0, 0.1970970332622528, 0, 0.8907483816146851),
    }
|
|
|
|
|
|
@pytest.fixture(scope="session")
def model():
    """Provide one real CPU-backed summary model shared across the session."""
    summary_model = MultimodalSummaryModel(device="cpu")
    try:
        yield summary_model
    finally:
        # Release model resources even if a test (or teardown) errors out.
        summary_model.close()
|
|
|
|
|
|
@pytest.fixture
def mock_model():
    """
    Mock model fixture that doesn't load the actual model.

    Useful for faster unit tests that don't need actual model inference.

    Returns a Mock exposing ``model``, ``processor``, ``tokenizer``,
    ``device`` and ``close`` attributes, mirroring the interface the
    tests use on MultimodalSummaryModel.
    """
    import torch

    # Create a mock model object
    mock_model_obj = MagicMock(spec=["generate", "eval"])
    mock_model_obj.device = "cpu"
    mock_model_obj.eval = MagicMock(return_value=mock_model_obj)

    # Create mock processor with necessary methods
    mock_processor = MagicMock()
    mock_processor.apply_chat_template = MagicMock(
        side_effect=lambda messages, **kwargs: "processed_text"
    )

    # Mock processor to return tensor-like inputs
    def mock_processor_call(text, images, **kwargs):
        batch_size = len(text) if isinstance(text, list) else 1
        return {
            "input_ids": torch.randint(0, 1000, (batch_size, 10)),
            "pixel_values": torch.randn(batch_size, 3, 224, 224),
            "attention_mask": torch.ones(batch_size, 10),
        }

    # BUG FIX: previously this was ``mock_processor.__call__ = MagicMock(...)``.
    # Assigning a ``__call__`` attribute on a mock instance has no effect on
    # calling it — special-method lookup goes through the type, so
    # ``mock_processor(text, images)`` returned a bare MagicMock instead of
    # the tensor dict above. Setting ``side_effect`` routes actual calls
    # through mock_processor_call as intended.
    mock_processor.side_effect = mock_processor_call

    # Create mock tokenizer
    mock_tokenizer = MagicMock()
    mock_tokenizer.batch_decode = MagicMock(
        side_effect=lambda ids, **kwargs: ["mock caption" for _ in range(len(ids))]
    )

    # Create the mock model instance
    mock_m = Mock()
    mock_m.model = mock_model_obj
    mock_m.processor = mock_processor
    mock_m.tokenizer = mock_tokenizer
    mock_m.device = "cpu"
    mock_m.close = MagicMock()

    return mock_m
|