cognitive/Things/Simple_POMDP/run_simple_pomdp.py
Daniel Ari Friedman 6caa1a7cb1 Update
2025-02-07 08:16:25 -08:00

393 строки
14 KiB
Python

#!/usr/bin/env python3
"""
Run script for SimplePOMDP simulation.
This script follows a structured cognitive modeling workflow:
1. Model Configuration
2. Matrix Validation & Visualization
3. Model Component Analysis
4. Simulation Execution
5. Results Analysis & Visualization
"""
import numpy as np
from pathlib import Path
import matplotlib
matplotlib.use('Agg') # Use non-interactive backend
import matplotlib.pyplot as plt
import seaborn as sns
from simple_pomdp import SimplePOMDP
def create_config():
    """Assemble the nested configuration dictionary for the SimplePOMDP run.

    Every section is rebuilt from literals on each call, so the returned
    dictionary (and the lists inside it) is never shared between callers.

    Returns:
        dict with the top-level sections 'model', 'state_space',
        'observation_space', 'action_space', 'matrices', 'inference'
        and 'visualization', in that order.
    """
    level_names = ('Low', 'Medium', 'High')

    # --- Descriptive metadata -------------------------------------------
    model_info = {
        'name': 'SimplePOMDP',
        'description': 'Three-state POMDP with Active Inference',
        'version': '0.1.0',
    }

    # --- State / observation / action spaces ----------------------------
    state_space = {
        'num_states': 3,
        'state_labels': list(level_names),
        'initial_state': 1,  # Start in Medium state
    }
    observation_space = {
        'num_observations': 3,
        'observation_labels': list(level_names),
    }
    action_space = {
        'num_actions': 3,
        'action_labels': ['Decrease', 'Stay', 'Increase'],
    }

    # --- A: observation model, shape [num_observations, num_states] -----
    a_matrix = {
        'shape': [3, 3],
        'initialization': 'identity_based',
        'initialization_params': {'strength': 0.7},  # 70% accurate observations
        'constraints': ['column_stochastic'],
    }

    # --- B: transition model, shape [next_state, current_state, action] -
    # One table per action; each row is the distribution over next states
    # when starting from Low, Medium and High respectively.
    decrease_rows = [
        [0.8, 0.2, 0.0],  # From Low
        [0.7, 0.2, 0.1],  # From Medium
        [0.2, 0.7, 0.1],  # From High
    ]
    stay_rows = [
        [0.8, 0.1, 0.1],  # From Low
        [0.1, 0.8, 0.1],  # From Medium
        [0.1, 0.1, 0.8],  # From High
    ]
    increase_rows = [
        [0.1, 0.7, 0.2],  # From Low
        [0.1, 0.2, 0.7],  # From Medium
        [0.0, 0.2, 0.8],  # From High
    ]
    b_matrix = {
        'shape': [3, 3, 3],
        'initialization': 'custom',
        'initialization_params': {
            'strength': 0.8,  # 80% success rate for actions
            'transitions': {
                'Decrease': decrease_rows,  # Action 0
                'Stay': stay_rows,          # Action 1
                'Increase': increase_rows,  # Action 2
            },
        },
        'constraints': ['column_stochastic'],
    }

    # --- C: log preferences over observations, shape [num_observations] -
    c_matrix = {
        'shape': [3],
        'initialization': 'log_preferences',
        'initialization_params': {
            'preferences': [0.1, 2.0, 0.1],  # Strong preference for medium observations
            'description': 'Log-preferences: Low=0.1 (avoid), Medium=2.0 (prefer), High=0.1 (avoid)',
        },
    }

    # --- D: initial state prior, shape [num_states] ---------------------
    d_matrix = {
        'shape': [3],
        'initialization': 'uniform',
        'description': 'Uniform prior over states',
    }

    # --- E: initial action prior, shape [num_actions] -------------------
    e_matrix = {
        'shape': [3],
        'initialization': 'uniform',
        'description': 'Initial uniform prior over actions (Decrease, Stay, Increase)',
        'learning_rate': 0.2,  # Rate at which policy prior is updated
    }

    matrices = {
        'A_matrix': a_matrix,
        'B_matrix': b_matrix,
        'C_matrix': c_matrix,
        'D_matrix': d_matrix,
        'E_matrix': e_matrix,
    }

    # --- Inference settings ---------------------------------------------
    inference = {
        'time_horizon': 1,            # Single-step policies
        'num_iterations': 10,
        'learning_rate': 0.5,         # For belief updates
        'temperature': 1.0,           # For policy selection
        'policy_learning_rate': 0.2,  # For E matrix updates
    }

    # --- Plot output settings -------------------------------------------
    visualization = {
        'output_dir': 'Output',
        'style': {
            'figure_size': (10, 8),
            'dpi': 100,
            'colormap': 'RdYlBu_r',  # Better for showing preferences
            'colormap_3d': 'viridis',
            'font_size': 12,
            'file_format': 'png',
        },
    }

    return {
        'model': model_info,
        'state_space': state_space,
        'observation_space': observation_space,
        'action_space': action_space,
        'matrices': matrices,
        'inference': inference,
        'visualization': visualization,
    }
