13
13
import blackjax
14
14
import blackjax .diagnostics as diagnostics
15
15
import blackjax .mcmc .random_walk
16
+ from blackjax .adaptation .base import get_filter_adapt_info_fn , return_all_adapt_info
16
17
from blackjax .util import run_inference_algorithm
17
18
18
19
@@ -56,6 +57,27 @@ def rmh_proposal_distribution(rng_key, position):
56
57
},
57
58
]
58
59
60
+ window_adaptation_filters = [
61
+ {
62
+ "filter_fn" : return_all_adapt_info ,
63
+ "return_sets" : None ,
64
+ },
65
+ {
66
+ "filter_fn" : get_filter_adapt_info_fn (),
67
+ "return_sets" : (set (), set (), set ()),
68
+ },
69
+ {
70
+ "filter_fn" : get_filter_adapt_info_fn (
71
+ {"position" }, {"is_divergent" }, {"ss_state" , "inverse_mass_matrix" }
72
+ ),
73
+ "return_sets" : (
74
+ {"position" },
75
+ {"is_divergent" },
76
+ {"ss_state" , "inverse_mass_matrix" },
77
+ ),
78
+ },
79
+ ]
80
+
59
81
60
82
class LinearRegressionTest (chex .TestCase ):
61
83
"""Test sampling of a linear regression model."""
@@ -112,8 +134,14 @@ def run_mclmc(self, logdensity_fn, num_steps, initial_position, key):
112
134
113
135
return samples
114
136
115
- @parameterized .parameters (itertools .product (regression_test_cases , [True , False ]))
116
- def test_window_adaptation (self , case , is_mass_matrix_diagonal ):
137
+ @parameterized .parameters (
138
+ itertools .product (
139
+ regression_test_cases , [True , False ], window_adaptation_filters
140
+ )
141
+ )
142
+ def test_window_adaptation (
143
+ self , case , is_mass_matrix_diagonal , window_adapt_config
144
+ ):
117
145
"""Test the HMC kernel and the Stan warmup."""
118
146
rng_key , init_key0 , init_key1 = jax .random .split (self .key , 3 )
119
147
x_data = jax .random .normal (init_key0 , shape = (1000 , 1 ))
@@ -131,15 +159,33 @@ def test_window_adaptation(self, case, is_mass_matrix_diagonal):
131
159
logposterior_fn ,
132
160
is_mass_matrix_diagonal ,
133
161
progress_bar = True ,
162
+ adaptation_info_fn = window_adapt_config ["filter_fn" ],
134
163
** case ["parameters" ],
135
164
)
136
- (state , parameters ), _ = warmup .run (
165
+ (state , parameters ), info = warmup .run (
137
166
warmup_key ,
138
167
case ["initial_position" ],
139
168
case ["num_warmup_steps" ],
140
169
)
141
170
inference_algorithm = case ["algorithm" ](logposterior_fn , ** parameters )
142
171
172
+ def check_attrs (attribute , keyset ):
173
+ for name , param in getattr (info , attribute )._asdict ().items ():
174
+ if name in keyset :
175
+ assert param is not None
176
+ else :
177
+ assert param is None
178
+
179
+ keysets = window_adapt_config ["return_sets" ]
180
+ if keysets is None :
181
+ keysets = (
182
+ info .state ._fields ,
183
+ info .info ._fields ,
184
+ info .adaptation_state ._fields ,
185
+ )
186
+ for i , attribute in enumerate (["state" , "info" , "adaptation_state" ]):
187
+ check_attrs (attribute , keysets [i ])
188
+
143
189
_ , states , _ = run_inference_algorithm (
144
190
inference_key , state , inference_algorithm , case ["num_sampling_steps" ]
145
191
)
0 commit comments