cognitive/Things/BioFirm/homeostatic.py
Daniel Ari Friedman 9de628ee16 Updates
2025-02-07 11:08:50 -08:00

320 строки
11 KiB
Python

"""
Homeostatic control implementation for BioFirm framework.
"""
import sys
from pathlib import Path
from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import Dict, List, Optional, Tuple, Union, Any
import numpy as np
import yaml
import logging
# Module-level logger; the classes below log caught exceptions through it
# before re-raising them or returning False.
logger = logging.getLogger(__name__)
@dataclass
class StateSpace:
    """Abstract representation of state spaces in active inference models.

    Attributes:
        dimensions: Size of each state factor.
        labels: One list of labels per factor, paired with ``dimensions`` in
            dict insertion order.
        mappings: Named arrays; every value must be an ``np.ndarray`` for
            ``validate`` to pass.
        hierarchical_levels: Defaults to 1; not checked by ``validate``.
    """
    dimensions: List[int]
    labels: Dict[str, List[str]]
    mappings: Dict[str, np.ndarray]
    hierarchical_levels: Optional[int] = 1

    def validate(self) -> bool:
        """Validate state space configuration.

        Returns:
            True if there is exactly one label list per dimension, each label
            list has the length of its dimension, and every mapping is an
            ``np.ndarray``. False otherwise, including when a check raises
            (e.g. a label entry without a length).
        """
        try:
            label_lists = list(self.labels.values())
            # zip() stops at the shorter input, so without this check a
            # missing or extra label list would go unnoticed.
            if len(label_lists) != len(self.dimensions):
                return False
            # Check dimensions match labels
            for dim, label_list in zip(self.dimensions, label_lists):
                if len(label_list) != dim:
                    return False
            # Check mappings are valid
            return all(isinstance(mapping, np.ndarray)
                       for mapping in self.mappings.values())
        except Exception as e:
            logger.error(f"Validation error: {str(e)}")
            return False
@dataclass
class ModelState:
    """Snapshot of an Active Inference model's current state.

    ``validate`` only checks runtime types: the two array fields must be
    ``np.ndarray`` and the three scalar fields must be ``float`` instances.
    """
    beliefs: np.ndarray
    policies: np.ndarray
    precision: float
    free_energy: float
    prediction_error: float

    def validate(self) -> bool:
        """Return True when every field has the expected runtime type."""
        try:
            arrays_ok = all(
                isinstance(arr, np.ndarray)
                for arr in (self.beliefs, self.policies)
            )
            scalars_ok = all(
                isinstance(val, float)
                for val in (self.precision, self.free_energy, self.prediction_error)
            )
            return arrays_ok and scalars_ok
        except Exception as e:
            logger.error(f"State validation error: {str(e)}")
            return False
class ControlMode(ABC):
    """Interface shared by all control modes.

    Subclasses implement ``compute_policy_prior``. ``validate_inputs`` is a
    shared helper they can call before doing any computation.
    """

    @abstractmethod
    def compute_policy_prior(self,
                             state: ModelState,
                             goal: np.ndarray) -> np.ndarray:
        """Return the policy prior this control mode assigns for ``state`` and ``goal``."""

    def validate_inputs(self, state: ModelState, goal: np.ndarray) -> bool:
        """Return True if ``state`` is a ModelState that passes its own
        ``validate`` and ``goal`` is an ``np.ndarray``; False otherwise."""
        try:
            state_ok = isinstance(state, ModelState) and state.validate()
            return bool(state_ok and isinstance(goal, np.ndarray))
        except Exception as e:
            logger.error(f"Input validation error: {str(e)}")
            return False
