cognitive/docs/guides/implementation/error_propagation.md
Daniel Ari Friedman 7ca29024d8 Updates
2025-02-12 13:23:51 -08:00


---
title: Error Propagation Implementation
type: implementation_guide
status: stable
created: 2024-02-12
tags:
- implementation
- predictive-processing
- error-handling
semantic_relations:
  - type: implements
    links: [[../../learning_paths/predictive_processing]]
  - type: relates
    links:
      - [[predictive_network]]
      - [[precision_mechanisms]]
---
# Error Propagation Implementation
## Overview
This guide describes how to implement error propagation in predictive processing networks, covering both forward (bottom-up) and backward (top-down) message passing.
## Core Mechanisms
### Error Types
```python
import torch


class PredictionError:
    def __init__(self,
                 predicted: torch.Tensor,
                 actual: torch.Tensor,
                 precision: torch.Tensor):
        """Initialize prediction error.

        Args:
            predicted: Predicted values
            actual: Actual values
            precision: Precision (inverse variance)
        """
        self.predicted = predicted
        self.actual = actual
        self.precision = precision
        self.error = self._compute_error()

    def _compute_error(self) -> torch.Tensor:
        """Compute precision-weighted prediction error."""
        raw_error = self.actual - self.predicted
        return raw_error * self.precision
```
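To make the effect of precision weighting concrete, here is a minimal scalar sketch. It uses plain Python floats in place of tensors, and the helper name `weighted_error` is hypothetical (it is not part of the class above). The same raw error counts for more when the channel is reliable (low variance) than when it is noisy.

```python
def weighted_error(predicted: float, actual: float, variance: float) -> float:
    """Precision-weighted prediction error for a single scalar channel."""
    precision = 1.0 / variance  # precision is the inverse variance
    return (actual - predicted) * precision


# Same raw error (0.5) under two different reliabilities
reliable = weighted_error(predicted=1.0, actual=1.5, variance=0.25)  # precision 4.0
noisy = weighted_error(predicted=1.0, actual=1.5, variance=4.0)      # precision 0.25
```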
### Message Passing
```python
from typing import List

import torch

# PredictiveNetwork is documented in [[predictive_network]]; the annotation is
# quoted so this snippet does not need the import to be defined.


class MessagePassing:
    def __init__(self, network: "PredictiveNetwork"):
        """Initialize message passing.

        Args:
            network: Predictive network instance
        """
        self.network = network
        self.messages_up = []
        self.messages_down = []

    def forward_messages(self,
                         input_data: torch.Tensor) -> List[PredictionError]:
        """Compute forward (bottom-up) messages.

        Args:
            input_data: Input tensor
        Returns:
            errors: List of prediction errors, ordered bottom to top
        """
        current = input_data
        errors = []
        for layer in self.network.layers:
            # Generate prediction
            prediction = layer.forward(current)
            # Compute error
            error = PredictionError(
                predicted=prediction,
                actual=current,
                precision=layer.precision
            )
            errors.append(error)
            # Update current input
            current = prediction
        self.messages_up = errors
        return errors

    def backward_messages(self,
                          top_down_signal: torch.Tensor) -> List[PredictionError]:
        """Compute backward (top-down) messages.

        Args:
            top_down_signal: Top-level signal
        Returns:
            errors: List of prediction errors, ordered top to bottom
        """
        current = top_down_signal
        errors = []
        for layer in reversed(self.network.layers):
            # Generate backward prediction
            prediction = layer.backward(current)
            # Compute error
            error = PredictionError(
                predicted=prediction,
                actual=current,
                precision=layer.precision
            )
            errors.append(error)
            # Update current signal
            current = prediction
        self.messages_down = errors
        return errors
```
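The bottom-up loop can be traced by hand on a toy chain. The sketch below is only an illustration: `ToyLayer` is a hypothetical scalar stand-in for the real layer class (prediction = weight * input), and plain floats replace tensors.

```python
from dataclasses import dataclass
from typing import List


@dataclass
class ToyLayer:
    """Hypothetical scalar layer: prediction = weight * input."""
    weight: float
    precision: float

    def forward(self, x: float) -> float:
        return self.weight * x


def forward_errors(layers: List[ToyLayer], x: float) -> List[float]:
    """Collect one precision-weighted error per layer, bottom-up."""
    current, errors = x, []
    for layer in layers:
        prediction = layer.forward(current)
        errors.append((current - prediction) * layer.precision)
        current = prediction  # the prediction feeds the next layer
    return errors


layers = [ToyLayer(weight=0.5, precision=1.0),
          ToyLayer(weight=2.0, precision=2.0)]
errors = forward_errors(layers, 4.0)
# Layer 1: (4.0 - 2.0) * 1.0; layer 2: (2.0 - 4.0) * 2.0
```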
## Error Integration
### Error Combination
```python
# Method of MessagePassing
def combine_errors(self,
                   bottom_up: PredictionError,
                   top_down: PredictionError) -> torch.Tensor:
    """Combine bottom-up and top-down errors.

    Args:
        bottom_up: Bottom-up prediction error
        top_down: Top-down prediction error
    Returns:
        combined: Combined error signal
    """
    # PredictionError.error is already precision-weighted, so it must not be
    # multiplied by the precision a second time
    weighted_up = bottom_up.error
    weighted_down = top_down.error
    # Normalize by the total precision (epsilon guards against division by zero)
    total_precision = bottom_up.precision + top_down.precision
    combined = (weighted_up + weighted_down) / (total_precision + 1e-6)
    return combined
```
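As a scalar illustration of precision-weighted combination, the following sketch averages two raw (not yet weighted) errors, each weighted by its own precision. The function name `combine` and the float inputs are assumptions made for the example; the small `eps` mirrors the `1e-6` guard used above.

```python
def combine(raw_up: float, prec_up: float,
            raw_down: float, prec_down: float,
            eps: float = 1e-6) -> float:
    """Precision-weighted average of two raw (unweighted) errors."""
    weighted_sum = raw_up * prec_up + raw_down * prec_down
    return weighted_sum / (prec_up + prec_down + eps)


balanced = combine(1.0, 1.0, 3.0, 1.0)  # equal precisions: close to the midpoint
skewed = combine(1.0, 9.0, 3.0, 1.0)    # pulled toward the high-precision error
silent = combine(1.0, 0.0, 3.0, 0.0)    # zero total precision: eps avoids 0/0
```

The result always lies between the two raw errors and moves toward whichever source has the higher precision.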
### Update Rules
```python
# Method of MessagePassing
def update_layer(self,
                 layer: "PredictiveLayer",
                 combined_error: torch.Tensor,
                 learning_rate: float = 0.01):
    """Update layer parameters based on combined error.

    Args:
        layer: Layer to update
        combined_error: Combined error signal
        learning_rate: Learning rate for updates
    """
    # Weights and biases to update (all must contribute to combined_error)
    params = [layer.W_hidden, layer.W_pred, layer.b_hidden, layer.b_pred]
    # Compute all gradients in a single call: separate torch.autograd.grad
    # calls would fail after the first one frees the graph
    loss = combined_error.mean()
    grads = torch.autograd.grad(loss, params)
    # Apply updates
    with torch.no_grad():
        for param, grad in zip(params, grads):
            param -= learning_rate * grad
```
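The update rule is ordinary gradient descent on an error-derived scalar. As a hand-checkable sketch, assume a single weight `w`, a prediction `w * x`, and a precision-weighted squared-error loss; none of these names come from the layer class above, and the squared-error loss is an assumption chosen so the gradient can be written in closed form.

```python
def sgd_step(w: float, x: float, y: float,
             precision: float = 1.0, lr: float = 0.1) -> float:
    """One gradient step on L(w) = 0.5 * precision * (y - w * x) ** 2."""
    error = y - w * x
    grad = -precision * error * x  # dL/dw
    return w - lr * grad


first = sgd_step(0.0, x=2.0, y=4.0)  # one step from w = 0

w = 0.0
for _ in range(50):
    w = sgd_step(w, x=2.0, y=4.0)    # approaches the solution w = y / x
```

Raising `precision` scales the gradient, so a more reliable error signal produces a larger step at the same learning rate.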
## Implementation Example
### Full Network Update
```python
def update_network(self,
                   input_data: torch.Tensor,
                   learning_rate: float = 0.01):
    """Perform full network update.

    Args:
        input_data: Input tensor
        learning_rate: Learning rate for updates
    """
    # Forward pass
    forward_errors = self.forward_messages(input_data)
    # Generate top-down signal
    top_signal = self.network.layers[-1].forward(forward_errors[-1].actual)
    # Backward pass
    backward_errors = self.backward_messages(top_signal)
    # Update each layer
    for layer_idx, layer in enumerate(self.network.layers):
        # Combine errors
        combined = self.combine_errors(
            forward_errors[layer_idx],
            backward_errors[-(layer_idx + 1)]
        )
        # Update layer
        self.update_layer(layer, combined, learning_rate)
```
## Advanced Features
### Error Gating
```python
def gate_error(self,
               error: PredictionError,
               threshold: float = 0.1) -> PredictionError:
    """Gate error signal based on magnitude.

    Args:
        error: Prediction error
        threshold: Gating threshold
    Returns:
        gated: Gated error signal
    """
    magnitude = torch.abs(error.error)
    mask = magnitude > threshold
    gated_error = PredictionError(
        predicted=error.predicted,
        actual=error.actual,
        precision=error.precision * mask.float()
    )
    return gated_error
```
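Gating works by zeroing the precision of small errors, which in turn zeroes their weighted error. A minimal sketch with plain Python lists (the helper `gate_precisions` is hypothetical) shows which entries survive, including the boundary case where the magnitude equals the threshold:

```python
from typing import List


def gate_precisions(errors: List[float],
                    precisions: List[float],
                    threshold: float = 0.1) -> List[float]:
    """Zero the precision of entries whose |error| does not exceed threshold."""
    return [p if abs(e) > threshold else 0.0
            for e, p in zip(errors, precisions)]


# Only the middle entry is strictly above the 0.1 threshold
gated = gate_precisions([0.05, -0.5, 0.1], [2.0, 2.0, 2.0])
```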
### Temporal Integration
```python
def integrate_temporal_errors(self,
                              current_error: PredictionError,
                              previous_errors: List[PredictionError],
                              window_size: int = 5) -> PredictionError:
    """Integrate errors over time.

    Args:
        current_error: Current prediction error
        previous_errors: List of previous errors, oldest first
        window_size: Number of previous errors to include
    Returns:
        integrated: Temporally integrated error
    """
    # Collect recent errors (slicing copies, so previous_errors is not mutated)
    recent_errors = previous_errors[-window_size:]
    recent_errors.append(current_error)
    # Exponentially decaying weights by age: the newest error (age 0) weighs most
    ages = torch.arange(len(recent_errors) - 1, -1, -1, dtype=torch.float32)
    weights = torch.exp(-ages)
    weights = weights / weights.sum()
    integrated_error = sum(
        e.error * w for e, w in zip(recent_errors, weights)
    )
    # PredictionError.__init__ takes no `error` argument, so the integrated
    # value replaces the attribute after construction
    integrated = PredictionError(
        predicted=current_error.predicted,
        actual=current_error.actual,
        precision=current_error.precision
    )
    integrated.error = integrated_error
    return integrated
```
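One possible weighting scheme for temporal integration is a normalized exponential decay by age, in which the most recent error receives the largest weight. The sketch below works on plain floats; `recency_weights` and `integrate` are hypothetical helper names, and the decay rate of 1 per step is an assumption.

```python
import math
from typing import List


def recency_weights(n: int) -> List[float]:
    """Normalized exp(-age) weights, ordered oldest to newest."""
    raw = [math.exp(-(n - 1 - i)) for i in range(n)]
    total = sum(raw)
    return [r / total for r in raw]


def integrate(errors: List[float]) -> float:
    """Weighted average of errors, given oldest first."""
    weights = recency_weights(len(errors))
    return sum(e * w for e, w in zip(errors, weights))


weights = recency_weights(3)
# A single recent spike contributes exactly the newest weight
smoothed = integrate([0.0, 0.0, 1.0])
```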
## Best Practices
### Error Handling
1. Validate error magnitudes
2. Check precision values
3. Monitor gradients
4. Handle edge cases
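The checks in this list can be gathered into a small validation helper. The following is a sketch under stated assumptions: errors and precisions are given as lists of floats, the function name `validate_error` is hypothetical, and the magnitude limit of `1e3` is an arbitrary placeholder to tune per model.

```python
import math
from typing import List


def validate_error(errors: List[float],
                   precisions: List[float],
                   max_magnitude: float = 1e3) -> List[str]:
    """Return a list of human-readable problems (empty if none found)."""
    problems = []
    if len(errors) != len(precisions):
        problems.append("length mismatch")
    if any(not math.isfinite(e) for e in errors):
        problems.append("non-finite error")
    elif any(abs(e) > max_magnitude for e in errors):
        problems.append("error magnitude too large")
    # Precision is an inverse variance, so it must be finite and non-negative
    if any(not math.isfinite(p) or p < 0.0 for p in precisions):
        problems.append("invalid precision")
    return problems


ok = validate_error([0.1, -0.2], [1.0, 2.0])
bad = validate_error([1e6], [-1.0])
```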
### Optimization
1. Batch processing
2. Memory management
3. Computational efficiency
4. Numerical stability
### Debugging
1. Visualize error flow
2. Track convergence
3. Monitor updates
4. Validate predictions
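Convergence tracking can be as simple as watching whether a recorded error magnitude has stopped changing. The sketch below assumes a per-step scalar such as the mean absolute error is logged into a list; `has_converged`, the window of 3, and the tolerance are hypothetical choices, and the halving sequence is synthetic stand-in data.

```python
from typing import List


def has_converged(history: List[float],
                  window: int = 3,
                  tol: float = 1e-3) -> bool:
    """True if the last `window` recorded values vary by less than tol."""
    if len(history) < window:
        return False
    recent = history[-window:]
    return max(recent) - min(recent) < tol


history: List[float] = []
for step in range(20):
    history.append(0.5 ** step)  # stand-in for a mean |error| that halves each step
    if has_converged(history):
        break
```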
## Common Issues
### Stability
1. Gradient explosion
2. Vanishing errors
3. Precision instability
4. Update oscillations
### Solutions
1. Gradient clipping
2. Error normalization
3. Adaptive learning rates
4. Error gating
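Two of these remedies, gradient clipping and error normalization, are sketched below on plain lists of floats. The function names are hypothetical, and RMS scaling is only one of several possible normalizations. In PyTorch code, `torch.nn.utils.clip_grad_norm_` provides norm-based clipping for parameter gradients.

```python
import math
from typing import List


def clip_by_norm(grads: List[float], max_norm: float) -> List[float]:
    """Rescale grads so their L2 norm does not exceed max_norm."""
    norm = math.sqrt(sum(g * g for g in grads))
    if norm <= max_norm:
        return list(grads)
    scale = max_norm / norm
    return [g * scale for g in grads]


def normalize_errors(errors: List[float], eps: float = 1e-6) -> List[float]:
    """Divide errors by their root-mean-square magnitude."""
    rms = math.sqrt(sum(e * e for e in errors) / len(errors))
    return [e / (rms + eps) for e in errors]


clipped = clip_by_norm([3.0, 4.0], max_norm=1.0)  # norm 5.0 scaled down to 1.0
normalized = normalize_errors([2.0, -2.0])        # RMS is 2.0
```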
## Related Documentation
- [[predictive_network]]
- [[precision_mechanisms]]
- [[temporal_models]]