cognitive/docs/guides/implementation/temporal_models.md
Daniel Ari Friedman 7ca29024d8 Updates
2025-02-12 13:23:51 -08:00

9.0 KiB

title type status created tags semantic_relations
Temporal Models Implementation implementation_guide stable 2024-02-12
implementation
predictive-processing
temporal-dynamics
type links
implements
../../learning_paths/predictive_processing
type links
relates
predictive_network
error_propagation
precision_mechanisms

Temporal Models Implementation

Overview

This guide details the implementation of temporal models in predictive processing networks, focusing on sequence prediction, temporal dependencies, and dynamic updating.

Core Components

Temporal Layer

class TemporalLayer(nn.Module):
    def __init__(self,
                 input_size: int,
                 hidden_size: int,
                 temporal_horizon: int = 10):
        """Initialize temporal layer.
        
        Args:
            input_size: Size of input features
            hidden_size: Size of hidden state
            temporal_horizon: Number of timesteps to consider
        """
        super().__init__()
        
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.temporal_horizon = temporal_horizon
        
        # Temporal state
        self.hidden_state = torch.zeros(hidden_size)
        self.state_history = []
        
        # Neural components
        self.state_update = nn.GRUCell(input_size, hidden_size)
        self.temporal_attention = TemporalAttention(hidden_size)
        self.prediction_head = nn.Linear(hidden_size, input_size)

Temporal Attention

class TemporalAttention(nn.Module):
    def __init__(self, hidden_size: int):
        """Initialize temporal attention mechanism.
        
        Args:
            hidden_size: Size of hidden state
        """
        super().__init__()
        
        # Attention components
        self.query_proj = nn.Linear(hidden_size, hidden_size)
        self.key_proj = nn.Linear(hidden_size, hidden_size)
        self.value_proj = nn.Linear(hidden_size, hidden_size)
        
    def forward(self,
                query: torch.Tensor,
                keys: List[torch.Tensor],
                values: List[torch.Tensor]) -> torch.Tensor:
        """Compute temporal attention.
        
        Args:
            query: Current query state
            keys: Historical key states
            values: Historical value states
            
        Returns:
            context: Temporal context vector
        """
        # Project query, keys, and values
        q = self.query_proj(query)
        k = torch.stack([self.key_proj(k) for k in keys])
        v = torch.stack([self.value_proj(v) for v in values])
        
        # Compute attention scores
        scores = torch.matmul(q, k.transpose(-2, -1))
        scores = scores / math.sqrt(self.hidden_size)
        weights = F.softmax(scores, dim=-1)
        
        # Compute weighted context
        context = torch.matmul(weights, v)
        return context

Implementation

Temporal Prediction

def predict_next(self,
                current_input: torch.Tensor) -> torch.Tensor:
    """Predict next timestep.
    
    Args:
        current_input: Current input tensor
        
    Returns:
        prediction: Prediction for next timestep
    """
    # Update hidden state
    self.hidden_state = self.state_update(
        current_input,
        self.hidden_state
    )
    
    # Update state history
    self.state_history.append(self.hidden_state)
    if len(self.state_history) > self.temporal_horizon:
        self.state_history.pop(0)
    
    # Compute temporal attention
    if len(self.state_history) > 1:
        context = self.temporal_attention(
            query=self.hidden_state,
            keys=self.state_history[:-1],
            values=self.state_history[:-1]
        )
        
        # Combine with current state
        combined = (self.hidden_state + context) / 2
    else:
        combined = self.hidden_state
    
    # Generate prediction
    prediction = self.prediction_head(combined)
    return prediction

Sequence Processing

