@@ -291,6 +291,62 @@ def true_logp(y):
291291 )
292292
293293
294+ @pytest .mark .filterwarnings ("error" )
295+ def test_nested_recover_marginals ():
296+ """Test that marginalization works when there are nested marginalized RVs"""
297+
298+ with MarginalModel () as m :
299+ idx = pm .Bernoulli ("idx" , p = 0.75 )
300+ sub_idx = pm .Bernoulli ("sub_idx" , p = pt .switch (pt .eq (idx , 0 ), 0.15 , 0.95 ))
301+ sub_dep = pm .Normal ("y" , mu = idx + sub_idx , sigma = 1.0 )
302+
303+ m .marginalize ([idx , sub_idx ])
304+
305+ rng = np .random .default_rng (211 )
306+
307+ with m :
308+ prior = pm .sample_prior_predictive (
309+ samples = 20 ,
310+ random_seed = rng ,
311+ return_inferencedata = False ,
312+ )
313+ idata = InferenceData (posterior = dict_to_dataset (prior ))
314+
315+ idata = m .recover_marginals (idata , include_samples = True )
316+ assert "idx" in idata
317+ assert "lp_idx" in idata
318+ assert idata .idx .shape == idata .y .shape
319+ assert idata .lp_idx .shape == idata .idx .shape + (2 ,)
320+ assert "sub_idx" in idata
321+ assert "lp_sub_idx" in idata
322+ assert idata .sub_idx .shape == idata .y .shape
323+ assert idata .lp_sub_idx .shape == idata .sub_idx .shape + (2 ,)
324+
325+ def true_idx_logp (y ):
326+ idx_0 = np .log (0.85 * 0.25 * norm .pdf (y , loc = 0 ) +
327+ 0.15 * 0.25 * norm .pdf (y , loc = 1 ))
328+ idx_1 = np .log (0.05 * 0.75 * norm .pdf (y , loc = 1 ) +
329+ 0.95 * 0.75 * norm .pdf (y , loc = 2 ))
330+ return np .stack ([idx_0 , idx_1 ]).T
331+
332+ np .testing .assert_almost_equal (
333+ true_idx_logp (idata .y .values .flatten ()),
334+ idata .lp_idx [0 ].values ,
335+ )
336+
337+ def true_sub_idx_logp (y ):
338+ sub_idx_0 = np .log (0.85 * 0.25 * norm .pdf (y , loc = 0 ) +
339+ 0.05 * 0.75 * norm .pdf (y , loc = 1 ))
340+ sub_idx_1 = np .log (0.15 * 0.25 * norm .pdf (y , loc = 1 ) +
341+ 0.95 * 0.75 * norm .pdf (y , loc = 2 ))
342+ return np .stack ([sub_idx_0 , sub_idx_1 ]).T
343+
344+ np .testing .assert_almost_equal (
345+ true_sub_idx_logp (idata .y .values .flatten ()),
346+ idata .lp_sub_idx [0 ].values ,
347+ )
348+
349+
294350@pytest .mark .filterwarnings ("error" )
295351def test_not_supported_marginalized ():
296352 """Marginalized graphs with non-Elemwise Operations are not supported as they
0 commit comments