Этот коммит содержится в:
Daniel Ari Friedman 2025-02-07 08:50:37 -08:00
родитель 6caa1a7cb1
Коммит ccdeafd068
45 изменённых файлов: 5417 добавлений и 1 удалений

24
.obsidian/workspace.json поставляемый
Просмотреть файл

@ -162,5 +162,27 @@
}
},
"active": "0ba8a1b9b9dd949a",
"lastOpenFiles": []
"lastOpenFiles": [
"src/models/active_inference/dispatcher.py",
"Things/Continuous_Generic/__pycache__/test_continuous_generic.cpython-310-pytest-8.3.2.pyc",
"Things/Continuous_Generic/__pycache__/conftest.cpython-310-pytest-8.3.2.pyc",
"Things/Continuous_Generic/Output/tests/single_step",
"Things/Continuous_Generic/Output/tests/multi_step",
"Things/Continuous_Generic/Output/tests/complex",
"Things/Continuous_Generic/Output/tests/basic",
"Things/Continuous_Generic/Output/tests",
"Things/Continuous_Generic/Output",
"Things/Continuous_Generic/__pycache__/visualization.cpython-310.pyc",
"Things/Continuous_Generic/__pycache__/continuous_generic.cpython-310.pyc",
"Things/Continuous_Generic/Output/tests/complex/energy_conservation/energy_evolution.gif",
"Things/Continuous_Generic/Output/tests/complex/energy_conservation/energy_analysis.png",
"Things/Continuous_Generic/Output/tests/complex/taylor_prediction/prediction_analysis.png",
"Things/Continuous_Generic/Output/tests/complex/generalized_coordinates/error_analysis.png",
"Things/Continuous_Generic/Output/tests/complex/generalized_coordinates/derivative_analysis.png",
"Things/Continuous_Generic/Output/tests/complex/generalized_coordinates/generalized_coordinates.gif",
"Things/Continuous_Generic/Output/tests/complex/driven_oscillator/prediction_analysis.png",
"Things/Continuous_Generic/Output/tests/complex/driven_oscillator/state_correlations.png",
"Things/Continuous_Generic/Output/tests/complex/driven_oscillator/time_evolution.png",
"Things/Continuous_Generic/Output/tests/complex/driven_oscillator/phase_space.png"
]
}

Просмотреть файл

@ -0,0 +1,100 @@
{
"test_cases": [
[
0.1,
0.2
],
[
0.01,
0.02
],
[
0.0,
0.0
],
[
-0.1,
0.1
],
[
1.0,
1.0
]
],
"results": [
{
"input": [
0.1,
0.2
],
"output": [
0.10250000000000001,
0.20400000000000001
],
"ratio": [
1.025,
1.02
],
"scale_factor": 1.0225
},
{
"input": [
0.01,
0.02
],
"output": [
0.010025,
0.02004
],
"ratio": [
1.0025,
1.002
],
"scale_factor": 1.00225
},
{
"input": [
0.0,
0.0
],
"output": [
0.0,
0.0
],
"ratio": null,
"scale_factor": null
},
{
"input": [
-0.1,
0.1
],
"output": [
-0.1005,
0.0995
],
"ratio": [
1.005,
0.995
],
"scale_factor": 1.0
},
{
"input": [
1.0,
1.0
],
"output": [
1.1500000000000001,
1.1500000000000001
],
"ratio": [
1.1500000000000001,
1.1500000000000001
],
"scale_factor": 1.1500000000000001
}
],
"mean_scale": 1.0436875,
"scale_std": 0.06200185859109391
}

Двоичный файл не отображается.

После

Ширина:  |  Высота:  |  Размер: 88 KiB

Просмотреть файл

@ -0,0 +1,29 @@
{
"operator": [
[
0.0,
1.0,
0.0
],
[
0.0,
0.0,
2.0
],
[
0.0,
0.0,
0.0
]
],
"input": [
1.0,
2.0,
3.0
],
"output": [
2.0,
6.0,
0.0
]
}

Разница между файлами не показана из-за своего большого размера Загрузить разницу

Двоичные данные
Things/Continuous_Generic/Output/tests/complex/driven_oscillator/phase_space.png Обычный файл

Двоичный файл не отображается.

После

Ширина:  |  Высота:  |  Размер: 45 KiB

Двоичный файл не отображается.

После

Ширина:  |  Высота:  |  Размер: 31 KiB

Двоичный файл не отображается.

После

Ширина:  |  Высота:  |  Размер: 37 KiB

Двоичные данные
Things/Continuous_Generic/Output/tests/complex/driven_oscillator/time_evolution.png Обычный файл

