@@ -168,7 +168,11 @@ def _get_log_likelihood(
     elemwise_logp = model.logp(model.observed_RVs, sum=False)
     jax_fn = get_jaxified_graph(inputs=model.value_vars, outputs=elemwise_logp)
     result = _postprocess_samples(
-        jax_fn, samples, backend, postprocessing_vectorize=postprocessing_vectorize
+        jax_fn,
+        samples,
+        backend,
+        postprocessing_vectorize=postprocessing_vectorize,
+        donate_samples=False,
     )
     return {v.name: r for v, r in zip(model.observed_RVs, result)}

@@ -181,7 +185,8 @@ def _postprocess_samples(
     jax_fn: Callable,
     raw_mcmc_samples: list[TensorVariable],
     postprocessing_backend: Literal["cpu", "gpu"] | None = None,
-    postprocessing_vectorize: Literal["vmap", "scan"] = "scan",
+    postprocessing_vectorize: Literal["vmap", "scan"] = "vmap",
+    donate_samples: bool = False,
 ) -> list[TensorVariable]:
     if postprocessing_vectorize == "scan":
         t_raw_mcmc_samples = [jnp.swapaxes(t, 0, 1) for t in raw_mcmc_samples]
@@ -193,7 +198,12 @@ def _postprocess_samples(
         )
         return [jnp.swapaxes(t, 0, 1) for t in outs]
     elif postprocessing_vectorize == "vmap":
-        return jax.vmap(jax.vmap(jax_fn))(*_device_put(raw_mcmc_samples, postprocessing_backend))
+
+        def process_fn(x):
+            return jax.vmap(jax.vmap(jax_fn))(*_device_put(x, postprocessing_backend))
+
+        return jax.jit(process_fn, donate_argnums=0 if donate_samples else None)(raw_mcmc_samples)
+
     else:
         raise ValueError(f"Unrecognized postprocessing_vectorize: {postprocessing_vectorize}")
@@ -253,7 +263,16 @@ def _blackjax_inference_loop(
     def _one_step(state, xs):
         _, rng_key = xs
         state, info = kernel(rng_key, state)
-        return state, (state, info)
+        position = state.position
+        stats = {
+            "diverging": info.is_divergent,
+            "energy": info.energy,
+            "tree_depth": info.num_trajectory_expansions,
+            "n_steps": info.num_integration_steps,
+            "acceptance_rate": info.acceptance_rate,
+            "lp": state.logdensity,
+        }
+        return state, (position, stats)

     progress_bar = adaptation_kwargs.pop("progress_bar", False)
     if progress_bar:
@@ -264,43 +283,9 @@ def _one_step(state, xs):
         one_step = jax.jit(_one_step)

     keys = jax.random.split(seed, draws)
-    _, (states, infos) = jax.lax.scan(one_step, last_state, (jnp.arange(draws), keys))
-
-    return states, infos
-
-
-def _blackjax_stats_to_dict(sample_stats, potential_energy) -> dict:
-    """Extract compatible stats from blackjax NUTS sampler
-    with PyMC/Arviz naming conventions.
-
-    Parameters
-    ----------
-    sample_stats: NUTSInfo
-        Blackjax NUTSInfo object containing sampler statistics
-    potential_energy: ArrayLike
-        Potential energy values of sampled positions.
+    _, (samples, stats) = jax.lax.scan(one_step, last_state, (jnp.arange(draws), keys))

-    Returns
-    -------
-    Dict[str, ArrayLike]
-        Dictionary of sampler statistics.
-    """
-    rename_key = {
-        "is_divergent": "diverging",
-        "energy": "energy",
-        "num_trajectory_expansions": "tree_depth",
-        "num_integration_steps": "n_steps",
-        "acceptance_rate": "acceptance_rate",  # naming here is
-        "acceptance_probability": "acceptance_rate",  # depending on blackjax version
-    }
-    converted_stats = {}
-    converted_stats["lp"] = potential_energy
-    for old_name, new_name in rename_key.items():
-        value = getattr(sample_stats, old_name, None)
-        if value is None:
-            continue
-        converted_stats[new_name] = value
-    return converted_stats
+    return samples, stats


 def _sample_blackjax_nuts(
@@ -410,11 +395,7 @@ def _sample_blackjax_nuts(
         **nuts_kwargs,
     )

-    states, stats = map_fn(get_posterior_samples)(keys, initial_points)
-    raw_mcmc_samples = states.position
-    potential_energy = states.logdensity.block_until_ready()
-    sample_stats = _blackjax_stats_to_dict(stats, potential_energy)
-
+    raw_mcmc_samples, sample_stats = map_fn(get_posterior_samples)(keys, initial_points)
     return raw_mcmc_samples, sample_stats, blackjax

@@ -515,7 +496,7 @@ def sample_jax_nuts(
     keep_untransformed: bool = False,
     chain_method: str = "parallel",
     postprocessing_backend: Literal["cpu", "gpu"] | None = None,
-    postprocessing_vectorize: Literal["vmap", "scan"] = "scan",
+    postprocessing_vectorize: Literal["vmap", "scan"] | None = None,
     postprocessing_chunks=None,
     idata_kwargs: dict | None = None,
     compute_convergence_checks: bool = True,
@@ -597,6 +578,16 @@ def sample_jax_nuts(
             DeprecationWarning,
         )

+    if postprocessing_vectorize is not None:
+        import warnings
+
+        warnings.warn(
+            'postprocessing_vectorize={"scan", "vmap"} will be removed in a future release.',
+            FutureWarning,
+        )
+    else:
+        postprocessing_vectorize = "vmap"
+
     model = modelcontext(model)

     if var_names is not None:
@@ -645,15 +636,6 @@ def sample_jax_nuts(
     )
     tic2 = datetime.now()

-    jax_fn = get_jaxified_graph(inputs=model.value_vars, outputs=vars_to_sample)
-    result = _postprocess_samples(
-        jax_fn,
-        raw_mcmc_samples,
-        postprocessing_backend=postprocessing_backend,
-        postprocessing_vectorize=postprocessing_vectorize,
-    )
-    mcmc_samples = {v.name: r for v, r in zip(vars_to_sample, result)}
-
     if idata_kwargs is None:
         idata_kwargs = {}
     else:
@@ -669,6 +651,17 @@ def sample_jax_nuts(
     else:
         log_likelihood = None

+    jax_fn = get_jaxified_graph(inputs=model.value_vars, outputs=vars_to_sample)
+    result = _postprocess_samples(
+        jax_fn,
+        raw_mcmc_samples,
+        postprocessing_backend=postprocessing_backend,
+        postprocessing_vectorize=postprocessing_vectorize,
+        donate_samples=True,
+    )
+    del raw_mcmc_samples
+    mcmc_samples = {v.name: r for v, r in zip(vars_to_sample, result)}
+
     attrs = {
         "sampling_time": (tic2 - tic1).total_seconds(),
         "tuning_steps": tune,