
Commit a8cc8f2

Update MarginalModel
Also fixes error when jaxifying logp
1 parent 08bf742 commit a8cc8f2

2 files changed (+33, -19 lines)

pymc_experimental/model/marginal_model.py

Lines changed: 15 additions & 12 deletions
```diff
@@ -410,7 +410,7 @@ def transform_input(inputs):
         marginalized_rv.type, dependent_logps
     )
 
-    rv_shape = constant_fold(tuple(marginalized_rv.shape))
+    rv_shape = constant_fold(tuple(marginalized_rv.shape), raise_not_constant=False)
     rv_domain = get_domain_of_finite_discrete_rv(marginalized_rv)
     rv_domain_tensor = pt.moveaxis(
         pt.full(
```
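
Why `raise_not_constant=False` is needed here: when the marginalized RV's shape depends on a shared variable (e.g. a `pm.Data` container), the shape cannot be folded to a constant and the default behavior raises. A minimal sketch of the difference, illustrative only (not part of the commit) and assuming `pm.Data` creates a mutable shared variable as in current PyMC:

```python
import numpy as np
import pymc as pm
from pymc.pytensorf import constant_fold

with pm.Model():
    data = pm.Data("data", np.zeros(10))  # shared variable: static shape is unknown
    rv = pm.Normal("rv", mu=data)

try:
    # The default raise_not_constant=True fails on a non-constant shape
    constant_fold(tuple(rv.shape))
except Exception as err:  # NotConstantValueError in recent PyMC releases
    print(f"raised {type(err).__name__}")

# With the flag, non-constant entries come back as (rewritten) symbolic variables
rv_shape = constant_fold(tuple(rv.shape), raise_not_constant=False)
```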
```diff
@@ -579,6 +579,15 @@ def is_elemwise_subgraph(rv_to_marginalize, other_input_rvs, output_rvs):
     return True
 
 
+from pytensor.graph.basic import graph_inputs
+
+
+def collect_shared_vars(outputs, blockers):
+    return [
+        inp for inp in graph_inputs(outputs, blockers=blockers) if isinstance(inp, SharedVariable)
+    ]
+
+
 def replace_finite_discrete_marginal_subgraph(fgraph, rv_to_marginalize, all_rvs):
     # TODO: This should eventually be integrated in a more general routine that can
     # identify other types of supported marginalization, of which finite discrete
```
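
The new helper builds on `pytensor.graph.basic.graph_inputs`, which walks from `outputs` back toward the graph's root variables and stops at `blockers`. A standalone sketch of what `collect_shared_vars` gathers, using standard pytensor behavior (this snippet is not from the commit):

```python
import numpy as np
import pytensor
import pytensor.tensor as pt
from pytensor.compile.sharedvalue import SharedVariable
from pytensor.graph.basic import graph_inputs

shared = pytensor.shared(np.zeros(3), name="shared")
x = pt.vector("x")
intermediate = x * 2
out = intermediate + shared

# Traversal stops at the blocker, so `x` is never reached,
# but the shared variable is still collected
shared_vars = [
    inp
    for inp in graph_inputs([out], blockers=[intermediate])
    if isinstance(inp, SharedVariable)
]
print(shared_vars)  # [shared]
```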
```diff
@@ -621,27 +630,21 @@ def replace_finite_discrete_marginal_subgraph(fgraph, rv_to_marginalize, all_rvs):
     rvs_to_marginalize = [rv_to_marginalize, *dependent_rvs]
 
     outputs = rvs_to_marginalize
-    # Clone replace inner RV rng inputs so that we can be sure of the update order
-    # replace_inputs = {rng: rng.type() for rng in updates_rvs_to_marginalize.keys()}
-    # Clone replace outter RV inputs, so that their shared RNGs don't make it into
-    # the inner graph of the marginalized RVs
-    # FIXME: This shouldn't be needed!
-    replace_inputs = {}
-    replace_inputs.update({input_rv: input_rv.type() for input_rv in input_rvs})
-    cloned_outputs = clone_replace(outputs, replace=replace_inputs)
+    # We are strict about shared variables in SymbolicRandomVariables
+    inputs = input_rvs + collect_shared_vars(rvs_to_marginalize, blockers=input_rvs)
 
     if isinstance(rv_to_marginalize.owner.op, DiscreteMarkovChain):
         marginalize_constructor = DiscreteMarginalMarkovChainRV
     else:
         marginalize_constructor = FiniteDiscreteMarginalRV
 
     marginalization_op = marginalize_constructor(
-        inputs=list(replace_inputs.values()),
-        outputs=cloned_outputs,
+        inputs=inputs,
+        outputs=outputs,
         ndim_supp=ndim_supp,
     )
 
-    marginalized_rvs = marginalization_op(*replace_inputs.keys())
+    marginalized_rvs = marginalization_op(*inputs)
     fgraph.replace_all(tuple(zip(rvs_to_marginalize, marginalized_rvs)))
     return rvs_to_marginalize, marginalized_rvs
 
```
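
Shared variables now become explicit inputs of the marginalized op instead of being cloned away through the old `FIXME` path. Both RNGs and data containers are affected, since both are shared variables; a quick illustration of why both get picked up (standard pymc/pytensor behavior, not commit code):

```python
import numpy as np
import pymc as pm
from pytensor.compile.sharedvalue import SharedVariable

# The RNG feeding a RandomVariable is a shared variable...
rv = pm.Normal.dist()
rng = rv.owner.inputs[0]
print(isinstance(rng, SharedVariable))  # True

# ...and so is a mutable pm.Data container
with pm.Model():
    data = pm.Data("data", np.zeros(10))
print(isinstance(data, SharedVariable))  # True
```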

pymc_experimental/tests/model/test_marginal_model.py

Lines changed: 18 additions & 7 deletions
```diff
@@ -56,13 +56,8 @@ def test_marginalized_bernoulli_logp():
     idx = pm.Bernoulli.dist(0.7, name="idx")
     y = pm.Normal.dist(mu=mu[idx], sigma=1.0, name="y")
     marginal_rv_node = FiniteDiscreteMarginalRV(
-        [mu],
-        [idx, y],
-        ndim_supp=0,
-        n_updates=0,
-    )(
-        mu
-    )[0].owner
+        [mu], [idx, y], ndim_supp=0, n_updates=0, strict=False
+    )(mu)[0].owner
 
     y_vv = y.clone()
     (logp,) = _logprob(
```
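
The new `strict=False` flag appears to opt this hand-built op out of the shared-variable strictness the module now enforces; the test passes only `mu` and collects no shared inputs or updates.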
```diff
@@ -758,3 +753,19 @@ def test_marginalized_hmm_multiple_emissions(batch_emission1, batch_emission2):
     test_value_emission2 = np.broadcast_to(-test_value, emission2_shape)
     test_point = {"emission_1": test_value_emission1, "emission_2": test_value_emission2}
     np.testing.assert_allclose(logp_fn(test_point), expected_logp)
+
+
+def test_mutable_indexing_jax_backend():
+    pytest.importorskip("jax")
+    from pymc.sampling.jax import get_jaxified_logp
+
+    with MarginalModel() as model:
+        data = pm.Data("data", np.zeros(10))
+
+        cat_effect = pm.Normal("cat_effect", sigma=1, shape=5)
+        cat_effect_idx = pm.Data("cat_effect_idx", np.array([0, 1] * 5))
+
+        is_outlier = pm.Bernoulli("is_outlier", 0.4, shape=10)
+        pm.LogNormal("y", mu=cat_effect[cat_effect_idx], sigma=1 + is_outlier, observed=data)
+    model.marginalize(["is_outlier"])
+    get_jaxified_logp(model)
```
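
This regression test exercises the jaxify fix from the commit message: `data` and `cat_effect_idx` are shared `pm.Data` variables that previously stayed implicit in the marginalized op's inner graph, which `get_jaxified_logp` apparently could not handle; with them as explicit inputs, jaxification goes through.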
