cognitive/docs/guides/implementation/error_propagation.md
Daniel Ari Friedman 7ca29024d8 Updates
2025-02-12 13:23:51 -08:00

8.5 KiB

title type status created tags semantic_relations
Error Propagation Implementation implementation_guide stable 2024-02-12
implementation
predictive-processing
error-handling
type links
implements
../../learning_paths/predictive_processing
type links
relates
predictive_network
precision_mechanisms

Error Propagation Implementation

Overview

This guide details the implementation of error propagation mechanisms in predictive processing networks, focusing on both forward and backward message passing.

Core Mechanisms

Error Types

class PredictionError:
    def __init__(self, 
                 predicted: torch.Tensor,
                 actual: torch.Tensor,
                 precision: torch.Tensor):
        """Initialize prediction error.
        
        Args:
            predicted: Predicted values
            actual: Actual values
            precision: Precision (inverse variance)
        """
        self.predicted = predicted
        self.actual = actual
        self.precision = precision
        self.error = self._compute_error()
        
    def _compute_error(self) -> torch.Tensor:
        """Compute weighted prediction error."""
        raw_error = self.actual - self.predicted
        return raw_error * self.precision

Message Passing

class MessagePassing:
    def __init__(self, network: PredictiveNetwork):
        """Initialize message passing.
        
        Args:
            network: Predictive network instance
        """
        self.network = network
        self.messages_up = []
        self.messages_down = []
    
    def forward_messages(self, 
                        input_data: torch.Tensor) -> List[PredictionError]:
        """Compute forward (bottom-up) messages.
        
        Args:
            input_data: Input tensor
            
        Returns:
            errors: List of prediction errors
        """
        current = input_data
        errors = []
        
        for layer in self.network.layers:
            # Generate prediction
            prediction = layer.forward(current)
            
            # Compute error
            error = PredictionError(
                predicted=prediction,
                actual=current,
                precision=layer.precision
            )
            errors.append(error)
            
            # Update current input
            current = prediction
        
        self.messages_up = errors
        return errors
    
    def backward_messages(self, 
                         top_down_signal: torch.Tensor) -> List[PredictionError]:
        """Compute backward (top-down) messages.
        
        Args:
            top_down_signal: Top-level signal
            
        Returns:
            errors: List of prediction errors
        """
        current = top_down_signal
        errors = []
        
        for layer in reversed(self.network.layers):
            # Generate backward prediction
            prediction = layer.backward(current)
            
            # Compute error
            error = PredictionError(
                predicted=prediction,
                actual=current,
                precision=layer.precision
            )
            errors.append(error)
            
            # Update current signal
            current = prediction
        
        self.messages_down = errors
        return errors

Error Integration

Error Combination

def combine_errors(self,
                  bottom_up: PredictionError,
                  top_down: PredictionError) -> torch.Tensor:
    """Combine bottom-up and top-down errors.
    
    Args:
        bottom_up: Bottom-up prediction error
        top_down: Top-down prediction error
        
    Returns:
        combined: Combined error signal
    """
    # Weight errors by their precisions
    weighted_up = bottom_up.error * bottom_up.precision
    weighted_down = top_down.error * top_down.precision
    
    # Combine weighted errors
    total_precision = bottom_up.precision + top_down.precision
    combined = (weighted_up + weighted_down) / (total_precision + 1e-6)
    
    return combined

Update Rules

def update_layer(self,
                 layer: PredictiveLayer,
                 combined_error: torch.Tensor,
                 learning_rate: float = 0.01):
    """Update layer parameters based on combined error.
    
    Args:
        layer: Layer to update
        combined_error: Combined error signal
        learning_rate: Learning rate for updates
    """
    # Compute gradients
    with torch.enable_grad():
        # Weight updates
        dW_hidden = torch.autograd.grad(
            combined_error.mean(),
            layer.W_hidden
        )[0]
        dW_pred = torch.autograd.grad(
            combined_error.mean(),
            layer.W_pred
        )[0]
        
        # Bias updates
        db_hidden = torch.autograd.grad(
            combined_error.mean(),
            layer.b_hidden
        )[0]
        db_pred = torch.autograd.grad(
            combined_error.mean(),
            layer.b_pred
        )[0]
    
    # Apply updates
    with torch.no_grad():
        layer.W_hidden -= learning_rate * dW_hidden
        layer.W_pred -= learning_rate * dW_pred
        layer.b_hidden -= learning_rate * db_hidden
        layer.b_pred -= learning_rate * db_pred

Implementation Example

Full Network Update

def update_network(self,
                  input_data: torch.Tensor,
                  learning_rate: float = 0.01):
    """Perform full network update.
    
    Args:
        input_data: Input tensor
        learning_rate: Learning rate for updates
    """
    # Forward pass
    forward_errors = self.forward_messages(input_data)
    
    # Generate top-down signal
    top_signal = self.network.layers[-1].forward(forward_errors[-1].actual)
    
    # Backward pass
    backward_errors = self.backward_messages(top_signal)
    
    # Update each layer
    for layer_idx, layer in enumerate(self.network.layers):
        # Combine errors
        combined = self.combine_errors(
            forward_errors[layer_idx],
            backward_errors[-(layer_idx + 1)]
        )
        
        # Update layer
        self.update_layer(layer, combined, learning_rate)

Advanced Features

Error Gating

def gate_error(self,
               error: PredictionError,
               threshold: float = 0.1) -> PredictionError:
    """Gate error signal based on magnitude.
    
    Args:
        error: Prediction error
        threshold: Gating threshold
        
    Returns:
        gated: Gated error signal
    """
    magnitude = torch.abs(error.error)
    mask = magnitude > threshold
    
    gated_error = PredictionError(
        predicted=error.predicted,
        actual=error.actual,
        precision=error.precision * mask.float()
    )
    
    return gated_error

Temporal Integration

def integrate_temporal_errors(self,
                            current_error: PredictionError,
                            previous_errors: List[PredictionError],
                            window_size: int = 5) -> PredictionError:
    """Integrate errors over time.
    
    Args:
        current_error: Current prediction error
        previous_errors: List of previous errors
        window_size: Integration window size
        
    Returns:
        integrated: Temporally integrated error
    """
    # Collect recent errors
    recent_errors = previous_errors[-window_size:]
    recent_errors.append(current_error)
    
    # Compute weighted average
    weights = torch.exp(-torch.arange(len(recent_errors)))
    weights = weights / weights.sum()
    
    integrated_error = sum(
        e.error * w for e, w in zip(recent_errors, weights)
    )
    
    return PredictionError(
        predicted=current_error.predicted,
        actual=current_error.actual,
        precision=current_error.precision,
        error=integrated_error
    )

Best Practices

Error Handling

  1. Validate error magnitudes
  2. Check precision values
  3. Monitor gradients
  4. Handle edge cases

Optimization

  1. Batch processing
  2. Memory management
  3. Computational efficiency
  4. Numerical stability

Debugging

  1. Visualize error flow
  2. Track convergence
  3. Monitor updates
  4. Validate predictions

Common Issues

Stability

  1. Gradient explosion
  2. Vanishing errors
  3. Precision instability
  4. Update oscillations

Solutions

  1. Gradient clipping
  2. Error normalization
  3. Adaptive learning rates
  4. Error gating