| title | type | status | created | tags | semantic_relations |
|---|---|---|---|---|---|
| Error Propagation Implementation | implementation_guide | stable | 2024-02-12 |  |  |
# Error Propagation Implementation

## Overview
This guide details the implementation of error propagation mechanisms in predictive processing networks, focusing on both forward and backward message passing.
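As a point of reference, the quantities manipulated below can be written compactly; this notation is introduced here for clarity and is not taken from the code itself. With $\hat{x}$ the prediction, $x$ the actual signal, and $\Pi$ the precision (inverse variance), the precision-weighted error and the combined bottom-up/top-down error are:

$$
\varepsilon = \Pi\,(x - \hat{x}),
\qquad
\varepsilon_{\text{combined}} =
\frac{\Pi_{\uparrow}\,(x_{\uparrow} - \hat{x}_{\uparrow}) + \Pi_{\downarrow}\,(x_{\downarrow} - \hat{x}_{\downarrow})}
     {\Pi_{\uparrow} + \Pi_{\downarrow} + \epsilon}
$$

where $\epsilon$ is a small constant for numerical stability (the `1e-6` term in `combine_errors` below).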
## Core Mechanisms

### Error Types
```python
from typing import List, Optional

import torch


class PredictionError:
    def __init__(self,
                 predicted: torch.Tensor,
                 actual: torch.Tensor,
                 precision: torch.Tensor,
                 error: Optional[torch.Tensor] = None):
        """Initialize prediction error.

        Args:
            predicted: Predicted values
            actual: Actual values
            precision: Precision (inverse variance)
            error: Optional precomputed error signal; if omitted, the
                precision-weighted error is computed from the inputs
        """
        self.predicted = predicted
        self.actual = actual
        self.precision = precision
        self.error = error if error is not None else self._compute_error()

    def _compute_error(self) -> torch.Tensor:
        """Compute precision-weighted prediction error."""
        raw_error = self.actual - self.predicted
        return raw_error * self.precision
```
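A minimal usage example with toy tensors (the values are illustrative only):

```python
predicted = torch.tensor([0.8, 0.2])
actual = torch.tensor([1.0, 0.0])
precision = torch.tensor([2.0, 0.5])

err = PredictionError(predicted=predicted, actual=actual, precision=precision)
print(err.error)  # (actual - predicted) * precision -> tensor([0.4000, -0.1000])
```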
### Message Passing

```python
# PredictiveNetwork (and PredictiveLayer used later) are assumed to be
# defined elsewhere; they are not shown in this guide.
class MessagePassing:
    def __init__(self, network: PredictiveNetwork):
        """Initialize message passing.
        
        Args:
            network: Predictive network instance
        """
        self.network = network
        self.messages_up = []
        self.messages_down = []
    
    def forward_messages(self, 
                        input_data: torch.Tensor) -> List[PredictionError]:
        """Compute forward (bottom-up) messages.
        
        Args:
            input_data: Input tensor
            
        Returns:
            errors: List of prediction errors
        """
        current = input_data
        errors = []
        
        for layer in self.network.layers:
            # Generate prediction
            prediction = layer.forward(current)
            
            # Compute error
            error = PredictionError(
                predicted=prediction,
                actual=current,
                precision=layer.precision
            )
            errors.append(error)
            
            # Update current input
            current = prediction
        
        self.messages_up = errors
        return errors
    
    def backward_messages(self, 
                         top_down_signal: torch.Tensor) -> List[PredictionError]:
        """Compute backward (top-down) messages.
        
        Args:
            top_down_signal: Top-level signal
            
        Returns:
            errors: List of prediction errors
        """
        current = top_down_signal
        errors = []
        
        for layer in reversed(self.network.layers):
            # Generate backward prediction
            prediction = layer.backward(current)
            
            # Compute error
            error = PredictionError(
                predicted=prediction,
                actual=current,
                precision=layer.precision
            )
            errors.append(error)
            
            # Update current signal
            current = prediction
        
        self.messages_down = errors
        return errors
```
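A sketch of how the class might be driven, assuming a `PredictiveNetwork` whose layers expose `forward`, `backward`, and `precision` as used above (the constructor arguments are hypothetical):

```python
network = PredictiveNetwork(layer_sizes=[64, 32, 16])  # hypothetical constructor
mp = MessagePassing(network)

x = torch.randn(8, 64)                    # batch of inputs
up_errors = mp.forward_messages(x)        # bottom-up errors, one per layer
top = network.layers[-1].forward(up_errors[-1].actual)
down_errors = mp.backward_messages(top)   # top-down errors, top layer first
```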
## Error Integration

### Error Combination
```python
def combine_errors(self,
                   bottom_up: PredictionError,
                   top_down: PredictionError) -> torch.Tensor:
    """Combine bottom-up and top-down errors.

    Args:
        bottom_up: Bottom-up prediction error
        top_down: Top-down prediction error

    Returns:
        combined: Combined error signal
    """
    # PredictionError.error is already precision-weighted (raw error times
    # precision), so summing the two signals and dividing by the total
    # precision yields a precision-weighted average of the raw errors.
    total_precision = bottom_up.precision + top_down.precision
    combined = (bottom_up.error + top_down.error) / (total_precision + 1e-6)

    return combined
```
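Dividing the sum of the two precision-weighted errors by the total precision gives a precision-weighted average, analogous to fusing two Gaussian estimates: the more precise stream dominates the combined signal, and the `1e-6` term keeps the division stable when both precisions are near zero.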
### Update Rules
```python
def update_layer(self,
                 layer: PredictiveLayer,
                 combined_error: torch.Tensor,
                 learning_rate: float = 0.01):
    """Update layer parameters based on combined error.

    Args:
        layer: Layer to update
        combined_error: Combined error signal; it must be part of the
            autograd graph, i.e. computed from predictions made with
            gradients enabled
        learning_rate: Learning rate for updates
    """
    params = [layer.W_hidden, layer.W_pred, layer.b_hidden, layer.b_pred]

    # Compute all gradients in a single call so the graph is traversed
    # once (separate calls per parameter would need retain_graph=True).
    dW_hidden, dW_pred, db_hidden, db_pred = torch.autograd.grad(
        combined_error.mean(), params
    )

    # Apply gradient-descent updates without tracking them in autograd
    with torch.no_grad():
        layer.W_hidden -= learning_rate * dW_hidden
        layer.W_pred -= learning_rate * dW_pred
        layer.b_hidden -= learning_rate * db_hidden
        layer.b_pred -= learning_rate * db_pred
```
## Implementation Example

### Full Network Update
```python
def update_network(self,
                  input_data: torch.Tensor,
                  learning_rate: float = 0.01):
    """Perform full network update.
    
    Args:
        input_data: Input tensor
        learning_rate: Learning rate for updates
    """
    # Forward pass
    forward_errors = self.forward_messages(input_data)
    
    # Generate top-down signal
    top_signal = self.network.layers[-1].forward(forward_errors[-1].actual)
    
    # Backward pass
    backward_errors = self.backward_messages(top_signal)
    
    # Update each layer
    for layer_idx, layer in enumerate(self.network.layers):
        # Combine errors
        combined = self.combine_errors(
            forward_errors[layer_idx],
            backward_errors[-(layer_idx + 1)]
        )
        
        # Update layer
        self.update_layer(layer, combined, learning_rate)
```
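A minimal training-loop sketch, assuming the standalone functions above are methods of `MessagePassing` (as their `self` parameters suggest) and that `data_loader` is some iterable of input tensors; both names are illustrative assumptions:

```python
mp = MessagePassing(network)

for epoch in range(10):
    for batch in data_loader:               # hypothetical iterable of input tensors
        mp.update_network(batch, learning_rate=0.01)

    # Track the mean bottom-up error magnitude as a rough convergence signal
    with torch.no_grad():
        errors = mp.forward_messages(batch)
        mean_err = torch.stack([e.error.abs().mean() for e in errors]).mean()
        print(f"epoch {epoch}: mean |error| = {mean_err.item():.4f}")
```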
## Advanced Features

### Error Gating
```python
def gate_error(self,
               error: PredictionError,
               threshold: float = 0.1) -> PredictionError:
    """Gate error signal based on magnitude.
    
    Args:
        error: Prediction error
        threshold: Gating threshold
        
    Returns:
        gated: Gated error signal
    """
    magnitude = torch.abs(error.error)
    mask = magnitude > threshold
    
    gated_error = PredictionError(
        predicted=error.predicted,
        actual=error.actual,
        precision=error.precision * mask.float()
    )
    
    return gated_error
```
### Temporal Integration
```python
def integrate_temporal_errors(self,
                              current_error: PredictionError,
                              previous_errors: List[PredictionError],
                              window_size: int = 5) -> PredictionError:
    """Integrate errors over time.

    Args:
        current_error: Current prediction error
        previous_errors: List of previous errors
        window_size: Integration window size

    Returns:
        integrated: Temporally integrated error
    """
    # Collect the most recent errors plus the current one
    recent_errors = previous_errors[-window_size:] + [current_error]

    # Exponentially decaying weights, with the newest error weighted highest
    ages = torch.arange(len(recent_errors) - 1, -1, -1, dtype=torch.float32)
    weights = torch.exp(-ages)
    weights = weights / weights.sum()

    integrated_error = sum(
        e.error * w for e, w in zip(recent_errors, weights)
    )

    return PredictionError(
        predicted=current_error.predicted,
        actual=current_error.actual,
        precision=current_error.precision,
        error=integrated_error
    )
```
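One way to maintain the history this function expects; the `error_history` buffer and the choice of the lowest-level error are illustrative assumptions:

```python
error_history: List[PredictionError] = []

for step, batch in enumerate(data_loader):   # hypothetical data source
    errors = mp.forward_messages(batch)
    smoothed = mp.integrate_temporal_errors(errors[0], error_history, window_size=5)
    error_history.append(errors[0])
```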
## Best Practices

### Error Handling
- Validate error magnitudes
- Check precision values
- Monitor gradients
- Handle edge cases (see the validation sketch after this list)
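A minimal sketch of the checks above; the threshold values are illustrative assumptions:

```python
def validate_error(error: PredictionError,
                   max_magnitude: float = 1e3,
                   min_precision: float = 1e-8) -> None:
    """Raise if an error signal looks numerically unhealthy."""
    if not torch.isfinite(error.error).all():
        raise ValueError("Non-finite values in prediction error")
    if error.error.abs().max() > max_magnitude:
        raise ValueError("Prediction error magnitude exceeds threshold")
    if (error.precision < min_precision).any():
        raise ValueError("Precision values are too small or negative")
```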
### Optimization
- Batch processing
- Memory management
- Computational efficiency
- Numerical stability
### Debugging
- Visualize error flow (see the logging sketch after this list)
- Track convergence
- Monitor updates
- Validate predictions
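One way to track error flow and convergence during training; the print-based logging format is an assumption and could be swapped for any plotting or logging tool:

```python
def log_error_flow(step: int, errors: List[PredictionError]) -> None:
    """Print per-layer error norms so error flow can be inspected or plotted."""
    norms = [e.error.norm().item() for e in errors]
    print(f"step {step}: " + ", ".join(
        f"layer {i}: {n:.4f}" for i, n in enumerate(norms)))
```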
## Common Issues

### Stability
- Gradient explosion
- Vanishing errors
- Precision instability
- Update oscillations
### Solutions
- Gradient clipping (see the sketch after this list)
- Error normalization
- Adaptive learning rates
- Error gating
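A minimal sketch of the first two remedies, gradient clipping and error normalization; the scale and threshold values are illustrative assumptions. The first helper can be applied to `combined_error` before calling `update_layer`, the second to each gradient tensor inside it:

```python
def normalize_error(error: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
    """Scale an error signal to unit standard deviation to limit vanishing/exploding errors."""
    return error / (error.std() + eps)


def clip_gradient(grad: torch.Tensor, max_norm: float = 1.0) -> torch.Tensor:
    """Rescale a gradient tensor so its norm does not exceed max_norm."""
    norm = grad.norm()
    if norm > max_norm:
        grad = grad * (max_norm / (norm + 1e-6))
    return grad
```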