Двоичный файл не отображается.

После

Ширина:  |  Высота:  |  Размер: 72 KiB

Двоичный файл не отображается.

После

Ширина:  |  Высота:  |  Размер: 188 KiB

Двоичный файл не отображается.

После

Ширина:  |  Высота:  |  Размер: 566 KiB

Двоичный файл не отображается.

После

Ширина:  |  Высота:  |  Размер: 84 KiB

Двоичный файл не отображается.

После

Ширина:  |  Высота:  |  Размер: 54 KiB

Двоичный файл не отображается.

После

Ширина:  |  Высота:  |  Размер: 353 KiB

Разница между файлами не показана из-за своего большого размера Загрузить разницу

Двоичные данные
Things/Continuous_Generic/Output/tests/complex/harmonic_motion/energy_analysis.png Обычный файл

Двоичный файл не отображается.

После

Ширина:  |  Высота:  |  Размер: 74 KiB

Двоичный файл не отображается.

После

Ширина:  |  Высота:  |  Размер: 89 KiB

Двоичные данные
Things/Continuous_Generic/Output/tests/complex/harmonic_motion/phase_space.png Обычный файл

Двоичный файл не отображается.

После

Ширина:  |  Высота:  |  Размер: 57 KiB

Двоичный файл не отображается.

После

Ширина:  |  Высота:  |  Размер: 59 KiB

Двоичный файл не отображается.

После

Ширина:  |  Высота:  |  Размер: 44 KiB

Двоичные данные
Things/Continuous_Generic/Output/tests/complex/harmonic_motion/time_evolution.png Обычный файл

Двоичный файл не отображается.

После

Ширина:  |  Высота:  |  Размер: 70 KiB

Двоичный файл не отображается.

После

Ширина:  |  Высота:  |  Размер: 123 KiB

Просмотреть файл

@ -0,0 +1,148 @@
{
"state_history": [
[
[
0.0010002076083012311,
0.0005000830404114237,
0.0001
],
[
-0.0009997924062441156,
0.0004999169595885763,
-0.0001
]
],
[
[
0.0010004152510810118,
0.0005001660808228474,
0.0001
],
[
-0.0009995848469667805,
0.0004998339191771527,
-0.0001
]
],
[
[
0.0010006229283393419,
0.000500249121234271,
0.0001
],
[
-0.0009993773221679953,
0.000499750878765729,
-0.0001
]
],
[
[
0.0010008306400762217,
0.0005003321616456947,
0.0001
],
[
-0.0009991698318477598,
0.0004996678383543053,
-0.0001
]
],
[
[
0.0010010383862916513,
0.0005004152020571184,
0.0001
],
[
-0.0009989623760060739,
0.0004995847979428817,
-0.0001
]
],
[
[
0.0010012461669856303,
0.000500498242468542,
0.0001
],
[
-0.0009987549546429378,
0.000499501757531458,
-0.0001
]
],
[
[
0.0010014539821581589,
0.0005005812828799657,
0.0001
],
[
-0.0009985475677583508,
0.0004994187171200343,
-0.0001
]
],
[
[
0.0010016618318092372,
0.0005006643232913894,
0.0001
],
[
-0.000998340215352314,
0.0004993356767086106,
-0.0001
]
],
[
[
0.0010018697159388656,
0.000500747363702813,
0.0001
],
[
-0.000998132897424827,
0.000499252636297187,
-0.0001
]
],
[
[
0.0010020776345470433,
0.0005008304041142367,
0.0001
],
[
-0.0009979256139758896,
0.0004991695958857633,
-0.0001
]
]
],
"free_energy": [
55.26203700203595,
55.262037002035136,
55.26203700203414,
55.26203700203296,
55.26203700203158,
55.26203700203003,
55.262037002028286,
55.26203700202633,
55.262037002024215,
55.2620370020219
],
"prediction_error": [
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0
]
}

Двоичный файл не отображается.

После

Ширина:  |  Высота:  |  Размер: 87 KiB

Двоичный файл не отображается.

После

Ширина:  |  Высота:  |  Размер: 31 KiB

Двоичные данные
Things/Continuous_Generic/Output/tests/multi_step/belief_evolution/phase_space.png Обычный файл

Двоичный файл не отображается.

После

Ширина:  |  Высота:  |  Размер: 86 KiB

Двоичный файл не отображается.

После

Ширина:  |  Высота:  |  Размер: 57 KiB

Двоичный файл не отображается.

После

Ширина:  |  Высота:  |  Размер: 26 KiB

Двоичный файл не отображается.

