зеркало из
https://github.com/docxology/cognitive.git
synced 2025-10-30 12:46:04 +02:00
393 строки
14 KiB
Python
393 строки
14 KiB
Python
#!/usr/bin/env python3
|
|
"""
|
|
Run script for SimplePOMDP simulation.
|
|
|
|
This script follows a structured cognitive modeling workflow:
|
|
1. Model Configuration
|
|
2. Matrix Validation & Visualization
|
|
3. Model Component Analysis
|
|
4. Simulation Execution
|
|
5. Results Analysis & Visualization
|
|
"""
|
|
|
|
import numpy as np
|
|
from pathlib import Path
|
|
import matplotlib
|
|
matplotlib.use('Agg') # Use non-interactive backend
|
|
import matplotlib.pyplot as plt
|
|
import seaborn as sns
|
|
from simple_pomdp import SimplePOMDP
|
|
|
|
def create_config():
    """Build the nested configuration dictionary for the SimplePOMDP model.

    Returns:
        dict: a freshly built configuration with the sections ``model``,
        ``state_space``, ``observation_space``, ``action_space``, ``matrices``,
        ``inference`` and ``visualization``. Every call constructs new objects,
        so callers may mutate the result without affecting later calls.
    """
    model_info = {
        'name': 'SimplePOMDP',
        'description': 'Three-state POMDP with Active Inference',
        'version': '0.1.0',
    }

    state_space = {
        'num_states': 3,
        'state_labels': ['Low', 'Medium', 'High'],
        'initial_state': 1,  # index into state_labels -> 'Medium'
    }

    observation_space = {
        'num_observations': 3,
        'observation_labels': ['Low', 'Medium', 'High'],
    }

    action_space = {
        'num_actions': 3,
        'action_labels': ['Decrease', 'Stay', 'Increase'],
    }

    # Per-action transition tables. Each inner list is a distribution over the
    # next state (Low, Medium, High) given the "from" state named in the
    # trailing comment; every inner list sums to 1.
    # NOTE(review): these tables are laid out [current][next], while the
    # B_matrix 'shape' comment below reads [next_state, current_state, action].
    # Presumably SimplePOMDP transposes/normalises them; confirm.
    transitions = {
        'Decrease': [        # action 0: drift towards a lower state
            [0.8, 0.2, 0.0],  # from Low
            [0.7, 0.2, 0.1],  # from Medium
            [0.2, 0.7, 0.1],  # from High
        ],
        'Stay': [            # action 1: mostly keep the current state
            [0.8, 0.1, 0.1],  # from Low
            [0.1, 0.8, 0.1],  # from Medium
            [0.1, 0.1, 0.8],  # from High
        ],
        'Increase': [        # action 2: drift towards a higher state
            [0.1, 0.7, 0.2],  # from Low
            [0.1, 0.2, 0.7],  # from Medium
            [0.0, 0.2, 0.8],  # from High
        ],
    }

    matrices = {
        'A_matrix': {
            'shape': [3, 3],  # [num_observations, num_states]
            'initialization': 'identity_based',
            'initialization_params': {'strength': 0.7},  # 70% accurate observations
            'constraints': ['column_stochastic'],
        },
        'B_matrix': {
            'shape': [3, 3, 3],  # [next_state, current_state, action]
            'initialization': 'custom',
            'initialization_params': {
                'strength': 0.8,  # 80% success rate for actions
                'transitions': transitions,
            },
            'constraints': ['column_stochastic'],
        },
        'C_matrix': {
            'shape': [3],  # [num_observations]: log preferences over observations
            'initialization': 'log_preferences',
            'initialization_params': {
                'preferences': [0.1, 2.0, 0.1],  # strongly favour Medium observations
                'description': 'Log-preferences: Low=0.1 (avoid), Medium=2.0 (prefer), High=0.1 (avoid)',
            },
        },
        'D_matrix': {
            'shape': [3],  # [num_states]: initial state prior
            'initialization': 'uniform',
            'description': 'Uniform prior over states',
        },
        'E_matrix': {
            'shape': [3],  # [num_actions]: initial action prior
            'initialization': 'uniform',
            'description': 'Initial uniform prior over actions (Decrease, Stay, Increase)',
            'learning_rate': 0.2,  # rate at which the policy prior is updated
        },
    }

    inference = {
        'time_horizon': 1,            # single-step policies
        'num_iterations': 10,
        'learning_rate': 0.5,         # belief updates
        'temperature': 1.0,           # policy selection
        'policy_learning_rate': 0.2,  # E-matrix updates
    }

    visualization = {
        'output_dir': 'Output',
        'style': {
            'figure_size': (10, 8),
            'dpi': 100,
            'colormap': 'RdYlBu_r',  # diverging map, suited to preferences
            'colormap_3d': 'viridis',
            'font_size': 12,
            'file_format': 'png',
        },
    }

    return {
        'model': model_info,
        'state_space': state_space,
        'observation_space': observation_space,
        'action_space': action_space,
        'matrices': matrices,
        'inference': inference,
        'visualization': visualization,
    }
