# cognitive/tests/test_visualization.py
# Daniel Ari Friedman 6caa1a7cb1 Update
# 2025-02-07 08:16:25 -08:00
#
# 173 lines
# 6.3 KiB
# Python

"""
Tests for visualization components.
"""
import pytest
import numpy as np
from src.visualization.matrix_plots import MatrixPlotter, StateSpacePlotter, NetworkPlotter
import matplotlib.pyplot as plt
@pytest.fixture
def style_config():
    """Provide the baseline plot-style settings handed to the plotters under test."""
    # Keyword form of the same mapping: theme, figure geometry, colormap, and
    # text/line sizing.
    return dict(
        theme='default',
        figure_size=(8, 6),
        dpi=100,
        colormap='viridis',
        font_size=12,
        line_width=1.5,
    )
class TestMatrixPlotter:
    """Test matrix plotting utilities.

    Every test builds a ``MatrixPlotter`` from the ``output_dir`` and
    ``style_config`` fixtures, renders one figure, and checks the returned
    figure's axes and labels as well as the PNG written under ``output_dir``.
    """

    @pytest.fixture(autouse=True)
    def _close_figures(self):
        """Close all pyplot-managed figures once each test in this class ends.

        Figures created through pyplot stay registered until they are closed
        explicitly, so without this teardown they would accumulate across the
        test run. NOTE(review): assumes the plotters create their figures via
        pyplot; if they do not, the close call is simply a no-op.
        """
        yield
        plt.close("all")

    def test_plot_heatmap(self, sample_matrix_2d, output_dir, style_config):
        """Test heatmap plotting: axes count, title/labels, and saved file."""
        plotter = MatrixPlotter(output_dir, style_config)
        fig = plotter.plot_heatmap(
            matrix=sample_matrix_2d,
            title="Test Heatmap",
            xlabel="States",
            ylabel="Observations",
            save_name="test_heatmap",
        )

        # Check figure properties
        assert isinstance(fig, plt.Figure)
        # Main axis and colorbar
        assert len(fig.axes) == 2

        # Check main axis properties
        main_ax = fig.axes[0]
        assert main_ax.get_title() == "Test Heatmap"
        assert main_ax.get_xlabel() == "States"
        assert main_ax.get_ylabel() == "Observations"

        # Verify file was saved
        assert (output_dir / "test_heatmap.png").exists()

    def test_plot_multi_heatmap(self, sample_matrix_3d, output_dir, style_config):
        """Test multiple heatmap plotting: one titled panel per slice name."""
        plotter = MatrixPlotter(output_dir, style_config)
        fig = plotter.plot_multi_heatmap(
            tensor=sample_matrix_3d,
            title="Test Multi-Heatmap",
            xlabel="Current State",
            ylabel="Next State",
            slice_names=["Action 1", "Action 2"],
            save_name="test_multi_heatmap",
        )

        # Check figure properties
        assert isinstance(fig, plt.Figure)
        # Two main axes and two colorbars
        assert len(fig.axes) == 4

        # Each panel title is the overall title suffixed with its slice name
        assert fig.axes[0].get_title() == "Test Multi-Heatmap - Action 1"
        assert fig.axes[1].get_title() == "Test Multi-Heatmap - Action 2"

        # Verify file was saved
        assert (output_dir / "test_multi_heatmap.png").exists()

    def test_plot_bar(self, sample_belief_vector, output_dir, style_config):
        """Test bar plot creation: single axis, title/labels, and saved file."""
        plotter = MatrixPlotter(output_dir, style_config)
        fig = plotter.plot_bar(
            values=sample_belief_vector,
            title="Test Bar Plot",
            xlabel="States",
            ylabel="Probability",
            save_name="test_bar",
        )

        # Check figure properties: a bar plot has no colorbar, so one axis only
        assert isinstance(fig, plt.Figure)
        assert len(fig.axes) == 1

        ax = fig.axes[0]
        assert ax.get_title() == "Test Bar Plot"
        assert ax.get_xlabel() == "States"
        assert ax.get_ylabel() == "Probability"

        # Verify file was saved
        assert (output_dir / "test_bar.png").exists()
