Mirror of
https://github.com/docxology/cognitive.git
synced 2025-10-31 21:26:04 +02:00
8.5 KiB
| title | type | status | created | tags | semantic_relations |
|---|---|---|---|---|---|
| Error Propagation Implementation | implementation_guide | stable | 2024-02-12 | | |
Error Propagation Implementation
Overview
This guide details the implementation of error propagation mechanisms in predictive processing networks, covering both forward (bottom-up) and backward (top-down) message passing.
Core Mechanisms
Error Types
class PredictionError:
    """Precision-weighted prediction error between a prediction and an observation.

    Attributes:
        predicted: Predicted values, stored as given.
        actual: Actual values, stored as given.
        precision: Precision (inverse variance), stored as given.
        error: ``(actual - predicted) * precision`` unless an explicit
            ``error`` was passed to the constructor.
    """

    def __init__(self,
                 predicted: torch.Tensor,
                 actual: torch.Tensor,
                 precision: torch.Tensor,
                 error: "torch.Tensor | None" = None):
        """Initialize prediction error.

        Args:
            predicted: Predicted values.
            actual: Actual values.
            precision: Precision (inverse variance).
            error: Optional precomputed error signal (for example a
                temporally integrated error). When omitted, the
                precision-weighted error is computed from the other
                arguments.
        """
        self.predicted = predicted
        self.actual = actual
        self.precision = precision
        # Callers that already hold a derived error signal can pass it in
        # directly instead of having it recomputed from predicted/actual.
        self.error = self._compute_error() if error is None else error

    def _compute_error(self) -> torch.Tensor:
        """Return the precision-weighted error ``(actual - predicted) * precision``."""
        raw_error = self.actual - self.predicted
        return raw_error * self.precision
Message Passing
class MessagePassing:
    """Run bottom-up and top-down prediction-error passes over a network.

    The errors produced by the most recent pass in each direction are kept
    on the instance as ``messages_up`` and ``messages_down``.
    """

    def __init__(self, network: PredictiveNetwork):
        """Initialize message passing.

        Args:
            network: Predictive network instance; both passes iterate over
                its ``layers``.
        """
        self.network = network
        self.messages_up = []
        self.messages_down = []

    @staticmethod
    def _propagate(signal: torch.Tensor,
                   layers,
                   predict) -> List[PredictionError]:
        """Chain ``predict(layer, current)`` through ``layers``, recording one error per layer."""
        current = signal
        errors = []
        for layer in layers:
            prediction = predict(layer, current)
            # Each layer's error compares its prediction with the signal it received.
            errors.append(PredictionError(
                predicted=prediction,
                actual=current,
                precision=layer.precision,
            ))
            # The prediction becomes the input to the next layer in the chain.
            current = prediction
        return errors

    def forward_messages(self,
                         input_data: torch.Tensor) -> List[PredictionError]:
        """Compute forward (bottom-up) messages.

        Layers are visited in order, starting from ``input_data``.

        Args:
            input_data: Input tensor.

        Returns:
            errors: One prediction error per layer, first layer first.
        """
        errors = self._propagate(input_data, self.network.layers,
                                 lambda layer, x: layer.forward(x))
        self.messages_up = errors
        return errors

    def backward_messages(self,
                          top_down_signal: torch.Tensor) -> List[PredictionError]:
        """Compute backward (top-down) messages.

        Layers are visited in reverse order, starting from ``top_down_signal``.

        Args:
            top_down_signal: Top-level signal.

        Returns:
            errors: One prediction error per layer, last layer first.
        """
        errors = self._propagate(top_down_signal, reversed(self.network.layers),
                                 lambda layer, x: layer.backward(x))
        self.messages_down = errors
        return errors
Error Integration
Error Combination
def combine_errors(self,
                   bottom_up: "PredictionError",
                   top_down: "PredictionError",
                   eps: float = 1e-6) -> torch.Tensor:
    """Combine bottom-up and top-down errors as a precision-weighted mean.

    Computes ``(pi_u * e_u + pi_d * e_d) / (pi_u + pi_d + eps)``, where
    ``e`` is the raw error and ``pi`` the precision of each message.

    Args:
        bottom_up: Bottom-up prediction error.
        top_down: Top-down prediction error.
        eps: Small constant added to the total precision to avoid division
            by zero.

    Returns:
        combined: Combined error signal.
    """
    # PredictionError.error already holds raw_error * precision, so it is the
    # precision-weighted term; multiplying by precision again would weight
    # each message by precision**2.
    weighted_sum = bottom_up.error + top_down.error
    total_precision = bottom_up.precision + top_down.precision
    return weighted_sum / (total_precision + eps)