def _normalized_preferences(log_prefs):
    """Turn log preferences into probabilities with a softmax."""
    log_prefs = np.asarray(log_prefs, dtype=float)
    # Shifting by the maximum leaves the softmax value unchanged but keeps
    # np.exp from overflowing when the log preferences are large.
    weights = np.exp(log_prefs - log_prefs.max())
    return weights / weights.sum()


def _build_validation_report(model, probs):
    """Return the text lines of the validation report for A, B and C."""
    obs_labels = model.config['observation_space']['observation_labels']
    state_labels = model.config['state_space']['state_labels']
    action_labels = model.config['action_space']['action_labels']

    lines = [
        "Matrix Validation Report",
        "======================",
        # A Matrix (Observation Model)
        "\nA Matrix (Observation Model):",
        f"- Shape: {model.A.shape}",
        f"- Column stochastic: {np.allclose(model.A.sum(axis=0), 1.0)}",
        f"- Non-negative: {np.all(model.A >= 0)}",
        "\nObservation probabilities:",
    ]
    for i, obs in enumerate(obs_labels):
        for j, state in enumerate(state_labels):
            lines.append(f" P({obs}|{state}) = {model.A[i,j]:.3f}")

    # B Matrix (Transition Model): one section per action slice B[:, :, a]
    lines.append("\nB Matrix (Transition Model):")
    lines.append(f"- Shape: {model.B.shape}")
    for a, action in enumerate(action_labels):
        b_slice = model.B[:, :, a]
        lines.append(f"\nAction: {action}")
        lines.append(f"- Column stochastic: {np.allclose(b_slice.sum(axis=0), 1.0)}")
        lines.append(f"- Non-negative: {np.all(b_slice >= 0)}")
        lines.append("\nTransition probabilities:")
        for i, next_state in enumerate(state_labels):
            for j, curr_state in enumerate(state_labels):
                lines.append(f" P({next_state}|{curr_state},{action}) = {model.B[i,j,a]:.3f}")

    # C Matrix (Log Preferences)
    lines.append("\nC Matrix (Log Preferences over Observations):")
    lines.append(f"- Shape: {model.C.shape}")
    lines.append("\nLog preference values:")
    for i, obs in enumerate(obs_labels):
        lines.append(f" ln P({obs}) = {model.C[i]:.3f}")
    lines.append("\nNormalized preference probabilities:")
    for i, obs in enumerate(obs_labels):
        lines.append(f" P({obs}) = {probs[i]:.3f}")
    return lines


def _draw_heatmap(ax, matrix, col_labels, row_labels):
    """Draw an annotated two-decimal heatmap of `matrix` on `ax`."""
    sns.heatmap(matrix,
                annot=True,
                fmt='.2f',
                xticklabels=col_labels,
                yticklabels=row_labels,
                ax=ax)


def _draw_bars(ax, labels, values, title, xlabel, ylabel):
    """Draw a labelled bar chart of `values` against `labels` on `ax`."""
    sns.barplot(x=labels, y=values, ax=ax)
    ax.set_title(title)
    ax.set_xlabel(xlabel)
    ax.set_ylabel(ylabel)


