--- title: Temporal Models Implementation type: implementation_guide status: stable created: 2024-02-12 tags: - implementation - predictive-processing - temporal-dynamics semantic_relations: - type: implements links: [[../../learning_paths/predictive_processing]] - type: relates links: - [[predictive_network]] - [[error_propagation]] - [[precision_mechanisms]] --- # Temporal Models Implementation ## Overview This guide details the implementation of temporal models in predictive processing networks, focusing on sequence prediction, temporal dependencies, and dynamic updating. ## Core Components ### Temporal Layer ```python class TemporalLayer(nn.Module): def __init__(self, input_size: int, hidden_size: int, temporal_horizon: int = 10): """Initialize temporal layer. Args: input_size: Size of input features hidden_size: Size of hidden state temporal_horizon: Number of timesteps to consider """ super().__init__() self.input_size = input_size self.hidden_size = hidden_size self.temporal_horizon = temporal_horizon # Temporal state self.hidden_state = torch.zeros(hidden_size) self.state_history = [] # Neural components self.state_update = nn.GRUCell(input_size, hidden_size) self.temporal_attention = TemporalAttention(hidden_size) self.prediction_head = nn.Linear(hidden_size, input_size) ``` ### Temporal Attention ```python class TemporalAttention(nn.Module): def __init__(self, hidden_size: int): """Initialize temporal attention mechanism. Args: hidden_size: Size of hidden state """ super().__init__() # Attention components self.query_proj = nn.Linear(hidden_size, hidden_size) self.key_proj = nn.Linear(hidden_size, hidden_size) self.value_proj = nn.Linear(hidden_size, hidden_size) def forward(self, query: torch.Tensor, keys: List[torch.Tensor], values: List[torch.Tensor]) -> torch.Tensor: """Compute temporal attention. Args: query: Current query state keys: Historical key states values: Historical value states Returns: context: Temporal context vector """ # Project query, keys, and values q = self.query_proj(query) k = torch.stack([self.key_proj(k) for k in keys]) v = torch.stack([self.value_proj(v) for v in values]) # Compute attention scores scores = torch.matmul(q, k.transpose(-2, -1)) scores = scores / math.sqrt(self.hidden_size) weights = F.softmax(scores, dim=-1) # Compute weighted context context = torch.matmul(weights, v) return context ``` ## Implementation ### Temporal Prediction ```python def predict_next(self, current_input: torch.Tensor) -> torch.Tensor: """Predict next timestep. Args: current_input: Current input tensor Returns: prediction: Prediction for next timestep """ # Update hidden state self.hidden_state = self.state_update( current_input, self.hidden_state ) # Update state history self.state_history.append(self.hidden_state) if len(self.state_history) > self.temporal_horizon: self.state_history.pop(0) # Compute temporal attention if len(self.state_history) > 1: context = self.temporal_attention( query=self.hidden_state, keys=self.state_history[:-1], values=self.state_history[:-1] ) # Combine with current state combined = (self.hidden_state + context) / 2 else: combined = self.hidden_state # Generate prediction prediction = self.prediction_head(combined) return prediction ``` ### Sequence Processing ```python def process_sequence(self, sequence: torch.Tensor, teacher_forcing_ratio: float = 0.5) -> List[torch.Tensor]: """Process temporal sequence. Args: sequence: Input sequence [T, input_size] teacher_forcing_ratio: Ratio of teacher forcing Returns: predictions: List of predictions """ predictions = [] current_input = sequence[0] for t in range(1, len(sequence)): # Generate prediction prediction = self.predict_next(current_input) predictions.append(prediction) # Teacher forcing if random.random() < teacher_forcing_ratio: current_input = sequence[t] else: current_input = prediction return predictions ``` ## Advanced Features ### Multi-scale Temporal Processing ```python class MultiScaleTemporalModel(nn.Module): def __init__(self, input_size: int, num_scales: int = 3): """Initialize multi-scale temporal model. Args: input_size: Size of input features num_scales: Number of temporal scales """ super().__init__() # Create temporal layers at different scales self.temporal_scales = nn.ModuleList([ TemporalLayer( input_size=input_size, hidden_size=input_size * 2, temporal_horizon=2 ** i ) for i in range(num_scales) ]) # Scale integration self.scale_attention = nn.Linear( num_scales * input_size, input_size ) def forward(self, sequence: torch.Tensor) -> torch.Tensor: """Process sequence at multiple scales. Args: sequence: Input sequence Returns: prediction: Integrated prediction """ # Process at each scale scale_predictions = [] for scale in self.temporal_scales: pred = scale.process_sequence(sequence) scale_predictions.append(pred[-1]) # Integrate predictions combined = torch.cat(scale_predictions, dim=-1) prediction = self.scale_attention(combined) return prediction ``` ### Temporal Error Integration ```python class TemporalErrorIntegration: def __init__(self, decay_factor: float = 0.9, max_history: int = 100): """Initialize temporal error integration. Args: decay_factor: Error decay rate max_history: Maximum history length """ self.decay_factor = decay_factor self.max_history = max_history self.error_history = [] def integrate_error(self, current_error: torch.Tensor) -> torch.Tensor: """Integrate error over time. Args: current_error: Current prediction error Returns: integrated: Temporally integrated error """ # Update history self.error_history.append(current_error) if len(self.error_history) > self.max_history: self.error_history.pop(0) # Compute temporal weights weights = torch.tensor([ self.decay_factor ** i for i in range(len(self.error_history)) ]) weights = weights / weights.sum() # Compute weighted sum integrated = sum( e * w for e, w in zip(self.error_history, weights) ) return integrated ``` ## Usage Examples ### Basic Usage ```python # Initialize temporal layer temporal_layer = TemporalLayer( input_size=64, hidden_size=128, temporal_horizon=10 ) # Process sequence sequence = torch.randn(20, 64) # [T, input_size] predictions = temporal_layer.process_sequence(sequence) # Get next prediction next_pred = temporal_layer.predict_next(sequence[-1]) ``` ### Multi-scale Processing ```python # Initialize multi-scale model multi_scale = MultiScaleTemporalModel( input_size=64, num_scales=3 ) # Process sequence sequence = torch.randn(20, 64) prediction = multi_scale(sequence) ``` ## Best Practices ### Temporal Processing 1. Handle variable sequences 2. Implement teacher forcing 3. Use appropriate horizons 4. Consider multiple scales ### Memory Management 1. Limit history size 2. Clear unused states 3. Batch sequences 4. Optimize storage ### Training 1. Curriculum learning 2. Scheduled sampling 3. Gradient clipping 4. Validation splits ## Common Issues ### Stability 1. Vanishing gradients 2. Exploding predictions 3. Memory overflow 4. Error accumulation ### Solutions 1. Residual connections 2. Gradient normalization 3. Memory optimization 4. Error gating ## Related Documentation - [[predictive_network]] - [[error_propagation]] - [[precision_mechanisms]]