diff --git a/pymc/distributions/distribution.py b/pymc/distributions/distribution.py
index 214ec505d7..bcd4613918 100644
--- a/pymc/distributions/distribution.py
+++ b/pymc/distributions/distribution.py
@@ -1630,7 +1630,9 @@ def partial_observed_rv_logprob(op, values, dist, mask, **kwargs):
     # For the logp, simply join the values
     [obs_value, unobs_value] = values
     antimask = ~mask
-    joined_value = pt.empty(constant_fold([dist.shape])[0])
+    # We don't need it to be completely folded, just to avoid any RVs in the graph of the shape
+    [folded_shape] = constant_fold([dist.shape], raise_not_constant=False)
+    joined_value = pt.empty(folded_shape)
     joined_value = pt.set_subtensor(joined_value[mask], unobs_value)
     joined_value = pt.set_subtensor(joined_value[antimask], obs_value)
     joined_logp = logp(dist, joined_value)
diff --git a/tests/distributions/test_distribution.py b/tests/distributions/test_distribution.py
index d5f3359dd1..caac1777bc 100644
--- a/tests/distributions/test_distribution.py
+++ b/tests/distributions/test_distribution.py
@@ -979,8 +979,9 @@ def test_univariate(self, symbolic_rv):
         np.testing.assert_allclose(obs_logp, st.norm([1, 2]).logpdf([0.25, 0.5]))
         np.testing.assert_allclose(unobs_logp, st.norm([3]).logpdf([0.25]))
 
+    @pytest.mark.parametrize("mutable_shape", (False, True))
     @pytest.mark.parametrize("obs_component_selected", (True, False))
-    def test_multivariate_constant_mask_separable(self, obs_component_selected):
+    def test_multivariate_constant_mask_separable(self, obs_component_selected, mutable_shape):
         if obs_component_selected:
             mask = np.zeros((1, 4), dtype=bool)
         else:
@@ -988,7 +989,11 @@ def test_multivariate_constant_mask_separable(self, obs_component_selected):
         obs_data = np.array([[0.1, 0.4, 0.1, 0.4]])
         unobs_data = np.array([[0.4, 0.1, 0.4, 0.1]])
 
-        rv = pm.Dirichlet.dist([1, 2, 3, 4], shape=(1, 4))
+        if mutable_shape:
+            shape = (1, pytensor.shared(np.array(4, dtype=int)))
+        else:
+            shape = (1, 4)
+        rv = pm.Dirichlet.dist(pt.arange(shape[-1]) + 1, shape=shape)
         (obs_rv, obs_mask), (unobs_rv, unobs_mask), joined_rv = create_partial_observed_rv(rv, mask)
 
         # Test types
@@ -1023,6 +1028,10 @@ def test_multivariate_constant_mask_separable(self, obs_component_selected):
         np.testing.assert_allclose(obs_logp, expected_obs_logp)
         np.testing.assert_allclose(unobs_logp, expected_unobs_logp)
 
+        if mutable_shape:
+            shape[-1].set_value(7)
+            assert tuple(joined_rv.shape.eval()) == (1, 7)
+
     def test_multivariate_constant_mask_unseparable(self):
         mask = pt.constant(np.array([[True, True, False, False]]))
         obs_data = np.array([[0.1, 0.4, 0.1, 0.4]])
         unobs_data = np.array([[0.4, 0.1, 0.4, 0.1]])
@@ -1097,14 +1106,19 @@ def test_multivariate_shared_mask_separable(self):
         np.testing.assert_almost_equal(obs_logp, new_expected_logp)
         np.testing.assert_array_equal(unobs_logp, [])
 
-    def test_multivariate_shared_mask_unseparable(self):
+    @pytest.mark.parametrize("mutable_shape", (False, True))
+    def test_multivariate_shared_mask_unseparable(self, mutable_shape):
         # Even if the mask is initially not mixing support dims,
         # it could later be changed in a way that does!
         mask = shared(np.array([[True, True, True, True]]))
         obs_data = np.array([[0.1, 0.4, 0.1, 0.4]])
         unobs_data = np.array([[0.4, 0.1, 0.4, 0.1]])
 
-        rv = pm.Dirichlet.dist([1, 2, 3, 4], shape=(1, 4))
+        if mutable_shape:
+            shape = mask.shape
+        else:
+            shape = (1, 4)
+        rv = pm.Dirichlet.dist([1, 2, 3, 4], shape=shape)
         (obs_rv, obs_mask), (unobs_rv, unobs_mask), joined_rv = create_partial_observed_rv(rv, mask)
 
         # Test types
@@ -1134,16 +1148,22 @@ def test_multivariate_shared_mask_unseparable(self):
 
         # Test that we can update a shared mask
         mask.set_value(np.array([[False, False, True, True]]))
+        equivalent_value = np.array([0.1, 0.4, 0.4, 0.1])
 
         assert tuple(obs_rv.shape.eval()) == (2,)
         assert tuple(unobs_rv.shape.eval()) == (2,)
 
-        new_expected_logp = pm.logp(rv, [0.1, 0.4, 0.4, 0.1]).eval()
+        new_expected_logp = pm.logp(rv, equivalent_value).eval()
         assert not np.isclose(expected_logp, new_expected_logp)  # Otherwise test is weak
         obs_logp, unobs_logp = logp_fn()
         np.testing.assert_almost_equal(obs_logp, new_expected_logp)
         np.testing.assert_array_equal(unobs_logp, [])
 
+        if mutable_shape:
+            mask.set_value(np.array([[False, False, True, False], [False, False, False, True]]))
+            assert tuple(obs_rv.shape.eval()) == (6,)
+            assert tuple(unobs_rv.shape.eval()) == (2,)
+
     def test_support_point(self):
         x = pm.GaussianRandomWalk.dist(init_dist=pm.Normal.dist(-5), mu=1, steps=9)
         ref_support_point = support_point(x).eval()
diff --git a/tests/model/test_core.py b/tests/model/test_core.py
index cdb93c72a8..c00250b739 100644
--- a/tests/model/test_core.py
+++ b/tests/model/test_core.py
@@ -28,6 +28,7 @@
 import pytensor.sparse as sparse
 import pytensor.tensor as pt
 import pytest
+import scipy
 import scipy.sparse as sps
 import scipy.stats as st
 
@@ -38,7 +39,7 @@
 
 import pymc as pm
 
-from pymc import Deterministic, Model, Potential
+from pymc import Deterministic, Model, MvNormal, Potential
 from pymc.blocking import DictToArrayBijection, RaveledVars
 from pymc.distributions import Normal, transforms
 from pymc.distributions.distribution import PartialObservedRV
@@ -1504,11 +1505,39 @@ def test_truncated_normal(self):
         """
         with Model() as m:
             mu = pm.TruncatedNormal("mu", mu=1, sigma=2, lower=0)
-            x = pm.TruncatedNormal(
-                "x", mu=mu, sigma=0.5, lower=0, observed=np.array([0.1, 0.2, 0.5, np.nan, np.nan])
-            )
+            with pytest.warns(ImputationWarning):
+                x = pm.TruncatedNormal(
+                    "x",
+                    mu=mu,
+                    sigma=0.5,
+                    lower=0,
+                    observed=np.array([0.1, 0.2, 0.5, np.nan, np.nan]),
+                )
         m.check_start_vals(m.initial_point())
 
+    def test_coordinates(self):
+        # Regression test for https://github.com/pymc-devs/pymc/issues/7304
+
+        coords = {"trial": range(30), "feature": range(2)}
+        observed = np.zeros((30, 2))
+        observed[0, 0] = np.nan
+
+        with Model(coords=coords) as model:
+            with pytest.warns(ImputationWarning):
+                MvNormal(
+                    "y",
+                    mu=np.zeros(2),
+                    cov=np.eye(2),
+                    observed=observed,
+                    dims=("trial", "feature"),
+                )
+
+        logp_fn = model.compile_logp()
+        np.testing.assert_allclose(
+            logp_fn({"y_unobserved": [0]}),
+            scipy.stats.multivariate_normal.logpdf([0, 0], cov=np.eye(2)) * 30,
+        )
+
 
 class TestShared:
     def test_deterministic(self):
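
For context, the scenario exercised by the new test_coordinates regression test can be run standalone as in the sketch below (same numbers and APIs as the test; nothing beyond what the patch itself uses). Before this fix, the partially observed MvNormal hit the strict constant_fold([dist.shape]) call, which raises when the dims-derived shape is symbolic rather than constant, so the model logp could not be built.

import numpy as np
import scipy.stats

import pymc as pm

coords = {"trial": range(30), "feature": range(2)}
observed = np.zeros((30, 2))
observed[0, 0] = np.nan  # one missing component triggers automatic imputation

with pm.Model(coords=coords) as model:
    # Emits an ImputationWarning and creates a y_unobserved variable for the missing entry
    pm.MvNormal("y", mu=np.zeros(2), cov=np.eye(2), observed=observed, dims=("trial", "feature"))

# With the fix, the joint logp graph can be compiled and matches the scipy reference:
# imputing 0 for the missing entry makes every one of the 30 rows equal to [0, 0].
logp_fn = model.compile_logp()
np.testing.assert_allclose(
    logp_fn({"y_unobserved": [0]}),
    scipy.stats.multivariate_normal.logpdf([0, 0], cov=np.eye(2)) * 30,
)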