Marginalization is reset by freeze_dims_and_data #383

Open
jessegrabowski opened this issue Oct 22, 2024 · 1 comment · May be fixed by #388


@jessegrabowski

I wasn't sure which repo this belongs in. If you marginalize a discrete variable with MarginalModel and then call freeze_dims_and_data, the marginalization is undone:

import pymc as pm
from pymc_experimental import MarginalModel
from pymc.model.transform.optimization import freeze_dims_and_data
import pytensor.tensor as pt

with MarginalModel() as m:
    p = pm.Beta('p', 1, 1)
    idx = pm.Bernoulli('idx', p=p, size=(100,))
    mu = pm.Normal('mu', 0, [1, 100])
    x = pm.Normal('x', pm.math.switch(pt.eq(idx, 0), mu[0], mu[1]), 1)

m.marginal(['idx'])
pm.inputvars(m.logp())   # [p_logodds__, mu, x]

pm.inputvars(freeze_dims_and_data(m).logp())  # Raises ValueError: Random variables detected in the logp graph
Full Traceback
---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
Cell In[19], line 1
----> 1 pm.inputvars(freeze_dims_and_data(m).logp())

File ~/mambaforge/envs/readystate-bonds/lib/python3.11/site-packages/pymc/model/core.py:742, in Model.logp(self, vars, jacobian, sum)
    740 rv_logps: list[TensorVariable] = []
    741 if rvs:
--> 742     rv_logps = transformed_conditional_logp(
    743         rvs=rvs,
    744         rvs_to_values=self.rvs_to_values,
    745         rvs_to_transforms=self.rvs_to_transforms,
    746         jacobian=jacobian,
    747     )
    748     assert isinstance(rv_logps, list)
    750 # Replace random variables by their value variables in potential terms

File ~/mambaforge/envs/readystate-bonds/lib/python3.11/site-packages/pymc/logprob/basic.py:630, in transformed_conditional_logp(rvs, rvs_to_values, rvs_to_transforms, jacobian, **kwargs)
    628 rvs_in_logp_expressions = _find_unallowed_rvs_in_graph(logp_terms_list)
    629 if rvs_in_logp_expressions:
--> 630     raise ValueError(RVS_IN_JOINT_LOGP_GRAPH_MSG % rvs_in_logp_expressions)
    632 return logp_terms_list

ValueError: Random variables detected in the logp graph: {bernoulli_rv{"()->()"}.out}.
This can happen when DensityDist logp or Interval transform functions reference nonlocal variables,
or when not all rvs have a corresponding value variable.
@ricardoV94

MarginalModel is not compatible with any model transformations. It's a temporary limitation until we get rid of the subclass.
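
For readers unsure what is lost when the marginalization is undone, here is a minimal plain-Python sketch of the logp that marginalizing `idx` should produce for the model in the report. It assumes each `x[i]` depends only on its own `idx[i]`, with `idx[i] == 0` selecting `mu[0]` (probability `1 - p`) and `idx[i] == 1` selecting `mu[1]` (probability `p`). The function names `normal_logpdf` and `marginal_logp` are hypothetical helpers for illustration only, and this is not how MarginalModel is implemented internally.

```python
import math


def normal_logpdf(x, mu, sigma=1.0):
    """Log density of Normal(mu, sigma) evaluated at x."""
    z = (x - mu) / sigma
    return -0.5 * math.log(2.0 * math.pi) - math.log(sigma) - 0.5 * z * z


def marginal_logp(x, p, mu0, mu1):
    """Log density of x with the Bernoulli indicator summed out.

    Computes log[(1 - p) * N(x | mu0, 1) + p * N(x | mu1, 1)]
    using a log-sum-exp for numerical stability.
    """
    a = math.log1p(-p) + normal_logpdf(x, mu0)  # branch idx == 0
    b = math.log(p) + normal_logpdf(x, mu1)     # branch idx == 1
    m = max(a, b)
    return m + math.log(math.exp(a - m) + math.exp(b - m))


# Illustrative values only (not taken from the issue).
xs = [0.1, -0.4, 2.5]
total_logp = sum(marginal_logp(x, p=0.3, mu0=0.0, mu1=2.0) for x in xs)
```

With the indicator summed out like this, the logp depends only on continuous inputs, which matches the `[p_logodds__, mu, x]` list reported before freezing. The error after `freeze_dims_and_data` shows the Bernoulli variable back in the graph without a value variable.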
