diff --git a/ammico/test/TESTING_WITH_MOCKS.md b/ammico/test/TESTING_WITH_MOCKS.md new file mode 100644 index 0000000..3b406da --- /dev/null +++ b/ammico/test/TESTING_WITH_MOCKS.md @@ -0,0 +1,82 @@ +# Testing with Mock Models + +This document explains how to use the mock model fixture to write fast unit tests that don't require loading the actual model. + +## Mock Model Fixture + +A `mock_model` fixture has been added to `conftest.py` that creates a lightweight mock of the `MultimodalSummaryModel` class. This fixture: + +- **Does not load any actual models** (super fast) +- **Mocks all necessary methods** (processor, tokenizer, model.generate, etc.) +- **Returns realistic tensor shapes** (so the code doesn't crash) +- **Can be used for fast unit tests** that don't need actual model inference + +## Usage + +Request the `mock_model` fixture instead of `model` as the argument of your test function: + +```python +def test_my_feature(mock_model): + detector = ImageSummaryDetector(summary_model=mock_model, subdict={}) + # Your test code here + pass +``` + +## When to Use Mock vs Real Model + +### Use `mock_model` when: +- Testing utility functions (like `_clean_list_of_questions`) +- Testing input validation logic +- Testing data processing methods +- Testing class initialization +- **Any test that doesn't need actual model inference** + +### Use `model` (real model) when: +- Testing end-to-end functionality +- Testing actual caption generation quality +- Testing actual question answering +- Integration tests that verify model behavior +- **Any test marked with `@pytest.mark.long`** + +## Example Tests Added + +The following new tests use the mock model: + +1. `test_image_summary_detector_init_mock` - Tests initialization +2. `test_load_pil_if_needed_string` - Tests image loading +3. `test_is_sequence_but_not_str` - Tests utility methods +4. `test_validate_analysis_type` - Tests validation logic + +All of these run quickly without loading the model. 
@pytest.fixture
def mock_model():
    """Return a lightweight mock standing in for ``MultimodalSummaryModel``.

    Every component is a ``unittest.mock`` object, so no model weights are
    loaded. Intended for unit tests that only need something to pass as
    ``summary_model`` (initialisation, validation, small helpers).

    Returns:
        Mock: an object exposing

        * ``model`` -- ``MagicMock`` restricted to ``generate`` and ``eval``;
          ``eval()`` returns the same mock and ``device`` is ``"cpu"``.
        * ``processor`` -- ``MagicMock``; calling it returns a dict with
          ``input_ids``, ``pixel_values`` and ``attention_mask`` tensors, and
          ``apply_chat_template`` returns ``"processed_text"``.
        * ``tokenizer`` -- ``MagicMock`` whose ``batch_decode`` returns one
          ``"mock caption"`` string per element of its first argument.
        * ``device`` -- the string ``"cpu"``.
        * ``close`` -- a ``MagicMock`` that does nothing.
    """
    # Imported lazily so that merely collecting conftest.py does not import torch.
    import torch

    # spec=[...] makes any attribute other than generate/eval raise
    # AttributeError on access; attributes assigned explicitly (device) still work.
    mock_model_obj = MagicMock(spec=["generate", "eval"])
    mock_model_obj.device = "cpu"
    mock_model_obj.eval = MagicMock(return_value=mock_model_obj)

    mock_processor = MagicMock()
    mock_processor.apply_chat_template = MagicMock(
        side_effect=lambda messages, **kwargs: "processed_text"
    )

    def mock_processor_call(text=None, images=None, **kwargs):
        """Build a fake processor output with one row per input text."""
        batch_size = len(text) if isinstance(text, list) else 1
        # NOTE(review): a plain dict has no ``.to(device)``; if the code under
        # test moves the processor output to a device, wrap this -- confirm
        # against ImageSummaryDetector.
        return {
            "input_ids": torch.randint(0, 1000, (batch_size, 10)),
            "pixel_values": torch.randn(batch_size, 3, 224, 224),
            "attention_mask": torch.ones(batch_size, 10, dtype=torch.long),
        }

    # Calling a mock is dispatched through type(mock).__call__, and __call__
    # is not one of the magic methods unittest.mock lets you configure, so
    # assigning ``mock_processor.__call__`` would never be consulted.
    # ``side_effect`` is the supported hook for "run this when called".
    mock_processor.side_effect = mock_processor_call

    mock_tokenizer = MagicMock()
    mock_tokenizer.batch_decode = MagicMock(
        side_effect=lambda ids, **kwargs: ["mock caption" for _ in range(len(ids))]
    )

    # Assemble the top-level object with the attributes listed in the docstring.
    mock_m = Mock()
    mock_m.model = mock_model_obj
    mock_m.processor = mock_processor
    mock_m.tokenizer = mock_tokenizer
    mock_m.device = "cpu"
    mock_m.close = MagicMock()

    return mock_m
# Fast tests using mock model (no actual model loading)
def test_image_summary_detector_init_mock(mock_model, get_testdict):
    """Detector stores the given model object and the supplied subdict."""
    detector = ImageSummaryDetector(summary_model=mock_model, subdict=get_testdict)
    assert detector.summary_model is mock_model
    assert len(detector.subdict) == 2


def test_load_pil_if_needed_string(mock_model):
    """A file path passed to ``_load_pil_if_needed`` yields an RGB PIL image."""
    import os

    from PIL import Image

    detector = ImageSummaryDetector(summary_model=mock_model)
    test_image_path = os.path.join(os.path.dirname(__file__), "data", "IMG_2746.png")
    # Skip visibly instead of passing with zero assertions when the sample
    # image is not available in this checkout.
    if not os.path.exists(test_image_path):
        pytest.skip(f"test image not found: {test_image_path}")

    img = detector._load_pil_if_needed(test_image_path)
    assert isinstance(img, Image.Image)
    assert img.mode == "RGB"


def test_is_sequence_but_not_str(mock_model):
    """Lists count as sequences; str, bytes and dict do not."""
    detector = ImageSummaryDetector(summary_model=mock_model)
    assert detector._is_sequence_but_not_str([1, 2, 3]) is True
    assert detector._is_sequence_but_not_str("string") is False
    assert detector._is_sequence_but_not_str(b"bytes") is False
    # A dict is a Mapping, not a Sequence, so it must be rejected as well.
    assert detector._is_sequence_but_not_str({"a": 1}) is False


def test_validate_analysis_type(mock_model):
    """Each valid analysis type sets the matching flags; bad input raises."""
    detector = ImageSummaryDetector(summary_model=mock_model)

    # Valid types: the last two returned values are the summary/questions flags.
    _, _, is_summary, is_questions = detector._validate_analysis_type(
        "summary", None, 10
    )
    assert is_summary is True
    assert is_questions is False

    _, _, is_summary, is_questions = detector._validate_analysis_type(
        "questions", ["What is this?"], 10
    )
    assert is_summary is False
    assert is_questions is True

    _, _, is_summary, is_questions = detector._validate_analysis_type(
        "summary_and_questions", ["What is this?"], 10
    )
    assert is_summary is True
    assert is_questions is True

    # Unknown analysis type is rejected.
    with pytest.raises(ValueError):
        detector._validate_analysis_type("invalid", None, 10)

    # 33 questions against a limit of 32 is rejected.
    with pytest.raises(ValueError):
        detector._validate_analysis_type(
            "questions", [f"Q{i}" for i in range(33)], 32
        )