def validate_matrices(model, output_dir: Path):
    """Validate and visualize model matrices.

    Writes ``matrix_validation.txt`` (shape, column-stochasticity and
    non-negativity checks plus every entry of A, B and C) and the figures
    ``A_matrix.png``, ``B_matrices.png``, ``C_preferences.png`` and
    ``all_matrices.png`` into ``output_dir``. The checks are only reported;
    a failing check does not raise.

    Args:
        model: SimplePOMDP model instance; its ``A``, ``B``, ``C``, ``D``,
            ``E`` arrays and ``config`` label lists are read.
        output_dir: Output directory for the report and visualizations.
            It is created if it does not exist.
    """
    output_dir = Path(output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)

    obs_labels = model.config['observation_space']['observation_labels']
    state_labels = model.config['state_space']['state_labels']
    action_labels = model.config['action_space']['action_labels']
    n_actions = model.B.shape[2]

    print("\n=== Matrix Validation and Visualization ===")

    # Softmax of C, shared by the report and the preference plot.
    probs = _normalized_preferences(model.C)

    # Save report; explicit encoding so labels are written the same way on
    # every platform instead of using the locale default.
    report_file = output_dir / "matrix_validation.txt"
    with open(report_file, "w", encoding="utf-8") as f:
        f.write("\n".join(_build_validation_report(model, probs)))
    print(f"Validation report saved to: {report_file}")

    # Visualize matrices
    print("\nGenerating matrix visualizations...")

    # Plot A matrix
    fig, ax = plt.subplots(figsize=(8, 6))
    _draw_heatmap(ax, model.A, state_labels, obs_labels)
    ax.set_title('A Matrix: Observation Model')
    ax.set_xlabel('State')
    ax.set_ylabel('Observation')
    fig.tight_layout()
    fig.savefig(output_dir / 'A_matrix.png')
    plt.close(fig)

    # Plot B matrices (one for each action). squeeze=False keeps `axes` a
    # 2-D array even when there is a single action, so it stays iterable.
    fig, axes = plt.subplots(1, n_actions, figsize=(15, 5), squeeze=False)
    for a, (ax, action) in enumerate(zip(axes[0], action_labels)):
        _draw_heatmap(ax, model.B[:, :, a], state_labels, state_labels)
        ax.set_title(f'B Matrix: {action} Action')
        ax.set_xlabel('Current State')
        ax.set_ylabel('Next State')
    fig.tight_layout()
    fig.savefig(output_dir / 'B_matrices.png')
    plt.close(fig)

    # Plot C matrix (log preferences and probabilities)
    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 4))
    _draw_bars(ax1, obs_labels, model.C,
               'C Matrix: Log Preferences', 'Observation', 'Log Preference')
    _draw_bars(ax2, obs_labels, probs,
               'Normalized Preference Probabilities', 'Observation', 'Probability')
    fig.tight_layout()
    fig.savefig(output_dir / 'C_preferences.png')
    plt.close(fig)

    # Plot all matrices combined. The top row holds A plus one panel per
    # action, so the column count follows the number of actions (4 columns
    # for three actions); at least 2 so C and D each get a non-empty span.
    n_cols = max(n_actions + 1, 2)
    split = n_cols // 2
    fig = plt.figure(figsize=(20, 12))
    gs = plt.GridSpec(3, n_cols, height_ratios=[1, 1, 0.8])

    # A matrix (top-left)
    ax_a = fig.add_subplot(gs[0, 0])
    _draw_heatmap(ax_a, model.A, state_labels, obs_labels)
    ax_a.set_title('A: Observation Model P(o|s)')

    # B matrices (remaining top-row panels)
    for a in range(n_actions):
        ax_b = fig.add_subplot(gs[0, a + 1])
        _draw_heatmap(ax_b, model.B[:, :, a], state_labels, state_labels)
        ax_b.set_title(f"B: Transition P(s'|s,{action_labels[a]})")

    # C matrix (middle-left)
    _draw_bars(fig.add_subplot(gs[1, 0:split]), obs_labels, model.C,
               'C: Log Preferences ln P(o)', 'Observation', 'Log Preference')

    # D matrix (middle-right)
    _draw_bars(fig.add_subplot(gs[1, split:]), state_labels, model.D,
               'D: Initial State Prior P(s₁)', 'State', 'Probability')

    # E matrix (bottom)
    _draw_bars(fig.add_subplot(gs[2, :]), action_labels, model.E,
               'E: Action Prior P(a)', 'Action', 'Probability')

    # Add overall title
    fig.suptitle('Active Inference POMDP Model Components', fontsize=16, y=0.98)
    fig.tight_layout()
    fig.savefig(output_dir / 'all_matrices.png', bbox_inches='tight', dpi=150)
    plt.close(fig)

    print("Matrix visualizations saved to output directory")
