Skip to content

Allow customizing style of model_graph nodes #7302

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 8 commits into from
May 8, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
233 changes: 194 additions & 39 deletions pymc/model_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,10 @@
import warnings

from collections import defaultdict
from collections.abc import Iterable, Sequence
from collections.abc import Callable, Iterable, Sequence
from enum import Enum
from os import path
from typing import Any

from pytensor import function
from pytensor.graph import Apply
Expand All @@ -41,6 +43,119 @@ def fast_eval(var):
return function([], var, mode="FAST_COMPILE")()


class NodeType(str, Enum):
"""Enum for the types of nodes in the graph."""

POTENTIAL = "Potential"
FREE_RV = "Free Random Variable"
OBSERVED_RV = "Observed Random Variable"
DETERMINISTIC = "Deterministic"
DATA = "Data"


GraphvizNodeKwargs = dict[str, Any]
NodeFormatter = Callable[[TensorVariable], GraphvizNodeKwargs]


def default_potential(var: TensorVariable) -> GraphvizNodeKwargs:
"""Default data for potential in the graph."""
return {
"shape": "octagon",
"style": "filled",
"label": f"{var.name}\n~\nPotential",
}


def random_variable_symbol(var: TensorVariable) -> str:
"""Get the symbol of the random variable."""
symbol = var.owner.op.__class__.__name__

if symbol.endswith("RV"):
symbol = symbol[:-2]

return symbol


def default_free_rv(var: TensorVariable) -> GraphvizNodeKwargs:
"""Default data for free RV in the graph."""
symbol = random_variable_symbol(var)

return {
"shape": "ellipse",
"style": None,
"label": f"{var.name}\n~\n{symbol}",
}


def default_observed_rv(var: TensorVariable) -> GraphvizNodeKwargs:
"""Default data for observed RV in the graph."""
symbol = random_variable_symbol(var)

return {
"shape": "ellipse",
"style": "filled",
"label": f"{var.name}\n~\n{symbol}",
}


def default_deterministic(var: TensorVariable) -> GraphvizNodeKwargs:
"""Default data for the deterministic in the graph."""
return {
"shape": "box",
"style": None,
"label": f"{var.name}\n~\nDeterministic",
}


def default_data(var: TensorVariable) -> GraphvizNodeKwargs:
"""Default data for the data in the graph."""
return {
"shape": "box",
"style": "rounded, filled",
"label": f"{var.name}\n~\nData",
}


def get_node_type(var_name: VarName, model) -> NodeType:
"""Return the node type of the variable in the model."""
v = model[var_name]

if v in model.deterministics:
return NodeType.DETERMINISTIC
elif v in model.free_RVs:
return NodeType.FREE_RV
elif v in model.observed_RVs:
return NodeType.OBSERVED_RV
elif v in model.data_vars:
return NodeType.DATA
else:
return NodeType.POTENTIAL


NodeTypeFormatterMapping = dict[NodeType, NodeFormatter]

DEFAULT_NODE_FORMATTERS: NodeTypeFormatterMapping = {
NodeType.POTENTIAL: default_potential,
NodeType.FREE_RV: default_free_rv,
NodeType.OBSERVED_RV: default_observed_rv,
NodeType.DETERMINISTIC: default_deterministic,
NodeType.DATA: default_data,
}


def update_node_formatters(node_formatters: NodeTypeFormatterMapping) -> NodeTypeFormatterMapping:
node_formatters = {**DEFAULT_NODE_FORMATTERS, **node_formatters}

unknown_keys = set(node_formatters.keys()) - set(NodeType)
if unknown_keys:
raise ValueError(
f"Node formatters must be of type NodeType. Found: {list(unknown_keys)}."
f" Please use one of {[node_type.value for node_type in NodeType]}."
)

return node_formatters


