зеркало из
https://github.com/ssciwr/AMMICO.git
synced 2025-10-29 13:06:04 +02:00
test: some small changes to mock model
Этот коммит содержится в:
родитель
731077be7d
Коммит
237c6265fe
@ -1,7 +1,7 @@
|
|||||||
import os
|
import os
|
||||||
import pytest
|
import pytest
|
||||||
from unittest.mock import Mock, MagicMock
|
|
||||||
from ammico.model import MultimodalSummaryModel
|
from ammico.model import MultimodalSummaryModel
|
||||||
|
import torch
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
@ -65,42 +65,55 @@ def mock_model():
|
|||||||
Mock model fixture that doesn't load the actual model.
|
Mock model fixture that doesn't load the actual model.
|
||||||
Useful for faster unit tests that don't need actual model inference.
|
Useful for faster unit tests that don't need actual model inference.
|
||||||
"""
|
"""
|
||||||
import torch
|
|
||||||
|
|
||||||
# Create a mock model object
|
class MockProcessor:
|
||||||
mock_model_obj = MagicMock(spec=["generate", "eval"])
|
"""Mock processor that mimics AutoProcessor behavior."""
|
||||||
mock_model_obj.device = "cpu"
|
|
||||||
mock_model_obj.eval = MagicMock(return_value=mock_model_obj)
|
|
||||||
|
|
||||||
# Create mock processor with necessary methods
|
def apply_chat_template(self, messages, **kwargs):
|
||||||
mock_processor = MagicMock()
|
return "processed_text"
|
||||||
mock_processor.apply_chat_template = MagicMock(
|
|
||||||
side_effect=lambda messages, **kwargs: "processed_text"
|
|
||||||
)
|
|
||||||
|
|
||||||
# Mock processor to return tensor-like inputs
|
def __call__(self, text, images, **kwargs):
|
||||||
def mock_processor_call(text, images, **kwargs):
|
"""Mock processing that returns tensor-like inputs."""
|
||||||
batch_size = len(text) if isinstance(text, list) else 1
|
batch_size = len(text) if isinstance(text, list) else 1
|
||||||
return {
|
return {
|
||||||
"input_ids": torch.randint(0, 1000, (batch_size, 10)),
|
"input_ids": torch.randint(0, 1000, (batch_size, 10)),
|
||||||
"pixel_values": torch.randn(batch_size, 3, 224, 224),
|
"pixel_values": torch.randn(batch_size, 3, 224, 224),
|
||||||
"attention_mask": torch.ones(batch_size, 10),
|
"attention_mask": torch.ones(batch_size, 10),
|
||||||
}
|
}
|
||||||
|
|
||||||
mock_processor.__call__ = MagicMock(side_effect=mock_processor_call)
|
class MockTokenizer:
|
||||||
|
"""Mock tokenizer that mimics AutoTokenizer behavior."""
|
||||||
|
|
||||||
# Create mock tokenizer
|
def batch_decode(self, ids, **kwargs):
|
||||||
mock_tokenizer = MagicMock()
|
"""Return mock captions for the given batch size."""
|
||||||
mock_tokenizer.batch_decode = MagicMock(
|
batch_size = ids.shape[0] if hasattr(ids, "shape") else len(ids)
|
||||||
side_effect=lambda ids, **kwargs: ["mock caption" for _ in range(len(ids))]
|
return ["mock caption" for _ in range(batch_size)]
|
||||||
)
|
|
||||||
|
|
||||||
# Create the mock model instance
|
class MockModelObj:
|
||||||
mock_m = Mock()
|
"""Mock model object that mimics the model.generate behavior."""
|
||||||
mock_m.model = mock_model_obj
|
|
||||||
mock_m.processor = mock_processor
|
|
||||||
mock_m.tokenizer = mock_tokenizer
|
|
||||||
mock_m.device = "cpu"
|
|
||||||
mock_m.close = MagicMock()
|
|
||||||
|
|
||||||
return mock_m
|
def __init__(self):
|
||||||
|
self.device = "cpu"
|
||||||
|
|
||||||
|
def eval(self):
|
||||||
|
return self
|
||||||
|
|
||||||
|
def generate(self, input_ids=None, **kwargs):
|
||||||
|
"""Generate mock token IDs."""
|
||||||
|
batch_size = input_ids.shape[0] if hasattr(input_ids, "shape") else 1
|
||||||
|
return torch.randint(0, 1000, (batch_size, 20))
|
||||||
|
|
||||||
|
class MockMultimodalSummaryModel:
|
||||||
|
"""Mock MultimodalSummaryModel that doesn't load actual models."""
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
self.model = MockModelObj()
|
||||||
|
self.processor = MockProcessor()
|
||||||
|
self.tokenizer = MockTokenizer()
|
||||||
|
self.device = "cpu"
|
||||||
|
|
||||||
|
def close(self):
|
||||||
|
"""Mock close method - no actual cleanup needed."""
|
||||||
|
pass
|
||||||
|
|
||||||
|
return MockMultimodalSummaryModel()
|
||||||
|
|||||||
@ -1,5 +1,6 @@
|
|||||||
from ammico.image_summary import ImageSummaryDetector
|
from ammico.image_summary import ImageSummaryDetector
|
||||||
|
import os
|
||||||
|
from PIL import Image
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
|
|
||||||
@ -70,13 +71,9 @@ def test_load_pil_if_needed_string(mock_model):
|
|||||||
"""Test loading image from file path."""
|
"""Test loading image from file path."""
|
||||||
detector = ImageSummaryDetector(summary_model=mock_model)
|
detector = ImageSummaryDetector(summary_model=mock_model)
|
||||||
# This will try to actually load a file, so we'll use a test image
|
# This will try to actually load a file, so we'll use a test image
|
||||||
import os
|
|
||||||
|
|
||||||
test_image_path = os.path.join(os.path.dirname(__file__), "data", "IMG_2746.png")
|
test_image_path = os.path.join(os.path.dirname(__file__), "data", "IMG_2746.png")
|
||||||
if os.path.exists(test_image_path):
|
if os.path.exists(test_image_path):
|
||||||
img = detector._load_pil_if_needed(test_image_path)
|
img = detector._load_pil_if_needed(test_image_path)
|
||||||
from PIL import Image
|
|
||||||
|
|
||||||
assert isinstance(img, Image.Image)
|
assert isinstance(img, Image.Image)
|
||||||
assert img.mode == "RGB"
|
assert img.mode == "RGB"
|
||||||
|
|
||||||
|
|||||||
Загрузка…
x
Ссылка в новой задаче
Block a user