Skip to content

Commit

Permalink
Reduce JAX post-processing memory usage (#7311)
Browse files (browse the repository at this point in the history)
Co-authored-by: andrewdipper <andrewdipper11235@gmail.com>
  • Loading branch information
andrewdipper and andrewdipper authored Jul 11, 2024
1 parent c43a4db commit 2216b59
Show file tree
Hide file tree
Showing 2 changed files with 52 additions and 55 deletions.
103 changes: 48 additions & 55 deletions pymc/sampling/jax.py
Original file line number Diff line number Diff line change
Expand Up @@ -168,7 +168,11 @@ def _get_log_likelihood(
elemwise_logp = model.logp(model.observed_RVs, sum=False)
jax_fn = get_jaxified_graph(inputs=model.value_vars, outputs=elemwise_logp)
result = _postprocess_samples(
jax_fn, samples, backend, postprocessing_vectorize=postprocessing_vectorize
jax_fn,
samples,
backend,
postprocessing_vectorize=postprocessing_vectorize,
donate_samples=False,
)
return {v.name: r for v, r in zip(model.observed_RVs, result)}

Expand All @@ -181,7 +185,8 @@ def _postprocess_samples(
jax_fn: Callable,
raw_mcmc_samples: list[TensorVariable],
postprocessing_backend: Literal["cpu", "gpu"] | None = None,
postprocessing_vectorize: Literal["vmap", "scan"] = "scan",
postprocessing_vectorize: Literal["vmap", "scan"] = "vmap",
donate_samples: bool = False,
) -> list[TensorVariable]:
if postprocessing_vectorize == "scan":
t_raw_mcmc_samples = [jnp.swapaxes(t, 0, 1) for t in raw_mcmc_samples]
Expand All @@ -193,7 +198,12 @@ def _postprocess_samples(
)
return [jnp.swapaxes(t, 0, 1) for t in outs]
elif postprocessing_vectorize == "vmap":
return jax.vmap(jax.vmap(jax_fn))(*_device_put(raw_mcmc_samples, postprocessing_backend))

def process_fn(x):
return jax.vmap(jax.vmap(jax_fn))(*_device_put(x, postprocessing_backend))

return jax.jit(process_fn, donate_argnums=0 if donate_samples else None)(raw_mcmc_samples)

else:
raise ValueError(f"Unrecognized postprocessing_vectorize: {postprocessing_vectorize}")

Expand Down Expand Up @@ -253,7 +263,16 @@ def _blackjax_inference_loop(
def _one_step(state, xs):
_, rng_key = xs
state, info = kernel(rng_key, state)
return state, (state, info)
position = state.position
stats = {
"diverging": info.is_divergent,
"energy": info.energy,
"tree_depth": info.num_trajectory_expansions,
"n_steps": info.num_integration_steps,
"acceptance_rate": info.acceptance_rate,
"lp": state.logdensity,
}
return state, (position, stats)

progress_bar = adaptation_kwargs.pop("progress_bar", False)
if progress_bar:
Expand All @@ -264,43 +283,9 @@ def _one_step(state, xs):
one_step = jax.jit(_one_step)

keys = jax.random.split(seed, draws)
_, (states, infos) = jax.lax.scan(one_step, last_state, (jnp.arange(draws), keys))

return states, infos


def _blackjax_stats_to_dict(sample_stats, potential_energy) -> dict:
"""Extract compatible stats from blackjax NUTS sampler
with PyMC/Arviz naming conventions.
Parameters
----------
sample_stats: NUTSInfo
Blackjax NUTSInfo object containing sampler statistics
potential_energy: ArrayLike
Potential energy values of sampled positions.
_, (samples, stats) = jax.lax.scan(one_step, last_state, (jnp.arange(draws), keys))

Returns
-------
Dict[str, ArrayLike]
Dictionary of sampler statistics.
"""
rename_key = {
"is_divergent": "diverging",
"energy": "energy",
"num_trajectory_expansions": "tree_depth",
"num_integration_steps": "n_steps",
"acceptance_rate": "acceptance_rate", # naming here is
"acceptance_probability": "acceptance_rate", # depending on blackjax version
}
converted_stats = {}
converted_stats["lp"] = potential_energy
for old_name, new_name in rename_key.items():
value = getattr(sample_stats, old_name, None)
if value is None:
continue
converted_stats[new_name] = value
return converted_stats
return samples, stats


def _sample_blackjax_nuts(
Expand Down Expand Up @@ -410,11 +395,7 @@ def _sample_blackjax_nuts(
**nuts_kwargs,
)

states, stats = map_fn(get_posterior_samples)(keys, initial_points)
raw_mcmc_samples = states.position
potential_energy = states.logdensity.block_until_ready()
sample_stats = _blackjax_stats_to_dict(stats, potential_energy)

raw_mcmc_samples, sample_stats = map_fn(get_posterior_samples)(keys, initial_points)
return raw_mcmc_samples, sample_stats, blackjax


Expand Down Expand Up @@ -515,7 +496,7 @@ def sample_jax_nuts(
keep_untransformed: bool = False,
chain_method: str = "parallel",
postprocessing_backend: Literal["cpu", "gpu"] | None = None,
postprocessing_vectorize: Literal["vmap", "scan"] = "scan",
postprocessing_vectorize: Literal["vmap", "scan"] | None = None,
postprocessing_chunks=None,
idata_kwargs: dict | None = None,
compute_convergence_checks: bool = True,
Expand Down Expand Up @@ -597,6 +578,16 @@ def sample_jax_nuts(
DeprecationWarning,
)

if postprocessing_vectorize is not None:
import warnings

warnings.warn(
'postprocessing_vectorize={"scan", "vmap"} will be removed in a future release.',
FutureWarning,
)
else:
postprocessing_vectorize = "vmap"

model = modelcontext(model)

if var_names is not None:
Expand Down Expand Up @@ -645,15 +636,6 @@ def sample_jax_nuts(
)
tic2 = datetime.now()

jax_fn = get_jaxified_graph(inputs=model.value_vars, outputs=vars_to_sample)
result = _postprocess_samples(
jax_fn,
raw_mcmc_samples,
postprocessing_backend=postprocessing_backend,
postprocessing_vectorize=postprocessing_vectorize,
)
mcmc_samples = {v.name: r for v, r in zip(vars_to_sample, result)}

if idata_kwargs is None:
idata_kwargs = {}
else:
Expand All @@ -669,6 +651,17 @@ def sample_jax_nuts(
else:
log_likelihood = None

jax_fn = get_jaxified_graph(inputs=model.value_vars, outputs=vars_to_sample)
result = _postprocess_samples(
jax_fn,
raw_mcmc_samples,
postprocessing_backend=postprocessing_backend,
postprocessing_vectorize=postprocessing_vectorize,
donate_samples=True,
)
del raw_mcmc_samples
mcmc_samples = {v.name: r for v, r in zip(vars_to_sample, result)}

attrs = {
"sampling_time": (tic2 - tic1).total_seconds(),
"tuning_steps": tune,
Expand Down
4 changes: 4 additions & 0 deletions pymc/sampling/mcmc.py
Original file line number Diff line number Diff line change
Expand Up @@ -272,6 +272,7 @@ def _sample_external_nuts(
var_names: Sequence[str] | None,
progressbar: bool,
idata_kwargs: dict | None,
compute_convergence_checks: bool,
nuts_sampler_kwargs: dict | None,
**kwargs,
):
Expand Down Expand Up @@ -364,6 +365,7 @@ def _sample_external_nuts(
progressbar=progressbar,
nuts_sampler=sampler,
idata_kwargs=idata_kwargs,
compute_convergence_checks=compute_convergence_checks,
**nuts_sampler_kwargs,
)
return idata
Expand Down Expand Up @@ -718,6 +720,7 @@ def joined_blas_limiter():
raise ValueError(
"Model can not be sampled with NUTS alone. Your model is probably not continuous."
)

with joined_blas_limiter():
return _sample_external_nuts(
sampler=nuts_sampler,
Expand All @@ -731,6 +734,7 @@ def joined_blas_limiter():
var_names=var_names,
progressbar=progressbar,
idata_kwargs=idata_kwargs,
compute_convergence_checks=compute_convergence_checks,
nuts_sampler_kwargs=nuts_sampler_kwargs,
**kwargs,
)
Expand Down

0 comments on commit 2216b59

Please sign in to comment.