|
|
|
|
def _labels(model):
    """Return the (state, observation, action) label lists from ``model.config``."""
    cfg = model.config
    return (
        cfg['state_space']['state_labels'],
        cfg['observation_space']['observation_labels'],
        cfg['action_space']['action_labels'],
    )


def _softmax(values):
    """Return ``exp(values)`` normalised to sum to 1, computed stably.

    Subtracting the maximum before exponentiating does not change the result
    mathematically, but it prevents overflow for large log-preference values.
    """
    values = np.asarray(values, dtype=float)
    exps = np.exp(values - np.max(values))
    return exps / exps.sum()


def _build_validation_report(model, pref_probs):
    """Return the lines of the plain-text matrix validation report.

    Args:
        model: SimplePOMDP instance exposing ``A``, ``B``, ``C``, ``D``, ``E``
            and ``config``.
        pref_probs: normalised preference probabilities derived from ``model.C``.
    """
    state_labels, obs_labels, action_labels = _labels(model)

    title = "Matrix Validation Report"
    report = [title, "=" * len(title)]

    # A: observation model, indexed [observation, state]. Each state's column
    # is a distribution over observations.
    report.append("\nA Matrix (Observation Model):")
    report.append(f"- Shape: {model.A.shape}")
    report.append(f"- Column stochastic: {np.allclose(model.A.sum(axis=0), 1.0)}")
    report.append(f"- Non-negative: {np.all(model.A >= 0)}")
    report.append("\nObservation probabilities:")
    for i, obs in enumerate(obs_labels):
        for j, state in enumerate(state_labels):
            report.append(f" P({obs}|{state}) = {model.A[i, j]:.3f}")

    # B: transition model, one [next_state, current_state] slice per action.
    report.append("\nB Matrix (Transition Model):")
    report.append(f"- Shape: {model.B.shape}")
    for a, action in enumerate(action_labels):
        b_a = model.B[:, :, a]
        report.append(f"\nAction: {action}")
        report.append(f"- Column stochastic: {np.allclose(b_a.sum(axis=0), 1.0)}")
        report.append(f"- Non-negative: {np.all(b_a >= 0)}")
        report.append("\nTransition probabilities:")
        for i, next_state in enumerate(state_labels):
            for j, curr_state in enumerate(state_labels):
                report.append(f" P({next_state}|{curr_state},{action}) = {b_a[i, j]:.3f}")

    # C: log preferences over observations, plus their normalised form.
    report.append("\nC Matrix (Log Preferences over Observations):")
    report.append(f"- Shape: {model.C.shape}")
    report.append("\nLog preference values:")
    for i, obs in enumerate(obs_labels):
        report.append(f" ln P({obs}) = {model.C[i]:.3f}")
    report.append("\nNormalized preference probabilities:")
    for obs, prob in zip(obs_labels, pref_probs):
        report.append(f" P({obs}) = {prob:.3f}")

    # D (initial state prior) and E (action prior) are plotted as probability
    # vectors, so check them the same way.
    priors = (
        ("D", "Initial State Prior", model.D, state_labels),
        ("E", "Action Prior", model.E, action_labels),
    )
    for name, description, values, labels in priors:
        vec = np.asarray(values, dtype=float)
        report.append(f"\n{name} Matrix ({description}):")
        report.append(f"- Shape: {vec.shape}")
        report.append(f"- Sums to one: {np.isclose(vec.sum(), 1.0)}")
        report.append(f"- Non-negative: {np.all(vec >= 0)}")
        for label, prob in zip(labels, vec):
            report.append(f" P({label}) = {prob:.3f}")

    return report


def _annotated_heatmap(ax, matrix, xticklabels, yticklabels):
    """Draw *matrix* on *ax* as a heatmap annotated with 2-decimal values."""
    sns.heatmap(matrix,
                annot=True,
                fmt='.2f',
                xticklabels=xticklabels,
                yticklabels=yticklabels,
                ax=ax)


