"""
|
|
Nestmate Agent Implementation
|
|
|
|
This module implements the Nestmate agent class, which represents an individual ant
|
|
in the colony using the Free Energy Principle (FEP) and Active Inference framework.
|
|
"""
|
|
|
|

from dataclasses import dataclass
from enum import Enum
from typing import Optional, Tuple

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F


class TaskType(Enum):
    """Possible task types for a Nestmate agent."""
    FORAGING = "foraging"
    MAINTENANCE = "maintenance"
    NURSING = "nursing"
    DEFENSE = "defense"
    EXPLORATION = "exploration"


@dataclass
class Observation:
    """Container for sensory observations."""
    pheromone: np.ndarray   # Pheromone gradients, one entry per pheromone type
    food: np.ndarray        # Food source locations
    nestmates: np.ndarray   # Other agent positions (may be an (N, 2) array)
    obstacles: np.ndarray   # Obstacle positions (may be an (N, 2) array)
    nest: np.ndarray        # Nest location/gradient
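
# Example (hypothetical values and shapes): a single observation with two
# pheromone channels and one detected obstacle. 2-D position blocks are
# flattened by Nestmate._preprocess_observation before inference.
#
#   obs = Observation(
#       pheromone=np.array([0.3, 0.0]),
#       food=np.array([1.5, -0.5]),
#       nestmates=np.zeros((1, 2)),
#       obstacles=np.array([[0.0, 2.0]]),
#       nest=np.array([-3.0, 1.0]),
#   )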


class GenerativeModel(nn.Module):
    """Hierarchical generative model for active inference."""

    def __init__(self, config: dict):
        super().__init__()

        # Model dimensions
        self.obs_dim = config['dimensions']['observations']
        self.state_dim = config['dimensions']['states']
        self.action_dim = config['dimensions']['actions']
        self.temporal_horizon = config['dimensions']['planning_horizon']

        # Hierarchical layers
        self.layers = nn.ModuleList([
            nn.Linear(self.state_dim, self.state_dim)
            for _ in range(config['active_inference']['model']['hierarchical_levels'])
        ])

        # State transition model (dynamics)
        self.transition = nn.Sequential(
            nn.Linear(self.state_dim + self.action_dim, self.state_dim * 2),
            nn.ReLU(),
            nn.Linear(self.state_dim * 2, self.state_dim)
        )

        # Observation model
        self.observation = nn.Sequential(
            nn.Linear(self.state_dim, self.obs_dim * 2),
            nn.ReLU(),
            nn.Linear(self.obs_dim * 2, self.obs_dim)
        )

        # Policy network
        self.policy = nn.Sequential(
            nn.Linear(self.state_dim, self.action_dim * 2),
            nn.ReLU(),
            nn.Linear(self.action_dim * 2, self.action_dim)
        )

        # Precision parameters
        self.alpha = nn.Parameter(torch.ones(1))  # Precision of beliefs
        self.beta = nn.Parameter(torch.ones(1))   # Precision of policies

    def forward(self,
                state: torch.Tensor,
                action: Optional[torch.Tensor] = None) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        """Forward pass through the generative model."""

        # Hierarchical state processing
        for layer in self.layers:
            state = F.relu(layer(state))

        # Generate observations
        predicted_obs = self.observation(state)

        # If an action is provided, also predict the next state
        if action is not None:
            state_action = torch.cat([state, action], dim=-1)
            next_state = self.transition(state_action)
            return predicted_obs, next_state

        return predicted_obs, None
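
    # State inference below is a simplified free-energy minimization: the
    # inferred state s is driven down the gradient of
    #
    #     F(s) = MSE(o, g(s)) + alpha * MSE(s, s_prev)
    #
    # where g is the observation model, the first term is the observation
    # prediction error, and the second (present only when a previous state
    # exists) acts as a prior keeping beliefs temporally coherent. This is
    # a sketch of variational inference, not an exact message-passing scheme.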

    def infer_state(self,
                    obs: torch.Tensor,
                    prev_state: Optional[torch.Tensor] = None,
                    n_steps: int = 10) -> torch.Tensor:
        """Infer the hidden state by iteratively minimizing prediction error."""

        if prev_state is None:
            state = torch.zeros(obs.shape[0], self.state_dim)
        else:
            # Clone so that optimizing `state` does not mutate `prev_state`
            # in place (otherwise both names alias the same tensor and the
            # state prior error below is identically zero).
            state = prev_state.clone().detach()

        state.requires_grad_(True)
        optimizer = torch.optim.Adam([state], lr=0.1)

        for _ in range(n_steps):
            optimizer.zero_grad()

            # Observation prediction error
            pred_obs, _ = self.forward(state)
            obs_error = F.mse_loss(pred_obs, obs)

            if prev_state is not None:
                # State prior error against the (detached) previous belief
                state_error = F.mse_loss(state, prev_state.detach())
                loss = obs_error + self.alpha * state_error
            else:
                loss = obs_error

            loss.backward()
            optimizer.step()

        return state.detach()

    def select_action(self,
                      state: torch.Tensor,
                      temperature: float = 1.0) -> torch.Tensor:
        """Select an action by sampling the temperature-scaled policy.

        Note: this is a softmax policy over the policy network's logits, a
        simplification of full active inference, which would score candidate
        policies by their expected free energy.
        """

        # Get action distribution
        action_logits = self.policy(state)
        action_probs = F.softmax(action_logits / temperature, dim=-1)

        # Sample a discrete action index, then one-hot encode it so the
        # result has action_dim entries, matching what the transition model
        # expects when the action is concatenated with the state.
        idx = torch.multinomial(action_probs, 1).squeeze(-1)
        action = F.one_hot(idx, num_classes=self.action_dim).float()

        return action


class Nestmate:
    """
    Individual ant agent implementing active inference for decision making.
    """

    def __init__(self, config: dict):
        """Initialize Nestmate agent."""
        self.config = config

        # Physical state
        self.position = np.zeros(2)
        self.velocity = np.zeros(2)
        self.orientation = 0.0
        self.energy = config['physical']['energy']['initial']

        # Task state
        self.current_task = TaskType.EXPLORATION
        self.carrying = None

        # Sensory state (placeholder observations until the first update)
        self.observations = Observation(
            pheromone=np.zeros(len(config['sensors']['pheromone']['types'])),
            food=np.zeros(2),
            nestmates=np.zeros(2),
            obstacles=np.zeros(2),
            nest=np.zeros(2)
        )

        # Active inference components
        self.generative_model = GenerativeModel(config)
        self.current_state = None
        self.previous_action = None

        # Memory
        self.memory = {
            'spatial': [],
            'temporal': [],
            'social': []
        }

        # Learning parameters
        self.learning_rate = config['learning']['parameters']['learning_rate']
        self.exploration_rate = config['learning']['parameters']['exploration_rate']

    def update(self, observation: Observation) -> np.ndarray:
        """
        Update agent state and select an action using active inference.

        Args:
            observation: Current sensory observations

        Returns:
            action: Selected action as a numpy array
        """
        # Convert observation to a float tensor with a batch dimension
        # (infer_state reads obs.shape[0] as the batch size)
        obs_tensor = torch.tensor(
            self._preprocess_observation(observation),
            dtype=torch.float32
        ).unsqueeze(0)

        # State inference
        inferred_state = self.generative_model.infer_state(
            obs_tensor,
            prev_state=self.current_state
        )
        self.current_state = inferred_state

        # Action selection
        action = self.generative_model.select_action(
            inferred_state,
            temperature=self.config['active_inference']['free_energy']['temperature']
        )

        # Update memory
        self._update_memory(observation, action)

        # Update internal state
        self._update_internal_state()

        return action.squeeze(0).numpy()

    def _preprocess_observation(self, observation: Observation) -> np.ndarray:
        """Preprocess raw observations into model input format."""
        # Combine all observations into a single flat vector; np.ravel lets
        # 2-D position blocks (e.g. (N, 2) arrays) pass through while 1-D
        # placeholders stay unchanged.
        obs_vector = np.concatenate([
            np.ravel(observation.pheromone),
            np.ravel(observation.food),
            np.ravel(observation.nestmates),
            np.ravel(observation.obstacles),
            np.ravel(observation.nest)
        ])

        # Normalize
        obs_vector = (obs_vector - obs_vector.mean()) / (obs_vector.std() + 1e-8)

        return obs_vector

    def _update_memory(self, observation: Observation, action: torch.Tensor):
        """Update agent's memory systems."""
        # Spatial memory
        self.memory['spatial'].append({
            'position': self.position.copy(),
            'observation': observation,
            'timestamp': None  # Add actual timestamp in implementation
        })

        # Temporal memory
        self.memory['temporal'].append({
            'state': self.current_state.detach().numpy(),
            'action': action.numpy(),
            'reward': self._compute_reward(observation)
        })

        # Social memory (interactions with other agents)
        if np.any(observation.nestmates):
            self.memory['social'].append({
                'nestmate_positions': observation.nestmates.copy(),
                'interaction_type': self._classify_interaction(observation)
            })

        # Maintain memory size limits
        for memory_type in self.memory:
            if len(self.memory[memory_type]) > self.config['memory'][memory_type]['capacity']:
                self.memory[memory_type].pop(0)

    def _update_internal_state(self):
        """Update agent's internal state variables."""
        # Update energy
        self.energy -= self.config['physical']['energy']['consumption_rate']
        if self.carrying is not None:
            self.energy -= self.config['physical']['energy']['consumption_rate'] * 2

        # Update task if needed
        if self._should_switch_task():
            self._switch_task()

        # Decay exploration, clamped to its configured minimum
        self.exploration_rate *= self.config['learning']['parameters']['decay_rate']
        self.exploration_rate = max(
            self.exploration_rate,
            self.config['learning']['parameters']['min_exploration']
        )

    def _compute_reward(self, observation: Observation) -> float:
        """Compute reward signal from current observation."""
        reward = 0.0

        # Task-specific rewards
        if self.current_task == TaskType.FORAGING:
            reward += np.sum(observation.food) * self.config['active_inference']['preferences']['food_weight']

        # Distance-to-nest penalty
        nest_distance = np.linalg.norm(observation.nest)
        reward -= nest_distance * self.config['active_inference']['preferences']['home_weight']

        # Safety penalty (avoiding obstacles); atleast_2d tolerates both a
        # single 1-D position and an (N, 2) array of obstacle positions
        obstacle_positions = np.atleast_2d(observation.obstacles)
        obstacle_penalty = np.sum(1.0 / (1.0 + np.linalg.norm(obstacle_positions, axis=1)))
        reward -= obstacle_penalty * self.config['active_inference']['preferences']['safety_weight']

        # Social reward
        if np.any(observation.nestmates):
            reward += self.config['active_inference']['preferences']['social_weight']

        return reward

    def _should_switch_task(self) -> bool:
        """Determine if the agent should switch its current task."""
        # Energy-based switching
        if self.energy < self.config['physical']['energy']['critical_level']:
            return True

        # Random switching based on behavioral flexibility
        if np.random.random() < self.config['behavior']['task_switching']['flexibility']:
            return True

        return False

    def _switch_task(self):
        """Switch to a new task based on current conditions."""
        # Get valid task options (anything but the current task)
        valid_tasks = [task for task in TaskType if task is not self.current_task]

        # Select a new task uniformly at random (can be made more sophisticated);
        # index-based selection avoids np.random.choice's array coercion of
        # Enum members.
        self.current_task = valid_tasks[np.random.randint(len(valid_tasks))]

    def _classify_interaction(self, observation: Observation) -> str:
        """Classify the type of interaction with nearby nestmates."""
        # Simple distance-based classification; atleast_2d tolerates both a
        # single 1-D position and an (N, 2) array of nestmate positions
        distances = np.linalg.norm(np.atleast_2d(observation.nestmates), axis=1)
        if np.any(distances < 1.0):
            return "direct"
        elif np.any(distances < 3.0):
            return "indirect"
        return "none"