2121import pytensor .tensor as at
2222
2323from arviz .data .base import make_attrs
24+ from jax .experimental .maps import SerialLoop , xmap
2425from pytensor .compile import SharedVariable , Supervisor , mode
2526from pytensor .graph .basic import graph_inputs
2627from pytensor .graph .fg import FunctionGraph
@@ -143,6 +144,27 @@ def _sample_stats_to_xarray(posterior):
143144 return data
144145
145146
def _postprocess_samples(
    jax_fn: List[TensorVariable],
    raw_mcmc_samples: List[TensorVariable],
    postprocessing_backend: Optional[str],
    num_chunks: Optional[int] = None,
) -> List[TensorVariable]:
    """Apply ``jax_fn`` to every (chain, sample) draw of ``raw_mcmc_samples``.

    Parameters
    ----------
    jax_fn
        Function called as ``jax_fn(*values)`` with one value per array in
        ``raw_mcmc_samples``; it is mapped over the two leading axes.
        NOTE(review): the annotation says ``List[TensorVariable]`` but the
        argument is invoked as a callable; the annotation is kept as-is so that
        no new import is required.
    raw_mcmc_samples
        Arrays whose two leading axes are ``(chain, samples)``.
    postprocessing_backend
        Backend name handed to ``jax.devices``; the first device of that
        backend receives the samples and runs the computation.
    num_chunks
        ``None`` maps over all draws at once with a doubly nested ``jax.vmap``.
        An integer splits the samples axis into that many equally sized chunks
        which are processed one after another with ``jax.lax.map`` (each chunk
        is still vectorized), lowering peak memory usage.

    Returns
    -------
    The outputs of ``jax_fn`` with leading ``(chain, samples)`` axes.

    Raises
    ------
    ValueError
        If ``num_chunks`` is smaller than 1 or does not evenly divide the
        number of samples per chain.
    """
    # Fail early on an unusable chunk count, before any data is moved.
    if num_chunks is not None and num_chunks < 1:
        raise ValueError(f"num_chunks must be a positive integer, got {num_chunks}")

    samples_on_device = jax.device_put(
        raw_mcmc_samples, jax.devices(postprocessing_backend)[0]
    )
    # Inner vmap handles the samples axis, outer vmap the chain axis.
    per_draw_fn = jax.vmap(jax.vmap(jax_fn))

    if num_chunks is None:
        return per_draw_fn(*samples_on_device)

    n_chains, n_samples = samples_on_device[0].shape[:2]
    if n_samples % num_chunks != 0:
        raise ValueError(
            f"num_chunks ({num_chunks}) must evenly divide the number of "
            f"samples per chain ({n_samples})"
        )
    chunk_size = n_samples // num_chunks

    def _split_into_chunks(values):
        # (chain, samples, ...) -> (chunk, chain, samples_in_chunk, ...)
        chunked_shape = (n_chains, num_chunks, chunk_size) + values.shape[2:]
        return values.reshape(chunked_shape).swapaxes(0, 1)

    def _merge_chunks(values):
        # (chunk, chain, samples_in_chunk, ...) -> (chain, samples, ...)
        merged_shape = (n_chains, n_samples) + values.shape[3:]
        return values.swapaxes(0, 1).reshape(merged_shape)

    chunked_inputs = [_split_into_chunks(values) for values in samples_on_device]
    # lax.map walks the leading (chunk) axis sequentially, so only one chunk of
    # intermediate results is alive at any time.
    chunked_outputs = jax.lax.map(lambda chunk: per_draw_fn(*chunk), chunked_inputs)
    return jax.tree_util.tree_map(_merge_chunks, chunked_outputs)
