Disconnected node in model graph after deterministic operations #7722

Open
@williambdean

The models that are now allowed following #7656 have a disconnected node in the model graph.

The sampling is as expected. It is just the graphviz representation that is incorrect.

import numpy as np
import pymc as pm
from pymc.model_graph import ModelGraph

seed = sum(map(ord, "Observed disconnected node"))
rng = np.random.default_rng(seed)

true_mu = 100
true_sigma = 30

n_obs = 10
coords = {
    "date": np.arange(n_obs),
}

dist = pm.Normal.dist(mu=true_mu, sigma=true_sigma, shape=n_obs)
data = pm.draw(dist, random_seed=rng)

scaling = data.max()

with pm.Model(coords=coords) as model:
    mu = pm.Normal("mu")
    sigma = pm.HalfNormal("sigma")

    target = pm.Data("target", data, dims="date")
    scaled_target = target / scaling

    pm.Normal("observed", mu=mu, sigma=sigma, observed=scaled_target, dims="date")

pm.model_to_graphviz(model).render("scaled_target")

ModelGraph(model).make_compute_graph()

The "observed" node should be connected to "target" in the compute graph, but the current output leaves "target" unconnected:

defaultdict(set,
            {'mu': set(),
             'sigma': set(),
             'target': set(),
             'observed': {'mu', 'sigma'}})
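
To make the problem easy to spot, here is a minimal sketch of a check for disconnected nodes in a compute-graph mapping like the one above. The helper name `disconnected_nodes` is hypothetical and not part of PyMC. It only assumes the mapping has the shape shown in the output, with each node name mapped to the set of its input names.

```python
def disconnected_nodes(compute_graph):
    """Return nodes that have no inputs and are not an input of any other node.

    Hypothetical helper for illustration only; not a PyMC function.
    """
    used_as_input = set()
    for inputs in compute_graph.values():
        used_as_input |= inputs
    return {
        name
        for name, inputs in compute_graph.items()
        if not inputs and name not in used_as_input
    }


# The mapping reported in this issue, copied as a plain dict.
compute_graph = {
    "mu": set(),
    "sigma": set(),
    "target": set(),
    "observed": {"mu", "sigma"},
}

print(disconnected_nodes(compute_graph))  # {'target'}
```

"mu" and "sigma" also have no inputs, but they feed into "observed", so only "target" is reported.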

Seems like it needs a fix here:

pymc/pymc/model_graph.py

Lines 322 to 343 in af81955

if var in self.model.observed_RVs:
    obs_node = self.model.rvs_to_values[var]
    # loop created so that the elif block can go through this again
    # and remove any intermediate ops, notably dtype casting, to observations
    while True:
        obs_name = obs_node.name
        if obs_name and obs_name != var_name:
            input_map[var_name] = input_map[var_name].difference({obs_name})
            input_map[obs_name] = input_map[obs_name].union({var_name})
            break
        elif (
            # for cases where observations are cast to a certain dtype
            # see issue 5795: https://github.com/pymc-devs/pymc/issues/5795
            obs_node.owner
            and isinstance(obs_node.owner.op, Elemwise)
            and isinstance(obs_node.owner.op.scalar_op, Cast)
        ):
            # we can retrieve the observation node by going up the graph
            obs_node = obs_node.owner.inputs[0]
        else:
            break
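
The loop above only walks up through `Cast` ops, so an unnamed node produced by another op (such as the division in `target / scaling`) falls into the `else` branch and the named data node is never found. One possible direction, sketched here with stand-in classes rather than real PyTensor variables, is to walk up through any op until a named variable is reached. `Var`, `Apply`, and `named_ancestors` are hypothetical names used only for illustration. They mimic the `name`, `owner`, and `inputs` attributes used in the snippet above, and this is not necessarily the fix PyMC should adopt.

```python
from dataclasses import dataclass
from typing import List, Optional


@dataclass
class Apply:
    """Hypothetical stand-in for the node that produced a variable."""
    op: str
    inputs: List["Var"]


@dataclass
class Var:
    """Hypothetical stand-in for a graph variable with a name and an owner."""
    name: Optional[str] = None
    owner: Optional[Apply] = None


def named_ancestors(var):
    """Collect the nearest named variables upstream of var, through any op."""
    found = set()
    seen = set()
    stack = [var]
    while stack:
        node = stack.pop()
        if id(node) in seen:
            continue
        seen.add(id(node))
        if node.name:
            # Stop at the first named variable on this path.
            found.add(node.name)
        elif node.owner is not None:
            # Unnamed intermediate result: keep walking up its inputs.
            stack.extend(node.owner.inputs)
    return found


# Mirrors `scaled_target = target / scaling` from the example model.
target = Var(name="target")
scaling = Var()  # unnamed constant
scaled_target = Var(owner=Apply(op="true_div", inputs=[target, scaling]))

print(named_ancestors(scaled_target))  # {'target'}
```

Unlike the `Cast` branch, which follows only `inputs[0]`, this sketch visits every input, because a general op can have the data variable in any position.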
