Skip to content

BUG: pymc.sample_smc fails with pymc.CustomDist #7224

Closed
@EliasRas

Description

@EliasRas

Describe the issue:

pymc.sample_smc raises a NotImplementedError (missing logp method) when a pymc.CustomDist in the model is defined without the dist argument. The error does not occur if the CustomDist is defined with the dist argument instead, or if the likelihood is added with pm.Potential.

Reproducible code example:

import pymc as pm
import numpy as np


def _logp(value, mu, sigma):
    """Return the Normal(mu, sigma) log-density evaluated at ``value``.

    Serves as the ``logp`` callback passed to ``pm.CustomDist``.
    """
    return pm.logp(pm.Normal.dist(mu=mu, sigma=sigma), value)


def _random(mu, sigma, rng, size):
    """Draw ``size`` Normal(mu, sigma) samples for ``pm.CustomDist``'s ``random`` hook.

    When ``rng`` is None a fresh, unseeded NumPy ``Generator`` is created,
    so results are only reproducible when the caller supplies ``rng``.
    """
    generator = np.random.default_rng() if rng is None else rng
    return generator.normal(loc=mu, scale=sigma, size=size)


def _logcdf(value, mu, sigma):
    """Return the Normal(mu, sigma) log-CDF evaluated at ``value``.

    Serves as the ``logcdf`` callback passed to ``pm.CustomDist``.
    """
    return pm.logcdf(pm.Normal.dist(mu=mu, sigma=sigma), value)


def _dist(mu, sigma, size):
    """Build the Normal(mu, sigma) variable used by ``pm.CustomDist(dist=...)``.

    This is the ``dist``-based variant that the issue reports as working
    with ``pm.sample_smc``.
    """
    return pm.Normal.dist(mu=mu, sigma=sigma, size=size)


def main():
    """Reproduce issue #7224: ``pm.sample_smc`` combined with ``pm.CustomDist``.

    Runs ``pm.sample_smc`` on four model variants and reports the outcome of each:

    1. ``CustomDist`` with ``logp``/``random``/``logcdf`` and observed data
       (reported to raise ``NotImplementedError``);
    2. the same ``CustomDist`` as an unobserved variable with fixed
       parameters (reported to raise ``NotImplementedError``);
    3. ``CustomDist`` built via ``dist=_dist`` (reported to work);
    4. the same likelihood expressed with ``pm.Potential`` (reported to work).

    Returns:
        dict mapping each variant's label to ``None`` if sampling succeeded,
        or to the ``NotImplementedError`` instance it raised.
    """
    # Fixed seed so repeated runs use identical synthetic data.
    data = np.random.default_rng(0).normal(5, 2, 1000)

    def custom_dist_logp_observed():
        mu = pm.Normal("mu", mu=0, sigma=10)
        sigma = pm.HalfNormal("sigma", sigma=10)
        pm.CustomDist(
            "y",
            mu,
            sigma,
            logp=_logp,
            random=_random,
            logcdf=_logcdf,
            observed=data,
        )

    def custom_dist_logp_unobserved():
        pm.CustomDist(
            "y",
            2,
            10,
            logp=_logp,
            random=_random,
            logcdf=_logcdf,
        )

    def custom_dist_dist_observed():
        mu = pm.Normal("mu", mu=0, sigma=10)
        sigma = pm.HalfNormal("sigma", sigma=10)
        pm.CustomDist(
            "y",
            mu,
            sigma,
            dist=_dist,
            observed=data,
        )

    def potential():
        mu = pm.Normal("mu", mu=0, sigma=10)
        sigma = pm.HalfNormal("sigma", sigma=10)
        pm.Potential(
            "y",
            _logp(data, mu, sigma),
        )

    scenarios = {
        "CustomDist(logp=..., observed=data)": custom_dist_logp_observed,
        "CustomDist(logp=...), unobserved": custom_dist_logp_unobserved,
        "CustomDist(dist=..., observed=data)": custom_dist_dist_observed,
        "Potential(logp(data))": potential,
    }

    outcomes = {}
    for label, build_model in scenarios.items():
        with pm.Model():
            build_model()
            # Catch per variant: one variant failing must not stop the
            # others from running, or the comparison is never shown.
            try:
                pm.sample_smc()
            except NotImplementedError as err:
                outcomes[label] = err
                print(f"{label}: FAILED ({type(err).__name__}: {err})")
            else:
                outcomes[label] = None
                print(f"{label}: OK")
    return outcomes