def _plot_observation_model(model, output_dir):
    """Save a heatmap of the A matrix to ``A_matrix.png``."""
    state_labels, obs_labels, _ = _labels(model)
    fig, ax = plt.subplots(figsize=(8, 6))
    _annotated_heatmap(ax, model.A, state_labels, obs_labels)
    ax.set_title('A Matrix: Observation Model')
    ax.set_xlabel('State')
    ax.set_ylabel('Observation')
    fig.tight_layout()
    fig.savefig(output_dir / 'A_matrix.png')
    plt.close(fig)


def _plot_transition_models(model, output_dir):
    """Save one B-matrix heatmap per action, side by side, to ``B_matrices.png``."""
    state_labels, _, action_labels = _labels(model)
    n_actions = model.B.shape[2]
    # squeeze=False keeps `axes` a 2-D array even when there is a single action;
    # otherwise a lone Axes object would be returned and could not be iterated.
    fig, axes = plt.subplots(1, n_actions, figsize=(5 * n_actions, 5), squeeze=False)
    for a, (ax, action) in enumerate(zip(axes[0], action_labels)):
        _annotated_heatmap(ax, model.B[:, :, a], state_labels, state_labels)
        ax.set_title(f'B Matrix: {action} Action')
        ax.set_xlabel('Current State')
        ax.set_ylabel('Next State')
    fig.tight_layout()
    fig.savefig(output_dir / 'B_matrices.png')
    plt.close(fig)


def _plot_preferences(model, pref_probs, output_dir):
    """Save C as log preferences and as normalised probabilities to ``C_preferences.png``."""
    _, obs_labels, _ = _labels(model)
    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 4))

    sns.barplot(x=obs_labels, y=model.C, ax=ax1)
    ax1.set_title('C Matrix: Log Preferences')
    ax1.set_xlabel('Observation')
    ax1.set_ylabel('Log Preference')

    sns.barplot(x=obs_labels, y=pref_probs, ax=ax2)
    ax2.set_title('Normalized Preference Probabilities')
    ax2.set_xlabel('Observation')
    ax2.set_ylabel('Probability')

    fig.tight_layout()
    fig.savefig(output_dir / 'C_preferences.png')
    plt.close(fig)


def _plot_all_components(model, output_dir):
    """Save an overview of A, every B slice, C, D and E to ``all_matrices.png``."""
    state_labels, obs_labels, action_labels = _labels(model)
    n_actions = model.B.shape[2]
    # Top row: A followed by one B panel per action. With 3 actions this gives
    # the 4-column, 20x12 layout the script was designed around.
    n_cols = 1 + n_actions
    split = n_cols // 2  # C takes the left half of the middle row, D the right half

    fig = plt.figure(figsize=(5 * n_cols, 12))
    gs = fig.add_gridspec(3, n_cols, height_ratios=[1, 1, 0.8])

    ax_a = fig.add_subplot(gs[0, 0])
    _annotated_heatmap(ax_a, model.A, state_labels, obs_labels)
    ax_a.set_title('A: Observation Model P(o|s)')

    for a in range(n_actions):
        ax_b = fig.add_subplot(gs[0, a + 1])
        _annotated_heatmap(ax_b, model.B[:, :, a], state_labels, state_labels)
        ax_b.set_title(f"B: Transition P(s'|s,{action_labels[a]})")

    ax_c = fig.add_subplot(gs[1, :split])
    sns.barplot(x=obs_labels, y=model.C, ax=ax_c)
    ax_c.set_title('C: Log Preferences ln P(o)')
    ax_c.set_xlabel('Observation')
    ax_c.set_ylabel('Log Preference')

    ax_d = fig.add_subplot(gs[1, split:])
    sns.barplot(x=state_labels, y=model.D, ax=ax_d)
    ax_d.set_title('D: Initial State Prior P(s₁)')
    ax_d.set_xlabel('State')
    ax_d.set_ylabel('Probability')

    ax_e = fig.add_subplot(gs[2, :])
    sns.barplot(x=action_labels, y=model.E, ax=ax_e)
    ax_e.set_title('E: Action Prior P(a)')
    ax_e.set_xlabel('Action')
    ax_e.set_ylabel('Probability')

    fig.suptitle('Active Inference POMDP Model Components', fontsize=16, y=0.98)
    fig.tight_layout()
    fig.savefig(output_dir / 'all_matrices.png', bbox_inches='tight', dpi=150)
    plt.close(fig)