166+
167+
146168def _blackjax_stats_to_dict (sample_stats , potential_energy ) -> Dict :
147169 """Extract compatible stats from blackjax NUTS sampler
148170 with PyMC/Arviz naming conventions.
@@ -177,11 +199,13 @@ def _blackjax_stats_to_dict(sample_stats, potential_energy) -> Dict:
177199 return converted_stats
178200
179201
def _get_log_likelihood(
    model: Model, samples, backend=None, num_chunks: Optional[int] = None
) -> Dict:
    """Compute log-likelihood for all observations.

    Parameters
    ----------
    model
        Model whose ``observed_RVs`` and ``value_vars`` define the computation.
    samples
        Sampled values, passed positionally to the jaxified log-probability
        function (one entry per ``model.value_vars`` input).
    backend
        Forwarded to ``_postprocess_samples`` as the postprocessing backend.
    num_chunks
        Forwarded to ``_postprocess_samples``; ``None`` keeps the default.

    Returns
    -------
    Dict mapping each observed variable's name to its log-likelihood result.
    """
    observed = model.observed_RVs
    # sum=False is passed so that one log-probability term is returned per
    # observed variable (they are paired with `observed` below).
    elemwise_logp = model.logp(observed, sum=False)
    jax_fn = get_jaxified_graph(inputs=model.value_vars, outputs=elemwise_logp)
    result = _postprocess_samples(jax_fn, samples, backend, num_chunks=num_chunks)
    return {rv.name: values for rv, values in zip(observed, result)}
186210
187211
@@ -275,6 +299,7 @@ def sample_blackjax_nuts(
275299 keep_untransformed : bool = False ,
276300 chain_method : str = "parallel" ,
277301 postprocessing_backend : Optional [str ] = None ,
302+ postprocessing_chunks : Optional [int ] = None ,
278303 idata_kwargs : Optional [Dict [str , Any ]] = None ,
279304) -> az .InferenceData :
280305 """
@@ -314,6 +339,10 @@ def sample_blackjax_nuts(
314339 "vectorized".
315340 postprocessing_backend : str, optional
316341 Specify how postprocessing should be computed. gpu or cpu
342+ postprocessing_chunks : Optional[int], default None
343+ Specify the number of chunks the postprocessing should be computed in. More
344+ chunks reduce memory usage at the cost of losing some vectorization. None
345+ uses jax.vmap over all samples at once.
317346 idata_kwargs : dict, optional
318347 Keyword arguments for :func:`arviz.from_dict`. It also accepts a boolean as
319348 value for the ``log_likelihood`` key to indicate that the pointwise log
@@ -400,8 +429,8 @@ def sample_blackjax_nuts(
400429
401430 print ("Transforming variables..." , file = sys .stdout )
402431 jax_fn = get_jaxified_graph (inputs = model .value_vars , outputs = vars_to_sample )
403- result = jax . vmap ( jax . vmap ( jax_fn )) (
404- * jax . device_put ( raw_mcmc_samples , jax . devices ( postprocessing_backend )[ 0 ])
432+ result = _postprocess_samples (
433+ jax_fn , raw_mcmc_samples , postprocessing_backend , num_chunks = postprocessing_chunks
405434 )
406435 mcmc_samples = {v .name : r for v , r in zip (vars_to_sample , result )}
407436 mcmc_stats = _blackjax_stats_to_dict (stats , potential_energy )
@@ -417,7 +446,10 @@ def sample_blackjax_nuts(
417446 tic5 = datetime .now ()
418447 print ("Computing Log Likelihood..." , file = sys .stdout )
419448 log_likelihood = _get_log_likelihood (
420- model , raw_mcmc_samples , backend = postprocessing_backend
449+ model ,
450+ raw_mcmc_samples ,
451+ backend = postprocessing_backend ,
452+ num_chunks = postprocessing_chunks ,
421453 )
422454 tic6 = datetime .now ()
423455 print ("Log Likelihood time = " , tic6 - tic5 , file = sys .stdout )
@@ -478,6 +510,7 @@ def sample_numpyro_nuts(
478510 keep_untransformed : bool = False ,
479511 chain_method : str = "parallel" ,
480512 postprocessing_backend : Optional [str ] = None ,
513+ postprocessing_chunks : Optional [int ] = None ,
481514 idata_kwargs : Optional [Dict ] = None ,
482515 nuts_kwargs : Optional [Dict ] = None ,
483516) -> az .InferenceData :
@@ -522,6 +555,10 @@ def sample_numpyro_nuts(
522555 "parallel", and "vectorized".
523556 postprocessing_backend : Optional[str]
524557 Specify how postprocessing should be computed. gpu or cpu
558+ postprocessing_chunks : Optional[int], default None
559+ Specify the number of chunks the postprocessing should be computed in. More
560+ chunks reduce memory usage at the cost of losing some vectorization. None
561+ uses jax.vmap over all samples at once.
525562 idata_kwargs : dict, optional
526563 Keyword arguments for :func:`arviz.from_dict`. It also accepts a boolean as
527564 value for the ``log_likelihood`` key to indicate that the pointwise log
@@ -622,8 +659,8 @@ def sample_numpyro_nuts(
622659
623660 print ("Transforming variables..." , file = sys .stdout )
624661 jax_fn = get_jaxified_graph (inputs = model .value_vars , outputs = vars_to_sample )
625- result = jax . vmap ( jax . vmap ( jax_fn )) (
626- * jax . device_put ( raw_mcmc_samples , jax . devices ( postprocessing_backend )[ 0 ])
662+ result = _postprocess_samples (
663+ jax_fn , raw_mcmc_samples , postprocessing_backend , num_chunks = postprocessing_chunks
627664 )
628665 mcmc_samples = {v .name : r for v , r in zip (vars_to_sample , result )}
629666
@@ -639,7 +676,10 @@ def sample_numpyro_nuts(
639676 tic5 = datetime .now ()
640677 print ("Computing Log Likelihood..." , file = sys .stdout )
641678 log_likelihood = _get_log_likelihood (
642- model , raw_mcmc_samples , backend = postprocessing_backend
679+ model ,
680+ raw_mcmc_samples ,
681+ backend = postprocessing_backend ,
682+ num_chunks = postprocessing_chunks ,
643683 )
644684 tic6 = datetime .now ()
645685 print ("Log Likelihood time = " , tic6 - tic5 , file = sys .stdout )
0 commit comments