# Mirror of https://github.com/docxology/cognitive.git
# Synced 2025-10-30 12:46:04 +02:00
# 320 lines, 11 KiB, Python
"""
Homeostatic control implementation for the BioFirm framework.
"""
|
|
|
|
import sys
|
|
from pathlib import Path
|
|
from abc import ABC, abstractmethod
|
|
from dataclasses import dataclass
|
|
from typing import Dict, List, Optional, Tuple, Union, Any
|
|
import numpy as np
|
|
import yaml
|
|
import logging
|
|
|
|
# Module-level logger; every class in this file reports failures through it.
# No handlers or levels are configured here.
logger: logging.Logger = logging.getLogger(name=__name__)
|
|
|
|
@dataclass
class StateSpace:
    """Abstract representation of state spaces in active inference models.

    Attributes:
        dimensions: Size of each state factor, one entry per factor.
        labels: Per-factor label lists keyed by factor name. They are paired
            with ``dimensions`` in dict insertion order.
        mappings: Named arrays attached to the state space. Only their type
            is checked by :meth:`validate`.
        hierarchical_levels: Number of hierarchical levels. It is stored only;
            this class never reads it.
    """

    dimensions: List[int]
    labels: Dict[str, List[str]]
    mappings: Dict[str, np.ndarray]
    hierarchical_levels: Optional[int] = 1

    def validate(self) -> bool:
        """Check that the configuration is internally consistent.

        Returns:
            ``True`` when all of these hold:

            - there is exactly one label list per dimension;
            - each label list has as many entries as its dimension;
            - every mapping is a ``numpy.ndarray``.

            ``False`` otherwise, including when inspecting the fields raises
            (the error is logged).
        """
        try:
            # zip() stops at the shorter input. Without this check, a missing
            # or extra label group would go unnoticed.
            if len(self.dimensions) != len(self.labels):
                return False

            # Each dimension must match the size of its label list.
            for dim, label_list in zip(self.dimensions, self.labels.values()):
                if len(label_list) != dim:
                    return False

            # Mappings must be NumPy arrays.
            for mapping in self.mappings.values():
                if not isinstance(mapping, np.ndarray):
                    return False

            return True
        except Exception as e:
            logger.error(f"Validation error: {str(e)}")
            return False
|
|
|
|
@dataclass
class ModelState:
    """Represents the current state of an Active Inference model.

    Attributes:
        beliefs: Current belief array.
        policies: Current policy probabilities.
        precision: Scalar precision applied to prediction errors.
        free_energy: Scalar free-energy value.
        prediction_error: Scalar prediction-error summary.
    """

    beliefs: np.ndarray
    policies: np.ndarray
    precision: float
    free_energy: float
    prediction_error: float

    @staticmethod
    def _is_real_scalar(value: object) -> bool:
        """Return True for real Python/NumPy scalars, excluding booleans."""
        # bool is a subclass of int, so reject it explicitly.
        if isinstance(value, (bool, np.bool_)):
            return False
        # NumPy reductions can yield np.float32 or NumPy integer scalars,
        # which are not instances of Python's float.
        return isinstance(value, (int, float, np.integer, np.floating))

    def validate(self) -> bool:
        """Validate model state.

        Returns:
            ``True`` when all of these hold:

            - ``beliefs`` and ``policies`` are ``numpy.ndarray`` instances;
            - ``precision``, ``free_energy`` and ``prediction_error`` are real
              (non-boolean) Python or NumPy scalars.

            ``False`` otherwise, including when a check raises (the error is
            logged).
        """
        try:
            if not isinstance(self.beliefs, np.ndarray) or not isinstance(self.policies, np.ndarray):
                return False
            scalars = (self.precision, self.free_energy, self.prediction_error)
            return all(self._is_real_scalar(value) for value in scalars)
        except Exception as e:
            logger.error(f"State validation error: {str(e)}")
            return False
