# Daniel Ari Friedman 6caa1a7cb1 Update
# 2025-02-07 08:16:25 -08:00
#
# 245 lines
# 8.7 KiB
# Python

"""
Analysis utilities for Continuous Active Inference.
This module provides tools for analyzing and visualizing the relationship between
generalized coordinates, Taylor series expansions, and belief updating in the
continuous-time active inference framework.
"""
import numpy as np
import matplotlib.pyplot as plt
from scipy.special import factorial
from pathlib import Path
from typing import List, Optional, Tuple, Union
import logging
# Module-level logging: a logger named after this module, plus an import-time
# configuration of the root logger at INFO level. logging.basicConfig does
# nothing if the root logger already has handlers.
# NOTE(review): configuring the root logger on import affects every program
# that imports this module; consider moving this call to an entry point.
logger = logging.getLogger(__name__)
logging.basicConfig(level=logging.INFO)
class ContinuousAnalyzer:
    """Analysis and plotting tools for continuous active inference.

    Each plotting method draws a single matplotlib figure. It writes the
    figure to ``save_path`` when one is given and always closes the figure
    afterwards, even when plotting raises, so repeated calls do not
    accumulate open figures. Nothing is displayed or returned.
    """

    # Subdirectories created under ``output_dir`` by ``__init__``.
    # NOTE(review): no method in this class writes into them; presumably they
    # are intended as save locations chosen by callers -- confirm.
    _SUBDIRS = (
        'taylor_series',
        'generalized_coords',
        'belief_dynamics',
        'prediction_errors',
    )

    def __init__(self, output_dir: Union[str, Path]):
        """Initialize analyzer.

        Args:
            output_dir: Directory for analysis plots. It is created (with
                parents) if missing, together with the subdirectories listed
                in ``_SUBDIRS``.
        """
        self.output_dir = Path(output_dir)
        self.output_dir.mkdir(parents=True, exist_ok=True)
        for subdir in self._SUBDIRS:
            (self.output_dir / subdir).mkdir(exist_ok=True)

    # ------------------------------------------------------------------
    # Numerical helpers. They do not touch matplotlib, so they can be reused
    # and tested without a plotting backend.
    # ------------------------------------------------------------------

    @staticmethod
    def _taylor_expansion(coefficients: np.ndarray,
                          time_points: np.ndarray,
                          order: int) -> np.ndarray:
        """Evaluate a truncated Taylor series anchored at ``time_points[0]``.

        Args:
            coefficients: Generalized coordinates at the anchor time;
                ``coefficients[n]`` is used as the n-th derivative.
            time_points: Times at which to evaluate the series.
            order: Highest derivative order to include. Requires
                ``len(coefficients) > order``; otherwise indexing raises
                IndexError.

        Returns:
            Float array ``sum_{n<=order} coefficients[n] / n! * (t - t0)**n``
            with one entry per time point.
        """
        # Work in float even for integer time grids. Accumulating float terms
        # into an integer buffer would raise a casting error.
        t = np.asarray(time_points, dtype=float)
        dt = t - t[0]
        expansion = np.zeros_like(dt)
        for n in range(order + 1):
            expansion += (coefficients[n] / factorial(n)) * dt ** n
        return expansion

    @staticmethod
    def _moving_abs_rate(values: np.ndarray,
                         time_points: np.ndarray,
                         window_size: int) -> np.ndarray:
        """Return the moving average of ``|d values / dt|``.

        The finite-difference rate ``|diff(values) / diff(time_points)|`` is
        averaged over each run of ``window_size`` consecutive steps. Entry
        ``i`` covers steps ``i .. i + window_size - 1``, so the result has
        ``len(values) - window_size`` entries, or is empty when that count is
        <= 0. The result is always float.

        ``time_points`` must have the same length as ``values`` and be
        strictly increasing; repeated times give inf/nan rates.
        """
        v = np.asarray(values, dtype=float)
        t = np.asarray(time_points, dtype=float)
        n_windows = len(v) - window_size
        if n_windows <= 0:
            # np.convolve would swap its operands when the kernel is longer
            # than the data, so handle the "no full window" case explicitly.
            return np.empty(0)
        rates = np.abs(np.diff(v) / np.diff(t))
        kernel = np.full(window_size, 1.0 / window_size)
        # 'valid' yields len(rates) - window_size + 1 == n_windows entries.
        return np.convolve(rates, kernel, mode='valid')

    @staticmethod
    def _annotation_indices(n_points: int) -> range:
        """Return indices 0, s, 2s, ... below ``n_points``.

        The stride is ``s = max(1, n_points // 5)``. The lower bound of 1
        means trajectories shorter than five points are annotated at every
        point instead of failing on a zero ``range`` step.
        """
        return range(0, n_points, max(1, n_points // 5))

    # ------------------------------------------------------------------
    # Plotting
    # ------------------------------------------------------------------

    def plot_taylor_expansion(self,
                              states: np.ndarray,
                              time_points: np.ndarray,
                              orders: List[int],
                              save_path: Optional[Union[str, Path]] = None) -> None:
        """Plot truncated Taylor expansions of a state trajectory.

        The expansions are anchored at the first time point. Each one uses
        the generalized coordinates ``states[0]`` as derivative coefficients
        and is drawn against the actual trajectory ``states[:, 0]``.

        Args:
            states: Trajectory in generalized coordinates, one row per time
                point: shape ``[n_timesteps, n_orders]``.
            time_points: Time points, one per row of ``states``.
            orders: Expansion orders to draw. Each must be ``< n_orders``.
            save_path: If truthy, the figure is saved there.
        """
        time_points = np.asarray(time_points)
        fig, ax = plt.subplots(figsize=(12, 8))
        try:
            ax.plot(time_points, states[:, 0], 'k-', label='Actual', linewidth=2)

            colors = plt.cm.viridis(np.linspace(0, 1, len(orders)))
            for order, color in zip(orders, colors):
                expansion = self._taylor_expansion(states[0], time_points, order)
                ax.plot(time_points, expansion, '--',
                        color=color,
                        label=f'Order {order}',
                        alpha=0.7)

            ax.set_title('Taylor Series Expansion of State Trajectory')
            ax.set_xlabel('Time')
            ax.set_ylabel('State')
            ax.legend()
            ax.grid(True)
            if save_path:
                fig.savefig(save_path)
        finally:
            plt.close(fig)

    def plot_generalized_coordinates(self,
                                     states: np.ndarray,
                                     time_points: np.ndarray,
                                     save_path: Optional[Union[str, Path]] = None) -> None:
        """Plot each generalized-coordinate order in its own subplot.

        Args:
            states: Trajectory in generalized coordinates, shape
                ``[n_timesteps, n_orders]``. Column ``i`` is drawn as order
                ``i``.
            time_points: Time points, one per row of ``states``.
            save_path: If truthy, the figure is saved there.
        """
        n_orders = states.shape[1]
        # squeeze=False keeps a 2-D array of Axes even when n_orders == 1.
        fig, axes = plt.subplots(n_orders, 1, figsize=(12, 4 * n_orders),
                                 squeeze=False)
        try:
            for i, ax in enumerate(axes[:, 0]):
                ax.plot(time_points, states[:, i], 'b-', linewidth=2)
                ax.set_title(f'Order {i} (d^{i}x/dt^{i})')
                ax.set_xlabel('Time')
                ax.set_ylabel('Value')
                ax.grid(True)
            fig.tight_layout()
            if save_path:
                fig.savefig(save_path)
        finally:
            plt.close(fig)

    def plot_belief_dynamics(self,
                             belief_means: List[np.ndarray],
                             belief_precisions: List[np.ndarray],
                             time_points: np.ndarray,
                             save_path: Optional[Union[str, Path]] = None) -> None:
        """Plot the belief trajectory of states 1 and 2 with uncertainty.

        The order-0 means of the first two states are plotted against each
        other. At a strided subset of time points, an axis-aligned ellipse of
        +/- one standard deviation per state is drawn. The standard
        deviation is derived from the per-state precisions, so off-diagonal
        covariance is ignored.

        Args:
            belief_means: Per-time-point belief means, each
                ``[n_states, n_orders]`` with ``n_states >= 2``.
            belief_precisions: Per-time-point precisions, same shape as the
                means.
            time_points: Time points used to label the ellipses.
            save_path: If truthy, the figure is saved there.
        """
        from matplotlib.patches import Ellipse

        means = np.asarray(belief_means)
        precisions = np.asarray(belief_precisions)

        fig, ax = plt.subplots(figsize=(12, 12))
        try:
            ax.plot(means[:, 0, 0], means[:, 1, 0], 'b-', alpha=0.5,
                    label='Trajectory')

            # Only index positions that exist in both sequences.
            n_points = min(len(time_points), len(means))
            for i in self._annotation_indices(n_points):
                center = (means[i, 0, 0], means[i, 1, 0])
                # Standard deviation from precision; the epsilon avoids
                # division by zero for zero precision.
                std = 1.0 / np.sqrt(precisions[i, :2, 0] + 1e-8)
                ax.add_patch(Ellipse(center, width=2 * std[0], height=2 * std[1],
                                     alpha=0.2, fc='gray', ec='none'))
                ax.annotate(f't={time_points[i]:.2f}',
                            xy=center,
                            xytext=(10, 10),
                            textcoords='offset points')

            ax.set_title('Belief Dynamics in Phase Space')
            ax.set_xlabel('State 1')
            ax.set_ylabel('State 2')
            ax.axis('equal')
            ax.grid(True)
            ax.legend()
            if save_path:
                fig.savefig(save_path)
        finally:
            plt.close(fig)

    def plot_prediction_errors(self,
                               observations: np.ndarray,
                               predictions: np.ndarray,
                               time_points: np.ndarray,
                               save_path: Optional[Union[str, Path]] = None) -> None:
        """Plot observations, predictions, and their difference over time.

        Args:
            observations: Actual observations ``[n_timesteps, n_obs]``.
            predictions: Predicted observations ``[n_timesteps, n_obs]``.
            time_points: Time points, one per row.
            save_path: If truthy, the figure is saved there.
        """
        n_obs = observations.shape[1]
        # squeeze=False keeps a 2-D array of Axes even when n_obs == 1.
        fig, axes = plt.subplots(n_obs, 1, figsize=(12, 4 * n_obs),
                                 squeeze=False)
        try:
            for i, ax in enumerate(axes[:, 0]):
                ax.plot(time_points, observations[:, i], 'b-', label='Actual', alpha=0.7)
                ax.plot(time_points, predictions[:, i], 'r--', label='Predicted', alpha=0.7)
                # Shade the signed error (observation - prediction) around zero.
                error = observations[:, i] - predictions[:, i]
                ax.fill_between(time_points, 0, error,
                                color='gray', alpha=0.2, label='Error')
                ax.set_title(f'Observation {i+1}')
                ax.set_xlabel('Time')
                ax.set_ylabel('Value')
                ax.legend()
                ax.grid(True)
            fig.tight_layout()
            if save_path:
                fig.savefig(save_path)
        finally:
            plt.close(fig)

    def analyze_convergence(self,
                            free_energies: np.ndarray,
                            time_points: np.ndarray,
                            window_size: int = 10,
                            save_path: Optional[Union[str, Path]] = None) -> None:
        """Plot free energy and its moving-average rate of change.

        The top panel shows the raw free energy. The bottom panel shows, on a
        log scale, the mean of ``|dF/dt|`` over each window of
        ``window_size`` consecutive steps, plotted at the window's start
        time.

        Args:
            free_energies: Free energy values, one per time point.
            time_points: Strictly increasing time points, same length as
                ``free_energies``.
            window_size: Number of steps per moving-average window (>= 1).
            save_path: If truthy, the figure is saved there.

        Raises:
            ValueError: If ``window_size`` is less than 1.
        """
        if window_size < 1:
            raise ValueError(f"window_size must be >= 1, got {window_size}")

        free_energies = np.asarray(free_energies, dtype=float)
        time_points = np.asarray(time_points, dtype=float)
        conv_rate = self._moving_abs_rate(free_energies, time_points, window_size)

        fig, (ax1, ax2) = plt.subplots(2, 1, figsize=(12, 8))
        try:
            ax1.plot(time_points, free_energies, 'b-', alpha=0.7)
            ax1.set_title('Free Energy Evolution')
            ax1.set_xlabel('Time')
            ax1.set_ylabel('Free Energy')
            ax1.grid(True)

            ax2.plot(time_points[:len(conv_rate)], conv_rate, 'r-', alpha=0.7)
            ax2.set_title('Convergence Rate (Moving Average of |dF/dt|)')
            ax2.set_xlabel('Time')
            ax2.set_ylabel('Rate')
            ax2.set_yscale('log')
            ax2.grid(True)

            fig.tight_layout()
            if save_path:
                fig.savefig(save_path)
        finally:
            plt.close(fig)