BUG: as_op not pickled, making parallel SMC crash #7078

Open
@jucor

Describe the issue:

As it stands, the SMC sampler cannot be parallelized when the model uses custom ops.

When the SMC sampler runs on more than one core (i.e. parallel sampling) with a custom op created by as_op, the op is not pickled properly by the "manual" pickling step in pymc/smc/sampling.py (the code under the comment "# "manually" (de)serialize params before/after multiprocessing"), which causes the run to fail.
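The traceback below ends in pytensor's load_back, which looks the op up by module and name (getattr(module, name)). The following standard-library sketch illustrates that failure mode in isolation. It is not PyMC or PyTensor code: it uses a plain function, and the module name fake_main is a hypothetical stand-in for the parent's __main__.

```python
import pickle
import sys
import types


def unpickle_without_definition():
    """Pickle a function by reference, then load it where the name is missing."""
    # Hypothetical module standing in for the parent process's __main__.
    mod = types.ModuleType("fake_main")
    exec("def twice(x):\n    return 2 * x\n", mod.__dict__)
    sys.modules["fake_main"] = mod
    try:
        # pickle stores only the module name and the function name,
        # not the function body.
        payload = pickle.dumps(mod.twice)

        # Simulate a worker process whose module never defined `twice`.
        del mod.twice
        try:
            pickle.loads(payload)
        except AttributeError as exc:
            return exc
        return None
    finally:
        del sys.modules["fake_main"]


error = unpickle_without_definition()
print(type(error).__name__)
```

If the worker processes do not re-execute the notebook cell or script that defined twice, the lookup by name has nothing to find, which would match the AttributeError reported here.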

Reproducible code example:

import pymc as pm
import pytensor.tensor as pt

from pytensor.compile.ops import as_op

# Wrap a plain Python function as a PyTensor op
@as_op(itypes=[pt.dvector], otypes=[pt.dvector])
def twice(x):
    return 2 * x

with pm.Model() as model:
    x = pm.Normal('x', mu=[0, 0], sigma=1)
    y = twice(x)
    z = pm.Normal(name='z', mu=y, observed=[1, 1])

    # cores=1 works, but cores=2 raises the AttributeError shown below
    trace = pm.sample_smc(10, cores=2)

Error message:

<details>
{
	"name": "AttributeError",
	"message": "module '__main__' has no attribute 'twice'",
	"stack": "---------------------------------------------------------------------------
RemoteTraceback                           Traceback (most recent call last)
RemoteTraceback: 
\"\"\"
Traceback (most recent call last):
  File \"/opt/homebrew/Caskroom/mambaforge/base/envs/pymc_env/lib/python3.11/multiprocessing/pool.py\", line 125, in worker
    result = (True, func(*args, **kwds))
                    ^^^^^^^^^^^^^^^^^^^
  File \"/opt/homebrew/Caskroom/mambaforge/base/envs/pymc_env/lib/python3.11/multiprocessing/pool.py\", line 51, in starmapstar
    return list(itertools.starmap(args[0], args[1]))
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File \"/opt/homebrew/Caskroom/mambaforge/base/envs/pymc_env/lib/python3.11/site-packages/pymc/smc/sampling.py\", line 419, in _apply_args_and_kwargs
    return fn(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^
  File \"/opt/homebrew/Caskroom/mambaforge/base/envs/pymc_env/lib/python3.11/site-packages/pymc/smc/sampling.py\", line 320, in _sample_smc_int
    (draws, kernel, start, model) = map(
    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File \"/opt/homebrew/Caskroom/mambaforge/base/envs/pymc_env/lib/python3.11/site-packages/pytensor/compile/ops.py\", line 221, in load_back
    obj = getattr(module, name)
          ^^^^^^^^^^^^^^^^^^^^^
AttributeError: module '__main__' has no attribute 'twice'
\"\"\"

The above exception was the direct cause of the following exception:

AttributeError                            Traceback (most recent call last)
Cell In[14], line 2
      1 with model:
----> 2     trace = pm.sample_smc(10,cores=2)

File /opt/homebrew/Caskroom/mambaforge/base/envs/pymc_env/lib/python3.11/site-packages/pymc/smc/sampling.py:213, in sample_smc(draws, kernel, start, model, random_seed, chains, cores, compute_convergence_checks, return_inferencedata, idata_kwargs, progressbar, **kernel_kwargs)
    210 t1 = time.time()
    212 if cores > 1:
--> 213     results = run_chains_parallel(
    214         chains, progressbar, _sample_smc_int, params, random_seed, kernel_kwargs, cores
    215     )
    216 else:
    217     results = run_chains_sequential(
    218         chains, progressbar, _sample_smc_int, params, random_seed, kernel_kwargs
    219     )

File /opt/homebrew/Caskroom/mambaforge/base/envs/pymc_env/lib/python3.11/site-packages/pymc/smc/sampling.py:388, in run_chains_parallel(chains, progressbar, to_run, params, random_seed, kernel_kwargs, cores)
    386 params = tuple(cloudpickle.dumps(p) for p in params)
    387 kernel_kwargs = {key: cloudpickle.dumps(value) for key, value in kernel_kwargs.items()}
--> 388 results = _starmap_with_kwargs(
    389     pool,
    390     to_run,
    391     [(*params, random_seed[chain], chain, pbars[chain]) for chain in range(chains)],
    392     repeat(kernel_kwargs),
    393 )
    394 results = tuple(cloudpickle.loads(r) for r in results)
    395 pool.close()

File /opt/homebrew/Caskroom/mambaforge/base/envs/pymc_env/lib/python3.11/site-packages/pymc/smc/sampling.py:415, in _starmap_with_kwargs(pool, fn, args_iter, kwargs_iter)
    411 def _starmap_with_kwargs(pool, fn, args_iter, kwargs_iter):
    412     # Helper function to allow kwargs with Pool.starmap
    413     # Copied from https://stackoverflow.com/a/53173433/13311693
    414     args_for_starmap = zip(repeat(fn), args_iter, kwargs_iter)
--> 415     return pool.starmap(_apply_args_and_kwargs, args_for_starmap)

File /opt/homebrew/Caskroom/mambaforge/base/envs/pymc_env/lib/python3.11/multiprocessing/pool.py:375, in Pool.starmap(self, func, iterable, chunksize)
    369 def starmap(self, func, iterable, chunksize=None):
    370     '''
    371     Like `map()` method but the elements of the `iterable` are expected to
    372     be iterables as well and will be unpacked as arguments. Hence
    373     `func` and (a, b) becomes func(a, b).
    374     '''
--> 375     return self._map_async(func, iterable, starmapstar, chunksize).get()

File /opt/homebrew/Caskroom/mambaforge/base/envs/pymc_env/lib/python3.11/multiprocessing/pool.py:774, in ApplyResult.get(self, timeout)
    772     return self._value
    773 else:
--> 774     raise self._value

AttributeError: module '__main__' has no attribute 'twice'"
}
</details>

PyMC version information:

pymc: 5.10.3
pytensor: 2.18.4
python: 3.11.7

Installed in a fresh conda environment with
conda create -c conda-forge -n pymc_env "pymc>=5"

Context for the issue:

As it stands, the SMC sampler cannot run the official PyMC example from https://www.pymc.io/projects/examples/en/latest/ode_models/ODE_Lotka_Volterra_multiple_ways.html on more than one core.
Any simple ODE model for which sunode is overkill will crash in the same way, because it requires a custom op, and that op is not pickled.

The workaround of using a single core makes the method much slower than needed.
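One avenue that might avoid the single-core fallback, untested here against PyMC, is to define the op in an importable module rather than in __main__ or a notebook cell, so that worker processes can resolve it by module and name. The standard-library sketch below only demonstrates that principle with a plain function. The module name my_custom_ops is hypothetical, and whether the same holds for an as_op op under sample_smc is an assumption.

```python
import importlib
import os
import pickle
import subprocess
import sys
import tempfile

# Code run by a fresh interpreter that never saw the original definition.
WORKER_CODE = (
    "import pickle, sys\n"
    "fn = pickle.loads(sys.stdin.buffer.read())\n"
    "print(fn(21))\n"
)


def roundtrip_in_fresh_interpreter():
    with tempfile.TemporaryDirectory() as tmp:
        # Hypothetical importable module holding the function.
        with open(os.path.join(tmp, "my_custom_ops.py"), "w") as fh:
            fh.write("def twice(x):\n    return 2 * x\n")

        sys.path.insert(0, tmp)
        importlib.invalidate_caches()
        try:
            import my_custom_ops

            # Pickled by reference: ("my_custom_ops", "twice").
            payload = pickle.dumps(my_custom_ops.twice)
        finally:
            sys.path.remove(tmp)
            sys.modules.pop("my_custom_ops", None)

        # Make the module importable in the child, keeping any existing path.
        paths = [tmp] + [p for p in [os.environ.get("PYTHONPATH")] if p]
        env = dict(os.environ, PYTHONPATH=os.pathsep.join(paths))
        done = subprocess.run(
            [sys.executable, "-c", WORKER_CODE],
            input=payload,
            env=env,
            capture_output=True,
            check=True,
        )
        return done.stdout.decode().strip()


result = roundtrip_in_fresh_interpreter()
print(result)
```

The child interpreter can rebuild the function only because it can import my_custom_ops itself. This does not address the underlying question of pickling ops defined interactively.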

Is there a way to serialize the custom operation please?