|
|
|
|
class ControlMode(ABC):
    """Interface for strategies that shape the policy prior.

    Subclasses implement :meth:`compute_policy_prior`. The shared
    :meth:`validate_inputs` helper performs the common type checks.
    """

    @abstractmethod
    def compute_policy_prior(self,
                             state: ModelState,
                             goal: np.ndarray) -> np.ndarray:
        """Return the policy prior implied by this control mode.

        Args:
            state: Current model state.
            goal: Goal array supplied by the caller.
        """

    def validate_inputs(self, state: ModelState, goal: np.ndarray) -> bool:
        """Report whether ``state`` and ``goal`` can be used.

        Returns:
            ``True`` only if ``state`` is a ``ModelState`` whose own
            ``validate()`` succeeds and ``goal`` is a ``numpy.ndarray``.
            ``False`` otherwise, including when a check raises (the error is
            logged).
        """
        try:
            state_ok = isinstance(state, ModelState) and state.validate()
            return bool(state_ok) and isinstance(goal, np.ndarray)
        except Exception as e:
            logger.error(f"Input validation error: {str(e)}")
            return False
|
|
|
|
class HomestaticControl(ControlMode):
    """Homeostatic control mode implementation.

    The prior is ``exp(-weight * |beliefs - goal|)``, evaluated element-wise.
    Entries whose belief is close to the corresponding goal value get larger
    prior values.

    ``bounds`` and ``target_state`` are stored on the instance, but
    :meth:`compute_policy_prior` does not read them.
    """

    def __init__(self,
                 bounds: Tuple[float, float],
                 target_state: Union[str, int],
                 weight: float = 1.0):
        """Store the control parameters.

        Args:
            bounds: Two-element bounds tuple. It is stored only.
            target_state: Target identifier. It is stored only.
            weight: Scale applied to the absolute deviation from the goal.
        """
        self.bounds, self.target_state, self.weight = bounds, target_state, weight

    def compute_policy_prior(self,
                             state: ModelState,
                             goal: np.ndarray) -> np.ndarray:
        """Return ``exp(-weight * |state.beliefs - goal|)``.

        Raises:
            ValueError: If :meth:`validate_inputs` rejects ``state`` or ``goal``.
        """
        if not self.validate_inputs(state, goal):
            raise ValueError("Invalid inputs for policy computation")

        try:
            distance = np.abs(state.beliefs - goal)
            log_prior = -self.weight * distance
            return np.exp(log_prior)
        except Exception as e:
            logger.error(f"Error computing policy prior: {str(e)}")
            raise
|
|
|
|
class AdaptiveControl(ControlMode):
    """Adaptive control mode implementation.

    The log-prior is a weighted combination of two terms:

    - an exploitation term, ``-|beliefs - goal|``;
    - an exploration term, ``-prediction_error``.

    ``exploration_weight`` weights the exploration term and
    ``1 - exploration_weight`` weights the exploitation term.
    """

    def __init__(self,
                 learning_rate: float = 0.1,
                 exploration_weight: float = 0.3):
        """Store the control parameters.

        Args:
            learning_rate: Stored as ``self.learning_rate``.
                :meth:`compute_policy_prior` does not read it.
            exploration_weight: Mixing weight of the exploration term.
        """
        self.learning_rate, self.exploration_weight = learning_rate, exploration_weight

    def compute_policy_prior(self,
                             state: ModelState,
                             goal: np.ndarray) -> np.ndarray:
        """Return the exponentiated mix of exploitation and exploration terms.

        Raises:
            ValueError: If :meth:`validate_inputs` rejects ``state`` or ``goal``.
        """
        if not self.validate_inputs(state, goal):
            raise ValueError("Invalid inputs for policy computation")

        try:
            mix = self.exploration_weight
            exploit_term = -np.abs(state.beliefs - goal)
            # NOTE(review): this term has the same value in every entry. Any
            # later shift-invariant normalisation (e.g. a softmax) cancels it,
            # so it does not change relative preferences. Confirm this is
            # intended.
            explore_term = -state.prediction_error * np.ones_like(state.beliefs)
            log_prior = (1 - mix) * exploit_term + mix * explore_term
            return np.exp(log_prior)
        except Exception as e:
            logger.error(f"Error computing adaptive policy prior: {str(e)}")
            raise
