Skip to content

Commit

Permalink
Allow mutable shape in PartialObservedRVs
Browse files Browse the repository at this point in the history
  • Loading branch information
ricardoV94 committed May 14, 2024
1 parent 3729614 commit 43b40de
Show file tree
Hide file tree
Showing 3 changed files with 61 additions and 10 deletions.
4 changes: 3 additions & 1 deletion pymc/distributions/distribution.py
Original file line number Diff line number Diff line change
Expand Up @@ -1630,7 +1630,9 @@ def partial_observed_rv_logprob(op, values, dist, mask, **kwargs):
# For the logp, simply join the values
[obs_value, unobs_value] = values
antimask = ~mask
joined_value = pt.empty(constant_fold([dist.shape])[0])
# We don't need it to be completely folded, just to avoid any RVs in the graph of the shape
[folded_shape] = constant_fold([dist.shape], raise_not_constant=False)
joined_value = pt.empty(folded_shape)
joined_value = pt.set_subtensor(joined_value[mask], unobs_value)
joined_value = pt.set_subtensor(joined_value[antimask], obs_value)
joined_logp = logp(dist, joined_value)
Expand Down
30 changes: 25 additions & 5 deletions tests/distributions/test_distribution.py
Original file line number Diff line number Diff line change
Expand Up @@ -979,16 +979,21 @@ def test_univariate(self, symbolic_rv):
np.testing.assert_allclose(obs_logp, st.norm([1, 2]).logpdf([0.25, 0.5]))
np.testing.assert_allclose(unobs_logp, st.norm([3]).logpdf([0.25]))

@pytest.mark.parametrize("mutable_shape", (False, True))
@pytest.mark.parametrize("obs_component_selected", (True, False))
def test_multivariate_constant_mask_separable(self, obs_component_selected):
def test_multivariate_constant_mask_separable(self, obs_component_selected, mutable_shape):
if obs_component_selected:
mask = np.zeros((1, 4), dtype=bool)
else:
mask = np.ones((1, 4), dtype=bool)
obs_data = np.array([[0.1, 0.4, 0.1, 0.4]])
unobs_data = np.array([[0.4, 0.1, 0.4, 0.1]])

rv = pm.Dirichlet.dist([1, 2, 3, 4], shape=(1, 4))
if mutable_shape:
shape = (1, pytensor.shared(np.array(4, dtype=int)))
else:
shape = (1, 4)
rv = pm.Dirichlet.dist(pt.arange(shape[-1]) + 1, shape=shape)
(obs_rv, obs_mask), (unobs_rv, unobs_mask), joined_rv = create_partial_observed_rv(rv, mask)

# Test types
Expand Down Expand Up @@ -1023,6 +1028,10 @@ def test_multivariate_constant_mask_separable(self, obs_component_selected):
np.testing.assert_allclose(obs_logp, expected_obs_logp)
np.testing.assert_allclose(unobs_logp, expected_unobs_logp)

if mutable_shape:
shape[-1].set_value(7)
assert tuple(joined_rv.shape.eval()) == (1, 7)

def test_multivariate_constant_mask_unseparable(self):
mask = pt.constant(np.array([[True, True, False, False]]))
obs_data = np.array([[0.1, 0.4, 0.1, 0.4]])
Expand Down Expand Up @@ -1097,14 +1106,19 @@ def test_multivariate_shared_mask_separable(self):
np.testing.assert_almost_equal(obs_logp, new_expected_logp)
np.testing.assert_array_equal(unobs_logp, [])

def test_multivariate_shared_mask_unseparable(self):
@pytest.mark.parametrize("mutable_shape", (False, True))
def test_multivariate_shared_mask_unseparable(self, mutable_shape):
# Even if the mask is initially not mixing support dims,
# it could later be changed in a way that does!
mask = shared(np.array([[True, True, True, True]]))
obs_data = np.array([[0.1, 0.4, 0.1, 0.4]])
unobs_data = np.array([[0.4, 0.1, 0.4, 0.1]])

rv = pm.Dirichlet.dist([1, 2, 3, 4], shape=(1, 4))
if mutable_shape:
shape = mask.shape
else:
shape = (1, 4)
rv = pm.Dirichlet.dist([1, 2, 3, 4], shape=shape)
(obs_rv, obs_mask), (unobs_rv, unobs_mask), joined_rv = create_partial_observed_rv(rv, mask)

# Test types
Expand Down Expand Up @@ -1134,16 +1148,22 @@ def test_multivariate_shared_mask_unseparable(self):

# Test that we can update a shared mask
mask.set_value(np.array([[False, False, True, True]]))
equivalent_value = np.array([0.1, 0.4, 0.4, 0.1])

assert tuple(obs_rv.shape.eval()) == (2,)
assert tuple(unobs_rv.shape.eval()) == (2,)

new_expected_logp = pm.logp(rv, [0.1, 0.4, 0.4, 0.1]).eval()
new_expected_logp = pm.logp(rv, equivalent_value).eval()
assert not np.isclose(expected_logp, new_expected_logp) # Otherwise test is weak
obs_logp, unobs_logp = logp_fn()
np.testing.assert_almost_equal(obs_logp, new_expected_logp)
np.testing.assert_array_equal(unobs_logp, [])

if mutable_shape:
mask.set_value(np.array([[False, False, True, False], [False, False, False, True]]))
assert tuple(obs_rv.shape.eval()) == (6,)
assert tuple(unobs_rv.shape.eval()) == (2,)

def test_support_point(self):
x = pm.GaussianRandomWalk.dist(init_dist=pm.Normal.dist(-5), mu=1, steps=9)
ref_support_point = support_point(x).eval()
Expand Down
37 changes: 33 additions & 4 deletions tests/model/test_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
import pytensor.sparse as sparse
import pytensor.tensor as pt
import pytest
import scipy
import scipy.sparse as sps
import scipy.stats as st

Expand All @@ -38,7 +39,7 @@

import pymc as pm

from pymc import Deterministic, Model, Potential
from pymc import Deterministic, Model, MvNormal, Potential
from pymc.blocking import DictToArrayBijection, RaveledVars
from pymc.distributions import Normal, transforms
from pymc.distributions.distribution import PartialObservedRV
Expand Down Expand Up @@ -1504,11 +1505,39 @@ def test_truncated_normal(self):
"""
with Model() as m:
mu = pm.TruncatedNormal("mu", mu=1, sigma=2, lower=0)
x = pm.TruncatedNormal(
"x", mu=mu, sigma=0.5, lower=0, observed=np.array([0.1, 0.2, 0.5, np.nan, np.nan])
)
with pytest.warns(ImputationWarning):
x = pm.TruncatedNormal(
"x",
mu=mu,
sigma=0.5,
lower=0,
observed=np.array([0.1, 0.2, 0.5, np.nan, np.nan]),
)
m.check_start_vals(m.initial_point())

def test_coordinates(self):
# Regression test for https://github.com/pymc-devs/pymc/issues/7304

coords = {"trial": range(30), "feature": range(2)}
observed = np.zeros((30, 2))
observed[0, 0] = np.nan

with Model(coords=coords) as model:
with pytest.warns(ImputationWarning):
MvNormal(
"y",
mu=np.zeros(2),
cov=np.eye(2),
observed=observed,
dims=("trial", "feature"),
)

logp_fn = model.compile_logp()
np.testing.assert_allclose(
logp_fn({"y_unobserved": [0]}),
scipy.stats.multivariate_normal.logpdf([0, 0], cov=np.eye(2)) * 30,
)


class TestShared:
def test_deterministic(self):
Expand Down

0 comments on commit 43b40de

Please sign in to comment.