Skip to content

Commit 2216b59

Browse files
andrewdipperandrewdipper
and
andrewdipper
authored
Reduce JAX post-processing memory usage (#7311)
Co-authored-by: andrewdipper <andrewdipper11235@gmail.com>
1 parent c43a4db commit 2216b59

File tree

2 files changed

+52
-55
lines changed

2 files changed

+52
-55
lines changed

pymc/sampling/jax.py

+48-55
Original file line numberDiff line numberDiff line change
@@ -168,7 +168,11 @@ def _get_log_likelihood(
168168
elemwise_logp = model.logp(model.observed_RVs, sum=False)
169169
jax_fn = get_jaxified_graph(inputs=model.value_vars, outputs=elemwise_logp)
170170
result = _postprocess_samples(
171-
jax_fn, samples, backend, postprocessing_vectorize=postprocessing_vectorize
171+
jax_fn,
172+
samples,
173+
backend,
174+
postprocessing_vectorize=postprocessing_vectorize,
175+
donate_samples=False,
172176
)
173177
return {v.name: r for v, r in zip(model.observed_RVs, result)}
174178

@@ -181,7 +185,8 @@ def _postprocess_samples(
181185
jax_fn: Callable,
182186
raw_mcmc_samples: list[TensorVariable],
183187
postprocessing_backend: Literal["cpu", "gpu"] | None = None,
184-
postprocessing_vectorize: Literal["vmap", "scan"] = "scan",
188+
postprocessing_vectorize: Literal["vmap", "scan"] = "vmap",
189+
donate_samples: bool = False,
185190
) -> list[TensorVariable]:
186191
if postprocessing_vectorize == "scan":
187192
t_raw_mcmc_samples = [jnp.swapaxes(t, 0, 1) for t in raw_mcmc_samples]
@@ -193,7 +198,12 @@ def _postprocess_samples(
193198
)
194199
return [jnp.swapaxes(t, 0, 1) for t in outs]
195200
elif postprocessing_vectorize == "vmap":
196-
return jax.vmap(jax.vmap(jax_fn))(*_device_put(raw_mcmc_samples, postprocessing_backend))
201+
202+
def process_fn(x):
203+
return jax.vmap(jax.vmap(jax_fn))(*_device_put(x, postprocessing_backend))
204+
205+
return jax.jit(process_fn, donate_argnums=0 if donate_samples else None)(raw_mcmc_samples)
206+
197207
else:
198208
raise ValueError(f"Unrecognized postprocessing_vectorize: {postprocessing_vectorize}")
199209

@@ -253,7 +263,16 @@ def _blackjax_inference_loop(
253263
def _one_step(state, xs):
254264
_, rng_key = xs
255265
state, info = kernel(rng_key, state)
256-
return state, (state, info)
266+
position = state.position
267+
stats = {
268+
"diverging": info.is_divergent,
269+
"energy": info.energy,
270+
"tree_depth": info.num_trajectory_expansions,
271+
"n_steps": info.num_integration_steps,
272+
"acceptance_rate": info.acceptance_rate,
273+
"lp": state.logdensity,
274+
}
275+
return state, (position, stats)
257276

258277
progress_bar = adaptation_kwargs.pop("progress_bar", False)
259278
if progress_bar:
@@ -264,43 +283,9 @@ def _one_step(state, xs):
264283
one_step = jax.jit(_one_step)
265284

266285
keys = jax.random.split(seed, draws)
267-
_, (states, infos) = jax.lax.scan(one_step, last_state, (jnp.arange(draws), keys))
268-
269-
return states, infos
270-
271-
272-
def _blackjax_stats_to_dict(sample_stats, potential_energy) -> dict:
273-
"""Extract compatible stats from blackjax NUTS sampler
274-
with PyMC/Arviz naming conventions.
275-
276-
Parameters
277-
----------
278-
sample_stats: NUTSInfo
279-
Blackjax NUTSInfo object containing sampler statistics
280-
potential_energy: ArrayLike
281-
Potential energy values of sampled positions.
286+
_, (samples, stats) = jax.lax.scan(one_step, last_state, (jnp.arange(draws), keys))
282287

283-
Returns
284-
-------
285-
Dict[str, ArrayLike]
286-
Dictionary of sampler statistics.
287-
"""
288-
rename_key = {
289-
"is_divergent": "diverging",
290-
"energy": "energy",
291-
"num_trajectory_expansions": "tree_depth",
292-
"num_integration_steps": "n_steps",
293-
"acceptance_rate": "acceptance_rate", # naming here is
294-
"acceptance_probability": "acceptance_rate", # depending on blackjax version
295-
}
296-
converted_stats = {}
297-
converted_stats["lp"] = potential_energy
298-
for old_name, new_name in rename_key.items():
299-
value = getattr(sample_stats, old_name, None)
300-
if value is None:
301-
continue
302-
converted_stats[new_name] = value
303-
return converted_stats
288+
return samples, stats
304289

305290

306291
def _sample_blackjax_nuts(
@@ -410,11 +395,7 @@ def _sample_blackjax_nuts(
410395
**nuts_kwargs,
411396
)
412397

413-
states, stats = map_fn(get_posterior_samples)(keys, initial_points)
414-
raw_mcmc_samples = states.position
415-
potential_energy = states.logdensity.block_until_ready()
416-
sample_stats = _blackjax_stats_to_dict(stats, potential_energy)
417-
398+
raw_mcmc_samples, sample_stats = map_fn(get_posterior_samples)(keys, initial_points)
418399
return raw_mcmc_samples, sample_stats, blackjax
419400

420401

@@ -515,7 +496,7 @@ def sample_jax_nuts(
515496
keep_untransformed: bool = False,
516497
chain_method: str = "parallel",
517498
postprocessing_backend: Literal["cpu", "gpu"] | None = None,
518-
postprocessing_vectorize: Literal["vmap", "scan"] = "scan",
499+
postprocessing_vectorize: Literal["vmap", "scan"] | None = None,
519500
postprocessing_chunks=None,
520501
idata_kwargs: dict | None = None,
521502
compute_convergence_checks: bool = True,
@@ -597,6 +578,16 @@ def sample_jax_nuts(
597578
DeprecationWarning,
598579
)
599580

581+
if postprocessing_vectorize is not None:
582+
import warnings
583+
584+
warnings.warn(
585+
'postprocessing_vectorize={"scan", "vmap"} will be removed in a future release.',
586+
FutureWarning,
587+
)
588+
else:
589+
postprocessing_vectorize = "vmap"
590+
600591
model = modelcontext(model)
601592

602593
if var_names is not None:
@@ -645,15 +636,6 @@ def sample_jax_nuts(
645636
)
646637
tic2 = datetime.now()
647638

648-
jax_fn = get_jaxified_graph(inputs=model.value_vars, outputs=vars_to_sample)
649-
result = _postprocess_samples(
650-
jax_fn,
651-
raw_mcmc_samples,
652-
postprocessing_backend=postprocessing_backend,
653-
postprocessing_vectorize=postprocessing_vectorize,
654-
)
655-
mcmc_samples = {v.name: r for v, r in zip(vars_to_sample, result)}
656-
657639
if idata_kwargs is None:
658640
idata_kwargs = {}
659641
else:
@@ -669,6 +651,17 @@ def sample_jax_nuts(
669651
else:
670652
log_likelihood = None
671653

654+
jax_fn = get_jaxified_graph(inputs=model.value_vars, outputs=vars_to_sample)
655+
result = _postprocess_samples(
656+
jax_fn,
657+
raw_mcmc_samples,
658+
postprocessing_backend=postprocessing_backend,
659+
postprocessing_vectorize=postprocessing_vectorize,
660+
donate_samples=True,
661+
)
662+
del raw_mcmc_samples
663+
mcmc_samples = {v.name: r for v, r in zip(vars_to_sample, result)}
664+
672665
attrs = {
673666
"sampling_time": (tic2 - tic1).total_seconds(),
674667
"tuning_steps": tune,

pymc/sampling/mcmc.py

+4
Original file line numberDiff line numberDiff line change
@@ -272,6 +272,7 @@ def _sample_external_nuts(
272272
var_names: Sequence[str] | None,
273273
progressbar: bool,
274274
idata_kwargs: dict | None,
275+
compute_convergence_checks: bool,
275276
nuts_sampler_kwargs: dict | None,
276277
**kwargs,
277278
):
@@ -364,6 +365,7 @@ def _sample_external_nuts(
364365
progressbar=progressbar,
365366
nuts_sampler=sampler,
366367
idata_kwargs=idata_kwargs,
368+
compute_convergence_checks=compute_convergence_checks,
367369
**nuts_sampler_kwargs,
368370
)
369371
return idata
@@ -718,6 +720,7 @@ def joined_blas_limiter():
718720
raise ValueError(
719721
"Model can not be sampled with NUTS alone. Your model is probably not continuous."
720722
)
723+
721724
with joined_blas_limiter():
722725
return _sample_external_nuts(
723726
sampler=nuts_sampler,
@@ -731,6 +734,7 @@ def joined_blas_limiter():
731734
var_names=var_names,
732735
progressbar=progressbar,
733736
idata_kwargs=idata_kwargs,
737+
compute_convergence_checks=compute_convergence_checks,
734738
nuts_sampler_kwargs=nuts_sampler_kwargs,
735739
**kwargs,
736740
)

0 commit comments

Comments
 (0)