Mirror of https://github.com/ssciwr/AMMICO.git, synced 2025-10-29 13:06:04 +02:00

Compare commits: ac6d70a355 ... 4a18b1b5a9 (5 commits)

- 4a18b1b5a9
- 8e9f2b6d87
- a65f1e2287
- 237c6265fe
- 731077be7d
TESTING_WITH_MOCKS.md (new file, 82 lines)
# Testing with Mock Models

This document explains how to use the mock model fixture to write fast unit tests that don't require loading the actual model.
## Mock Model Fixture

A `mock_model` fixture has been added to `conftest.py` that creates a lightweight mock of the `MultimodalSummaryModel` class. This fixture:

- **Does not load any actual models** (super fast)
- **Mocks all necessary methods** (processor, tokenizer, model.generate, etc.)
- **Returns realistic tensor shapes** (so the code doesn't crash)
- **Can be used for fast unit tests** that don't need actual model inference
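As a quick illustration of what the mock hands back (based on the fixture added to `conftest.py` in this changeset; the test name here is only an example):

```python
def test_mock_shapes(mock_model):
    # The mock processor returns tensors with plausible shapes instead of real encodings.
    inputs = mock_model.processor(text=["describe the image"], images=None)
    assert inputs["input_ids"].shape == (1, 10)
    assert inputs["pixel_values"].shape == (1, 3, 224, 224)

    # generate() returns fake token IDs; batch_decode() returns one mock caption per batch item.
    ids = mock_model.model.generate(input_ids=inputs["input_ids"])
    assert mock_model.tokenizer.batch_decode(ids) == ["mock caption"]
```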
## Usage

Simply use `mock_model` instead of `model` in your test fixtures:

```python
def test_my_feature(mock_model):
    detector = ImageSummaryDetector(summary_model=mock_model, subdict={})
    # Your test code here
    pass
```
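For a slightly fuller sketch, the test below exercises a utility method through the mock; it mirrors the behaviour asserted in `test_clean_list_of_questions` later in this changeset (the test name is illustrative):

```python
from ammico.image_summary import ImageSummaryDetector


def test_clean_questions_with_mock(mock_model):
    detector = ImageSummaryDetector(summary_model=mock_model, subdict={})
    prompt = detector.token_prompt_config["default"]["questions"]["prompt"]
    cleaned = detector._clean_list_of_questions(
        ["What is happening in the image?", "", None], prompt
    )
    # Empty and None entries are dropped; the remaining question keeps the prompt prefix.
    assert cleaned == [prompt + "What is happening in the image?"]
```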
## When to Use Mock vs Real Model

### Use `mock_model` when:

- Testing utility functions (like `_clean_list_of_questions`)
- Testing input validation logic
- Testing data processing methods
- Testing class initialization
- **Any test that doesn't need actual model inference**
### Use `model` (real model) when:

- Testing end-to-end functionality
- Testing actual caption generation quality
- Testing actual question answering
- Integration tests that verify model behavior
- **Any test marked with `@pytest.mark.long`**
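As a rough sketch of the second case (the analysis call itself is left as a comment, since it is not shown in this document), a real-model test requests the `model` fixture and carries the `long` marker so it only runs when explicitly selected:

```python
import pytest

from ammico.image_summary import ImageSummaryDetector


@pytest.mark.long
def test_caption_quality_real_model(model, get_testdict):
    detector = ImageSummaryDetector(summary_model=model, subdict=get_testdict)
    # ... run the detector's analysis here and assert on the generated captions/answers
```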
## Example Tests Added

The following new tests use the mock model:

1. `test_image_summary_detector_init_mock` - Tests initialization
2. `test_load_pil_if_needed_string` - Tests image loading
3. `test_is_sequence_but_not_str` - Tests utility methods
4. `test_validate_analysis_type` - Tests validation logic

All of these run quickly without loading the model.
## Running Tests

### Run only fast tests (with mocks):

```bash
pytest ammico/test/test_image_summary.py -m "not long" -v
```

### Run only long tests (with real model):

```bash
pytest ammico/test/test_image_summary.py -m long -v
```

### Run all tests:

```bash
pytest ammico/test/test_image_summary.py -v
```
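If the `long` marker is not already registered in the project configuration, pytest will warn about an unknown marker; a minimal registration in `conftest.py` could look like this (hypothetical, the project may already declare it elsewhere):

```python
def pytest_configure(config):
    # Register the custom `long` marker so `-m long` / `-m "not long"` work without warnings.
    config.addinivalue_line("markers", "long: slow tests that load the real model")
```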
## Customizing the Mock

If you need to customize the mock's behavior for specific tests, you can override its methods:

```python
def test_custom_behavior(mock_model):
    # Override the mock tokenizer's batch_decode to return a custom value
    mock_model.tokenizer.batch_decode = lambda ids, **kwargs: ["custom", "output"]

    detector = ImageSummaryDetector(summary_model=mock_model, subdict={})
    # Test with custom behavior
    pass
```
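An alternative sketch uses pytest's built-in `monkeypatch` fixture, which undoes the override automatically when the test finishes:

```python
def test_custom_behavior_monkeypatch(mock_model, monkeypatch):
    monkeypatch.setattr(
        mock_model.tokenizer, "batch_decode", lambda ids, **kwargs: ["custom", "output"]
    )
    detector = ImageSummaryDetector(summary_model=mock_model, subdict={})
    # Test with custom behavior
```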
ammico/image_summary.py:

```diff
@@ -289,7 +289,7 @@ class ImageSummaryDetector(AnalysisMethod):
         max_new_tokens = self.token_prompt_config[
             "concise" if is_concise_summary else "default"
         ]["summary"]["max_new_tokens"]
-        inputs = self._prepare_inputs(prompt, entry)
+        inputs = self._prepare_inputs([prompt], entry)

         gen_conf = GenerationConfig(
             max_new_tokens=max_new_tokens,
@@ -384,10 +384,10 @@ class ImageSummaryDetector(AnalysisMethod):
         """
         prompt = self.token_prompt_config[
             "concise" if is_concise_answer else "default"
-        ]["answer"]["prompt"]
+        ]["questions"]["prompt"]
         max_new_tokens = self.token_prompt_config[
             "concise" if is_concise_answer else "default"
-        ]["answer"]["max_new_tokens"]
+        ]["questions"]["max_new_tokens"]

         list_of_questions = self._clean_list_of_questions(list_of_questions, prompt)
         gen_conf = GenerationConfig(max_new_tokens=max_new_tokens, do_sample=False)
```
Jupyter notebook (filename not shown in this view):

```diff
@@ -969,9 +969,7 @@
    "outputs": [],
    "source": [
     "for key in image_dict.keys():\n",
-    "    image_dict[key] = ammico.colors.ColorDetector(image_dict[key]).analyse_image()\n",
-    "\n",
-    "print(\"testing signature\")"
+    "    image_dict[key] = ammico.colors.ColorDetector(image_dict[key]).analyse_image()"
   ]
  },
  {
```
conftest.py:

```diff
@@ -1,6 +1,7 @@
 import os
 import pytest
 from ammico.model import MultimodalSummaryModel
+import torch


 @pytest.fixture
@@ -56,3 +57,63 @@ def model():
         yield m
     finally:
         m.close()
+
+
+@pytest.fixture
+def mock_model():
+    """
+    Mock model fixture that doesn't load the actual model.
+    Useful for faster unit tests that don't need actual model inference.
+    """
+
+    class MockProcessor:
+        """Mock processor that mimics AutoProcessor behavior."""
+
+        def apply_chat_template(self, messages, **kwargs):
+            return "processed_text"
+
+        def __call__(self, text, images, **kwargs):
+            """Mock processing that returns tensor-like inputs."""
+            batch_size = len(text) if isinstance(text, list) else 1
+            return {
+                "input_ids": torch.randint(0, 1000, (batch_size, 10)),
+                "pixel_values": torch.randn(batch_size, 3, 224, 224),
+                "attention_mask": torch.ones(batch_size, 10),
+            }
+
+    class MockTokenizer:
+        """Mock tokenizer that mimics AutoTokenizer behavior."""
+
+        def batch_decode(self, ids, **kwargs):
+            """Return mock captions for the given batch size."""
+            batch_size = ids.shape[0] if hasattr(ids, "shape") else len(ids)
+            return ["mock caption" for _ in range(batch_size)]
+
+    class MockModelObj:
+        """Mock model object that mimics the model.generate behavior."""
+
+        def __init__(self):
+            self.device = "cpu"
+
+        def eval(self):
+            return self
+
+        def generate(self, input_ids=None, **kwargs):
+            """Generate mock token IDs."""
+            batch_size = input_ids.shape[0] if hasattr(input_ids, "shape") else 1
+            return torch.randint(0, 1000, (batch_size, 20))
+
+    class MockMultimodalSummaryModel:
+        """Mock MultimodalSummaryModel that doesn't load actual models."""
+
+        def __init__(self):
+            self.model = MockModelObj()
+            self.processor = MockProcessor()
+            self.tokenizer = MockTokenizer()
+            self.device = "cpu"
+
+        def close(self):
+            """Mock close method - no actual cleanup needed."""
+            pass
+
+    return MockMultimodalSummaryModel()
```
ammico/test/test_image_summary.py:

```diff
@@ -1,5 +1,6 @@
 from ammico.image_summary import ImageSummaryDetector
+import os
+from PIL import Image
 import pytest


@@ -37,7 +38,7 @@ def test_image_summary_detector_questions(model, get_testdict):
     )


-def test_clean_list_of_questions(model):
+def test_clean_list_of_questions(mock_model):
     list_of_questions = [
         "What is happening in the image?",
         "",
@@ -45,7 +46,7 @@ def test_clean_list_of_questions(model):
         None,
         "How many cars are in the image in total",
     ]
-    detector = ImageSummaryDetector(summary_model=model, subdict={})
+    detector = ImageSummaryDetector(summary_model=mock_model, subdict={})
     prompt = detector.token_prompt_config["default"]["questions"]["prompt"]
     cleaned_questions = detector._clean_list_of_questions(list_of_questions, prompt)
     assert len(cleaned_questions) == 2
@@ -56,3 +57,66 @@ def test_clean_list_of_questions(model):
     assert len(cleaned_questions) == 2
     assert cleaned_questions[0] == prompt + "What is happening in the image?"
     assert cleaned_questions[1] == prompt + "How many cars are in the image in total?"
+
+
+# Fast tests using mock model (no actual model loading)
+def test_image_summary_detector_init_mock(mock_model, get_testdict):
+    """Test detector initialization with mocked model."""
+    detector = ImageSummaryDetector(summary_model=mock_model, subdict=get_testdict)
+    assert detector.summary_model is mock_model
+    assert len(detector.subdict) == 2
+
+
+def test_load_pil_if_needed_string(mock_model):
+    """Test loading image from file path."""
+    detector = ImageSummaryDetector(summary_model=mock_model)
+    # This will try to actually load a file, so we'll use a test image
+    test_image_path = os.path.join(os.path.dirname(__file__), "data", "IMG_2746.png")
+    if os.path.exists(test_image_path):
+        img = detector._load_pil_if_needed(test_image_path)
+        assert isinstance(img, Image.Image)
+        assert img.mode == "RGB"
+
+
+def test_is_sequence_but_not_str(mock_model):
+    """Test sequence detection utility."""
+    detector = ImageSummaryDetector(summary_model=mock_model)
+    assert detector._is_sequence_but_not_str([1, 2, 3]) is True
+    assert detector._is_sequence_but_not_str("string") is False
+    assert detector._is_sequence_but_not_str(b"bytes") is False
+    assert (
+        detector._is_sequence_but_not_str({"a": 1}) is False
+    )  # dict is sequence-like but not handled as such
+
+
+def test_validate_analysis_type(mock_model):
+    """Test analysis type validation."""
+    detector = ImageSummaryDetector(summary_model=mock_model)
+    # Test valid types
+    _, _, is_summary, is_questions = detector._validate_analysis_type(
+        "summary", None, 10
+    )
+    assert is_summary is True
+    assert is_questions is False
+
+    _, _, is_summary, is_questions = detector._validate_analysis_type(
+        "questions", ["What is this?"], 10
+    )
+    assert is_summary is False
+    assert is_questions is True
+
+    _, _, is_summary, is_questions = detector._validate_analysis_type(
+        "summary_and_questions", ["What is this?"], 10
+    )
+    assert is_summary is True
+    assert is_questions is True
+
+    # Test invalid type
+    with pytest.raises(ValueError):
+        detector._validate_analysis_type("invalid", None, 10)
+
+    # Test too many questions
+    with pytest.raises(ValueError):
+        detector._validate_analysis_type(
+            "questions", ["Q" + str(i) for i in range(33)], 32
+        )
```