
---
title: Predictive Network Implementation
type: implementation_guide
status: stable
created: 2024-02-12
tags:
  - implementation
  - predictive-processing
  - neural-networks
semantic_relations:
  - type: implements
    links:
      - ../../learning_paths/predictive_processing
  - type: relates
    links:
      - error_propagation
      - precision_mechanisms
---

Predictive Network Implementation

Overview

This guide provides a detailed implementation of a basic predictive processing network, focusing on the core mechanisms of prediction generation and error computation.

Architecture

Network Structure

import torch
from typing import List, Tuple

class PredictiveLayer:
    def __init__(self, input_size: int, hidden_size: int, output_size: int):
        """Initialize a predictive processing layer.
        
        Args:
            input_size: Size of input features
            hidden_size: Size of hidden representation
            output_size: Size of predictions
        """
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.output_size = output_size
        
        # Initialize weights with small random values; gradients are needed
        # for the manual update step later in this guide
        self.W_hidden = (torch.randn(input_size, hidden_size) * 0.1).requires_grad_(True)
        self.W_pred = (torch.randn(hidden_size, output_size) * 0.1).requires_grad_(True)
        
        # Initialize biases to zero
        self.b_hidden = torch.zeros(hidden_size, requires_grad=True)
        self.b_pred = torch.zeros(output_size, requires_grad=True)
        
        # Initialize precision (inverse variance) of this layer's predictions
        self.precision = torch.ones(output_size)

class PredictiveNetwork:
    def __init__(self, layer_sizes: List[int]):
        """Initialize hierarchical predictive network.
        
        Args:
            layer_sizes: List of layer sizes from bottom to top
        """
        self.layers = []
        for i in range(len(layer_sizes) - 1):
            layer = PredictiveLayer(
                input_size=layer_sizes[i],
                hidden_size=layer_sizes[i] * 2,
                output_size=layer_sizes[i + 1]
            )
            self.layers.append(layer)
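
For reference, constructing a network with the example sizes used later in this guide produces two layers whose shapes chain together (a quick illustrative check, not part of the core implementation):

# Illustrative check of how layer shapes chain together
network = PredictiveNetwork(layer_sizes=[64, 32, 16])

for i, layer in enumerate(network.layers):
    print(f"Layer {i}: input {layer.input_size}, "
          f"hidden {layer.hidden_size}, prediction {layer.output_size}")
# Layer 0: input 64, hidden 128, prediction 32
# Layer 1: input 32, hidden 64, prediction 16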

Forward Pass

def forward(self, input_data: torch.Tensor) -> Tuple[List[torch.Tensor], List[torch.Tensor]]:
    """Forward pass through the network.
    
    Args:
        input_data: Input tensor
        
    Returns:
        predictions: List of predictions at each layer
        prediction_errors: List of precision-weighted prediction errors
            between adjacent levels
    """
    current_input = input_data
    predictions = []
    prediction_errors = []
    
    # Bottom-up pass
    for i, layer in enumerate(self.layers):
        # Generate prediction of the next level's activity
        hidden = torch.tanh(current_input @ layer.W_hidden + layer.b_hidden)
        prediction = hidden @ layer.W_pred + layer.b_pred
        
        # Prediction error at this level: the activity arriving from below
        # versus what the previous layer predicted for this level (both have
        # size layer_sizes[i]). In a pure feedforward sweep the two coincide,
        # so the error only becomes informative when current_input is clamped
        # to observed activity; see compute_errors below.
        if predictions:
            error = current_input - predictions[-1]
            weighted_error = error * self.layers[i - 1].precision
            prediction_errors.append(weighted_error)
        
        predictions.append(prediction)
        current_input = prediction
    
    return predictions, prediction_errors
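
A quick check of what the forward pass returns on random data, using the network built above (shapes assume the example sizes [64, 32, 16]):

# Run a forward pass on a random batch (batch of 8, input size 64)
x = torch.randn(8, 64)
predictions, errors = network.forward(x)

print([p.shape for p in predictions])  # [torch.Size([8, 32]), torch.Size([8, 16])]
print(len(errors))                     # one inter-level error for a two-layer stack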

Error Computation

def compute_errors(self, 
                  predictions: List[torch.Tensor], 
                  targets: List[torch.Tensor]) -> List[torch.Tensor]:
    """Compute prediction errors at each layer.
    
    Args:
        predictions: List of predictions
        targets: List of target activities, one per layer, each matching
            that layer's prediction size
        
    Returns:
        errors: List of prediction errors
    """
    errors = []
    for pred, target, layer in zip(predictions, targets, self.layers):
        error = target - pred
        weighted_error = error * layer.precision
        errors.append(weighted_error)
    return errors
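
In this setup each layer predicts the activity of the level above it, so the target list pairs one tensor with each layer's prediction. A small sketch with random stand-in targets (real targets would come from the dataset or from observed higher-level activity):

# Hypothetical targets matching each layer's prediction size (32 and 16)
targets = [torch.randn(8, 32), torch.randn(8, 16)]

predictions, _ = network.forward(x)
errors = network.compute_errors(predictions, targets)
print([e.shape for e in errors])  # [torch.Size([8, 32]), torch.Size([8, 16])]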

Training

Loss Function

def compute_loss(self, 
                prediction_errors: List[torch.Tensor], 
                precision_errors: List[torch.Tensor]) -> torch.Tensor:
    """Compute total loss from prediction and precision errors.
    
    Args:
        prediction_errors: List of prediction errors
        precision_errors: List of precision estimation errors
        
    Returns:
        total_loss: Combined loss value
    """
    # Prediction error loss
    pred_loss = sum(torch.mean(error ** 2) for error in prediction_errors)
    
    # Precision error loss, down-weighted relative to the prediction term
    # (the 0.1 factor is a heuristic trade-off)
    prec_loss = sum(torch.mean(error ** 2) for error in precision_errors)
    
    return pred_loss + 0.1 * prec_loss
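
The guide does not construct precision_errors anywhere else (the basic training loop below passes an empty list). One plausible choice, shown here purely as an assumption, is the gap between each layer's stored precision and a fresh estimate from the current errors (see estimate_precision under Advanced Features below):

# One possible way to form precision errors (an illustrative assumption):
# compare each layer's stored precision with an estimate from the latest
# prediction errors
estimated = network.estimate_precision(errors)
precision_errors = [layer.precision - est
                    for layer, est in zip(network.layers, estimated)]

loss = network.compute_loss(errors, precision_errors)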

Update Step

def update_step(self, 
                loss: torch.Tensor,
                learning_rate: float = 0.01):
    """Perform one manual gradient-descent update.
    
    Args:
        loss: Loss value
        learning_rate: Learning rate for updates
    
    Assumes the weights were created with requires_grad=True and that the
    network exposes a parameters() helper (see the sketch below).
    """
    # Compute gradients of the loss with respect to all parameters
    gradients = torch.autograd.grad(loss, self.parameters())
    
    # Update parameters in place without tracking the update itself
    with torch.no_grad():
        for param, grad in zip(self.parameters(), gradients):
            param -= learning_rate * grad
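
The update step relies on a parameters() method that the PredictiveNetwork class above does not define. A minimal sketch (the precision vectors are deliberately left out, since they are intended to be updated from error statistics rather than by gradient descent):

def parameters(self) -> List[torch.Tensor]:
    """Collect the trainable tensors of every layer."""
    params = []
    for layer in self.layers:
        params.extend([layer.W_hidden, layer.b_hidden,
                       layer.W_pred, layer.b_pred])
    return params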

Usage Example

Basic Training Loop

# Initialize network
layer_sizes = [64, 32, 16]  # Example sizes
network = PredictiveNetwork(layer_sizes)

# Training loop (num_epochs and data_loader are assumed to be defined;
# each batch is assumed to yield inputs plus target activities per layer)
for epoch in range(num_epochs):
    for inputs, targets in data_loader:
        # Forward pass
        predictions, _ = network.forward(inputs)
        
        # Compute precision-weighted errors against the targets
        errors = network.compute_errors(predictions, targets)
        
        # Compute loss (no precision errors in this basic loop)
        loss = network.compute_loss(errors, [])
        
        # Update step
        network.update_step(loss)
    
    # Log progress once per epoch
    if epoch % 100 == 0:
        print(f"Epoch {epoch}, Loss: {loss.item():.4f}")

Prediction Generation

def generate_predictions(self, input_data: torch.Tensor) -> List[torch.Tensor]:
    """Generate predictions for input data.
    
    Args:
        input_data: Input tensor
        
    Returns:
        predictions: List of predictions at each layer
    """
    predictions, _ = self.forward(input_data)
    return predictions
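
For inference only, wrapping the call in torch.no_grad() avoids building the autograd graph (a small usage note, not required for correctness):

# Generate predictions without tracking gradients
with torch.no_grad():
    layer_predictions = network.generate_predictions(torch.randn(8, 64))
print([p.shape for p in layer_predictions])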

Advanced Features

Precision Estimation

def estimate_precision(self, 
                      errors: List[torch.Tensor],
                      window_size: int = 100) -> List[torch.Tensor]:
    """Estimate precision based on prediction errors.
    
    Args:
        errors: List of prediction errors
        window_size: Window size for a running estimate (not used in this
            simple batch-wise version)
        
    Returns:
        precisions: Updated precision estimates
    """
    precisions = []
    for error in errors:
        # Variance of the errors across the batch dimension
        var = torch.mean(error ** 2, dim=0)
        # Precision is the inverse variance (small constant avoids division by zero)
        precision = 1.0 / (var + 1e-6)
        precisions.append(precision)
    return precisions
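
estimate_precision returns the estimates without applying them. One way to write them back, sketched here under the assumption that the error list is ordered to match self.layers (the smoothing factor is arbitrary):

# Refresh each layer's precision from the latest error statistics
# (errors are the precision-weighted errors from compute_errors)
new_precisions = network.estimate_precision(errors)
for layer, new_precision in zip(network.layers, new_precisions):
    # Smooth the update and detach so precision does not track gradients
    layer.precision = (0.9 * layer.precision + 0.1 * new_precision).detach()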

Layer Normalization

def normalize_layer(self, 
                   activations: torch.Tensor,
                   epsilon: float = 1e-5) -> torch.Tensor:
    """Apply layer normalization.
    
    Args:
        activations: Layer activations
        epsilon: Small constant for numerical stability
        
    Returns:
        normalized: Normalized activations
    """
    mean = torch.mean(activations, dim=-1, keepdim=True)
    std = torch.std(activations, dim=-1, keepdim=True)
    return (activations - mean) / (std + epsilon)
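
A quick standalone check of the normalization on a random activation matrix; a natural place to apply it, though the guide leaves this open, would be the hidden activations inside the forward pass:

# Normalized activations have roughly zero mean and unit variance per row
activations = torch.randn(8, 128) * 3.0 + 2.0
normalized = network.normalize_layer(activations)
print(normalized.mean(dim=-1))  # close to 0
print(normalized.std(dim=-1))   # close to 1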

Best Practices

Initialization

  1. Use small random weights
  2. Initialize biases to zero
  3. Set reasonable precision values
  4. Validate layer sizes (see the sketch after this list)
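
Items 1–3 are already reflected in the PredictiveLayer initializer above. A small helper for item 4 might look like this (the specific checks are illustrative):

def validate_layer_sizes(layer_sizes: List[int]) -> None:
    """Basic sanity checks before building a PredictiveNetwork."""
    if len(layer_sizes) < 2:
        raise ValueError("Need at least an input level and one higher level")
    if any(size <= 0 for size in layer_sizes):
        raise ValueError("All layer sizes must be positive")

validate_layer_sizes([64, 32, 16])  # passes silently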

Training

  1. Monitor convergence
  2. Use appropriate learning rates
  3. Implement early stopping
  4. Save checkpoints

Validation

  1. Test prediction accuracy
  2. Check error distributions
  3. Validate precision estimates
  4. Monitor layer activities

Common Issues

Numerical Stability

  1. Use layer normalization
  2. Add small constants to divisions
  3. Clip gradient values (see the sketch after this list)
  4. Monitor activation ranges
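
Gradient clipping slots into the manual update step with one extra line per gradient. A sketch (the clip_value of 1.0 is arbitrary; torch.clamp is used because the gradients come from torch.autograd.grad rather than being stored in .grad attributes):

def update_step_clipped(self,
                        loss: torch.Tensor,
                        learning_rate: float = 0.01,
                        clip_value: float = 1.0):
    """Variant of update_step with element-wise gradient clipping."""
    gradients = torch.autograd.grad(loss, self.parameters())
    
    with torch.no_grad():
        for param, grad in zip(self.parameters(), gradients):
            # Limit each gradient element to [-clip_value, clip_value]
            grad = torch.clamp(grad, -clip_value, clip_value)
            param -= learning_rate * grad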

Performance

  1. Batch processing
  2. GPU acceleration
  3. Memory management
  4. Efficient updates