зеркало из
				https://github.com/docxology/cognitive.git
				synced 2025-10-30 20:56:04 +02:00 
			
		
		
		
	
		
			
				
	
	
		
			304 строки
		
	
	
		
			9.7 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
			
		
		
	
	
			304 строки
		
	
	
		
			9.7 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
| """
 | |
| Visualization module for Generic POMDP implementation.
 | |
| """
 | |
| 
 | |
| import matplotlib
 | |
| matplotlib.use('Agg')  # Use non-interactive backend
 | |
| import matplotlib.pyplot as plt
 | |
| import numpy as np
 | |
| import seaborn as sns
 | |
| from pathlib import Path
 | |
| from typing import Dict, List, Optional, Union
 | |
| import yaml
 | |
| 
 | |
class POMDPVisualizer:
    """Visualizer for Generic POMDP simulation results.

    Settings are read from the ``visualization`` section of a YAML file:
    ``style`` (``figure_size``, ``font_size``, ``colormap``, ``dpi``),
    ``output_dir`` and ``formats``. Saved figures are written to
    ``output_dir`` once per configured format. They are then closed, so
    repeated plotting (e.g. ``plot_all`` over many actions) does not keep
    every figure alive.
    """

    def __init__(self, config_path: Optional[Union[str, Path]] = None):
        """Initialize visualizer.

        Args:
            config_path: Optional path to a YAML configuration file. ``None``
                selects ``configuration.yaml`` next to this module.
        """
        self.config = self._load_config(config_path)
        self._setup_style()

    def _load_config(self, config_path: Optional[Union[str, Path]] = None) -> Dict:
        """Load the ``visualization`` section of the YAML configuration.

        Args:
            config_path: Path to configuration file; ``None`` selects
                ``configuration.yaml`` in this module's directory.

        Returns:
            The ``visualization`` sub-dictionary of the configuration.

        Raises:
            FileNotFoundError: If the configuration file does not exist.
            KeyError: If the file is empty or has no ``visualization`` section.
        """
        if config_path is None:
            config_path = Path(__file__).parent / 'configuration.yaml'

        # Explicit encoding so parsing does not depend on the platform locale.
        with open(config_path, 'r', encoding='utf-8') as f:
            # An empty YAML document loads as None; treat it as an empty mapping
            # so the missing section surfaces as a KeyError, not a TypeError.
            config = yaml.safe_load(f) or {}

        return config['visualization']

    def _setup_style(self):
        """Reset matplotlib to its default style and apply configured sizes."""
        plt.style.use('default')  # Use default style instead of seaborn
        plt.rcParams.update({
            'figure.figsize': self.config['style']['figure_size'],
            'font.size': self.config['style']['font_size']
        })

    def _setup_output_dir(self) -> Path:
        """Create (if needed) and return the configured output directory.

        Returns:
            Path to output directory
        """
        output_dir = Path(self.config['output_dir'])
        output_dir.mkdir(parents=True, exist_ok=True)
        return output_dir

    @staticmethod
    def _as_time_series(values) -> np.ndarray:
        """Stack a per-step history into a ``(time_steps, components)`` array.

        An empty history becomes a ``(0, 0)`` array, so callers loop over zero
        components instead of failing on ``shape[1]``. A history of scalars
        (one number per step) becomes a single ``(time_steps, 1)`` column.
        Arrays with more than two dimensions are returned unchanged.
        """
        array = np.asarray(values)
        if array.size == 0:
            return np.zeros((0, 0))
        if array.ndim == 1:
            return array.reshape(-1, 1)
        return array

    @staticmethod
    def _normalize_formats(formats: Union[str, List[str]]) -> List[str]:
        """Return the configured output formats as a list of bare extensions.

        A single string such as ``'png'`` is wrapped in a list. Iterating it
        directly would yield ``'p'``, ``'n'``, ``'g'``. Leading dots are
        stripped so ``'.png'`` does not produce ``name..png``.
        """
        if isinstance(formats, str):
            formats = [formats]
        return [str(fmt).lstrip('.') for fmt in formats]

    def _finalize_figure(self, fig, name: str, save: bool) -> None:
        """Save ``fig`` as ``name`` when ``save`` is true, then close it.

        Figures that are not saved are left open so the caller can still
        show or modify them.
        """
        if save:
            self._save_plot(name, fig=fig)
            # Release the figure; otherwise every call keeps one alive.
            plt.close(fig)

    def plot_belief_evolution(self,
                            beliefs: List[np.ndarray],
                            save: bool = True) -> None:
        """Plot evolution of beliefs over time.

        Args:
            beliefs: List of belief arrays, one per time step; column ``s`` of
                the stacked history is drawn as the line labelled ``State s``.
            save: Whether to save the plot (the figure is closed after saving)
        """
        fig = plt.figure()

        beliefs_array = self._as_time_series(beliefs)
        num_states = beliefs_array.shape[1]

        # Plot each state's belief trajectory
        for s in range(num_states):
            plt.plot(beliefs_array[:, s],
                     label=f'State {s}',
                     alpha=0.8)

        plt.xlabel('Time Step')
        plt.ylabel('Belief Probability')
        plt.title('Belief Evolution')
        if num_states:
            # A legend with no labelled lines only produces a warning.
            plt.legend()
        plt.grid(True, alpha=0.3)

        self._finalize_figure(fig, 'belief_evolution', save)

    def plot_free_energy(self,
                        free_energies: List[float],
                        save: bool = True) -> None:
        """Plot free energy over time.

        Args:
            free_energies: List of free energy values, one per time step
            save: Whether to save the plot (the figure is closed after saving)
        """
        fig = plt.figure()

        plt.plot(free_energies, 'b-', alpha=0.8)
        plt.xlabel('Time Step')
        plt.ylabel('Free Energy')
        plt.title('Free Energy Evolution')
        plt.grid(True, alpha=0.3)

        self._finalize_figure(fig, 'free_energy', save)

    def plot_action_probabilities(self,
                                action_probs: List[np.ndarray],
                                save: bool = True) -> None:
        """Plot action probability evolution.

        Args:
            action_probs: List of action probability arrays, one per time
                step; column ``a`` is drawn as the line labelled ``Action a``.
            save: Whether to save the plot (the figure is closed after saving)
        """
        fig = plt.figure()

        probs_array = self._as_time_series(action_probs)
        num_actions = probs_array.shape[1]

        # Plot each action's probability trajectory
        for a in range(num_actions):
            plt.plot(probs_array[:, a],
                     label=f'Action {a}',
                     alpha=0.8)

        plt.xlabel('Time Step')
        plt.ylabel('Action Probability')
        plt.title('Action Selection Evolution')
        if num_actions:
            plt.legend()
        plt.grid(True, alpha=0.3)

        self._finalize_figure(fig, 'action_probabilities', save)

    def plot_observation_counts(self,
                              observations: List[int],
                              num_observations: int,
                              save: bool = True) -> None:
        """Plot histogram of observations.

        Args:
            observations: List of observation indices
            num_observations: Total number of possible observations; one
                unit-wide bin is centred on each index ``0..num_observations-1``.
            save: Whether to save the plot (the figure is closed after saving)
        """
        fig = plt.figure()

        # Bin edges at k - 0.5 centre each bar on its integer observation index.
        plt.hist(observations,
                 bins=np.arange(num_observations + 1) - 0.5,
                 rwidth=0.8,
                 alpha=0.8)
        plt.xlabel('Observation')
        plt.ylabel('Count')
        plt.title('Observation Distribution')
        plt.grid(True, alpha=0.3)

        self._finalize_figure(fig, 'observation_counts', save)

    def plot_state_transition_matrix(self,
                                   B_matrix: np.ndarray,
                                   action: int,
                                   save: bool = True) -> None:
        """Plot state transition matrix for given action.

        Draws ``B_matrix[:, :, action]``: rows go on the y axis ('Next State')
        and columns on the x axis ('Current State').

        Args:
            B_matrix: State transition matrix indexed as ``[row, column, action]``
            action: Action index
            save: Whether to save the plot (the figure is closed after saving)
        """
        fig = plt.figure()

        plt.imshow(B_matrix[:, :, action],
                   cmap=self.config['style']['colormap'],
                   aspect='auto')
        plt.colorbar(label='Transition Probability')
        plt.xlabel('Current State')
        plt.ylabel('Next State')
        plt.title(f'State Transition Matrix (Action {action})')

        self._finalize_figure(fig, f'transition_matrix_action_{action}', save)

    def plot_observation_matrix(self,
                              A_matrix: np.ndarray,
                              save: bool = True) -> None:
        """Plot observation matrix.

        Rows are drawn on the y axis ('Observation') and columns on the x
        axis ('State').

        Args:
            A_matrix: 2-D observation matrix
            save: Whether to save the plot (the figure is closed after saving)
        """
        fig = plt.figure()

        plt.imshow(A_matrix,
                   cmap=self.config['style']['colormap'],
                   aspect='auto')
        plt.colorbar(label='Observation Probability')
        plt.xlabel('State')
        plt.ylabel('Observation')
        plt.title('Observation Matrix')

        self._finalize_figure(fig, 'observation_matrix', save)

    def plot_preferences(self,
                        C_matrix: np.ndarray,
                        save: bool = True) -> None:
        """Plot preference matrix.

        Rows are drawn on the y axis ('Observation') and columns on the x
        axis ('Time Step'). A 1-D preference vector is shown as a single
        column, because ``imshow`` requires 2-D input.

        Args:
            C_matrix: Preference matrix (or 1-D preference vector)
            save: Whether to save the plot (the figure is closed after saving)
        """
        fig = plt.figure()

        preferences = np.asarray(C_matrix)
        if preferences.ndim == 1:
            preferences = preferences.reshape(-1, 1)

        plt.imshow(preferences,
                   cmap=self.config['style']['colormap'],
                   aspect='auto')
        plt.colorbar(label='Preference Value')
        plt.xlabel('Time Step')
        plt.ylabel('Observation')
        plt.title('Preference Matrix')

        self._finalize_figure(fig, 'preferences', save)

    def plot_belief_entropy(self,
                           beliefs: List[np.ndarray],
                           save: bool = True) -> None:
        """Plot Shannon entropy (in bits) of each belief over time.

        Args:
            beliefs: List of belief arrays (lists are accepted too)
            save: Whether to save the plot (the figure is closed after saving)
        """
        fig = plt.figure()

        # The 1e-12 offset keeps log2 finite for zero-probability entries;
        # those terms still contribute 0 because they are multiplied by b == 0.
        entropies = [-np.sum(b * np.log2(b + 1e-12))
                     for b in (np.asarray(belief) for belief in beliefs)]

        plt.plot(entropies, 'r-', alpha=0.8)
        plt.xlabel('Time Step')
        plt.ylabel('Belief Entropy (bits)')
        plt.title('Belief Entropy Evolution')
        plt.grid(True, alpha=0.3)

        self._finalize_figure(fig, 'belief_entropy', save)

    def plot_all(self,
                 model_state: Dict,
                 model_params: Dict) -> None:
        """Plot and save all available visualizations.

        Args:
            model_state: Dictionary whose ``'history'`` entry holds
                ``'beliefs'``, ``'free_energy'``, ``'observations'`` and,
                optionally, ``'policy_probs'``.
            model_params: Dictionary holding ``'A'``, ``'B'``, ``'C'``,
                ``'num_observations'`` and ``'num_actions'``.
        """
        history = model_state['history']

        # Plot belief evolution
        self.plot_belief_evolution(history['beliefs'])

        # Plot free energy
        self.plot_free_energy(history['free_energy'])

        # Plot action probabilities if available
        if 'policy_probs' in history:
            self.plot_action_probabilities(history['policy_probs'])

        # Plot observation counts
        self.plot_observation_counts(
            history['observations'],
            model_params['num_observations']
        )

        # Plot matrices
        self.plot_observation_matrix(model_params['A'])

        for a in range(model_params['num_actions']):
            self.plot_state_transition_matrix(model_params['B'], a)

        self.plot_preferences(model_params['C'])

        # Plot belief entropy
        self.plot_belief_entropy(history['beliefs'])

    def _save_plot(self, name: str,
                   fig: Optional["matplotlib.figure.Figure"] = None) -> None:
        """Save a figure to file in every configured format.

        Args:
            name: Base name for the plot file
            fig: Figure to save; ``None`` saves the current pyplot figure.
        """
        output_dir = self._setup_output_dir()
        target = plt.gcf() if fig is None else fig

        # Save in all configured formats
        for fmt in self._normalize_formats(self.config['formats']):
            path = output_dir / f'{name}.{fmt}'
            target.savefig(path,
                           dpi=self.config['style']['dpi'],
                           bbox_inches='tight')