def _format_step_log(step, state, obs, action, vfe, efe, beliefs, config):
    """Return the multi-line log text describing one simulation step."""
    state_labels = config['state_space']['state_labels']
    obs_labels = config['observation_space']['observation_labels']
    action_labels = config['action_space']['action_labels']

    lines = [
        f"\nStep {step + 1}:",
        f"State: {state_labels[state]}",
        f"Observation: {obs_labels[obs]}",
        f"Action: {action_labels[action]}",
        f"Variational FE: {vfe:.3f}",
        "\nExpected Free Energies:",
    ]
    # One expected-free-energy value per action, in action-index order.
    lines.extend(f"{action_labels[a]}: {efe_a:.3f}" for a, efe_a in enumerate(efe))
    lines.append("\nBeliefs:")
    # Belief distribution over states.
    lines.extend(f"{label}: {prob:.3f}" for label, prob in zip(state_labels, beliefs))
    return "\n".join(lines)


def run_simulation(n_steps: int = 20):
    """Run the POMDP simulation.

    Builds the configuration and model, writes the matrix validation
    report and figures, steps the model ``n_steps`` times while logging
    each step to ``simulation_log.txt`` and the console, then asks the
    model to render its simulation plots and lists the PNG files found in
    the output directory.

    Args:
        n_steps: Number of steps to simulate
    """
    config = create_config()

    # Create output directory. The name comes from the configuration
    # (falling back to 'Output') so it is defined in one place only.
    output_dir = Path(config.get('visualization', {}).get('output_dir', 'Output'))
    output_dir.mkdir(parents=True, exist_ok=True)

    print("\n=== Model Configuration and Initialization ===")
    model = SimplePOMDP(config)

    # First validate and visualize matrices
    validate_matrices(model, output_dir)

    print("\n=== Starting Simulation ===")
    print(f"Initial state: {config['state_space']['state_labels'][model.state.current_state]}")

    # Create simulation log file; explicit encoding so the log does not
    # depend on the platform's locale default.
    log_file = output_dir / "simulation_log.txt"
    with open(log_file, "w", encoding="utf-8") as f:
        f.write("Simulation Log\n")
        f.write("==============\n\n")

        # Run simulation
        for step in range(n_steps):
            # Take a step and get observation and free energy
            obs, vfe = model.step()

            # Read the post-step state plus the latest action and expected
            # free energies recorded in the model's history.
            log_text = _format_step_log(
                step,
                model.state.current_state,
                obs,
                model.state.history['actions'][-1],
                vfe,
                model.state.history['expected_fe'][-1],
                model.state.beliefs,
                config,
            )

            # Write to log and print to console
            f.write(log_text + "\n")
            print(log_text)

    print("\n=== Generating Simulation Visualizations ===")
    for plot_name in (
        "belief_evolution",
        "state_transitions",
        "observation_likelihood",
        "action_history",
        "belief_history",
        "free_energies",  # both VFE and EFE
        "policy_evolution",
    ):
        model.visualize(plot_name)

    # Plot detailed EFE components
    print("\nGenerating detailed EFE visualization...")
    model.visualize("efe_components_detailed")

    print(f"\nSimulation results and visualizations saved to: {output_dir.absolute()}")

    # Print list of generated files
    print("\nGenerated visualization files:")
    for file in sorted(output_dir.glob("*.png")):
        print(f"- {file.name}")
# Script entry point: run the simulation with its default step count only
# when this file is executed directly, not when it is imported.
if __name__ == "__main__":
    run_simulation()