После

Ширина:  |  Высота:  |  Размер: 67 KiB

Двоичный файл не отображается.

После

Ширина:  |  Высота:  |  Размер: 95 KiB

Просмотреть файл

@ -0,0 +1,660 @@
{
"state_history": [
[
[
0.0999999991758025,
0.0,
0.0
],
[
-0.09999999916334613,
0.0,
0.0
]
],
[
[
0.099999998351605,
0.0,
0.0
],
[
-0.09999999832669229,
0.0,
0.0
]
],
[
[
0.09999999752740753,
0.0,
0.0
],
[
-0.09999999749003845,
0.0,
0.0
]
],
[
[
0.09999999670321011,
0.0,
0.0
],
[
-0.09999999665338462,
0.0,
0.0
]
],
[
[
0.09999999587901269,
0.0,
0.0
],
[
-0.09999999581673082,
0.0,
0.0
]
],
[
[
0.09999999505481527,
0.0,
0.0
],
[
-0.09999999498007706,
0.0,
0.0
]
],
[
[
0.09999999423061788,
0.0,
0.0
],
[
-0.09999999414342331,
0.0,
0.0
]
],
[
[
0.09999999340642052,
0.0,
0.0
],
[
-0.09999999330676958,
0.0,
0.0
]
],
[
[
0.09999999258222318,
0.0,
0.0
],
[
-0.09999999247011587,
0.0,
0.0
]
],
[
[
0.09999999175802586,
0.0,
0.0
],
[
-0.09999999163346218,
0.0,
0.0
]
],
[
[
0.09999999093382855,
0.0,
0.0
],
[
-0.09999999079680853,
0.0,
0.0
]
],
[
[
0.09999999010963126,
0.0,
0.0
],
[
-0.0999999899601549,
0.0,
0.0
]
],
[
[
0.09999998928543398,
0.0,
0.0
],
[
-0.09999998912350128,
0.0,
0.0
]
],
[
[
0.09999998846123674,
0.0,
0.0
],
[
-0.0999999882868477,
0.0,
0.0
]
],
[
[
0.09999998763703953,
0.0,
0.0
],
[
-0.09999998745019412,
0.0,
0.0
]
],
[
[
0.09999998681284235,
0.0,
0.0
],
[
-0.09999998661354056,
0.0,
0.0
]
],
[
[
0.09999998598864518,
0.0,
0.0
],
[
-0.09999998577688703,
0.0,
0.0
]
],
[
[
0.09999998516444801,
0.0,
0.0
],
[
-0.0999999849402335,
0.0,
0.0
]
],
[
[
0.09999998434025086,
0.0,
0.0
],
[
-0.09999998410358002,
0.0,
0.0
]
],
[
[
0.09999998351605374,
0.0,
0.0
],
[
-0.09999998326692656,
0.0,
0.0
]
],
[
[
0.09999998269185664,
0.0,
0.0
],
[
-0.09999998243027311,
0.0,
0.0
]
],
[
[
0.09999998186765957,
0.0,
0.0
],
[
-0.09999998159361968,
0.0,
0.0
]
],
[
[
0.0999999810434625,
0.0,
0.0
],
[
-0.09999998075696626,
0.0,
0.0
]
],
[
[
0.09999998021926547,
0.0,
0.0
],
[
-0.09999997992031284,
0.0,
0.0
]
],
[
[
0.09999997939506845,
0.0,
0.0
],
[
-0.09999997908365944,
0.0,
0.0
]
],
[
[
0.09999997857087146,
0.0,
0.0
],
[
-0.09999997824700611,
0.0,
0.0
]
],
[
[
0.09999997774667448,
0.0,
0.0
],
[
-0.09999997741035281,
0.0,
0.0
]
],
[
[
0.09999997692247753,
0.0,
0.0
],
[
-0.09999997657369951,
0.0,
0.0
]
],
[
[
0.09999997609828061,
0.0,
0.0
],
[
-0.09999997573704623,
0.0,
0.0
]
],
[
[
0.09999997527408369,
0.0,
0.0
],
[
-0.09999997490039297,
0.0,
0.0
]
],
[
[
0.0999999744498868,
0.0,
0.0
],
[
-0.09999997406373973,
0.0,
0.0
]
],
[
[
0.09999997362568992,
0.0,
0.0
],
[
-0.0999999732270865,
0.0,
0.0
]
],
[
[
0.09999997280149306,
0.0,
0.0
],
[
-0.09999997239043328,
0.0,
0.0
]
],
[
[
0.09999997197729622,
0.0,
0.0
],
[
-0.0999999715537801,
0.0,
0.0
]
],
[
[
0.09999997115309943,
0.0,
0.0
],
[
-0.09999997071712693,
0.0,
0.0
]
],
[
[
0.09999997032890263,
0.0,
0.0
],
[
-0.09999996988047377,
0.0,
0.0
]
],
[
[
0.09999996950470587,
0.0,
0.0
],
[
-0.09999996904382066,
0.0,
0.0
]
],
[
[
0.0999999686805091,
0.0,
0.0
],
[
-0.09999996820716758,
0.0,
0.0
]
],
[
[
0.09999996785631235,
0.0,
0.0
],
[
-0.09999996737051453,
0.0,
0.0
]
],
[
[
0.09999996703211564,
0.0,
0.0
],
[
-0.09999996653386149,
0.0,
0.0
]
],
[
[
0.09999996620791893,
0.0,
0.0
],
[
-0.09999996569720845,
0.0,
0.0
]
],
[
[
0.09999996538372227,
0.0,
0.0
],
[
-0.09999996486055543,
0.0,
0.0
]
],
[
[
0.09999996455952563,
0.0,
0.0
],
[
-0.09999996402390247,
0.0,
0.0
]
],
[
[
0.09999996373532899,
0.0,
0.0
],
[
-0.09999996318724952,
0.0,
0.0
]
],
[
[
0.09999996291113236,
0.0,
0.0
],
[
-0.09999996235059656,
0.0,
0.0
]
],
[
[
0.09999996208693578,
0.0,
0.0
],
[
-0.09999996151394362,
0.0,
0.0
]
],
[
[
0.0999999612627392,
0.0,
0.0
],
[
-0.09999996067729072,
0.0,
0.0
]
],
[
[
0.09999996043854267,
0.0,
0.0
],
[
-0.09999995984063785,
0.0,
0.0
]
],
[
[
0.09999995961434617,
0.0,
0.0
],
[
-0.09999995900398499,
0.0,
0.0
]
],
[
[
0.09999995879014968,
0.0,
0.0
],
[
-0.09999995816733213,
0.0,
0.0
]
]
],
"distance_history": [
0.14142135506291026,
0.141421353888511,
0.14142135271411177,
0.14142135153971258,
0.14142135036531342,
0.1414213491909143,
0.14142134801651518,
0.14142134684211613,
0.14142134566771708,
0.14142134449331806,
0.1414213433189191,
0.14142134214452015,
0.1414213409701212,
0.14142133979572233,
0.14142133862132347,
0.14142133744692467,
0.14142133627252587,
0.1414213350981271,
0.14142133392372833,
0.14142133274932964,
0.14142133157493095,
0.14142133040053229,
0.14142132922613365,
0.14142132805173502,
0.14142132687733644,
0.14142132570293792,
0.14142132452853942,
0.14142132335414095,
0.14142132217974251,
0.14142132100534408,
0.1414213198309457,
0.1414213186565473,
0.14142131748214895,
0.14142131630775065,
0.14142131513335238,
0.1414213139589541,
0.1414213127845559,
0.14142131161015772,
0.14142131043575956,
0.14142130926136143,
0.14142130808696332,
0.14142130691256524,
0.14142130573816722,
0.14142130456376922,
0.14142130338937123,
0.14142130221497326,
0.14142130104057532,
0.14142129986617746,
0.1414212986917796,
0.14142129751738178
],
"target": [
0.0,
0.0
]
}

