Skip to content

Commit 591afd0

Browse files
committed
Add nested model test
1 parent 94fc1b1 commit 591afd0

File tree

1 file changed

+56
-0
lines changed

1 file changed

+56
-0
lines changed

pymc_experimental/tests/model/test_marginal_model.py

Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -291,6 +291,62 @@ def true_logp(y):
291291
)
292292

293293

294+
@pytest.mark.filterwarnings("error")
295+
def test_nested_recover_marginals():
296+
"""Test that marginalization works when there are nested marginalized RVs"""
297+
298+
with MarginalModel() as m:
299+
idx = pm.Bernoulli("idx", p=0.75)
300+
sub_idx = pm.Bernoulli("sub_idx", p=pt.switch(pt.eq(idx, 0), 0.15, 0.95))
301+
sub_dep = pm.Normal("y", mu=idx + sub_idx, sigma=1.0)
302+
303+
m.marginalize([idx, sub_idx])
304+
305+
rng = np.random.default_rng(211)
306+
307+
with m:
308+
prior = pm.sample_prior_predictive(
309+
samples=20,
310+
random_seed=rng,
311+
return_inferencedata=False,
312+
)
313+
idata = InferenceData(posterior=dict_to_dataset(prior))
314+
315+
idata = m.recover_marginals(idata, include_samples=True)
316+
assert "idx" in idata
317+
assert "lp_idx" in idata
318+
assert idata.idx.shape == idata.y.shape
319+
assert idata.lp_idx.shape == idata.idx.shape + (2,)
320+
assert "sub_idx" in idata
321+
assert "lp_sub_idx" in idata
322+
assert idata.sub_idx.shape == idata.y.shape
323+
assert idata.lp_sub_idx.shape == idata.sub_idx.shape + (2,)
324+
325+
def true_idx_logp(y):
326+
idx_0 = np.log(0.85*0.25*norm.pdf(y, loc=0) +
327+
0.15*0.25*norm.pdf(y, loc=1))
328+
idx_1 = np.log(0.05*0.75*norm.pdf(y, loc=1) +
329+
0.95*0.75*norm.pdf(y, loc=2))
330+
return np.stack([idx_0, idx_1]).T
331+
332+
np.testing.assert_almost_equal(
333+
true_idx_logp(idata.y.values.flatten()),
334+
idata.lp_idx[0].values,
335+
)
336+
337+
def true_sub_idx_logp(y):
338+
sub_idx_0 = np.log(0.85*0.25*norm.pdf(y, loc=0) +
339+
0.05*0.75*norm.pdf(y, loc=1))
340+
sub_idx_1 = np.log(0.15*0.25*norm.pdf(y, loc=1) +
341+
0.95*0.75*norm.pdf(y, loc=2))
342+
return np.stack([sub_idx_0, sub_idx_1]).T
343+
344+
np.testing.assert_almost_equal(
345+
true_sub_idx_logp(idata.y.values.flatten()),
346+
idata.lp_sub_idx[0].values,
347+
)
348+
349+
294350
@pytest.mark.filterwarnings("error")
295351
def test_not_supported_marginalized():
296352
"""Marginalized graphs with non-Elemwise Operations are not supported as they

0 commit comments

Comments
 (0)