--- title: Error Propagation Implementation type: implementation_guide status: stable created: 2024-02-12 tags: - implementation - predictive-processing - error-handling semantic_relations: - type: implements links: [[../../learning_paths/predictive_processing]] - type: relates links: - [[predictive_network]] - [[precision_mechanisms]] --- # Error Propagation Implementation ## Overview This guide details the implementation of error propagation mechanisms in predictive processing networks, focusing on both forward and backward message passing. ## Core Mechanisms ### Error Types ```python class PredictionError: def __init__(self, predicted: torch.Tensor, actual: torch.Tensor, precision: torch.Tensor): """Initialize prediction error. Args: predicted: Predicted values actual: Actual values precision: Precision (inverse variance) """ self.predicted = predicted self.actual = actual self.precision = precision self.error = self._compute_error() def _compute_error(self) -> torch.Tensor: """Compute weighted prediction error.""" raw_error = self.actual - self.predicted return raw_error * self.precision ``` ### Message Passing ```python class MessagePassing: def __init__(self, network: PredictiveNetwork): """Initialize message passing. Args: network: Predictive network instance """ self.network = network self.messages_up = [] self.messages_down = [] def forward_messages(self, input_data: torch.Tensor) -> List[PredictionError]: """Compute forward (bottom-up) messages. Args: input_data: Input tensor Returns: errors: List of prediction errors """ current = input_data errors = [] for layer in self.network.layers: # Generate prediction prediction = layer.forward(current) # Compute error error = PredictionError( predicted=prediction, actual=current, precision=layer.precision ) errors.append(error) # Update current input current = prediction self.messages_up = errors return errors def backward_messages(self, top_down_signal: torch.Tensor) -> List[PredictionError]: """Compute backward (top-down) messages. Args: top_down_signal: Top-level signal Returns: errors: List of prediction errors """ current = top_down_signal errors = [] for layer in reversed(self.network.layers): # Generate backward prediction prediction = layer.backward(current) # Compute error error = PredictionError( predicted=prediction, actual=current, precision=layer.precision ) errors.append(error) # Update current signal current = prediction self.messages_down = errors return errors ``` ## Error Integration ### Error Combination ```python def combine_errors(self, bottom_up: PredictionError, top_down: PredictionError) -> torch.Tensor: """Combine bottom-up and top-down errors. Args: bottom_up: Bottom-up prediction error top_down: Top-down prediction error Returns: combined: Combined error signal """ # Weight errors by their precisions weighted_up = bottom_up.error * bottom_up.precision weighted_down = top_down.error * top_down.precision # Combine weighted errors total_precision = bottom_up.precision + top_down.precision combined = (weighted_up + weighted_down) / (total_precision + 1e-6) return combined ``` ### Update Rules ```python def update_layer(self, layer: PredictiveLayer, combined_error: torch.Tensor, learning_rate: float = 0.01): """Update layer parameters based on combined error. Args: layer: Layer to update combined_error: Combined error signal learning_rate: Learning rate for updates """ # Compute gradients with torch.enable_grad(): # Weight updates dW_hidden = torch.autograd.grad( combined_error.mean(), layer.W_hidden )[0] dW_pred = torch.autograd.grad( combined_error.mean(), layer.W_pred )[0] # Bias updates db_hidden = torch.autograd.grad( combined_error.mean(), layer.b_hidden )[0] db_pred = torch.autograd.grad( combined_error.mean(), layer.b_pred )[0] # Apply updates with torch.no_grad(): layer.W_hidden -= learning_rate * dW_hidden layer.W_pred -= learning_rate * dW_pred layer.b_hidden -= learning_rate * db_hidden layer.b_pred -= learning_rate * db_pred ``` ## Implementation Example ### Full Network Update ```python def update_network(self, input_data: torch.Tensor, learning_rate: float = 0.01): """Perform full network update. Args: input_data: Input tensor learning_rate: Learning rate for updates """ # Forward pass forward_errors = self.forward_messages(input_data) # Generate top-down signal top_signal = self.network.layers[-1].forward(forward_errors[-1].actual) # Backward pass backward_errors = self.backward_messages(top_signal) # Update each layer for layer_idx, layer in enumerate(self.network.layers): # Combine errors combined = self.combine_errors( forward_errors[layer_idx], backward_errors[-(layer_idx + 1)] ) # Update layer self.update_layer(layer, combined, learning_rate) ``` ## Advanced Features ### Error Gating ```python def gate_error(self, error: PredictionError, threshold: float = 0.1) -> PredictionError: """Gate error signal based on magnitude. Args: error: Prediction error threshold: Gating threshold Returns: gated: Gated error signal """ magnitude = torch.abs(error.error) mask = magnitude > threshold gated_error = PredictionError( predicted=error.predicted, actual=error.actual, precision=error.precision * mask.float() ) return gated_error ``` ### Temporal Integration ```python def integrate_temporal_errors(self, current_error: PredictionError, previous_errors: List[PredictionError], window_size: int = 5) -> PredictionError: """Integrate errors over time. Args: current_error: Current prediction error previous_errors: List of previous errors window_size: Integration window size Returns: integrated: Temporally integrated error """ # Collect recent errors recent_errors = previous_errors[-window_size:] recent_errors.append(current_error) # Compute weighted average weights = torch.exp(-torch.arange(len(recent_errors))) weights = weights / weights.sum() integrated_error = sum( e.error * w for e, w in zip(recent_errors, weights) ) return PredictionError( predicted=current_error.predicted, actual=current_error.actual, precision=current_error.precision, error=integrated_error ) ``` ## Best Practices ### Error Handling 1. Validate error magnitudes 2. Check precision values 3. Monitor gradients 4. Handle edge cases ### Optimization 1. Batch processing 2. Memory management 3. Computational efficiency 4. Numerical stability ### Debugging 1. Visualize error flow 2. Track convergence 3. Monitor updates 4. Validate predictions ## Common Issues ### Stability 1. Gradient explosion 2. Vanishing errors 3. Precision instability 4. Update oscillations ### Solutions 1. Gradient clipping 2. Error normalization 3. Adaptive learning rates 4. Error gating ## Related Documentation - [[predictive_network]] - [[precision_mechanisms]] - [[temporal_models]]