Двоичные данные
Things/Continuous_Generic/Output/tests/multi_step/convergence/phase_space.png Обычный файл

Двоичный файл не отображается.

После

Ширина:  |  Высота:  |  Размер: 40 KiB

Двоичный файл не отображается.

После

Ширина:  |  Высота:  |  Размер: 46 KiB

Двоичные данные
Things/Continuous_Generic/Output/tests/multi_step/convergence/state_correlations.png Обычный файл

Двоичный файл не отображается.

После

Ширина:  |  Высота:  |  Размер: 36 KiB

Двоичные данные
Things/Continuous_Generic/Output/tests/multi_step/convergence/time_evolution.png Обычный файл

Двоичный файл не отображается.

После

Ширина:  |  Высота:  |  Размер: 68 KiB

Просмотреть файл

@ -0,0 +1,25 @@
{
"state": [
[
0.09999999587901269,
0.0,
0.0
],
[
-0.09999999581673082,
0.0,
0.0
]
],
"observation": [
0.0,
0.0
],
"free_energy": [
0.020000279999999673,
13.835582850610585,
13.835582900097055,
13.835582949583518,
13.835582999069986
]
}

Двоичный файл не отображается.

После

Ширина:  |  Высота:  |  Размер: 73 KiB

Двоичный файл не отображается.