Update Rules
def update_layer(self,
                 layer: "PredictiveLayer",
                 combined_error: torch.Tensor,
                 learning_rate: float = 0.01):
    """Update layer parameters based on combined error.

    Takes one gradient-descent step on ``combined_error.mean()`` with
    respect to ``W_hidden``, ``W_pred``, ``b_hidden`` and ``b_pred``. The
    parameters are modified in place.

    Args:
        layer: Layer to update.
        combined_error: Combined error signal. It must be attached to an
            autograd graph that reaches all four parameters.
        learning_rate: Learning rate for updates.
    """
    params = (layer.W_hidden, layer.W_pred, layer.b_hidden, layer.b_pred)
    with torch.enable_grad():
        # NOTE(review): this descends on the signed mean of the error, which
        # has no lower bound; a squared-error objective may be intended — confirm.
        loss = combined_error.mean()
        # Request every gradient in one call: separate autograd.grad calls on
        # the same graph fail after the first, because the first call frees it.
        grads = torch.autograd.grad(loss, params)
    with torch.no_grad():
        for param, grad in zip(params, grads):
            param.sub_(learning_rate * grad)
Implementation Example
Full Network Update
def update_network(self,
                   input_data: torch.Tensor,
                   learning_rate: float = 0.01):
    """Perform full network update.

    The steps are:
    1. Run the forward (bottom-up) pass.
    2. Build a top-down signal from the last layer.
    3. Run the backward (top-down) pass.
    4. Update each layer from its combined forward/backward error.

    Args:
        input_data: Input tensor.
        learning_rate: Learning rate for updates.
    """
    forward_errors = self.forward_messages(input_data)
    # A network without layers yields no errors: nothing to propagate or update.
    if not forward_errors:
        return
    # Top-down signal: the last layer's prediction for the input it received.
    top_signal = self.network.layers[-1].forward(forward_errors[-1].actual)
    backward_errors = self.backward_messages(top_signal)
    # NOTE(review): the combined errors of different layers may share autograd
    # graph nodes. If update_layer frees the graph or mutates weights in place,
    # gradient calls for later layers could fail — confirm.
    for layer_idx, layer in enumerate(self.network.layers):
        # backward_messages walks the layers in reverse, so layer_idx's
        # top-down error is counted from the end of backward_errors.
        combined = self.combine_errors(
            forward_errors[layer_idx],
            backward_errors[-(layer_idx + 1)],
        )
        self.update_layer(layer, combined, learning_rate)
Advanced Features
Error Gating
def gate_error(self,
               error: PredictionError,
               threshold: float = 0.1) -> PredictionError:
    """Silence error entries whose weighted magnitude does not exceed a threshold.

    Gating is done by zeroing the precision of the suppressed entries; the
    prediction and observation are passed through unchanged.

    Args:
        error: Prediction error to gate.
        threshold: Entries with ``|error.error| <= threshold`` are suppressed.

    Returns:
        gated: New PredictionError built with the masked precision.
    """
    # 1.0 where the entry is kept, 0.0 where it is gated out.
    passes = error.error.abs().gt(threshold).float()
    return PredictionError(
        predicted=error.predicted,
        actual=error.actual,
        precision=error.precision * passes,
    )
Temporal Integration
def integrate_temporal_errors(self,
                              current_error: PredictionError,
                              previous_errors: List[PredictionError],
                              window_size: int = 5) -> PredictionError:
    """Integrate errors over time with exponentially decaying weights.

    The window holds up to ``window_size`` of the most recent entries of
    ``previous_errors`` plus ``current_error``. An error of age ``k`` (0 for
    the current error) gets weight proportional to ``exp(-k)``, and the
    weights are normalised to sum to 1.

    Args:
        current_error: Current prediction error.
        previous_errors: Past errors, oldest first. The list is not modified.
        window_size: Maximum number of past errors to include; values <= 0
            use only the current error.

    Returns:
        integrated: PredictionError carrying the current error's predicted,
            actual and precision, with ``error`` set to the weighted average.
    """
    # Guard explicitly: previous_errors[-0:] would return the entire history.
    history = previous_errors[-window_size:] if window_size > 0 else []
    window = [*history, current_error]  # new list; the caller's list is untouched

    reference = current_error.error
    dtype = reference.dtype if reference.is_floating_point() else torch.get_default_dtype()
    # Ages run from len(window)-1 (oldest) down to 0 (current), so the most
    # recent errors receive the largest weights.
    ages = torch.arange(len(window) - 1, -1, -1, dtype=dtype, device=reference.device)
    weights = torch.exp(-ages)
    weights = weights / weights.sum()
    integrated_error = sum(e.error * w for e, w in zip(window, weights))

    integrated = PredictionError(
        predicted=current_error.predicted,
        actual=current_error.actual,
        precision=current_error.precision,
    )
    # Overwrite the recomputed error with the temporally integrated one.
    integrated.error = integrated_error
    return integrated
Best Practices
Error Handling
- Validate error magnitudes
- Check precision values
- Monitor gradients
- Handle edge cases
Optimization
- Batch processing
- Memory management
- Computational efficiency
- Numerical stability
Debugging
- Visualize error flow
- Track convergence
- Monitor updates
- Validate predictions
Common Issues
Stability
- Gradient explosion
- Vanishing errors
- Precision instability
- Update oscillations
Solutions
- Gradient clipping
- Error normalization
- Adaptive learning rates
- Error gating