def validate_matrices(model, output_dir: Path):
    """Validate the model matrices, write a text report and save plots of them.

    Writes ``matrix_validation.txt`` with shape, stochasticity and
    non-negativity checks for A, B, C, D and E. Also saves
    ``A_matrix.png``, ``B_matrices.png``, ``C_preferences.png`` and
    ``all_matrices.png``, all into *output_dir*.

    Args:
        model: SimplePOMDP model instance.
        output_dir: existing directory for the report and figures. A ``str``
            path is accepted as well.
    """
    output_dir = Path(output_dir)
    print("\n=== Matrix Validation and Visualization ===")

    pref_probs = _softmax(model.C)

    report_file = output_dir / "matrix_validation.txt"
    report_file.write_text("\n".join(_build_validation_report(model, pref_probs)),
                           encoding="utf-8")
    print(f"Validation report saved to: {report_file}")

    print("\nGenerating matrix visualizations...")
    _plot_observation_model(model, output_dir)
    _plot_transition_models(model, output_dir)
    _plot_preferences(model, pref_probs, output_dir)
    _plot_all_components(model, output_dir)

    print("Matrix visualizations saved to output directory")
|
|
|
|
def run_simulation(n_steps: int = 20, config=None):
    """Configure a SimplePOMDP, run it for ``n_steps`` steps and save the results.

    The following files are written to the configured output directory:
    the matrix validation report and plots, a per-step ``simulation_log.txt``,
    and the model's own visualizations.

    Args:
        n_steps: Number of steps to simulate.
        config: Optional configuration dict. If omitted, ``create_config()``
            is used. Its ``['visualization']['output_dir']`` entry selects
            where the files are written.
    """
    print("\n=== Model Configuration and Initialization ===")
    if config is None:
        config = create_config()

    # Take the directory from the config instead of hard-coding it, so the
    # log/report written here and the PNG listing at the end look in the same
    # place the configuration names.
    # NOTE(review): presumably SimplePOMDP.visualize also saves into this
    # directory, since the final listing globs it for *.png; confirm.
    output_dir = Path(config['visualization']['output_dir'])
    output_dir.mkdir(parents=True, exist_ok=True)

    model = SimplePOMDP(config)

    # Validate and visualize the matrices before simulating.
    validate_matrices(model, output_dir)

    state_labels = config['state_space']['state_labels']
    obs_labels = config['observation_space']['observation_labels']
    action_labels = config['action_space']['action_labels']

    print("\n=== Starting Simulation ===")
    print(f"Initial state: {state_labels[model.state.current_state]}")

    log_file = output_dir / "simulation_log.txt"
    with open(log_file, "w", encoding="utf-8") as log:
        log.write("Simulation Log\n")
        log.write("==============\n\n")

        for step in range(n_steps):
            # Advance one step; the model returns the observation index and
            # the variational free energy for that step.
            obs, vfe = model.step()

            history = model.state.history
            state = model.state.current_state
            action = history['actions'][-1]
            efe = history['expected_fe'][-1]  # one expected free energy per action

            step_info = [
                f"\nStep {step + 1}:",
                f"State: {state_labels[state]}",
                f"Observation: {obs_labels[obs]}",
                f"Action: {action_labels[action]}",
                f"Variational FE: {vfe:.3f}",
                "\nExpected Free Energies:",
            ]
            step_info.extend(f"{action_labels[a]}: {efe_a:.3f}" for a, efe_a in enumerate(efe))

            step_info.append("\nBeliefs:")
            step_info.extend(
                f"{label}: {prob:.3f}"
                for label, prob in zip(state_labels, model.state.beliefs)
            )

            log_text = "\n".join(step_info)
            log.write(log_text + "\n")
            print(log_text)

    print("\n=== Generating Simulation Visualizations ===")
    for plot_name in (
        "belief_evolution",
        "state_transitions",
        "observation_likelihood",
        "action_history",
        "belief_history",
        "free_energies",  # both VFE and EFE
        "policy_evolution",
    ):
        model.visualize(plot_name)

    print("\nGenerating detailed EFE visualization...")
    model.visualize("efe_components_detailed")

    print(f"\nSimulation results and visualizations saved to: {output_dir.absolute()}")

    print("\nGenerated visualization files:")
    for png_path in sorted(output_dir.glob("*.png")):
        print(f"- {png_path.name}")
|
|
|
|
if __name__ == "__main__":
|
|
run_simulation() |