После

Ширина:  |  Высота:  |  Размер: 26 KiB

Просмотреть файл

@ -0,0 +1,35 @@
{
"initial_state": [
[
0.010002075937558768,
0.004999169595885761,
-0.0010000000000000022
],
[
0.019997924062441125,
-0.004999169595885761,
0.0009999999999999979
]
],
"final_state": [
[
0.010002075937558768,
0.004999169595885761,
-0.0010000000000000022
],
[
0.019997924062441125,
-0.004999169595885761,
0.0009999999999999979
]
],
"predicted_position": [
0.010000499995,
0.019999500005
],
"predicted_velocity": [
0.0049999,
-0.0049999
],
"dt": 0.0001
}

Двоичный файл не отображается.

Двоичный файл не отображается.

Двоичный файл не отображается.

Двоичный файл не отображается.

202
src/models/active_inference/dispatcher.py Обычный файл
Просмотреть файл

@ -0,0 +1,202 @@
"""
Active Inference method dispatcher and high-level abstractions.
Provides a clean interface for dispatching active inference operations
to appropriate low-level implementations.
"""
from enum import Enum
from typing import Dict, List, Optional, Tuple, Union, Any, Callable
import numpy as np
from dataclasses import dataclass
from pathlib import Path
from .base import ActiveInferenceModel, ModelState
from ..matrices.matrix_ops import MatrixOps, MatrixInitializer
class InferenceMethod(Enum):
"""Supported inference methods."""
VARIATIONAL = "variational"
SAMPLING = "sampling"
MEAN_FIELD = "mean_field"
class PolicyType(Enum):
"""Supported policy types."""
DISCRETE = "discrete"
CONTINUOUS = "continuous"
HIERARCHICAL = "hierarchical"
@dataclass
class InferenceConfig:
"""Configuration for inference method dispatch."""
method: InferenceMethod
policy_type: PolicyType
temporal_horizon: int
learning_rate: float
precision_init: float
use_gpu: bool = False
custom_params: Optional[Dict[str, Any]] = None
class ActiveInferenceDispatcher:
"""
Dispatcher for Active Inference operations.
Provides high-level interface and handles routing to specific implementations.
"""
def __init__(self, config: InferenceConfig):
"""Initialize dispatcher with configuration."""
self.config = config
self._setup_implementations()
self._initialize_matrices()
def _setup_implementations(self):
"""Set up mapping of operations to implementations."""
self._implementations = {
InferenceMethod.VARIATIONAL: {
'belief_update': self._variational_belief_update,
'policy_inference': self._variational_policy_inference
},
InferenceMethod.SAMPLING: {
'belief_update': self._sampling_belief_update,
'policy_inference': self._sampling_policy_inference
},
InferenceMethod.MEAN_FIELD: {
'belief_update': self._mean_field_belief_update,
'policy_inference': self._mean_field_policy_inference
}
}
def _initialize_matrices(self):
"""Initialize required matrices based on configuration."""
self.matrix_ops = MatrixOps()
self.matrix_init = MatrixInitializer()
def dispatch_belief_update(self,
observation: np.ndarray,
current_state: ModelState,
**kwargs) -> np.ndarray:
"""
Dispatch belief update to appropriate implementation.
Args:
observation: Current observation
current_state: Current model state
**kwargs: Additional parameters for specific implementations
Returns:
Updated beliefs
"""
update_fn = self._implementations[self.config.method]['belief_update']
return update_fn(observation, current_state, **kwargs)
def dispatch_policy_inference(self,
state: ModelState,
goal_prior: Optional[np.ndarray] = None,
**kwargs) -> np.ndarray:
"""
Dispatch policy inference to appropriate implementation.
Args:
state: Current model state
goal_prior: Optional prior over goal states
**kwargs: Additional parameters for specific implementations
Returns:
Inferred policy distributions
"""
inference_fn = self._implementations[self.config.method]['policy_inference']
return inference_fn(state, goal_prior, **kwargs)
def _variational_belief_update(self,
observation: np.ndarray,
state: ModelState,
**kwargs) -> np.ndarray:
"""Variational implementation of belief updates."""
# Implementation details for variational belief updates
prediction = np.dot(state.beliefs, kwargs.get('generative_matrix', np.eye(len(state.beliefs))))
prediction_error = observation - prediction
belief_update = state.precision * prediction_error
return state.beliefs + belief_update
def _sampling_belief_update(self,
observation: np.ndarray,
state: ModelState,
**kwargs) -> np.ndarray:
"""Sampling-based implementation of belief updates."""
# Implementation for sampling-based updates
raise NotImplementedError("Sampling-based belief updates not yet implemented")
def _mean_field_belief_update(self,
observation: np.ndarray,
state: ModelState,
**kwargs) -> np.ndarray:
"""Mean-field implementation of belief updates."""
# Implementation for mean-field updates
raise NotImplementedError("Mean-field belief updates not yet implemented")
def _variational_policy_inference(self,
state: ModelState,
goal_prior: Optional[np.ndarray] = None,
**kwargs) -> np.ndarray:
"""Variational implementation of policy inference."""
# Implementation for variational policy inference
if goal_prior is None:
goal_prior = np.ones(len(state.policies)) / len(state.policies)
expected_free_energy = self._calculate_expected_free_energy(
state, goal_prior, **kwargs)
return self.matrix_ops.softmax(-expected_free_energy)
def _sampling_policy_inference(self,
state: ModelState,
goal_prior: Optional[np.ndarray] = None,
**kwargs) -> np.ndarray:
"""Sampling-based implementation of policy inference."""
raise NotImplementedError("Sampling-based policy inference not yet implemented")
def _mean_field_policy_inference(self,
state: ModelState,
goal_prior: Optional[np.ndarray] = None,
**kwargs) -> np.ndarray:
"""Mean-field implementation of policy inference."""
raise NotImplementedError("Mean-field policy inference not yet implemented")
def _calculate_expected_free_energy(self,
state: ModelState,
goal_prior: np.ndarray,
**kwargs) -> np.ndarray:
"""Calculate expected free energy for policy evaluation."""
# Basic implementation - can be extended based on specific needs
pragmatic_value = -np.log(goal_prior + 1e-8) # Avoid log(0)
epistemic_value = self._calculate_epistemic_value(state)
return pragmatic_value + epistemic_value
def _calculate_epistemic_value(self, state: ModelState) -> np.ndarray:
"""Calculate epistemic value component of expected free energy."""
# Simple implementation - can be extended
return -state.prediction_error * np.ones(len(state.policies))
class ActiveInferenceFactory:
"""Factory for creating Active Inference instances with specific configurations."""
@staticmethod
def create(config: InferenceConfig) -> ActiveInferenceDispatcher:
"""Create an Active Inference dispatcher with specified configuration."""
return ActiveInferenceDispatcher(config)
@staticmethod
def create_from_yaml(config_path: Union[str, Path]) -> ActiveInferenceDispatcher:
"""Create an Active Inference dispatcher from YAML configuration."""
import yaml
with open(config_path, 'r') as f:
config_dict = yaml.safe_load(f)
config = InferenceConfig(
method=InferenceMethod(config_dict['method']),
policy_type=PolicyType(config_dict['policy_type']),
temporal_horizon=config_dict['temporal_horizon'],
learning_rate=config_dict['learning_rate'],
precision_init=config_dict['precision_init'],
use_gpu=config_dict.get('use_gpu', False),
custom_params=config_dict.get('custom_params', None)
)
return ActiveInferenceFactory.create(config)

