test: some small changes to mock model

Этот коммит содержится в:
Inga Ulusoy 2025-10-27 09:49:41 +01:00
родитель 731077be7d
Коммит 237c6265fe
Не найден ключ, соответствующий данной подписи
Идентификатор ключа GPG: 7E8998B3002D3B7C
3 изменённых файлов: 48 добавлений и 38 удалений

Просмотреть файл

Просмотреть файл

@ -1,7 +1,7 @@
import os import os
import pytest import pytest
from unittest.mock import Mock, MagicMock
from ammico.model import MultimodalSummaryModel from ammico.model import MultimodalSummaryModel
import torch
@pytest.fixture @pytest.fixture
@ -65,42 +65,55 @@ def mock_model():
Mock model fixture that doesn't load the actual model. Mock model fixture that doesn't load the actual model.
Useful for faster unit tests that don't need actual model inference. Useful for faster unit tests that don't need actual model inference.
""" """
import torch
# Create a mock model object class MockProcessor:
mock_model_obj = MagicMock(spec=["generate", "eval"]) """Mock processor that mimics AutoProcessor behavior."""
mock_model_obj.device = "cpu"
mock_model_obj.eval = MagicMock(return_value=mock_model_obj)
# Create mock processor with necessary methods def apply_chat_template(self, messages, **kwargs):
mock_processor = MagicMock() return "processed_text"
mock_processor.apply_chat_template = MagicMock(
side_effect=lambda messages, **kwargs: "processed_text"
)
# Mock processor to return tensor-like inputs def __call__(self, text, images, **kwargs):
def mock_processor_call(text, images, **kwargs): """Mock processing that returns tensor-like inputs."""
batch_size = len(text) if isinstance(text, list) else 1 batch_size = len(text) if isinstance(text, list) else 1
return { return {
"input_ids": torch.randint(0, 1000, (batch_size, 10)), "input_ids": torch.randint(0, 1000, (batch_size, 10)),
"pixel_values": torch.randn(batch_size, 3, 224, 224), "pixel_values": torch.randn(batch_size, 3, 224, 224),
"attention_mask": torch.ones(batch_size, 10), "attention_mask": torch.ones(batch_size, 10),
} }
mock_processor.__call__ = MagicMock(side_effect=mock_processor_call) class MockTokenizer:
"""Mock tokenizer that mimics AutoTokenizer behavior."""
# Create mock tokenizer def batch_decode(self, ids, **kwargs):
mock_tokenizer = MagicMock() """Return mock captions for the given batch size."""
mock_tokenizer.batch_decode = MagicMock( batch_size = ids.shape[0] if hasattr(ids, "shape") else len(ids)
side_effect=lambda ids, **kwargs: ["mock caption" for _ in range(len(ids))] return ["mock caption" for _ in range(batch_size)]
)
# Create the mock model instance class MockModelObj:
mock_m = Mock() """Mock model object that mimics the model.generate behavior."""
mock_m.model = mock_model_obj
mock_m.processor = mock_processor
mock_m.tokenizer = mock_tokenizer
mock_m.device = "cpu"
mock_m.close = MagicMock()
return mock_m def __init__(self):
self.device = "cpu"
def eval(self):
return self
def generate(self, input_ids=None, **kwargs):
"""Generate mock token IDs."""
batch_size = input_ids.shape[0] if hasattr(input_ids, "shape") else 1
return torch.randint(0, 1000, (batch_size, 20))
class MockMultimodalSummaryModel:
"""Mock MultimodalSummaryModel that doesn't load actual models."""
def __init__(self):
self.model = MockModelObj()
self.processor = MockProcessor()
self.tokenizer = MockTokenizer()
self.device = "cpu"
def close(self):
"""Mock close method - no actual cleanup needed."""
pass
return MockMultimodalSummaryModel()

Просмотреть файл

@ -1,5 +1,6 @@
from ammico.image_summary import ImageSummaryDetector from ammico.image_summary import ImageSummaryDetector
import os
from PIL import Image
import pytest import pytest
@ -70,13 +71,9 @@ def test_load_pil_if_needed_string(mock_model):
"""Test loading image from file path.""" """Test loading image from file path."""
detector = ImageSummaryDetector(summary_model=mock_model) detector = ImageSummaryDetector(summary_model=mock_model)
# This will try to actually load a file, so we'll use a test image # This will try to actually load a file, so we'll use a test image
import os
test_image_path = os.path.join(os.path.dirname(__file__), "data", "IMG_2746.png") test_image_path = os.path.join(os.path.dirname(__file__), "data", "IMG_2746.png")
if os.path.exists(test_image_path): if os.path.exists(test_image_path):
img = detector._load_pil_if_needed(test_image_path) img = detector._load_pil_if_needed(test_image_path)
from PIL import Image
assert isinstance(img, Image.Image) assert isinstance(img, Image.Image)
assert img.mode == "RGB" assert img.mode == "RGB"