class ModelGraph:
def __init__(self, model):
self.model = model
Expand Down Expand Up @@ -148,42 +263,23 @@ def make_compute_graph(

return input_map

def _make_node(self, var_name, graph, *, nx=False, cluster=False, formatting: str = "plain"):
def _make_node(
self,
var_name,
graph,
*,
node_formatters: NodeTypeFormatterMapping,
nx=False,
cluster=False,
formatting: str = "plain",
):
"""Attaches the given variable to a graphviz or networkx Digraph"""
v = self.model[var_name]

shape = None
style = None
label = str(v)

if v in self.model.potentials:
shape = "octagon"
style = "filled"
label = f"{var_name}\n~\nPotential"
elif v in self.model.basic_RVs:
shape = "ellipse"
if v in self.model.observed_RVs:
style = "filled"
else:
style = None
symbol = v.owner.op.__class__.__name__
if symbol.endswith("RV"):
symbol = symbol[:-2]
label = f"{var_name}\n~\n{symbol}"
elif v in self.model.deterministics:
shape = "box"
style = None
label = f"{var_name}\n~\nDeterministic"
else:
shape = "box"
style = "rounded, filled"
label = f"{var_name}\n~\nData"

kwargs = {
"shape": shape,
"style": style,
"label": label,
}
node_type = get_node_type(var_name, self.model)
node_formatter = node_formatters[node_type]

kwargs = node_formatter(v)

if cluster:
kwargs["cluster"] = cluster
Expand Down Expand Up @@ -240,6 +336,7 @@ def make_graph(
save=None,
figsize=None,
dpi=300,
node_formatters: NodeTypeFormatterMapping | None = None,
):
"""Make graphviz Digraph of PyMC model

Expand All @@ -255,18 +352,26 @@ def make_graph(
"The easiest way to install all of this is by running\n\n"
"\tconda install -c conda-forge python-graphviz"
)

node_formatters = node_formatters or {}
node_formatters = update_node_formatters(node_formatters)

graph = graphviz.Digraph(self.model.name)
for plate_label, all_var_names in self.get_plates(var_names).items():
if plate_label:
# must be preceded by 'cluster' to get a box around it
with graph.subgraph(name="cluster" + plate_label) as sub:
for var_name in all_var_names:
self._make_node(var_name, sub, formatting=formatting)
self._make_node(
var_name, sub, formatting=formatting, node_formatters=node_formatters
)
# plate label goes bottom right
sub.attr(label=plate_label, labeljust="r", labelloc="b", style="rounded")
else:
for var_name in all_var_names:
self._make_node(var_name, graph, formatting=formatting)
self._make_node(
var_name, graph, formatting=formatting, node_formatters=node_formatters
)

for child, parents in self.make_compute_graph(var_names=var_names).items():
# parents is a set of rv names that precede child rv nodes
Expand All @@ -287,7 +392,12 @@ def make_graph(

return graph

def make_networkx(self, var_names: Iterable[VarName] | None = None, formatting: str = "plain"):
def make_networkx(
self,
var_names: Iterable[VarName] | None = None,
formatting: str = "plain",
node_formatters: NodeTypeFormatterMapping | None = None,
):
"""Make networkx Digraph of PyMC model

Returns
Expand All @@ -302,6 +412,10 @@ def make_networkx(self, var_names: Iterable[VarName] | None = None, formatting:
"The easiest way to install all of this is by running\n\n"
"\tconda install networkx"
)

node_formatters = node_formatters or {}
node_formatters = update_node_formatters(node_formatters)

graphnetwork = networkx.DiGraph(name=self.model.name)
for plate_label, all_var_names in self.get_plates(var_names).items():
if plate_label:
Expand All @@ -314,6 +428,7 @@ def make_networkx(self, var_names: Iterable[VarName] | None = None, formatting:
var_name,
subgraphnetwork,
nx=True,
node_formatters=node_formatters,
cluster="cluster" + plate_label,
formatting=formatting,
)
Expand All @@ -332,7 +447,13 @@ def make_networkx(self, var_names: Iterable[VarName] | None = None, formatting:
graphnetwork.graph["name"] = self.model.name
else:
for var_name in all_var_names:
self._make_node(var_name, graphnetwork, nx=True, formatting=formatting)
self._make_node(
var_name,
graphnetwork,
nx=True,
formatting=formatting,
node_formatters=node_formatters,
)

for child, parents in self.make_compute_graph(var_names=var_names).items():
# parents is a set of rv names that precede child rv nodes
Expand All @@ -346,6 +467,7 @@ def model_to_networkx(
*,
var_names: Iterable[VarName] | None = None,
formatting: str = "plain",
node_formatters: NodeTypeFormatterMapping | None = None,
):
"""Produce a networkx Digraph from a PyMC model.

Expand All @@ -367,6 +489,10 @@ def model_to_networkx(
Subset of variables to be plotted that identify a subgraph with respect to the entire model graph
formatting : str, optional
one of { "plain" }
node_formatters : dict, optional
A dictionary mapping node types to functions that return a dictionary of node attributes.
Check out the networkx documentation for more information
how attributes are added to nodes: https://networkx.org/documentation/stable/reference/classes/generated/networkx.Graph.add_node.html

Examples
--------
Expand All @@ -392,6 +518,17 @@ def model_to_networkx(
obs = Normal("obs", theta, sigma=sigma, observed=y)

model_to_networkx(schools)

Add custom attributes to Free Random Variables and Observed Random Variables nodes.

.. code-block:: python

node_formatters = {
"Free Random Variable": lambda var: {"shape": "circle", "label": var.name},
"Observed Random Variable": lambda var: {"shape": "square", "label": var.name},
}
model_to_networkx(schools, node_formatters=node_formatters)

"""
if "plain" not in formatting:
raise ValueError(f"Unsupported formatting for graph nodes: '{formatting}'. See docstring.")
Expand All @@ -403,7 +540,9 @@ def model_to_networkx(
stacklevel=2,
)
model = pm.modelcontext(model)
return ModelGraph(model).make_networkx(var_names=var_names, formatting=formatting)
return ModelGraph(model).make_networkx(
var_names=var_names, formatting=formatting, node_formatters=node_formatters
)


def model_to_graphviz(
Expand All @@ -414,6 +553,7 @@ def model_to_graphviz(
save: str | None = None,
figsize: tuple[int, int] | None = None,
dpi: int = 300,
node_formatters: NodeTypeFormatterMapping | None = None,
):
"""Produce a graphviz Digraph from a PyMC model.

Expand Down Expand Up @@ -441,6 +581,10 @@ def model_to_graphviz(
the size of the saved figure.
dpi : int, optional
Dots per inch. It only affects the resolution of the saved figure. The default is 300.
node_formatters : dict, optional
A dictionary mapping node types to functions that return a dictionary of node attributes.
Check out graphviz documentation for more information on available
attributes. https://graphviz.org/docs/nodes/

Examples
--------
Expand Down Expand Up @@ -475,6 +619,16 @@ def model_to_graphviz(

# creates the file `schools.pdf`
model_to_graphviz(schools).render("schools")

Display Free Random Variables and Observed Random Variables nodes with custom formatting.

.. code-block:: python

node_formatters = {
"Free Random Variable": lambda var: {"shape": "circle", "label": var.name},
"Observed Random Variable": lambda var: {"shape": "square", "label": var.name},
}
model_to_graphviz(schools, node_formatters=node_formatters)
"""
if "plain" not in formatting:
raise ValueError(f"Unsupported formatting for graph nodes: '{formatting}'. See docstring.")
Expand All @@ -491,4 +645,5 @@ def model_to_graphviz(
save=save,
figsize=figsize,
dpi=dpi,
node_formatters=node_formatters,
)
Loading