276
tests/test_continuous_generic.py Обычный файл
Просмотреть файл

@ -0,0 +1,276 @@
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.gridspec import GridSpec
import seaborn as sns
from scipy.signal import hilbert
import pytest
from pathlib import Path
class TestComplexDynamics:
"""Test suite for complex dynamical behaviors."""
@pytest.fixture(autouse=True)
def setup(self, tmp_path):
"""Setup test fixtures."""
plt.style.use('seaborn')
sns.set_palette("husl")
self.output_dir = tmp_path / "figures"
self.output_dir.mkdir(exist_ok=True)
# Test parameters
self.dt = 0.01
self.t_max = 10.0
self.num_steps = int(self.t_max / self.dt)
self.time = np.linspace(0, self.t_max, self.num_steps)
# System parameters
self.omega = 2.0 # Natural frequency
self.damping = 0.1
self.amplitude = 1.0
self.frequency = 2.5
yield
# Cleanup
plt.close('all')
def simulate_harmonic_motion(self):
"""Simulate harmonic oscillator dynamics."""
states = np.zeros((self.num_steps, 2))
states[0] = [1.0, 0.0] # Initial conditions
for i in range(1, self.num_steps):
# Simple harmonic motion with damping
states[i, 0] = states[i-1, 0] + self.dt * states[i-1, 1]
states[i, 1] = states[i-1, 1] - self.dt * (
self.omega**2 * states[i-1, 0] +
2 * self.damping * states[i-1, 1]
)
return states
def test_harmonic_motion(self):
"""Test visualization of harmonic motion dynamics."""
states = self.simulate_harmonic_motion()
free_energy = self.compute_free_energy(states)
fig = plt.figure(figsize=(15, 10))
gs = GridSpec(2, 3, figure=fig)
# Phase space trajectory
ax1 = fig.add_subplot(gs[0, 0])
ax1.plot(states[:, 0], states[:, 1], 'b-', label='Phase trajectory')
ax1.set_xlabel('Position')
ax1.set_ylabel('Velocity')
ax1.set_title('Phase Space')
ax1.grid(True)
ax1.legend()
# Time series
ax2 = fig.add_subplot(gs[0, 1:])
ax2.plot(self.time, states[:, 0], 'b-', label='Position')
ax2.plot(self.time, states[:, 1], 'r--', label='Velocity')
ax2.set_xlabel('Time (s)')
ax2.set_ylabel('State Variables')
ax2.set_title('Time Evolution')
ax2.grid(True)
ax2.legend()
# Energy plot
ax3 = fig.add_subplot(gs[1, 0])
kinetic = 0.5 * states[:, 1]**2
potential = 0.5 * self.omega**2 * states[:, 0]**2
total = kinetic + potential
ax3.plot(self.time, kinetic, 'g-', label='Kinetic')
ax3.plot(self.time, potential, 'r-', label='Potential')
ax3.plot(self.time, total, 'k--', label='Total')
ax3.set_xlabel('Time (s)')
ax3.set_ylabel('Energy')
ax3.set_title('Energy Components')
ax3.grid(True)
ax3.legend()
# Free energy evolution
ax4 = fig.add_subplot(gs[1, 1:])
if len(free_energy) > 0: # Only plot if free energy is computed
ax4.plot(self.time, free_energy, 'b-', label='Free Energy')
ax4.legend()
ax4.set_xlabel('Time (s)')
ax4.set_ylabel('Free Energy')
ax4.set_title('Free Energy Evolution')
ax4.grid(True)
plt.tight_layout()
plt.savefig(self.output_dir / 'harmonic_motion_analysis.png', dpi=300, bbox_inches='tight')
plt.close()
# Assertions to verify the simulation
assert np.all(np.isfinite(states)), "States contain invalid values"
assert np.all(np.abs(total - total[0]) < 1e-2), "Energy is not conserved"
def test_driven_oscillator(self):
"""Test visualization of driven oscillator dynamics."""
states = self.simulate_driven_oscillator()
fig = plt.figure(figsize=(15, 12))
gs = GridSpec(3, 2, figure=fig)
# Phase space with driving force
ax1 = fig.add_subplot(gs[0, 0])
ax1.plot(states[:, 0], states[:, 1], 'b-', label='Phase trajectory')
ax1.set_xlabel('Position')
ax1.set_ylabel('Velocity')
ax1.set_title('Phase Space')
ax1.grid(True)
ax1.legend()
# Time series with driving force
ax2 = fig.add_subplot(gs[0, 1])
driving_force = self.amplitude * np.sin(self.frequency * self.time)
ax2.plot(self.time, states[:, 0], 'b-', label='Position')
ax2.plot(self.time, states[:, 1], 'r--', label='Velocity')
ax2.plot(self.time, driving_force, 'g:', label='Driving Force')
ax2.set_xlabel('Time (s)')
ax2.set_ylabel('State Variables')
ax2.set_title('Time Evolution')
ax2.grid(True)
ax2.legend()
# Power spectrum
ax3 = fig.add_subplot(gs[1, :])
frequencies = np.fft.fftfreq(len(states), self.dt)
spectrum = np.abs(np.fft.fft(states[:, 0]))
mask = frequencies > 0 # Only show positive frequencies
ax3.plot(frequencies[mask], spectrum[mask], 'b-', label='Position Spectrum')
ax3.set_xlabel('Frequency (Hz)')
ax3.set_ylabel('Amplitude')
ax3.set_title('Power Spectrum')
ax3.grid(True)
ax3.legend()
# Phase difference analysis
ax4 = fig.add_subplot(gs[2, 0])
phase_diff = np.angle(hilbert(states[:, 0])) - np.angle(hilbert(driving_force))
ax4.plot(self.time, np.unwrap(phase_diff), 'r-', label='Phase Difference')
ax4.set_xlabel('Time (s)')
ax4.set_ylabel('Phase Difference (rad)')
ax4.set_title('Phase Relationship')
ax4.grid(True)
ax4.legend()
# Response amplitude vs time
ax5 = fig.add_subplot(gs[2, 1])
envelope = np.abs(hilbert(states[:, 0]))
ax5.plot(self.time, envelope, 'r-', label='Response Amplitude')
ax5.plot(self.time, np.abs(driving_force), 'b--', label='Driving Amplitude')
ax5.set_xlabel('Time (s)')
ax5.set_ylabel('Amplitude')
ax5.set_title('Response Amplitude')
ax5.grid(True)
ax5.legend()
plt.tight_layout()
plt.savefig(self.output_dir / 'driven_oscillator_analysis.png', dpi=300, bbox_inches='tight')
plt.close()
# Assertions to verify the simulation
assert np.all(np.isfinite(states)), "States contain invalid values"
assert np.max(np.abs(states[:, 0])) > 0, "No oscillation detected"
def simulate_driven_oscillator(self):
"""Simulate driven oscillator dynamics."""
states = np.zeros((self.num_steps, 2))
states[0] = [0.0, 0.0] # Initial conditions
for i in range(1, self.num_steps):
driving_force = self.amplitude * np.sin(self.frequency * self.time[i])
states[i, 0] = states[i-1, 0] + self.dt * states[i-1, 1]
states[i, 1] = states[i-1, 1] - self.dt * (
self.omega**2 * states[i-1, 0] +
2 * self.damping * states[i-1, 1] -
driving_force
)
return states
def compute_free_energy(self, states):
"""Compute free energy for the system."""
# Placeholder - implement actual free energy computation
return np.zeros_like(self.time)
class TestGeneralizedCoordinates:
"""Test suite for generalized coordinates."""
@pytest.fixture(autouse=True)
def setup(self, tmp_path):
"""Setup test fixtures."""
plt.style.use('seaborn')
sns.set_palette("husl")
self.output_dir = tmp_path / "figures"
self.output_dir.mkdir(exist_ok=True)
yield
plt.close('all')
def test_generalized_coordinates_consistency(self):
"""Test consistency of generalized coordinates predictions."""
# ... existing test code ...
# Plotting
fig, (ax1, ax2, ax3, ax4) = plt.subplots(2, 2, figsize=(12, 10))
# Plot position predictions
lines1 = ax1.plot(time_points, positions, 'b-', label='Actual')
ax1.plot(time_points, predicted_positions, 'r--', label='Predicted')
ax1.set_xlabel('Time (s)')
ax1.set_ylabel('Position')
ax1.set_title('Position Prediction')
if lines1: # Only add legend if there are plotted lines
ax1.legend()
ax1.grid(True)
# Plot velocity predictions
lines2 = ax2.plot(time_points, velocities, 'b-', label='Actual')
ax2.plot(time_points, predicted_velocities, 'r--', label='Predicted')
ax2.set_xlabel('Time (s)')
ax2.set_ylabel('Velocity')
ax2.set_title('Velocity Prediction')
if lines2: # Only add legend if there are plotted lines
ax2.legend()
ax2.grid(True)
# Plot acceleration predictions
lines3 = ax3.plot(time_points, accelerations, 'b-', label='Actual')
ax3.plot(time_points, predicted_accelerations, 'r--', label='Predicted')
ax3.set_xlabel('Time (s)')
ax3.set_ylabel('Acceleration')
ax3.set_title('Acceleration Prediction')
if lines3: # Only add legend if there are plotted lines
ax3.legend()
ax3.grid(True)
# Plot prediction errors
lines4 = []
if len(position_errors) > 0:
lines4.extend(ax4.plot(time_points[1:], position_errors, 'r-', label='Position Error'))
if len(velocity_errors) > 0:
lines4.extend(ax4.plot(time_points[1:], velocity_errors, 'b--', label='Velocity Error'))
if len(acceleration_errors) > 0:
lines4.extend(ax4.plot(time_points[1:], acceleration_errors, 'g:', label='Acceleration Error'))
ax4.set_xlabel('Time (s)')
ax4.set_ylabel('Prediction Error')
ax4.set_title('Prediction Errors')
if lines4: # Only add legend if there are plotted lines
ax4.legend()
ax4.grid(True)
plt.tight_layout()
plt.savefig(self.output_dir / 'generalized_coordinates_predictions.png', dpi=300, bbox_inches='tight')
plt.close()
# Assertions
assert np.all(np.isfinite(positions)), "Invalid position values"
assert np.all(np.isfinite(velocities)), "Invalid velocity values"
assert np.all(np.isfinite(accelerations)), "Invalid acceleration values"