|
|
|
|
class HomeostaticInference:
    """Homeostatic control using Active Inference.

    The class does three things:

    - loads the A/B/C/D model arrays from a YAML configuration file;
    - keeps a :class:`ModelState`;
    - selects actions by combining an expected-free-energy score with the
      policy prior from a :class:`ControlMode`.
    """

    def __init__(self,
                 config_path: Union[str, Path],
                 control_mode: ControlMode):
        """Initialize homeostatic inference.

        Args:
            config_path: Path to the YAML configuration file.
            control_mode: Control mode used to compute the policy prior in
                :meth:`select_action`.

        Raises:
            OSError: If the configuration file cannot be opened.
            yaml.YAMLError: If the file is not valid YAML.
            ValueError: If the configuration is not a mapping, lacks a
                required section, or the model arrays have inconsistent shapes.
        """
        self.config_path = Path(config_path)
        self.control_mode = control_mode
        self.config = self._load_config()
        self._initialize_matrices()
        self.state = self._initialize_state()

    def _load_config(self) -> Dict:
        """Load and validate the YAML model configuration.

        Returns:
            The parsed configuration mapping.
        """
        try:
            # Binary mode lets PyYAML detect the encoding (UTF-8/UTF-16, BOM)
            # instead of depending on the platform's default text encoding.
            with open(self.config_path, 'rb') as f:
                # safe_load builds only plain Python objects, so a config file
                # cannot instantiate arbitrary classes.
                config = yaml.safe_load(f)
            self._validate_config(config)
            return config
        except Exception as e:
            logger.error(f"Error loading config: {str(e)}")
            raise

    def _validate_config(self, config: Dict) -> bool:
        """Check that the configuration contains every required section.

        Returns:
            ``True`` when validation passes.

        Raises:
            ValueError: If ``config`` is not a mapping or a required field is
                missing.
        """
        # An empty YAML file loads as None. Report that clearly instead of
        # failing with a TypeError on the membership test below.
        if not isinstance(config, dict):
            raise ValueError(
                f"Config must be a mapping, got {type(config).__name__}")

        required_fields = [
            'observation_model',
            'transition_model',
            'preference_model',
            'prior_beliefs'
        ]

        for field in required_fields:
            if field not in config:
                raise ValueError(f"Missing required config field: {field}")
        return True

    def _initialize_matrices(self):
        """Build the A, B, C and D arrays from the configuration.

        The ``.get`` defaults are fallbacks only: ``_validate_config`` has
        already required all four keys.
        """
        try:
            # Initialize observation model (A matrix)
            self.A = np.array(self.config.get('observation_model', np.eye(5)))

            # Initialize transition model (B matrix)
            self.B = np.array(self.config.get('transition_model', np.eye(5)))

            # Initialize preference model (C matrix)
            self.C = np.array(self.config.get('preference_model', np.zeros(5)))

            # Initialize prior beliefs (D matrix)
            self.D = np.array(self.config.get('prior_beliefs', np.ones(5) / 5))

            self._validate_matrices()
        except Exception as e:
            logger.error(f"Error initializing matrices: {str(e)}")
            raise

    def _validate_matrices(self):
        """Check basic shape consistency of the loaded model arrays.

        Raises:
            ValueError: If A is not square, B's first two axes differ, or C/D
                do not have one entry per row of A.
        """
        # NOTE(review): _compute_expected_free_energy indexes B as B[:, :, i].
        # B's rank and last-axis length are not checked here. Confirm whether
        # 2-D B configs should be rejected at load time.
        if self.A.shape[0] != self.A.shape[1]:
            raise ValueError("A matrix must be square")
        if self.B.shape[0] != self.B.shape[1]:
            raise ValueError("B matrix must be square")
        if len(self.C) != self.A.shape[0]:
            raise ValueError("C vector dimension mismatch")
        if len(self.D) != self.A.shape[0]:
            raise ValueError("D vector dimension mismatch")

    def _initialize_state(self) -> ModelState:
        """Create the initial model state.

        Beliefs start from D, and policies start uniform.
        """
        try:
            n_states = len(self.D)
            state = ModelState(
                # astype() returns a copy, so D is never mutated. A float dtype
                # lets belief updates proceed even if the config held integers.
                beliefs=self.D.astype(float),
                policies=np.ones(n_states) / n_states,
                precision=1.0,
                free_energy=0.0,
                prediction_error=0.0
            )
            if not state.validate():
                raise ValueError("Invalid initial state")
            return state
        except Exception as e:
            logger.error(f"Error initializing state: {str(e)}")
            raise

    def update_beliefs(self, observation: np.ndarray) -> np.ndarray:
        """Update beliefs from an observation using precision-weighted errors.

        Args:
            observation: Observation values. They must broadcast against
                ``A @ beliefs``.

        Returns:
            The updated, normalized beliefs (also stored on ``self.state``).
        """
        try:
            # Generate prediction
            prediction = np.dot(self.A, self.state.beliefs)

            # Compute prediction error
            prediction_error = observation - prediction

            # Apply the precision-weighted update out of place, so that
            # integer-typed beliefs are promoted instead of raising a casting
            # error.
            beliefs = self.state.beliefs + self.state.precision * prediction_error

            # A large error can push entries below zero. Clip them so the
            # result stays a valid distribution; negative beliefs can make the
            # log terms in _compute_expected_free_energy return NaN.
            beliefs = np.maximum(beliefs, 0.0)
            total = np.sum(beliefs)
            if total == 0:
                # No probability mass left: fall back to a uniform distribution.
                beliefs = np.ones_like(beliefs) / beliefs.size
            else:
                beliefs = beliefs / total
            self.state.beliefs = beliefs

            # float() keeps ModelState.validate() satisfied whatever NumPy
            # scalar type np.mean produces (e.g. float32 inputs).
            self.state.prediction_error = float(np.mean(np.square(prediction_error)))

            return self.state.beliefs
        except Exception as e:
            logger.error(f"Error updating beliefs: {str(e)}")
            raise

    def select_action(self) -> int:
        """Select the index of the most probable policy.

        Negative expected free energy is added to the log of the control
        mode's policy prior, and the sum is normalized with softmax. The
        result is stored in ``self.state.policies``.

        Raises:
            KeyError: If the configuration has no ``target_state`` entry.
            ValueError: If the control mode rejects its inputs.
        """
        try:
            # Compute expected free energy for each policy
            expected_free_energy = self._compute_expected_free_energy()

            # YAML yields lists or scalars, but ControlMode.validate_inputs
            # accepts only numpy arrays. Convert the goal before passing it on.
            goal = np.asarray(self.config["target_state"])
            policy_prior = self.control_mode.compute_policy_prior(
                self.state,
                goal
            )

            # Combine expected free energy and prior
            policies = self._softmax(-expected_free_energy + np.log(policy_prior))

            # Update state
            self.state.policies = policies

            # Return a plain int to match the annotated return type.
            return int(np.argmax(policies))
        except Exception as e:
            logger.error(f"Error selecting action: {str(e)}")
            raise

    def _compute_expected_free_energy(self) -> np.ndarray:
        """Compute an expected-free-energy score for each policy.

        There is one policy per belief entry. Policy ``i`` uses the
        transition slice ``B[:, :, i]``, so B must be 3-D with at least
        ``len(beliefs)`` slices along its last axis.

        Returns:
            Array of length ``len(self.state.beliefs)``. Lower scores make a
            policy more likely in :meth:`select_action`.
        """
        try:
            n_policies = len(self.state.beliefs)
            expected_free_energy = np.zeros(n_policies)

            for i in range(n_policies):
                # Compute predicted next state
                predicted_state = np.dot(self.B[:, :, i], self.state.beliefs)

                # Compute predicted observation
                predicted_obs = np.dot(self.A, predicted_state)

                # Entropy of the predicted observations. 1e-8 avoids log(0).
                ambiguity = -np.sum(predicted_obs * np.log(predicted_obs + 1e-8))
                # NOTE(review): adding predicted_state * C means larger C values
                # raise the score, so C acts as a cost over states. The usual
                # formulation treats C as log-preferences over observations
                # with a negative sign. Confirm the intended convention.
                risk = np.sum(predicted_state * self.C)

                expected_free_energy[i] = ambiguity + risk

            return expected_free_energy
        except Exception as e:
            logger.error(f"Error computing expected free energy: {str(e)}")
            raise

    def _softmax(self, x: np.ndarray) -> np.ndarray:
        """Compute softmax values, shifting by the max for numerical stability."""
        try:
            exp_x = np.exp(x - np.max(x))
            return exp_x / np.sum(exp_x)
        except Exception as e:
            logger.error(f"Error computing softmax: {str(e)}")
            raise

    def update_precision(self, beta: float = 0.9) -> float:
        """Blend the current precision with the inverse prediction error.

        Args:
            beta: Weight kept on the previous precision. ``1 - beta`` goes to
                ``1 / prediction_error``.

        Returns:
            The updated precision (also stored on ``self.state``).
        """
        try:
            # 1e-8 guards against division by zero when the error is exactly 0.
            # float() keeps ModelState.validate() satisfied regardless of the
            # NumPy scalar type of prediction_error.
            self.state.precision = float(
                beta * self.state.precision +
                (1 - beta) * (1.0 / (self.state.prediction_error + 1e-8))
            )
            return self.state.precision
        except Exception as e:
            logger.error(f"Error updating precision: {str(e)}")
            raise