Closed
Description
Describe the issue:
`pymc.sample_smc` raises a `NotImplementedError` due to a missing logp method if a `pymc.CustomDist` is used in a model without the `dist` argument. Besides passing `dist`, switching to `pm.Potential` also works.
Reproducible code example:
import pymc as pm
import numpy as np
def _logp(value, mu, sigma):
    """Return the log-probability of ``value`` under ``Normal(mu, sigma)``.

    Passed as the ``logp`` callback of ``pm.CustomDist`` in ``main``.
    """
    dist = pm.Normal.dist(mu=mu, sigma=sigma)
    return pm.logp(dist, value)
def _random(mu, sigma, rng, size):
    """Draw normal samples with mean ``mu`` and standard deviation ``sigma``.

    Passed as the ``random`` callback of ``pm.CustomDist`` in ``main``.

    Parameters
    ----------
    mu, sigma
        Location and scale forwarded to ``rng.normal``.
    rng
        A NumPy random generator, or ``None`` to create a fresh default one.
    size
        Output shape forwarded to ``rng.normal``.
    """
    if rng is None:
        rng = np.random.default_rng()
    return rng.normal(loc=mu, scale=sigma, size=size)
def _logcdf(value, mu, sigma):
    """Return the log-CDF of ``value`` under ``Normal(mu, sigma)``.

    Passed as the ``logcdf`` callback of ``pm.CustomDist`` in ``main``.
    """
    dist = pm.Normal.dist(mu=mu, sigma=sigma)
    return pm.logcdf(dist, value)
def _dist(mu, sigma, size):
    """Build an unnamed ``Normal(mu, sigma)`` distribution of the given ``size``.

    Passed as the ``dist`` callback of ``pm.CustomDist`` in ``main``.
    """
    return pm.Normal.dist(mu, sigma, size=size)
def main():
    """Run ``pm.sample_smc`` on four small models to reproduce the issue.

    The inline comments record the reported outcome of each case: the two
    models using ``pm.CustomDist`` with ``logp``/``random``/``logcdf``
    callbacks raise ``NotImplementedError``, while the ``dist``-based
    ``CustomDist`` and the ``pm.Potential`` variants work.

    NOTE(review): because the first case raises, the later cases are only
    reached if the failing ones are commented out.
    """
    data = np.random.default_rng().normal(5, 2, 1000)

    # Case 1: observed CustomDist defined through logp/random/logcdf callbacks.
    with pm.Model():
        mu = pm.Normal("mu", mu=0, sigma=10)
        sigma = pm.HalfNormal("sigma", sigma=10)
        pm.CustomDist(
            "y",
            mu,
            sigma,
            logp=_logp,
            random=_random,
            logcdf=_logcdf,
            observed=data,
        )
        sample = pm.sample_smc()  # NotImplementedError

    # Case 2: unobserved CustomDist with constant parameters, same callbacks.
    with pm.Model():
        pm.CustomDist(
            "y",
            2,
            10,
            logp=_logp,
            random=_random,
            logcdf=_logcdf,
        )
        sample = pm.sample_smc()  # NotImplementedError

    # Case 3: observed CustomDist defined through the `dist` callback.
    with pm.Model():
        mu = pm.Normal("mu", mu=0, sigma=10)
        sigma = pm.HalfNormal("sigma", sigma=10)
        pm.CustomDist(
            "y",
            mu,
            sigma,
            dist=_dist,
            observed=data,
        )
        sample = pm.sample_smc()  # Works

    # Case 4: the same likelihood added as a Potential term instead.
    with pm.Model():
        mu = pm.Normal("mu", mu=0, sigma=10)
        sigma = pm.HalfNormal("sigma", sigma=10)
        pm.Potential(
            "y",
            _logp(data, mu, sigma),
        )
        sample = pm.sample_smc()  # Works
# Guard the entry point so importing this module does not run the reproducer.
if __name__ == "__main__":
    main()
