Description
Describe the issue:
The Jax-based samplers crash after sampling, following the "Transforming variables..." message on medium-to-large models (thousands of rows, hundreds of parameters). This occurs both on GPU and CPU systems, and using either the numpyro or blackjax samplers. The failure on GPU returns a backtrace that isolates the issue at the vmap
in _postprocess_samples
. On a CPU (MacBook Pro M1), the process is simply killed without any error messages. I have tried running the GPU model with the postprocessing_backend="cpu"
argument for the numpyro sampler, but this does not seem to make a difference. Should it be using vmap when the postprocessing backend is CPU?
Reproduceable code example:
Will add example when I can come up with one
Error message:
CPU machine error:
Compilation time = 0:00:09.225151
Sampling...
Running chain 0: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 2000/2000 [12:00:33<00:00, 21.62s/it]
Running chain 1: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 2000/2000 [12:00:33<00:00, 21.62s/it]
Sampling time = 12:00:35.215191
Transforming variables...
Killed: 9
/Users/cfonnesbeck/mambaforge/envs/pymc/lib/python3.11/multiprocessing/resource_tracker.py:224: UserWarning: resource_tracker: There appear to be 1 leaked semaphore objects to clean up at shutdown
warnings.warn('resource_tracker: There appear to be %d '
PyMC version information:
PyMC 5.3.0
PyTensor 2.11.1
Context for the issue:
The numpyro sampler is currently unusable for moderate-sized models due to this issue.