def process_sequence(self,
                    sequence: torch.Tensor,
                    teacher_forcing_ratio: float = 0.5) -> List[torch.Tensor]:
    """Process temporal sequence.
    
    Args:
        sequence: Input sequence [T, input_size]
        teacher_forcing_ratio: Ratio of teacher forcing
        
    Returns:
        predictions: List of predictions
    """
    predictions = []
    current_input = sequence[0]
    
    for t in range(1, len(sequence)):
        # Generate prediction
        prediction = self.predict_next(current_input)
        predictions.append(prediction)
        
        # Teacher forcing
        if random.random() < teacher_forcing_ratio:
            current_input = sequence[t]
        else:
            current_input = prediction
    
    return predictions

Advanced Features

Multi-scale Temporal Processing

class MultiScaleTemporalModel(nn.Module):
    def __init__(self,
                 input_size: int,
                 num_scales: int = 3):
        """Initialize multi-scale temporal model.
        
        Args:
            input_size: Size of input features
            num_scales: Number of temporal scales
        """
        super().__init__()
        
        # Create temporal layers at different scales
        self.temporal_scales = nn.ModuleList([
            TemporalLayer(
                input_size=input_size,
                hidden_size=input_size * 2,
                temporal_horizon=2 ** i
            )
            for i in range(num_scales)
        ])
        
        # Scale integration
        self.scale_attention = nn.Linear(
            num_scales * input_size,
            input_size
        )
    
    def forward(self,
                sequence: torch.Tensor) -> torch.Tensor:
        """Process sequence at multiple scales.
        
        Args:
            sequence: Input sequence
            
        Returns:
            prediction: Integrated prediction
        """
        # Process at each scale
        scale_predictions = []
        for scale in self.temporal_scales:
            pred = scale.process_sequence(sequence)
            scale_predictions.append(pred[-1])
        
        # Integrate predictions
        combined = torch.cat(scale_predictions, dim=-1)
        prediction = self.scale_attention(combined)
        
        return prediction

Temporal Error Integration

class TemporalErrorIntegration:
    def __init__(self,
                 decay_factor: float = 0.9,
                 max_history: int = 100):
        """Initialize temporal error integration.
        
        Args:
            decay_factor: Error decay rate
            max_history: Maximum history length
        """
        self.decay_factor = decay_factor
        self.max_history = max_history
        self.error_history = []
        
    def integrate_error(self,
                       current_error: torch.Tensor) -> torch.Tensor:
        """Integrate error over time.
        
        Args:
            current_error: Current prediction error
            
        Returns:
            integrated: Temporally integrated error
        """
        # Update history
        self.error_history.append(current_error)
        if len(self.error_history) > self.max_history:
            self.error_history.pop(0)
        
        # Compute temporal weights
        weights = torch.tensor([
            self.decay_factor ** i
            for i in range(len(self.error_history))
        ])
        weights = weights / weights.sum()
        
        # Compute weighted sum
        integrated = sum(
            e * w for e, w in zip(self.error_history, weights)
        )
        
        return integrated

Usage Examples

Basic Usage

# Initialize temporal layer
temporal_layer = TemporalLayer(
    input_size=64,
    hidden_size=128,
    temporal_horizon=10
)

# Process sequence
sequence = torch.randn(20, 64)  # [T, input_size]
predictions = temporal_layer.process_sequence(sequence)

# Get next prediction
next_pred = temporal_layer.predict_next(sequence[-1])

Multi-scale Processing

# Initialize multi-scale model
multi_scale = MultiScaleTemporalModel(
    input_size=64,
    num_scales=3
)

# Process sequence
sequence = torch.randn(20, 64)
prediction = multi_scale(sequence)

Best Practices

Temporal Processing

  1. Handle variable sequences
  2. Implement teacher forcing
  3. Use appropriate horizons
  4. Consider multiple scales

Memory Management

  1. Limit history size
  2. Clear unused states
  3. Batch sequences
  4. Optimize storage

Training

  1. Curriculum learning
  2. Scheduled sampling
  3. Gradient clipping
  4. Validation splits

Common Issues

Stability

  1. Vanishing gradients
  2. Exploding predictions
  3. Memory overflow
  4. Error accumulation

Solutions

  1. Residual connections
  2. Gradient normalization
  3. Memory optimization
  4. Error gating