class TestStateSpacePlotter:
    """Test state space plotting utilities.

    Every test builds a ``StateSpacePlotter`` from the ``output_dir`` fixture,
    renders one figure from a small inline NumPy array, and checks that a
    figure is returned and a PNG is written under ``output_dir``.
    """

    @pytest.fixture(autouse=True)
    def _close_figures(self):
        """Close all pyplot-managed figures once each test in this class ends.

        Figures created through pyplot stay registered until they are closed
        explicitly, so without this teardown they would accumulate across the
        test run. NOTE(review): assumes the plotters create their figures via
        pyplot; if they do not, the close call is simply a no-op.
        """
        yield
        plt.close("all")

    def test_plot_belief_evolution(self, output_dir):
        """Test belief evolution plotting."""
        plotter = StateSpacePlotter(output_dir)
        # Three rows of two values each, matching the two state labels below
        beliefs = np.array([[0.8, 0.2], [0.6, 0.4], [0.5, 0.5]])
        fig = plotter.plot_belief_evolution(
            beliefs=beliefs,
            title="Belief Evolution",
            state_labels=["State 1", "State 2"],
            save_name="test_belief_evolution",
        )

        assert isinstance(fig, plt.Figure)
        # Verify file was saved
        assert (output_dir / "test_belief_evolution.png").exists()

    def test_plot_free_energy_landscape(self, output_dir):
        """Test free energy landscape plotting."""
        plotter = StateSpacePlotter(output_dir)
        free_energy = np.array([[1.0, 2.0], [2.0, 3.0]])
        fig = plotter.plot_free_energy_landscape(
            free_energy=free_energy,
            title="Free Energy Landscape",
            save_name="test_landscape",
        )

        assert isinstance(fig, plt.Figure)
        # Verify file was saved
        assert (output_dir / "test_landscape.png").exists()

    def test_plot_policy_evaluation(self, output_dir):
        """Test policy evaluation plotting."""
        plotter = StateSpacePlotter(output_dir)
        # One value per policy label
        policy_values = np.array([0.8, 0.6, 0.4])
        fig = plotter.plot_policy_evaluation(
            policy_values=policy_values,
            policy_labels=["Policy 1", "Policy 2", "Policy 3"],
            title="Policy Evaluation",
            save_name="test_policy_eval",
        )

        assert isinstance(fig, plt.Figure)
        # Verify file was saved
        assert (output_dir / "test_policy_eval.png").exists()
class TestNetworkPlotter:
    """Test network plotting utilities.

    The test builds a ``NetworkPlotter`` from the ``output_dir`` fixture and
    checks that a figure is returned and a PNG is written under ``output_dir``.
    """

    @pytest.fixture(autouse=True)
    def _close_figures(self):
        """Close all pyplot-managed figures once each test in this class ends.

        Figures created through pyplot stay registered until they are closed
        explicitly, so without this teardown they would accumulate across the
        test run. NOTE(review): assumes the plotters create their figures via
        pyplot; if they do not, the close call is simply a no-op.
        """
        yield
        plt.close("all")

    def test_plot_belief_network(self, output_dir):
        """Test belief network plotting."""
        plotter = NetworkPlotter(output_dir)
        # 3x3 adjacency matrix with one non-zero entry per row, one row per label
        adjacency = np.array([[0, 1, 0], [0, 0, 1], [1, 0, 0]])
        fig = plotter.plot_belief_network(
            adjacency=adjacency,
            node_labels=["A", "B", "C"],
            title="Belief Network",
            save_name="test_network",
        )

        assert isinstance(fig, plt.Figure)
        # Verify file was saved
        assert (output_dir / "test_network.png").exists()
@pytest.mark.parametrize("matrix_shape,expected_axes", [
    ((2, 2), 2),  # Main axis + colorbar
    ((3, 3), 2),
    ((4, 4), 2),
])
def test_heatmap_shapes(matrix_shape, expected_axes, output_dir, style_config):
    """Test heatmap plotting with different matrix shapes.

    Args:
        matrix_shape: ``(rows, cols)`` of the matrix to plot.
        expected_axes: Number of axes the resulting figure should hold.
        output_dir: Fixture giving the directory the plotter saves into.
        style_config: Fixture giving the style settings for the plotter.
    """
    rows, cols = matrix_shape
    save_name = f"test_heatmap_{rows}x{cols}"

    plotter = MatrixPlotter(output_dir, style_config)
    # A seeded generator keeps the plotted values identical from run to run,
    # so any failure is reproducible.
    matrix = np.random.default_rng(0).random(matrix_shape)
    fig = plotter.plot_heatmap(
        matrix=matrix,
        title=f"Test {rows}x{cols} Heatmap",
        save_name=save_name,
    )

    try:
        assert len(fig.axes) == expected_axes

        # Get the data from the heatmap
        heatmap_data = fig.axes[0].collections[0].get_array()
        # Compare element counts rather than reshaping first: reshaping to
        # ``matrix_shape`` and then checking the shape can never fail, it can
        # only raise. The count check holds whether the array is flat or 2-D.
        assert heatmap_data.size == matrix.size

        # Verify file was saved
        assert (output_dir / f"{save_name}.png").exists()
    finally:
        # Release the figure even when an assertion above fails.
        plt.close(fig)