class HomestaticControl(ControlMode):
    """Homeostatic control mode implementation.

    The policy prior decays exponentially with the elementwise absolute
    distance between the current beliefs and the goal, scaled by ``weight``.

    NOTE(review): the class name is misspelled ("Homestatic") but is kept
    because callers may import it by this name. ``bounds`` and
    ``target_state`` are stored but not read by ``compute_policy_prior``.
    """

    def __init__(self,
                 bounds: Tuple[float, float],
                 target_state: Union[str, int],
                 weight: float = 1.0):
        self.weight = weight
        self.target_state = target_state
        self.bounds = bounds

    def compute_policy_prior(self,
                             state: ModelState,
                             goal: np.ndarray) -> np.ndarray:
        """Return ``exp(-weight * |beliefs - goal|)`` elementwise.

        Raises:
            ValueError: if ``validate_inputs`` rejects ``state`` or ``goal``.
        """
        inputs_ok = self.validate_inputs(state, goal)
        if not inputs_ok:
            raise ValueError("Invalid inputs for policy computation")
        try:
            distance = np.abs(state.beliefs - goal)
            scaled = -self.weight * distance
            return np.exp(scaled)
        except Exception as e:
            logger.error(f"Error computing policy prior: {str(e)}")
            raise
class AdaptiveControl(ControlMode):
    """Adaptive control mode that blends exploitation with exploration.

    ``learning_rate`` is stored but not read by ``compute_policy_prior``.
    """

    def __init__(self,
                 learning_rate: float = 0.1,
                 exploration_weight: float = 0.3):
        self.learning_rate = learning_rate
        self.exploration_weight = exploration_weight

    def compute_policy_prior(self,
                             state: ModelState,
                             goal: np.ndarray) -> np.ndarray:
        """Return ``exp((1 - w) * exploit + w * explore)`` elementwise.

        ``w`` is ``exploration_weight``. The exploit term is
        ``-|beliefs - goal|``. The explore term is the negated scalar
        ``state.prediction_error``, broadcast to the shape of the beliefs.

        Raises:
            ValueError: if ``validate_inputs`` rejects ``state`` or ``goal``.
        """
        if not self.validate_inputs(state, goal):
            raise ValueError("Invalid inputs for policy computation")
        try:
            w = self.exploration_weight
            exploit_term = -np.abs(state.beliefs - goal)
            explore_term = -state.prediction_error * np.ones_like(state.beliefs)
            blended = (1 - w) * exploit_term + w * explore_term
            return np.exp(blended)
        except Exception as e:
            logger.error(f"Error computing adaptive policy prior: {str(e)}")
            raise
