зеркало из
https://github.com/docxology/cognitive.git
synced 2025-11-01 05:36:05 +02:00
333 строки
8.5 KiB
Markdown
333 строки
8.5 KiB
Markdown
---
|
|
title: Error Propagation Implementation
|
|
type: implementation_guide
|
|
status: stable
|
|
created: 2024-02-12
|
|
tags:
|
|
- implementation
|
|
- predictive-processing
|
|
- error-handling
|
|
semantic_relations:
|
|
- type: implements
|
|
links: [[../../learning_paths/predictive_processing]]
|
|
- type: relates
|
|
links:
|
|
- [[predictive_network]]
|
|
- [[precision_mechanisms]]
|
|
---
|
|
|
|
# Error Propagation Implementation
|
|
|
|
## Overview
|
|
|
|
This guide details the implementation of error propagation mechanisms in predictive processing networks, focusing on both forward and backward message passing.
|
|
|
|
## Core Mechanisms
|
|
|
|
### Error Types
|
|
```python
|
|
class PredictionError:
|
|
def __init__(self,
|
|
predicted: torch.Tensor,
|
|
actual: torch.Tensor,
|
|
precision: torch.Tensor):
|
|
"""Initialize prediction error.
|
|
|
|
Args:
|
|
predicted: Predicted values
|
|
actual: Actual values
|
|
precision: Precision (inverse variance)
|
|
"""
|
|
self.predicted = predicted
|
|
self.actual = actual
|
|
self.precision = precision
|
|
self.error = self._compute_error()
|
|
|
|
def _compute_error(self) -> torch.Tensor:
|
|
"""Compute weighted prediction error."""
|
|
raw_error = self.actual - self.predicted
|
|
return raw_error * self.precision
|
|
```
|
|
|
|
### Message Passing
|
|
```python
|
|
class MessagePassing:
|
|
def __init__(self, network: PredictiveNetwork):
|
|
"""Initialize message passing.
|
|
|
|
Args:
|
|
network: Predictive network instance
|
|
"""
|
|
self.network = network
|
|
self.messages_up = []
|
|
self.messages_down = []
|
|
|
|
def forward_messages(self,
|
|
input_data: torch.Tensor) -> List[PredictionError]:
|
|
"""Compute forward (bottom-up) messages.
|
|
|
|
Args:
|
|
input_data: Input tensor
|
|
|
|
Returns:
|
|
errors: List of prediction errors
|
|
"""
|
|
current = input_data
|
|
errors = []
|
|
|
|
for layer in self.network.layers:
|
|
# Generate prediction
|
|
prediction = layer.forward(current)
|
|
|
|
# Compute error
|
|
error = PredictionError(
|
|
predicted=prediction,
|
|
actual=current,
|
|
precision=layer.precision
|
|
)
|
|
errors.append(error)
|
|
|
|
# Update current input
|
|
current = prediction
|
|
|
|
self.messages_up = errors
|
|
return errors
|
|
|
|
def backward_messages(self,
|
|
top_down_signal: torch.Tensor) -> List[PredictionError]:
|
|
"""Compute backward (top-down) messages.
|
|
|
|
Args:
|
|
top_down_signal: Top-level signal
|
|
|
|
Returns:
|
|
errors: List of prediction errors
|
|
"""
|
|
current = top_down_signal
|
|
errors = []
|
|
|
|
for layer in reversed(self.network.layers):
|
|
# Generate backward prediction
|
|
prediction = layer.backward(current)
|
|
|
|
# Compute error
|
|
error = PredictionError(
|
|
predicted=prediction,
|
|
actual=current,
|
|
precision=layer.precision
|
|
)
|
|
errors.append(error)
|
|
|
|
# Update current signal
|
|
current = prediction
|
|
|
|
self.messages_down = errors
|
|
return errors
|
|
```
|
|
|
|
## Error Integration
|
|
|
|
### Error Combination
|
|
```python
|
|
def combine_errors(self,
|
|
bottom_up: PredictionError,
|
|
top_down: PredictionError) -> torch.Tensor:
|
|
"""Combine bottom-up and top-down errors.
|
|
|
|
Args:
|
|
bottom_up: Bottom-up prediction error
|
|
top_down: Top-down prediction error
|
|
|
|
Returns:
|
|
combined: Combined error signal
|
|
"""
|
|
# Weight errors by their precisions
|
|
weighted_up = bottom_up.error * bottom_up.precision
|
|
weighted_down = top_down.error * top_down.precision
|
|
|
|
# Combine weighted errors
|
|
total_precision = bottom_up.precision + top_down.precision
|
|
combined = (weighted_up + weighted_down) / (total_precision + 1e-6)
|
|
|
|
return combined
|
|
```
|
|
|
|
### Update Rules
|
|
```python
|
|
def update_layer(self,
|
|
layer: PredictiveLayer,
|
|
combined_error: torch.Tensor,
|
|
learning_rate: float = 0.01):
|
|
"""Update layer parameters based on combined error.
|
|
|
|
Args:
|
|
layer: Layer to update
|
|
combined_error: Combined error signal
|
|
learning_rate: Learning rate for updates
|
|
"""
|
|
# Compute gradients
|
|
with torch.enable_grad():
|
|
# Weight updates
|
|
dW_hidden = torch.autograd.grad(
|
|
combined_error.mean(),
|
|
layer.W_hidden
|
|
)[0]
|
|
dW_pred = torch.autograd.grad(
|
|
combined_error.mean(),
|
|
layer.W_pred
|
|
)[0]
|
|
|
|
# Bias updates
|
|
db_hidden = torch.autograd.grad(
|
|
combined_error.mean(),
|
|
layer.b_hidden
|
|
)[0]
|
|
db_pred = torch.autograd.grad(
|
|
combined_error.mean(),
|
|
layer.b_pred
|
|
)[0]
|
|
|
|
# Apply updates
|
|
with torch.no_grad():
|
|
layer.W_hidden -= learning_rate * dW_hidden
|
|
layer.W_pred -= learning_rate * dW_pred
|
|
layer.b_hidden -= learning_rate * db_hidden
|
|
layer.b_pred -= learning_rate * db_pred
|
|
```
|
|
|
|
## Implementation Example
|
|
|
|
### Full Network Update
|
|
```python
|
|
def update_network(self,
|
|
input_data: torch.Tensor,
|
|
learning_rate: float = 0.01):
|
|
"""Perform full network update.
|
|
|
|
Args:
|
|
input_data: Input tensor
|
|
learning_rate: Learning rate for updates
|
|
"""
|
|
# Forward pass
|
|
forward_errors = self.forward_messages(input_data)
|
|
|
|
# Generate top-down signal
|
|
top_signal = self.network.layers[-1].forward(forward_errors[-1].actual)
|
|
|
|
# Backward pass
|
|
backward_errors = self.backward_messages(top_signal)
|
|
|
|
# Update each layer
|
|
for layer_idx, layer in enumerate(self.network.layers):
|
|
# Combine errors
|
|
combined = self.combine_errors(
|
|
forward_errors[layer_idx],
|
|
backward_errors[-(layer_idx + 1)]
|
|
)
|
|
|
|
# Update layer
|
|
self.update_layer(layer, combined, learning_rate)
|
|
```
|
|
|
|
## Advanced Features
|
|
|
|
### Error Gating
|
|
```python
|
|
def gate_error(self,
|
|
error: PredictionError,
|
|
threshold: float = 0.1) -> PredictionError:
|
|
"""Gate error signal based on magnitude.
|
|
|
|
Args:
|
|
error: Prediction error
|
|
threshold: Gating threshold
|
|
|
|
Returns:
|
|
gated: Gated error signal
|
|
"""
|
|
magnitude = torch.abs(error.error)
|
|
mask = magnitude > threshold
|
|
|
|
gated_error = PredictionError(
|
|
predicted=error.predicted,
|
|
actual=error.actual,
|
|
precision=error.precision * mask.float()
|
|
)
|
|
|
|
return gated_error
|
|
```
|
|
|
|
### Temporal Integration
|
|
```python
|
|
def integrate_temporal_errors(self,
|
|
current_error: PredictionError,
|
|
previous_errors: List[PredictionError],
|
|
window_size: int = 5) -> PredictionError:
|
|
"""Integrate errors over time.
|
|
|
|
Args:
|
|
current_error: Current prediction error
|
|
previous_errors: List of previous errors
|
|
window_size: Integration window size
|
|
|
|
Returns:
|
|
integrated: Temporally integrated error
|
|
"""
|
|
# Collect recent errors
|
|
recent_errors = previous_errors[-window_size:]
|
|
recent_errors.append(current_error)
|
|
|
|
# Compute weighted average
|
|
weights = torch.exp(-torch.arange(len(recent_errors)))
|
|
weights = weights / weights.sum()
|
|
|
|
integrated_error = sum(
|
|
e.error * w for e, w in zip(recent_errors, weights)
|
|
)
|
|
|
|
return PredictionError(
|
|
predicted=current_error.predicted,
|
|
actual=current_error.actual,
|
|
precision=current_error.precision,
|
|
error=integrated_error
|
|
)
|
|
```
|
|
|
|
## Best Practices
|
|
|
|
### Error Handling
|
|
1. Validate error magnitudes
|
|
2. Check precision values
|
|
3. Monitor gradients
|
|
4. Handle edge cases
|
|
|
|
### Optimization
|
|
1. Batch processing
|
|
2. Memory management
|
|
3. Computational efficiency
|
|
4. Numerical stability
|
|
|
|
### Debugging
|
|
1. Visualize error flow
|
|
2. Track convergence
|
|
3. Monitor updates
|
|
4. Validate predictions
|
|
|
|
## Common Issues
|
|
|
|
### Stability
|
|
1. Gradient explosion
|
|
2. Vanishing errors
|
|
3. Precision instability
|
|
4. Update oscillations
|
|
|
|
### Solutions
|
|
1. Gradient clipping
|
|
2. Error normalization
|
|
3. Adaptive learning rates
|
|
4. Error gating
|
|
|
|
## Related Documentation
|
|
- [[predictive_network]]
|
|
- [[precision_mechanisms]]
|
|
- [[temporal_models]] |