# Entry-point guard. pm.sample_smc can run chains in a multiprocessing pool.
# Worker processes started with the "spawn" method (the default on Windows)
# re-import this module, and this guard keeps them from re-running main().
if __name__ == "__main__":
    main()

Error message:

multiprocessing.pool.RemoteTraceback: 
"""
Traceback (most recent call last):
  File "\envs\pymc\Lib\multiprocessing\pool.py", line 125, in worker
    result = (True, func(*args, **kwds))
                    ^^^^^^^^^^^^^^^^^^^
  File "\envs\pymc\Lib\multiprocessing\pool.py", line 51, in starmapstar
    return list(itertools.starmap(args[0], args[1]))
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "\envs\pymc\Lib\site-packages\pymc\smc\sampling.py", line 421, in _apply_args_and_kwargs
    return fn(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^
  File "\envs\pymc\Lib\site-packages\pymc\smc\sampling.py", line 344, in _sample_smc_int
    smc._initialize_kernel()
  File "\envs\pymc\Lib\site-packages\pymc\smc\kernels.py", line 239, in _initialize_kernel
    initial_point, [self.model.varlogp], self.variables, shared
                    ^^^^^^^^^^^^^^^^^^
  File "\envs\pymc\Lib\site-packages\pymc\model\core.py", line 832, in varlogp
    return self.logp(vars=self.free_RVs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "\envs\pymc\Lib\site-packages\pymc\model\core.py", line 717, in logp
    rv_logps = transformed_conditional_logp(
               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "\envs\pymc\Lib\site-packages\pymc\logprob\basic.py", line 612, in transformed_conditional_logp
    temp_logp_terms = conditional_logp(
                      ^^^^^^^^^^^^^^^^^
  File "\envs\pymc\Lib\site-packages\pymc\logprob\basic.py", line 542, in conditional_logp
    q_logprob_vars = _logprob(
                     ^^^^^^^^^
  File "\envs\pymc\Lib\functools.py", line 909, in wrapper
    return dispatch(args[0].__class__)(*args, **kw)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "\envs\pymc\Lib\site-packages\pymc\logprob\abstract.py", line 63, in _logprob
    raise NotImplementedError(f"Logprob method not implemented for {op}")
NotImplementedError: Logprob method not implemented for CustomDist_y_rv{0, (0, 0), floatX, False}
"""

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "issue.py", line 79, in <module>
    main()
  File "issue.py", line 44, in main
    sample = pm.sample_smc()  # Exception has occurred: NotImplementedError
             ^^^^^^^^^^^^^^^
  File "\envs\pymc\Lib\site-packages\pymc\smc\sampling.py", line 213, in sample_smc
    results = run_chains_parallel(
              ^^^^^^^^^^^^^^^^^^^^
  File "\envs\pymc\Lib\site-packages\pymc\smc\sampling.py", line 390, in run_chains_parallel
    results = _starmap_with_kwargs(
              ^^^^^^^^^^^^^^^^^^^^^
  File "\envs\pymc\Lib\site-packages\pymc\smc\sampling.py", line 417, in _starmap_with_kwargs
    return pool.starmap(_apply_args_and_kwargs, args_for_starmap)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "\envs\pymc\Lib\multiprocessing\pool.py", line 375, in starmap
    return self._map_async(func, iterable, starmapstar, chunksize).get()
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "\envs\pymc\Lib\multiprocessing\pool.py", line 774, in get
    raise self._value
NotImplementedError: Logprob method not implemented for CustomDist_y_rv{0, (0, 0), floatX, False}

PyMC version information:

Python 3.11.7
pymc 5.10.0
pytensor 2.18.6
Win 10
Environment set up via conda; pymc and pytensor were then updated with pip.

Also fails with these environments
conda_env.txt
conda_env_dev.txt

Context for the issue:

I'm testing a model that samples slowly, possibly because of expensive gradient calculations. As suggested on the forums, I tried SMC as a possible solution, but got the error above.

Using the dist argument could work in most cases, but there are cases where the distributions provided by pymc are not enough. Using pm.Potential could help with sampling, but it would in turn make forward sampling less straightforward.

Metadata

Metadata

Assignees

No one assigned

    Labels

    SMC (Sequential Monte Carlo), bug

    Type

    No type

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions