Description
Describe the issue:
As it stands, the SMC sampler cannot be parallelized with custom ops.
When the SMC sampler is used with more than one core (i.e. parallel sampling) together with a
custom op created via as_op, the op is not pickled properly by the "manual" pickling at
Line 385 in 118be0f
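To illustrate what seems to be going on (a minimal sketch using only the standard library, not the actual PyTensor code): the traceback below ends in load_back calling getattr(module, name), which suggests the op is pickled by reference, i.e. as a module name plus an attribute name, the same way pickle handles plain functions. The module name fake_main and the exec-based setup are hypothetical stand-ins for the notebook's __main__, and the second, empty module plays the role of the worker process.

```python
import pickle
import sys
import types

# Hypothetical stand-in for the parent's __main__ (the notebook or script).
parent_main = types.ModuleType("fake_main")
exec("def twice(x):\n    return 2 * x", parent_main.__dict__)
sys.modules["fake_main"] = parent_main

# Plain functions are pickled by reference: the payload stores only the
# module name and the function name, never the function body.
payload = pickle.dumps(parent_main.twice)

# In the same process the name still exists, so the lookup succeeds.
restored = pickle.loads(payload)

# Simulate a worker process: a module with the same name exists there,
# but `twice` was never defined in it.
sys.modules["fake_main"] = types.ModuleType("fake_main")
worker_error = None
try:
    pickle.loads(payload)
except AttributeError as err:
    worker_error = err

print(f"{type(worker_error).__name__}: {worker_error}")
```

If this reading is right, calling cloudpickle.dumps on the model does not help here, because the op itself decides how it is reduced and only ships a name that the worker then fails to resolve.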
Reproducible code example:
import pymc as pm
import pytensor.tensor as pt
from pytensor.compile.ops import as_op
@as_op(itypes=[pt.dvector], otypes=[pt.dvector])
def twice(x):
    return 2 * x

with pm.Model() as model:
    x = pm.Normal('x', mu=[0, 0], sigma=1)
    y = twice(x)
    z = pm.Normal(name='z', mu=y, observed=[1, 1])
    # Using cores=1 would work, but cores=2 throws an error
    trace = pm.sample_smc(10, cores=2)
Error message:
<details>
{
"name": "AttributeError",
"message": "module '__main__' has no attribute 'twice'",
"stack": "---------------------------------------------------------------------------
RemoteTraceback Traceback (most recent call last)
RemoteTraceback:
\"\"\"
Traceback (most recent call last):
File \"/opt/homebrew/Caskroom/mambaforge/base/envs/pymc_env/lib/python3.11/multiprocessing/pool.py\", line 125, in worker
result = (True, func(*args, **kwds))
^^^^^^^^^^^^^^^^^^^
File \"/opt/homebrew/Caskroom/mambaforge/base/envs/pymc_env/lib/python3.11/multiprocessing/pool.py\", line 51, in starmapstar
return list(itertools.starmap(args[0], args[1]))
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File \"/opt/homebrew/Caskroom/mambaforge/base/envs/pymc_env/lib/python3.11/site-packages/pymc/smc/sampling.py\", line 419, in _apply_args_and_kwargs
return fn(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^
File \"/opt/homebrew/Caskroom/mambaforge/base/envs/pymc_env/lib/python3.11/site-packages/pymc/smc/sampling.py\", line 320, in _sample_smc_int
(draws, kernel, start, model) = map(
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File \"/opt/homebrew/Caskroom/mambaforge/base/envs/pymc_env/lib/python3.11/site-packages/pytensor/compile/ops.py\", line 221, in load_back
obj = getattr(module, name)
^^^^^^^^^^^^^^^^^^^^^
AttributeError: module '__main__' has no attribute 'twice'
\"\"\"
The above exception was the direct cause of the following exception:
AttributeError Traceback (most recent call last)
Cell In[14], line 2
1 with model:
----> 2 trace = pm.sample_smc(10,cores=2)
File /opt/homebrew/Caskroom/mambaforge/base/envs/pymc_env/lib/python3.11/site-packages/pymc/smc/sampling.py:213, in sample_smc(draws, kernel, start, model, random_seed, chains, cores, compute_convergence_checks, return_inferencedata, idata_kwargs, progressbar, **kernel_kwargs)
210 t1 = time.time()
212 if cores > 1:
--> 213 results = run_chains_parallel(
214 chains, progressbar, _sample_smc_int, params, random_seed, kernel_kwargs, cores
215 )
216 else:
217 results = run_chains_sequential(
218 chains, progressbar, _sample_smc_int, params, random_seed, kernel_kwargs
219 )
File /opt/homebrew/Caskroom/mambaforge/base/envs/pymc_env/lib/python3.11/site-packages/pymc/smc/sampling.py:388, in run_chains_parallel(chains, progressbar, to_run, params, random_seed, kernel_kwargs, cores)
386 params = tuple(cloudpickle.dumps(p) for p in params)
387 kernel_kwargs = {key: cloudpickle.dumps(value) for key, value in kernel_kwargs.items()}
--> 388 results = _starmap_with_kwargs(
389 pool,
390 to_run,
391 [(*params, random_seed[chain], chain, pbars[chain]) for chain in range(chains)],
392 repeat(kernel_kwargs),
393 )
394 results = tuple(cloudpickle.loads(r) for r in results)
395 pool.close()
File /opt/homebrew/Caskroom/mambaforge/base/envs/pymc_env/lib/python3.11/site-packages/pymc/smc/sampling.py:415, in _starmap_with_kwargs(pool, fn, args_iter, kwargs_iter)
411 def _starmap_with_kwargs(pool, fn, args_iter, kwargs_iter):
412 # Helper function to allow kwargs with Pool.starmap
413 # Copied from https://stackoverflow.com/a/53173433/13311693
414 args_for_starmap = zip(repeat(fn), args_iter, kwargs_iter)
--> 415 return pool.starmap(_apply_args_and_kwargs, args_for_starmap)
File /opt/homebrew/Caskroom/mambaforge/base/envs/pymc_env/lib/python3.11/multiprocessing/pool.py:375, in Pool.starmap(self, func, iterable, chunksize)
369 def starmap(self, func, iterable, chunksize=None):
370 '''
371 Like `map()` method but the elements of the `iterable` are expected to
372 be iterables as well and will be unpacked as arguments. Hence
373 `func` and (a, b) becomes func(a, b).
374 '''
--> 375 return self._map_async(func, iterable, starmapstar, chunksize).get()
File /opt/homebrew/Caskroom/mambaforge/base/envs/pymc_env/lib/python3.11/multiprocessing/pool.py:774, in ApplyResult.get(self, timeout)
772 return self._value
773 else:
--> 774 raise self._value
AttributeError: module '__main__' has no attribute 'twice'"
}
</details>
PyMC version information:
pymc: 5.10.3
pytensor: 2.18.4
python: 3.11.7
Installed in a fresh conda environment with
conda create -c conda-forge -n pymc_env "pymc>=5"
Context for the issue:
As it stands, the SMC sampler cannot run the official PyMC example from https://www.pymc.io/projects/examples/en/latest/ode_models/ODE_Lotka_Volterra_multiple_ways.html with more than one core.
Any simple ODE model for which sunode is overkill will crash in the same way, because it relies on a custom op that is not pickled correctly.
The workaround of using a single core makes the method much slower than necessary.
Is there a way to serialize the custom operation please?
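One workaround that might be worth trying (an untested assumption on my side, not something confirmed for PyMC) is to move the as_op function out of the notebook into an importable module, so that a by-reference lookup like the getattr(module, name) in load_back can succeed inside the worker. The sketch below only demonstrates that principle with plain pickle and a plain function; the module name my_ops and the temporary directory are hypothetical.

```python
import importlib
import os
import pickle
import subprocess
import sys
import tempfile

with tempfile.TemporaryDirectory() as tmp_dir:
    # Hypothetical module that holds the function instead of __main__.
    with open(os.path.join(tmp_dir, "my_ops.py"), "w") as f:
        f.write("def twice(x):\n    return 2 * x\n")

    # Parent process: import the module and pickle the function by reference.
    sys.path.insert(0, tmp_dir)
    try:
        my_ops = importlib.import_module("my_ops")
        payload = pickle.dumps(my_ops.twice)
    finally:
        sys.path.remove(tmp_dir)

    # "Worker": a fresh interpreter that can import my_ops via PYTHONPATH.
    worker_code = (
        "import pickle, sys\n"
        "fn = pickle.loads(sys.stdin.buffer.read())\n"
        "print(fn(21))\n"
    )
    result = subprocess.run(
        [sys.executable, "-c", worker_code],
        input=payload,
        capture_output=True,
        env={**os.environ, "PYTHONPATH": tmp_dir},
        check=True,
    )

worker_output = result.stdout.decode().strip()
print(worker_output)
```

Applied to the example above, this would mean putting the decorated twice into a file such as my_ops.py next to the notebook and using from my_ops import twice, assuming the worker processes can import that file.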