diff --git a/ammico/test/TESTING_WITH_MOCKS.md b/TESTING_WITH_MOCKS.md similarity index 100% rename from ammico/test/TESTING_WITH_MOCKS.md rename to TESTING_WITH_MOCKS.md diff --git a/ammico/test/conftest.py b/ammico/test/conftest.py index c2d13a5..f4c8040 100644 --- a/ammico/test/conftest.py +++ b/ammico/test/conftest.py @@ -1,7 +1,7 @@ import os import pytest -from unittest.mock import Mock, MagicMock from ammico.model import MultimodalSummaryModel +import torch @pytest.fixture @@ -65,42 +65,55 @@ def mock_model(): Mock model fixture that doesn't load the actual model. Useful for faster unit tests that don't need actual model inference. """ - import torch - # Create a mock model object - mock_model_obj = MagicMock(spec=["generate", "eval"]) - mock_model_obj.device = "cpu" - mock_model_obj.eval = MagicMock(return_value=mock_model_obj) + class MockProcessor: + """Mock processor that mimics AutoProcessor behavior.""" - # Create mock processor with necessary methods - mock_processor = MagicMock() - mock_processor.apply_chat_template = MagicMock( - side_effect=lambda messages, **kwargs: "processed_text" - ) + def apply_chat_template(self, messages, **kwargs): + return "processed_text" - # Mock processor to return tensor-like inputs - def mock_processor_call(text, images, **kwargs): - batch_size = len(text) if isinstance(text, list) else 1 - return { - "input_ids": torch.randint(0, 1000, (batch_size, 10)), - "pixel_values": torch.randn(batch_size, 3, 224, 224), - "attention_mask": torch.ones(batch_size, 10), - } + def __call__(self, text, images, **kwargs): + """Mock processing that returns tensor-like inputs.""" + batch_size = len(text) if isinstance(text, list) else 1 + return { + "input_ids": torch.randint(0, 1000, (batch_size, 10)), + "pixel_values": torch.randn(batch_size, 3, 224, 224), + "attention_mask": torch.ones(batch_size, 10), + } - mock_processor.__call__ = MagicMock(side_effect=mock_processor_call) + class MockTokenizer: + """Mock tokenizer that mimics AutoTokenizer behavior.""" - # Create mock tokenizer - mock_tokenizer = MagicMock() - mock_tokenizer.batch_decode = MagicMock( - side_effect=lambda ids, **kwargs: ["mock caption" for _ in range(len(ids))] - ) + def batch_decode(self, ids, **kwargs): + """Return mock captions for the given batch size.""" + batch_size = ids.shape[0] if hasattr(ids, "shape") else len(ids) + return ["mock caption" for _ in range(batch_size)] - # Create the mock model instance - mock_m = Mock() - mock_m.model = mock_model_obj - mock_m.processor = mock_processor - mock_m.tokenizer = mock_tokenizer - mock_m.device = "cpu" - mock_m.close = MagicMock() + class MockModelObj: + """Mock model object that mimics the model.generate behavior.""" - return mock_m + def __init__(self): + self.device = "cpu" + + def eval(self): + return self + + def generate(self, input_ids=None, **kwargs): + """Generate mock token IDs.""" + batch_size = input_ids.shape[0] if hasattr(input_ids, "shape") else 1 + return torch.randint(0, 1000, (batch_size, 20)) + + class MockMultimodalSummaryModel: + """Mock MultimodalSummaryModel that doesn't load actual models.""" + + def __init__(self): + self.model = MockModelObj() + self.processor = MockProcessor() + self.tokenizer = MockTokenizer() + self.device = "cpu" + + def close(self): + """Mock close method - no actual cleanup needed.""" + pass + + return MockMultimodalSummaryModel() diff --git a/ammico/test/test_image_summary.py b/ammico/test/test_image_summary.py index 475fbac..1d0dda2 100644 --- a/ammico/test/test_image_summary.py +++ b/ammico/test/test_image_summary.py @@ -1,5 +1,6 @@ from ammico.image_summary import ImageSummaryDetector - +import os +from PIL import Image import pytest @@ -70,13 +71,9 @@ def test_load_pil_if_needed_string(mock_model): """Test loading image from file path.""" detector = ImageSummaryDetector(summary_model=mock_model) # This will try to actually load a file, so we'll use a test image - import os - test_image_path = os.path.join(os.path.dirname(__file__), "data", "IMG_2746.png") if os.path.exists(test_image_path): img = detector._load_pil_if_needed(test_image_path) - from PIL import Image - assert isinstance(img, Image.Image) assert img.mode == "RGB"