Error message:
multiprocessing.pool.RemoteTraceback:
"""
Traceback (most recent call last):
File "\envs\pymc\Lib\multiprocessing\pool.py", line 125, in worker
result = (True, func(*args, **kwds))
^^^^^^^^^^^^^^^^^^^
File "\envs\pymc\Lib\multiprocessing\pool.py", line 51, in starmapstar
return list(itertools.starmap(args[0], args[1]))
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "\envs\pymc\Lib\site-packages\pymc\smc\sampling.py", line 421, in _apply_args_and_kwargs
return fn(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^
File "\envs\pymc\Lib\site-packages\pymc\smc\sampling.py", line 344, in _sample_smc_int
smc._initialize_kernel()
File "\envs\pymc\Lib\site-packages\pymc\smc\kernels.py", line 239, in _initialize_kernel
initial_point, [self.model.varlogp], self.variables, shared
^^^^^^^^^^^^^^^^^^
File "\envs\pymc\Lib\site-packages\pymc\model\core.py", line 832, in varlogp
return self.logp(vars=self.free_RVs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "\envs\pymc\Lib\site-packages\pymc\model\core.py", line 717, in logp
rv_logps = transformed_conditional_logp(
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "\envs\pymc\Lib\site-packages\pymc\logprob\basic.py", line 612, in transformed_conditional_logp
temp_logp_terms = conditional_logp(
^^^^^^^^^^^^^^^^^
File "\envs\pymc\Lib\site-packages\pymc\logprob\basic.py", line 542, in conditional_logp
q_logprob_vars = _logprob(
^^^^^^^^^
File "\envs\pymc\Lib\functools.py", line 909, in wrapper
return dispatch(args[0].__class__)(*args, **kw)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "\envs\pymc\Lib\site-packages\pymc\logprob\abstract.py", line 63, in _logprob
raise NotImplementedError(f"Logprob method not implemented for {op}")
NotImplementedError: Logprob method not implemented for CustomDist_y_rv{0, (0, 0), floatX, False}
"""
The above exception was the direct cause of the following exception:
Traceback (most recent call last):
File "issue.py", line 79, in <module>
main()
File "issue.py", line 44, in main
sample = pm.sample_smc() # Exception has occurred: NotImplementedError
^^^^^^^^^^^^^^^
File "\envs\pymc\Lib\site-packages\pymc\smc\sampling.py", line 213, in sample_smc
results = run_chains_parallel(
^^^^^^^^^^^^^^^^^^^^
File "\envs\pymc\Lib\site-packages\pymc\smc\sampling.py", line 390, in run_chains_parallel
results = _starmap_with_kwargs(
^^^^^^^^^^^^^^^^^^^^^
File "\envs\pymc\Lib\site-packages\pymc\smc\sampling.py", line 417, in _starmap_with_kwargs
return pool.starmap(_apply_args_and_kwargs, args_for_starmap)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "\envs\pymc\Lib\multiprocessing\pool.py", line 375, in starmap
return self._map_async(func, iterable, starmapstar, chunksize).get()
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "\envs\pymc\Lib\multiprocessing\pool.py", line 774, in get
raise self._value
NotImplementedError: Logprob method not implemented for CustomDist_y_rv{0, (0, 0), floatX, False}
PyMC version information:
Python 3.11.7
pymc 5.10.0
pytensor 2.18.6
Win 10
Environment set up via conda but updated pymc and pytensor with pip
Also fails with these environments
conda_env.txt
conda_env_dev.txt
Context for the issue:
I'm testing a model which suffers from slow sampling, possibly due to expensive gradient calculations. I tested SMC as a possible solution as suggested on the forums but got this error message.
Using the `dist` argument could work in most cases, but there are cases where the distributions provided by `pymc` are not enough. Using `pm.Potential` could help with sampling, but that would in turn make forward sampling less straightforward.