---
title: Temporal Models Implementation
type: implementation_guide
status: stable
created: 2024-02-12
tags:
  - implementation
  - predictive-processing
  - temporal-dynamics
semantic_relations:
  - type: implements
    links: [[../../learning_paths/predictive_processing]]
  - type: relates
    links:
      - [[predictive_network]]
      - [[error_propagation]]
      - [[precision_mechanisms]]
---

# Temporal Models Implementation

## Overview

This guide details the implementation of temporal models in predictive processing networks, focusing on sequence prediction, temporal dependencies, and dynamic updating.

## Core Components

### Temporal Layer

```python
import math
import random
from typing import List

import torch
import torch.nn as nn
import torch.nn.functional as F


class TemporalLayer(nn.Module):
    def __init__(self,
                 input_size: int,
                 hidden_size: int,
                 temporal_horizon: int = 10):
        """Initialize temporal layer.

        Args:
            input_size: Size of input features
            hidden_size: Size of hidden state
            temporal_horizon: Number of timesteps to consider
        """
        super().__init__()

        self.input_size = input_size
        self.hidden_size = hidden_size
        self.temporal_horizon = temporal_horizon

        # Temporal state
        self.hidden_state = torch.zeros(hidden_size)
        self.state_history = []

        # Neural components
        self.state_update = nn.GRUCell(input_size, hidden_size)
        self.temporal_attention = TemporalAttention(hidden_size)
        self.prediction_head = nn.Linear(hidden_size, input_size)
```

### Temporal Attention

```python
class TemporalAttention(nn.Module):
    def __init__(self, hidden_size: int):
        """Initialize temporal attention mechanism.

        Args:
            hidden_size: Size of hidden state
        """
        super().__init__()

        self.hidden_size = hidden_size

        # Attention components
        self.query_proj = nn.Linear(hidden_size, hidden_size)
        self.key_proj = nn.Linear(hidden_size, hidden_size)
        self.value_proj = nn.Linear(hidden_size, hidden_size)

    def forward(self,
                query: torch.Tensor,
                keys: List[torch.Tensor],
                values: List[torch.Tensor]) -> torch.Tensor:
        """Compute temporal attention.

        Args:
            query: Current query state
            keys: Historical key states
            values: Historical value states

        Returns:
            context: Temporal context vector
        """
        # Project query, keys, and values
        q = self.query_proj(query)
        k = torch.stack([self.key_proj(k) for k in keys])
        v = torch.stack([self.value_proj(v) for v in values])

        # Compute scaled dot-product attention scores
        scores = torch.matmul(q, k.transpose(-2, -1))
        scores = scores / math.sqrt(self.hidden_size)
        weights = F.softmax(scores, dim=-1)

        # Compute weighted context
        context = torch.matmul(weights, v)
        return context
```

## Implementation

The following methods belong to `TemporalLayer` and are shown separately for clarity.

### Temporal Prediction

```python
def predict_next(self,
                 current_input: torch.Tensor) -> torch.Tensor:
    """Predict next timestep.

    Args:
        current_input: Current input tensor

    Returns:
        prediction: Prediction for next timestep
    """
    # Update hidden state
    self.hidden_state = self.state_update(
        current_input,
        self.hidden_state
    )

    # Update state history, keeping at most temporal_horizon states
    self.state_history.append(self.hidden_state)
    if len(self.state_history) > self.temporal_horizon:
        self.state_history.pop(0)

    # Attend over past states once there is history to attend to
    if len(self.state_history) > 1:
        context = self.temporal_attention(
            query=self.hidden_state,
            keys=self.state_history[:-1],
            values=self.state_history[:-1]
        )

        # Combine with current state
        combined = (self.hidden_state + context) / 2
    else:
        combined = self.hidden_state

    # Generate prediction
    prediction = self.prediction_head(combined)
    return prediction
```

### Sequence Processing

```python
def process_sequence(self,
                     sequence: torch.Tensor,
                     teacher_forcing_ratio: float = 0.5) -> List[torch.Tensor]:
    """Process temporal sequence.

    Args:
        sequence: Input sequence [T, input_size]
        teacher_forcing_ratio: Probability of feeding the ground-truth
            input instead of the model's own prediction

    Returns:
        predictions: List of predictions, where predictions[t - 1]
            estimates sequence[t]
    """
    predictions = []
    current_input = sequence[0]

    for t in range(1, len(sequence)):
        # Generate prediction
        prediction = self.predict_next(current_input)
        predictions.append(prediction)

        # Teacher forcing: sometimes feed the true next input,
        # otherwise feed the model's own prediction
        if random.random() < teacher_forcing_ratio:
            current_input = sequence[t]
        else:
            current_input = prediction

    return predictions
```

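For orientation, here is a minimal training sketch around `process_sequence`. Everything in it beyond the classes above is an assumption: the MSE loss, the Adam optimizer, the random stand-in data, and the manual state reset.

```python
# Hypothetical training step for TemporalLayer; the loss, optimizer,
# and data are illustrative stand-ins, not prescribed by this guide.
layer = TemporalLayer(input_size=64, hidden_size=128, temporal_horizon=10)
optimizer = torch.optim.Adam(layer.parameters(), lr=1e-3)
criterion = nn.MSELoss()

sequence = torch.randn(20, 64)  # [T, input_size]

# Reset persistent state so this sequence starts fresh
layer.hidden_state = torch.zeros(layer.hidden_size)
layer.state_history = []

predictions = layer.process_sequence(sequence, teacher_forcing_ratio=0.5)

# predictions[t - 1] estimates sequence[t], so targets drop the first step
loss = criterion(torch.stack(predictions), sequence[1:])

optimizer.zero_grad()
loss.backward()
optimizer.step()
```
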
## Advanced Features

### Multi-scale Temporal Processing

```python
class MultiScaleTemporalModel(nn.Module):
    def __init__(self,
                 input_size: int,
                 num_scales: int = 3):
        """Initialize multi-scale temporal model.

        Args:
            input_size: Size of input features
            num_scales: Number of temporal scales
        """
        super().__init__()

        # Create temporal layers with exponentially growing horizons (1, 2, 4, ...)
        self.temporal_scales = nn.ModuleList([
            TemporalLayer(
                input_size=input_size,
                hidden_size=input_size * 2,
                temporal_horizon=2 ** i
            )
            for i in range(num_scales)
        ])

        # Scale integration
        self.scale_attention = nn.Linear(
            num_scales * input_size,
            input_size
        )

    def forward(self,
                sequence: torch.Tensor) -> torch.Tensor:
        """Process sequence at multiple scales.

        Args:
            sequence: Input sequence

        Returns:
            prediction: Integrated prediction
        """
        # Process at each scale, keeping the final prediction per scale
        scale_predictions = []
        for scale in self.temporal_scales:
            pred = scale.process_sequence(sequence)
            scale_predictions.append(pred[-1])

        # Integrate predictions across scales
        combined = torch.cat(scale_predictions, dim=-1)
        prediction = self.scale_attention(combined)

        return prediction
```

### Temporal Error Integration

```python
class TemporalErrorIntegration:
    def __init__(self,
                 decay_factor: float = 0.9,
                 max_history: int = 100):
        """Initialize temporal error integration.

        Args:
            decay_factor: Error decay rate
            max_history: Maximum history length
        """
        self.decay_factor = decay_factor
        self.max_history = max_history
        self.error_history = []

    def integrate_error(self,
                        current_error: torch.Tensor) -> torch.Tensor:
        """Integrate error over time.

        Args:
            current_error: Current prediction error

        Returns:
            integrated: Temporally integrated error
        """
        # Update history
        self.error_history.append(current_error)
        if len(self.error_history) > self.max_history:
            self.error_history.pop(0)

        # Compute temporal weights: the most recent error gets weight 1,
        # older errors decay geometrically
        n = len(self.error_history)
        weights = torch.tensor([
            self.decay_factor ** (n - 1 - i)
            for i in range(n)
        ])
        weights = weights / weights.sum()

        # Compute weighted sum
        integrated = sum(
            e * w for e, w in zip(self.error_history, weights)
        )

        return integrated
```

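A minimal usage sketch; the random error vectors are placeholders for real prediction errors.

```python
# Usage sketch for TemporalErrorIntegration; the random errors are
# placeholders for real prediction errors.
integrator = TemporalErrorIntegration(decay_factor=0.9, max_history=100)

for _ in range(5):
    current_error = torch.randn(64)  # stand-in prediction error
    smoothed = integrator.integrate_error(current_error)

# `smoothed` is an exponentially weighted average that emphasizes
# the most recent errors.
```
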
## Usage Examples

### Basic Usage

```python
# Initialize temporal layer
temporal_layer = TemporalLayer(
    input_size=64,
    hidden_size=128,
    temporal_horizon=10
)

# Process sequence
sequence = torch.randn(20, 64)  # [T, input_size]
predictions = temporal_layer.process_sequence(sequence)

# Get next prediction
next_pred = temporal_layer.predict_next(sequence[-1])
```

### Multi-scale Processing

```python
# Initialize multi-scale model
multi_scale = MultiScaleTemporalModel(
    input_size=64,
    num_scales=3
)

# Process sequence
sequence = torch.randn(20, 64)
prediction = multi_scale(sequence)
```

## Best Practices

### Temporal Processing

1. Handle variable-length sequences (reset state between independent sequences; see the sketch below)
2. Implement teacher forcing
3. Use appropriate temporal horizons
4. Consider multiple scales

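Because `TemporalLayer` keeps a persistent hidden state and history, independent sequences need an explicit reset. A minimal sketch; the `reset_state` helper is a hypothetical addition, not part of the class above.

```python
# Hypothetical TemporalLayer method: clear persistent state so one
# sequence's history does not leak into the next.
def reset_state(self) -> None:
    self.hidden_state = torch.zeros(self.hidden_size)
    self.state_history = []
```
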
### Memory Management

1. Limit history size (see the bounded-buffer sketch below)
2. Clear unused states
3. Batch sequences where possible
4. Optimize storage

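One way to bound the history and keep stored states off the autograd graph is a fixed-size `deque` of detached states; both the `deque` buffer and the `.detach()` call are assumptions rather than the implementation above. Note the tradeoff: detaching stops gradients flowing through attention into past steps, trading credit assignment for memory.

```python
from collections import deque

import torch

# Hypothetical bounded history buffer: deque(maxlen=...) evicts the
# oldest state automatically; .detach() keeps stored states out of
# the autograd graph.
temporal_horizon = 10
state_history = deque(maxlen=temporal_horizon)

for step in range(25):
    hidden_state = torch.randn(128)  # stand-in for the layer's hidden state
    state_history.append(hidden_state.detach())

assert len(state_history) == temporal_horizon  # never exceeds the horizon
```
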
### Training

1. Use curriculum learning
2. Apply scheduled sampling (decay the teacher-forcing ratio; see the sketch below)
3. Clip gradients
4. Hold out validation splits

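A sketch combining gradient clipping with scheduled sampling, reusing `layer`, `optimizer`, `criterion`, and `sequence` from the training sketch earlier; the linear decay schedule and the `max_norm` value are assumptions.

```python
# Sketch: scheduled sampling + gradient clipping. The linear decay
# and max_norm=1.0 are illustrative choices, not prescribed here.
max_norm = 1.0
num_epochs = 50

for epoch in range(num_epochs):
    # Decay teacher forcing from 1.0 toward 0.0 over training
    tf_ratio = max(0.0, 1.0 - epoch / num_epochs)

    # Reset persistent state before each pass over the sequence
    layer.hidden_state = torch.zeros(layer.hidden_size)
    layer.state_history = []

    predictions = layer.process_sequence(sequence, teacher_forcing_ratio=tf_ratio)
    loss = criterion(torch.stack(predictions), sequence[1:])

    optimizer.zero_grad()
    loss.backward()
    torch.nn.utils.clip_grad_norm_(layer.parameters(), max_norm)
    optimizer.step()
```
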
## Common Issues

### Stability

1. Vanishing gradients
2. Exploding predictions
3. Memory overflow
4. Error accumulation

### Solutions

1. Residual connections (see the sketch below)
2. Gradient normalization
3. Memory optimization
4. Error gating

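A sketch of a residual, gated combination step that could replace the simple averaging in `predict_next`; the `GatedResidualCombiner` module is a hypothetical illustration of items 1 and 4 above, not the guide's implementation.

```python
import torch
import torch.nn as nn


class GatedResidualCombiner(nn.Module):
    """Hypothetical drop-in for the averaging step in predict_next:
    a learned gate controls how much temporal context is added on top
    of a residual pathway (residual connection + error gating)."""

    def __init__(self, hidden_size: int):
        super().__init__()
        self.gate = nn.Linear(hidden_size, hidden_size)

    def forward(self,
                hidden_state: torch.Tensor,
                context: torch.Tensor) -> torch.Tensor:
        g = torch.sigmoid(self.gate(hidden_state))  # per-unit gate in (0, 1)
        return hidden_state + g * context           # residual pathway

combiner = GatedResidualCombiner(hidden_size=128)
combined = combiner(torch.randn(128), torch.randn(128))
```
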
## Related Documentation

- [[predictive_network]]
- [[error_propagation]]
- [[precision_mechanisms]]