Skip to content

Commit cd91e41

Browse files
authored
Enable filtering of AdaptationInfo (#674)
* enable AdaptationInfo filtering * revert progress_bar * fix pre-commit * fix empty sets * enable adapt info filtering for all adaptation algorithms * fix precommit /progressbar=True * change filter tuple to use tree_map
1 parent af79fa4 commit cd91e41

File tree

7 files changed

+167
-20
lines changed

7 files changed

+167
-20
lines changed

blackjax/adaptation/base.py

+34-1
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,9 @@
1111
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
14-
from typing import NamedTuple
14+
from typing import NamedTuple, Set
15+
16+
import jax
1517

1618
from blackjax.types import ArrayTree
1719

@@ -25,3 +27,34 @@ class AdaptationInfo(NamedTuple):
2527
state: NamedTuple
2628
info: NamedTuple
2729
adaptation_state: NamedTuple
30+
31+
32+
def return_all_adapt_info(state, info, adaptation_state):
    """Bundle all three arguments, unfiltered, into an AdaptationInfo.

    Intended as a value for the ``adaptation_info_fn`` parameter of the
    adaptation algorithms when nothing should be dropped.

    Parameters
    ----------
    state
        Stored as the ``state`` field of the result.
    info
        Stored as the ``info`` field of the result.
    adaptation_state
        Stored as the ``adaptation_state`` field of the result.

    Returns
    -------
    An ``AdaptationInfo`` holding the three arguments as given.

    """
    full_record = AdaptationInfo(
        state=state,
        info=info,
        adaptation_state=adaptation_state,
    )
    return full_record
37+
38+
39+
def get_filter_adapt_info_fn(
    state_keys: Set[str] = set(),
    info_keys: Set[str] = set(),
    adapt_state_keys: Set[str] = set(),
):
    """Generate a function to filter what is saved in AdaptationInfo.

    Used for the ``adaptation_info_fn`` parameter of the adaptation
    algorithms. ``adaptation_info_fn=get_filter_adapt_info_fn()`` saves no
    auxiliary information.

    Parameters
    ----------
    state_keys
        Names of the top-level fields of ``state`` to keep.
    info_keys
        Names of the top-level fields of ``info`` to keep.
    adapt_state_keys
        Names of the top-level fields of ``adaptation_state`` to keep.

    Each argument may be any iterable of field names; a bare string is taken
    as a single field name. The names are copied when this function is
    called, so mutating the passed collections afterwards does not change
    the behavior of the returned filter.

    Returns
    -------
    A function ``filter_fn(state, info, adaptation_state)`` that returns an
    ``AdaptationInfo`` in which every top-level field whose name was not
    selected is replaced by ``None``.

    """

    def freeze(keys):
        # A lone string would otherwise be tested by substring membership
        # ("pos" in "position"); treat it as exactly one field name instead.
        if isinstance(keys, str):
            return frozenset((keys,))
        return frozenset(keys)

    # Snapshot the selections so the closure never aliases caller-owned sets
    # or the shared default ``set()`` objects of the signature.
    kept_state = freeze(state_keys)
    kept_info = freeze(info_keys)
    kept_adapt_state = freeze(adapt_state_keys)

    def filter_tuple(tup, key_set):
        def keep_or_drop(key, val):
            return val if key in key_set else None

        # The first tree is the same namedtuple type filled with its own
        # field names, so each name is a leaf and ``val`` is the whole
        # sub-tree stored under that field in ``tup``.
        field_names = type(tup)(*tup._fields)
        return jax.tree.map(keep_or_drop, field_names, tup)

    def filter_fn(state, info, adaptation_state):
        sample_state = filter_tuple(state, kept_state)
        new_info = filter_tuple(info, kept_info)
        new_adapt_state = filter_tuple(adaptation_state, kept_adapt_state)

        return AdaptationInfo(sample_state, new_info, new_adapt_state)

    return filter_fn

blackjax/adaptation/chees_adaptation.py

+9-5
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010

1111
import blackjax.mcmc.dynamic_hmc as dynamic_hmc
1212
import blackjax.optimizers.dual_averaging as dual_averaging
13-
from blackjax.adaptation.base import AdaptationInfo, AdaptationResults
13+
from blackjax.adaptation.base import AdaptationResults, return_all_adapt_info
1414
from blackjax.base import AdaptationAlgorithm
1515
from blackjax.types import Array, ArrayLikeTree, PRNGKey
1616
from blackjax.util import pytree_size
@@ -278,6 +278,7 @@ def chees_adaptation(
278278
jitter_amount: float = 1.0,
279279
target_acceptance_rate: float = OPTIMAL_TARGET_ACCEPTANCE_RATE,
280280
decay_rate: float = 0.5,
281+
adaptation_info_fn: Callable = return_all_adapt_info,
281282
) -> AdaptationAlgorithm:
282283
"""Adapt the step size and trajectory length (number of integration steps / step size)
283284
parameters of the jittered HMC algorithm.
@@ -337,6 +338,11 @@ def chees_adaptation(
337338
Float representing how much to favor recent iterations over earlier ones in the optimization
338339
of step size and trajectory length. A value of 1 gives equal weight to all history. A value
339340
of 0 gives weight only to the most recent iteration.
341+
adaptation_info_fn
342+
Function to select the adaptation info returned. See return_all_adapt_info
343+
and get_filter_adapt_info_fn in blackjax.adaptation.base. By default all
344+
information is saved - this can result in excessive memory usage if the
345+
information is unused.
340346
341347
Returns
342348
-------
@@ -411,10 +417,8 @@ def one_step(carry, rng_key):
411417
info.is_divergent,
412418
)
413419

414-
return (new_states, new_adaptation_state), AdaptationInfo(
415-
new_states,
416-
info,
417-
new_adaptation_state,
420+
return (new_states, new_adaptation_state), adaptation_info_fn(
421+
new_states, info, new_adaptation_state
418422
)
419423

420424
batch_init = jax.vmap(

blackjax/adaptation/meads_adaptation.py

+9-5
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717
import jax.numpy as jnp
1818

1919
import blackjax.mcmc as mcmc
20-
from blackjax.adaptation.base import AdaptationInfo, AdaptationResults
20+
from blackjax.adaptation.base import AdaptationResults, return_all_adapt_info
2121
from blackjax.base import AdaptationAlgorithm
2222
from blackjax.types import Array, ArrayLikeTree, ArrayTree, PRNGKey
2323

@@ -165,6 +165,7 @@ def update(
165165
def meads_adaptation(
166166
logdensity_fn: Callable,
167167
num_chains: int,
168+
adaptation_info_fn: Callable = return_all_adapt_info,
168169
) -> AdaptationAlgorithm:
169170
"""Adapt the parameters of the Generalized HMC algorithm.
170171
@@ -194,6 +195,11 @@ def meads_adaptation(
194195
The log density probability density function from which we wish to sample.
195196
num_chains
196197
Number of chains used for cross-chain warm-up training.
198+
adaptation_info_fn
199+
Function to select the adaptation info returned. See return_all_adapt_info
200+
and get_filter_adapt_info_fn in blackjax.adaptation.base. By default all
201+
information is saved - this can result in excessive memory usage if the
202+
information is unused.
197203
198204
Returns
199205
-------
@@ -227,10 +233,8 @@ def one_step(carry, rng_key):
227233
adaptation_state, new_states.position, new_states.logdensity_grad
228234
)
229235

230-
return (new_states, new_adaptation_state), AdaptationInfo(
231-
new_states,
232-
info,
233-
new_adaptation_state,
236+
return (new_states, new_adaptation_state), adaptation_info_fn(
237+
new_states, info, new_adaptation_state
234238
)
235239

236240
def run(rng_key: PRNGKey, positions: ArrayLikeTree, num_steps: int = 1000):

blackjax/adaptation/pathfinder_adaptation.py

+8-2
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818
import jax.numpy as jnp
1919

2020
import blackjax.vi as vi
21-
from blackjax.adaptation.base import AdaptationInfo, AdaptationResults
21+
from blackjax.adaptation.base import AdaptationResults, return_all_adapt_info
2222
from blackjax.adaptation.step_size import (
2323
DualAveragingAdaptationState,
2424
dual_averaging_adaptation,
@@ -141,6 +141,7 @@ def pathfinder_adaptation(
141141
logdensity_fn: Callable,
142142
initial_step_size: float = 1.0,
143143
target_acceptance_rate: float = 0.80,
144+
adaptation_info_fn: Callable = return_all_adapt_info,
144145
**extra_parameters,
145146
) -> AdaptationAlgorithm:
146147
"""Adapt the value of the inverse mass matrix and step size parameters of
@@ -156,6 +157,11 @@ def pathfinder_adaptation(
156157
The initial step size used in the algorithm.
157158
target_acceptance_rate
158159
The acceptance rate that we target during step size adaptation.
160+
adaptation_info_fn
161+
Function to select the adaptation info returned. See return_all_adapt_info
162+
and get_filter_adapt_info_fn in blackjax.adaptation.base. By default all
163+
information is saved - this can result in excessive memory usage if the
164+
information is unused.
159165
**extra_parameters
160166
The extra parameters to pass to the algorithm, e.g. the number of
161167
integration steps for HMC.
@@ -188,7 +194,7 @@ def one_step(carry, rng_key):
188194
)
189195
return (
190196
(new_state, new_adaptation_state),
191-
AdaptationInfo(new_state, info, new_adaptation_state),
197+
adaptation_info_fn(new_state, info, new_adaptation_state),
192198
)
193199

194200
def run(rng_key: PRNGKey, position: ArrayLikeTree, num_steps: int = 400):

blackjax/adaptation/window_adaptation.py

+8-2
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717
import jax
1818
import jax.numpy as jnp
1919

20-
from blackjax.adaptation.base import AdaptationInfo, AdaptationResults
20+
from blackjax.adaptation.base import AdaptationResults, return_all_adapt_info
2121
from blackjax.adaptation.mass_matrix import (
2222
MassMatrixAdaptationState,
2323
mass_matrix_adaptation,
@@ -248,6 +248,7 @@ def window_adaptation(
248248
initial_step_size: float = 1.0,
249249
target_acceptance_rate: float = 0.80,
250250
progress_bar: bool = False,
251+
adaptation_info_fn: Callable = return_all_adapt_info,
251252
**extra_parameters,
252253
) -> AdaptationAlgorithm:
253254
"""Adapt the value of the inverse mass matrix and step size parameters of
@@ -278,6 +279,11 @@ def window_adaptation(
278279
The acceptance rate that we target during step size adaptation.
279280
progress_bar
280281
Whether we should display a progress bar.
282+
adaptation_info_fn
283+
Function to select the adaptation info returned. See return_all_adapt_info
284+
and get_filter_adapt_info_fn in blackjax.adaptation.base. By default all
285+
information is saved - this can result in excessive memory usage if the
286+
information is unused.
281287
**extra_parameters
282288
The extra parameters to pass to the algorithm, e.g. the number of
283289
integration steps for HMC.
@@ -316,7 +322,7 @@ def one_step(carry, xs):
316322

317323
return (
318324
(new_state, new_adaptation_state),
319-
AdaptationInfo(new_state, info, new_adaptation_state),
325+
adaptation_info_fn(new_state, info, new_adaptation_state),
320326
)
321327

322328
def run(rng_key: PRNGKey, position: ArrayLikeTree, num_steps: int = 1000):

tests/adaptation/test_adaptation.py

+50-2
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66

77
import blackjax
88
from blackjax.adaptation import window_adaptation
9+
from blackjax.adaptation.base import get_filter_adapt_info_fn, return_all_adapt_info
910
from blackjax.util import run_inference_algorithm
1011

1112

@@ -34,7 +35,32 @@ def test_adaptation_schedule(num_steps, expected_schedule):
3435
assert np.array_equal(adaptation_schedule, expected_schedule)
3536

3637

37-
def test_chees_adaptation():
38+
@pytest.mark.parametrize(
39+
"adaptation_filters",
40+
[
41+
{
42+
"filter_fn": return_all_adapt_info,
43+
"return_sets": None,
44+
},
45+
{
46+
"filter_fn": get_filter_adapt_info_fn(),
47+
"return_sets": (set(), set(), set()),
48+
},
49+
{
50+
"filter_fn": get_filter_adapt_info_fn(
51+
{"logdensity"},
52+
{"proposal"},
53+
{"random_generator_arg", "step", "da_state"},
54+
),
55+
"return_sets": (
56+
{"logdensity"},
57+
{"proposal"},
58+
{"random_generator_arg", "step", "da_state"},
59+
),
60+
},
61+
],
62+
)
63+
def test_chees_adaptation(adaptation_filters):
3864
logprob_fn = lambda x: jax.scipy.stats.norm.logpdf(
3965
x, loc=0.0, scale=jnp.array([1.0, 10.0])
4066
).sum()
@@ -47,7 +73,10 @@ def test_chees_adaptation():
4773
init_key, warmup_key, inference_key = jax.random.split(jax.random.key(346), 3)
4874

4975
warmup = blackjax.chees_adaptation(
50-
logprob_fn, num_chains=num_chains, target_acceptance_rate=0.75
76+
logprob_fn,
77+
num_chains=num_chains,
78+
target_acceptance_rate=0.75,
79+
adaptation_info_fn=adaptation_filters["filter_fn"],
5180
)
5281

5382
initial_positions = jax.random.normal(init_key, (num_chains, 2))
@@ -66,6 +95,25 @@ def test_chees_adaptation():
6695
)(chain_keys, last_states)
6796

6897
harmonic_mean = 1.0 / jnp.mean(1.0 / infos.acceptance_rate)
98+
99+
def check_attrs(attribute, keyset):
100+
for name, param in getattr(warmup_info, attribute)._asdict().items():
101+
print(name, param)
102+
if name in keyset:
103+
assert param is not None
104+
else:
105+
assert param is None
106+
107+
keysets = adaptation_filters["return_sets"]
108+
if keysets is None:
109+
keysets = (
110+
warmup_info.state._fields,
111+
warmup_info.info._fields,
112+
warmup_info.adaptation_state._fields,
113+
)
114+
for i, attribute in enumerate(["state", "info", "adaptation_state"]):
115+
check_attrs(attribute, keysets[i])
116+
69117
np.testing.assert_allclose(harmonic_mean, 0.75, atol=1e-1)
70118
np.testing.assert_allclose(parameters["step_size"], 1.5, rtol=2e-1)
71119
np.testing.assert_array_less(infos.num_integration_steps.mean(), 15.0)

tests/mcmc/test_sampling.py

+49-3
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
import blackjax
1414
import blackjax.diagnostics as diagnostics
1515
import blackjax.mcmc.random_walk
16+
from blackjax.adaptation.base import get_filter_adapt_info_fn, return_all_adapt_info
1617
from blackjax.util import run_inference_algorithm
1718

1819

@@ -56,6 +57,27 @@ def rmh_proposal_distribution(rng_key, position):
5657
},
5758
]
5859

60+
window_adaptation_filters = [
61+
{
62+
"filter_fn": return_all_adapt_info,
63+
"return_sets": None,
64+
},
65+
{
66+
"filter_fn": get_filter_adapt_info_fn(),
67+
"return_sets": (set(), set(), set()),
68+
},
69+
{
70+
"filter_fn": get_filter_adapt_info_fn(
71+
{"position"}, {"is_divergent"}, {"ss_state", "inverse_mass_matrix"}
72+
),
73+
"return_sets": (
74+
{"position"},
75+
{"is_divergent"},
76+
{"ss_state", "inverse_mass_matrix"},
77+
),
78+
},
79+
]
80+
5981

6082
class LinearRegressionTest(chex.TestCase):
6183
"""Test sampling of a linear regression model."""
@@ -112,8 +134,14 @@ def run_mclmc(self, logdensity_fn, num_steps, initial_position, key):
112134

113135
return samples
114136

115-
@parameterized.parameters(itertools.product(regression_test_cases, [True, False]))
116-
def test_window_adaptation(self, case, is_mass_matrix_diagonal):
137+
@parameterized.parameters(
138+
itertools.product(
139+
regression_test_cases, [True, False], window_adaptation_filters
140+
)
141+
)
142+
def test_window_adaptation(
143+
self, case, is_mass_matrix_diagonal, window_adapt_config
144+
):
117145
"""Test the HMC kernel and the Stan warmup."""
118146
rng_key, init_key0, init_key1 = jax.random.split(self.key, 3)
119147
x_data = jax.random.normal(init_key0, shape=(1000, 1))
@@ -131,15 +159,33 @@ def test_window_adaptation(self, case, is_mass_matrix_diagonal):
131159
logposterior_fn,
132160
is_mass_matrix_diagonal,
133161
progress_bar=True,
162+
adaptation_info_fn=window_adapt_config["filter_fn"],
134163
**case["parameters"],
135164
)
136-
(state, parameters), _ = warmup.run(
165+
(state, parameters), info = warmup.run(
137166
warmup_key,
138167
case["initial_position"],
139168
case["num_warmup_steps"],
140169
)
141170
inference_algorithm = case["algorithm"](logposterior_fn, **parameters)
142171

172+
def check_attrs(attribute, keyset):
173+
for name, param in getattr(info, attribute)._asdict().items():
174+
if name in keyset:
175+
assert param is not None
176+
else:
177+
assert param is None
178+
179+
keysets = window_adapt_config["return_sets"]
180+
if keysets is None:
181+
keysets = (
182+
info.state._fields,
183+
info.info._fields,
184+
info.adaptation_state._fields,
185+
)
186+
for i, attribute in enumerate(["state", "info", "adaptation_state"]):
187+
check_attrs(attribute, keysets[i])
188+
143189
_, states, _ = run_inference_algorithm(
144190
inference_key, state, inference_algorithm, case["num_sampling_steps"]
145191
)

0 commit comments

Comments
 (0)