class HomeostaticInference:
    """Homeostatic control using Active Inference.

    Loads A (observation model), B (transition model), C (preference model)
    and D (prior beliefs) from a YAML config, keeps the current model state
    in ``self.state``, and asks a control mode for the policy prior used when
    selecting actions.
    """

    def __init__(self,
                 config_path: Union[str, Path],
                 control_mode: "ControlMode"):
        """Initialize homeostatic inference.

        Args:
            config_path: Path to the YAML configuration file.
            control_mode: Control mode instance that supplies policy priors.

        Raises:
            ValueError: if the config is not a mapping, is missing a required
                field, or holds matrices with inconsistent shapes.
        """
        self.config_path = Path(config_path)
        self.control_mode = control_mode
        self.config = self._load_config()
        self._initialize_matrices()
        self.state = self._initialize_state()

    def _load_config(self) -> Dict:
        """Load and validate the YAML model configuration.

        Returns:
            The parsed configuration mapping.

        Raises:
            ValueError: if the file is empty, its top level is not a mapping,
                or a required field is missing.
        """
        try:
            with open(self.config_path, 'r', encoding='utf-8') as f:
                config = yaml.safe_load(f)
            # safe_load returns None for an empty file and may return a list
            # or scalar; reject those here instead of failing with a TypeError
            # inside _validate_config.
            if not isinstance(config, dict):
                raise ValueError(
                    f"Config file {self.config_path} must contain a mapping, "
                    f"got {type(config).__name__}"
                )
            self._validate_config(config)
            return config
        except Exception as e:
            logger.error(f"Error loading config: {str(e)}")
            raise

    def _validate_config(self, config: Dict) -> bool:
        """Check that every required top-level field is present.

        Returns:
            True when all required fields are present.

        Raises:
            ValueError: naming the first missing field.
        """
        required_fields = [
            'observation_model',
            'transition_model',
            'preference_model',
            'prior_beliefs',
        ]
        for field in required_fields:
            if field not in config:
                raise ValueError(f"Missing required config field: {field}")
        return True

    def _initialize_matrices(self):
        """Build the A, B, C and D arrays from the config and validate them.

        Every array is cast to float. Integer-valued YAML lists (e.g. a one-hot
        ``prior_beliefs``) would otherwise produce integer arrays that break
        the floating-point belief updates. The ``.get`` defaults are only used
        when the config did not go through ``_validate_config``.
        """
        try:
            # Observation model (A matrix)
            self.A = np.array(self.config.get('observation_model', np.eye(5)), dtype=float)
            # Transition model (B matrix)
            self.B = np.array(self.config.get('transition_model', np.eye(5)), dtype=float)
            # Preference model (C vector)
            self.C = np.array(self.config.get('preference_model', np.zeros(5)), dtype=float)
            # Prior beliefs (D vector)
            self.D = np.array(self.config.get('prior_beliefs', np.ones(5) / 5), dtype=float)
            self._validate_matrices()
        except Exception as e:
            logger.error(f"Error initializing matrices: {str(e)}")
            raise

    def _validate_matrices(self):
        """Validate matrix ranks and dimensions against A.

        Raises:
            ValueError: if A is not a square 2-D matrix, B's first two axes
                differ, or C/D are scalars or have a length other than A's size.
        """
        # The rank checks turn IndexError/TypeError on malformed arrays into
        # the same ValueError used for the dimension checks.
        if self.A.ndim != 2 or self.A.shape[0] != self.A.shape[1]:
            raise ValueError("A matrix must be square")
        if self.B.ndim < 2 or self.B.shape[0] != self.B.shape[1]:
            raise ValueError("B matrix must be square")
        if self.C.ndim == 0 or len(self.C) != self.A.shape[0]:
            raise ValueError("C vector dimension mismatch")
        if self.D.ndim == 0 or len(self.D) != self.A.shape[0]:
            raise ValueError("D vector dimension mismatch")

    def _initialize_state(self) -> "ModelState":
        """Create the initial state.

        The state starts from D for beliefs, uniform policies, unit precision,
        and zero free energy and prediction error.
        """
        try:
            state = ModelState(
                beliefs=self.D.copy(),
                policies=np.ones(len(self.D)) / len(self.D),
                precision=1.0,
                free_energy=0.0,
                prediction_error=0.0
            )
            if not state.validate():
                raise ValueError("Invalid initial state")
            return state
        except Exception as e:
            logger.error(f"Error initializing state: {str(e)}")
            raise

    def update_beliefs(self, observation: np.ndarray) -> np.ndarray:
        """Update beliefs with a precision-weighted prediction error.

        The prediction is ``A @ beliefs``. Beliefs are shifted by
        ``precision * (observation - prediction)``, clipped at zero and
        renormalised to sum to one. ``state.prediction_error`` is set to the
        mean squared prediction error.

        Args:
            observation: Observation vector with the same length as the beliefs.

        Returns:
            The updated, normalised belief vector.
        """
        try:
            # Generate prediction
            prediction = np.dot(self.A, self.state.beliefs)
            # Compute prediction error
            prediction_error = observation - prediction
            # Precision-weighted step; not in-place, so an integer-typed
            # beliefs array cannot make the addition fail.
            beliefs = self.state.beliefs + self.state.precision * prediction_error
            # Beliefs are a distribution. An unbounded additive step can push
            # entries negative or the sum to zero, which would yield negative
            # or NaN "probabilities" and NaN logs in the expected free energy.
            beliefs = np.clip(beliefs, 0.0, None)
            total = np.sum(beliefs)
            if total == 0 and beliefs.size:
                # Every entry was clipped away; fall back to a uniform belief.
                beliefs = np.full(beliefs.shape, 1.0 / beliefs.size)
            else:
                beliefs = beliefs / total
            self.state.beliefs = beliefs
            # Scalar summary of the error for precision updates / exploration.
            self.state.prediction_error = float(np.mean(np.square(prediction_error)))
            return self.state.beliefs
        except Exception as e:
            logger.error(f"Error updating beliefs: {str(e)}")
            raise

    def select_action(self) -> int:
        """Select the action with the highest posterior policy probability.

        Computes ``softmax(-G + log(prior))``, where G is the expected free
        energy per policy and the prior comes from the control mode, evaluated
        against ``config["target_state"]``. The result is stored in
        ``state.policies``.

        Returns:
            Index of the most probable policy, as a plain ``int``.

        Raises:
            KeyError: if the config has no ``target_state`` entry.
            ValueError: if the control mode rejects the state or goal.
        """
        try:
            # Compute expected free energy for each policy
            expected_free_energy = self._compute_expected_free_energy()
            # YAML yields lists/scalars, never ndarrays, and
            # ControlMode.validate_inputs requires an ndarray goal.
            goal = np.asarray(self.config["target_state"])
            # Get policy prior from control mode
            policy_prior = self.control_mode.compute_policy_prior(self.state, goal)
            # Combine expected free energy and prior
            policies = self._softmax(-expected_free_energy + np.log(policy_prior))
            self.state.policies = policies
            # np.argmax returns a NumPy integer; return the annotated int type.
            return int(np.argmax(policies))
        except Exception as e:
            logger.error(f"Error selecting action: {str(e)}")
            raise

    def _compute_expected_free_energy(self) -> np.ndarray:
        """Compute the expected free energy of each policy.

        For policy ``i``, the predicted state is ``B[:, :, i] @ beliefs`` and
        the predicted observation is ``A @ predicted_state``. G is the entropy
        of the predicted observation plus ``sum(predicted_state * C)``.

        NOTE(review): one policy per belief entry is assumed, so B is expected
        to be 3-D with at least ``len(beliefs)`` slices on its last axis.
        """
        try:
            n_policies = len(self.state.beliefs)
            expected_free_energy = np.zeros(n_policies)
            for i in range(n_policies):
                # Predicted next state under policy i
                predicted_state = np.dot(self.B[:, :, i], self.state.beliefs)
                # Predicted observation
                predicted_obs = np.dot(self.A, predicted_state)
                # 1e-8 keeps log() finite for zero-probability observations.
                ambiguity = -np.sum(predicted_obs * np.log(predicted_obs + 1e-8))
                risk = np.sum(predicted_state * self.C)
                expected_free_energy[i] = ambiguity + risk
            return expected_free_energy
        except Exception as e:
            logger.error(f"Error computing expected free energy: {str(e)}")
            raise

    def _softmax(self, x: np.ndarray) -> np.ndarray:
        """Compute softmax values (shifted by the max for numerical stability)."""
        try:
            exp_x = np.exp(x - np.max(x))
            return exp_x / np.sum(exp_x)
        except Exception as e:
            logger.error(f"Error computing softmax: {str(e)}")
            raise

    def update_precision(self, beta: float = 0.9) -> float:
        """Move precision toward the inverse of the latest prediction error.

        This is an exponential moving average with retention ``beta``. With a
        zero prediction error the target is ``1 / 1e-8``, so precision jumps
        sharply.

        Returns:
            The updated precision.
        """
        try:
            self.state.precision = (
                beta * self.state.precision +
                (1 - beta) * (1.0 / (self.state.prediction_error + 1e-8))
            )
            return self.state.precision
        except Exception as e:
            logger.error(f"Error updating precision: {str(e)}")
            raise