-
-
Notifications
You must be signed in to change notification settings - Fork 2.1k
Refactor model graph and allow suppressing dim lengths #7392
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Refactor model graph and allow suppressing dim lengths #7392
Conversation
pymc/model_graph.py
Outdated
# parents is a set of rv names that precede child rv nodes | ||
for parent in parents: | ||
yield child.replace(":", "&"), parent.replace(":", "&") | ||
|
||
def make_graph( |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Should make_graph
and make_networkx
now be functions that take plates and edges as inputs?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
That would just remove calling get_plates
and edges
methods. Don't have much of a preference
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It would make it more modular, in that if you find a way to create your own plates and edges, you can just pass it to the functions that then display it?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yeah, sure. I think that makes sense then.
The dictionary of {PlateMeta : set[NodeMeta]} is a bit weird and hard to work with. i.e. set is not subscritable and looking up by PlateMeta key is a bit tricky.
I was thinking of having another object, Plate
which would be:
@dataclass
class Plate:
plate_meta: PlateMeta
nodes: list[NodeMeta]
and that would be in the input to make_graph
and make_networkx
instead. Making the signature: (plates: list[Plate], edges: list[tuple[str, str]], ...)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Also, does it make sense as a method still? Do you see model_to_graphviz taking this input as well?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Lost track of the specific methods we're discussing. My low resolution guess was that once we have the plates / edges we can just pass them to a function that uses those to render graphviz or networkx graphs. Let me know if you were asking about something else or see a problem (or no point) with that
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Sounds good. Let me push something up and you can give feedback
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Just pushed.
If user has arbitrary list[Plate] and list[tuple[VarName, VarName]] then they can use make_graph
or make_networkx
in order to make the graphviz or networkx, respectively.
pm.model_to_graphviz
and pm.model_to_networkx
are still wrappers.
ModelGraph
class can be used to create the plates and edges in the previous manner if desired with the get_plates
and edges
methods
pymc/model_graph.py
Outdated
# parents is a set of rv names that precede child rv nodes | ||
for parent in parents: | ||
yield child.replace(":", "&"), parent.replace(":", "&") | ||
|
||
def make_graph( |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
That would just remove calling get_plates
and edges
methods. Don't have much of a preference
pymc/model_graph.py
Outdated
# must be preceded by 'cluster' to get a box around it | ||
plate_label = create_plate_label(plate_meta, include_size=include_shape_size) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Noticing that the plate_label actually depends on var_name in case of previous "{var_name}_dim{d}". However, the plate_label is required before the looping of all_var_names
. i.e. all_var_names
is assumed to be one element? Maybe that should be an explicit case?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Didn't manage to follow, can you explain again?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The graph.subgraph
name of "cluster" + plate_label is dependent on the var_name which used to be constructed in the get_plates
method (the previous keys of dictionary where the plate_label).
However, after the subgraph is constructed, the all_var_names
is looped over. This is assuming that all_var_names
is only one element since the plate_label is used in the subgraph name.
pymc/model_graph.py
Outdated
for plate in self.get_plates(var_names): | ||
plate_meta = plate.meta | ||
all_vars = plate.variables | ||
if plate_meta.names or plate_meta.sizes: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Can we simplify? Could plate_meta be None for the scalar variables?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This logic would still be needed somewhere. Likely in get_plates
then.
How about having the __bool__
method for Plate
class that does this logic.
Then would act like None and read like:
if plate_meta: # Truthy if sizes or names
# plate_meta has sizes or names that are not empty tuples
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
You have that information when you defined the plate.meta no? Can't you do it immediately?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yes. I changed to have it happen in the get_plates
methods. Scalars will have Plate(meta=None, variables=[...])
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think it's enough to check for sizes
? It is not possible for a plate to have names, but not sizes?
We should rename those to dim_names, and dim_lengths
. And perhaps use None
for dim_lengths for which we don't know the name?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
So IIUC, scalars should belong to a "Plate" with dim_names = (), and dim_lengths = ()?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
And now I understand your approach and I think it was better like you did. The __bool__
sounds fine as well!
Sorry I got confused by the names of the things
pymc/model_graph.py
Outdated
@@ -49,6 +49,9 @@ class PlateMeta: | |||
def __hash__(self): | |||
return hash((self.names, self.sizes)) | |||
|
|||
def __bool__(self) -> bool: | |||
return len(self.sizes) > 0 or len(self.names) > 0 |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
What is a plate without names?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
with pm.Model():
pm.Normal("y", shape=3)
y has sizes but no dim names.
Currently creates Plate(meta=PlateMeta(names=(), sizes=(5, )), variables=[NodeMeta(var=y, node_type=...)])
Think there should be some cases to test now that this logic is exposed. Will be much easier to confirm
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
names here are dim names
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
What happens for a Deterministic with dims=("test_dim", None)
? Apparently we still allow None dims for things that are not RVs
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
That y
should be names=(None,)
?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I'm thinking of pm.Deterministic("x", np.zeros((3, 3)), dims=("hello", None))
and pm.Deterministic("y", np.zeros((3, 3)), dims=(None, "hello")
. We don't want to put those in the same plate because dims can't be repeated, so they are definitely different things?
Can we add a test for that?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Added a test. I had to wrap the data in as_tensor_variable or I'd get an error saying the data needs name attribute
pymc/model_graph.py
Outdated
plate_meta = PlateMeta( | ||
names=tuple(names), | ||
sizes=tuple(sizes), | ||
) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I don't understand this tbh. Are we creating one plate per variable? But a plate can contain multiple variables?
Also names is ambiguous, it is dim_names? We should name it like that to distinguish from var_names?
Also sizes -> dim_lengths
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Ah plates are hashable... so you mutate the same thing...
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Just working with what was there. The historical { str: set[VarName] } is created with loop which I changed to { PlateMeta : set[NodeMeta] }
But switched to list[Plate] ultimately.
Ideally, there could be more straight-foward path to list[Plate]
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This logic feels rather convoluted to be honest. Maybe we can take a step back and see what is actually needed.
Step1: Collect the dim names and dim lengths of every variable we want to plot. This seems simple enough, and we can do in a loop
Step2: Merge variables that have identical dim_names and dim_lengths into "plates". The hashable Plate thing may be a good trick to achieve that, or just a defaultdict with keys: tuple[dim_names, dim_lengths]
Would the code be more readable if we didn't try to do both things at once?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Edit: Updated comment above
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
defaultdict with dims_names and dim_lengths is same as what is currently happening. But there is a wrapper class around it. Personally, I find the class helpful and more user friendly. But I could be wrong
For instance,
Plate(
DimInfo(names=("obs", "covariate"), sizes=(10, 5)),
variables=[
NodeInfo(X, node_type=DATA),
NodeInfo(X_transform, node_type=DETERMINISTIC),
NodeInfo(tvp, node_type=FREE_RV),
]
)
over
(("obs", "covariate"), (10, 5), (X, X_transform, tvp), (DATA, DETERMINSTIC, FREE_RV))
lines up a bit better in my mind that the first two are related objects and the last two are related objects as well
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Sure, just thinking of how easy we make it for users to define their custom stuff. Either way seems manageable
The name Is Any thoughts here on terminology? |
I'm okay with Plate or Cluster. Why the Meta in it? |
Meta would be information about the variables / plate to construct a plate label. Previously it was always " x ".join([f"{dname} ({dlen})" for ...] |
I don't love the word meta, it's too abstract. |
I think itd be nice to keep the names and sizes together since they are related. How about DimInfo |
Is the question whether we represent a data structure that looks like (in terms of access): |
This PR refreshed my mind that #6485 and #7048 exist. To summarize: We can have variables that have entries in Then dims can have coords or not, but always have dim_lengths, which always work when we do the fast_eval for |
pymc/model_graph.py
Outdated
plate_label = create_plate_label( | ||
plate.variables[0].var.name, | ||
plate.meta, | ||
include_size=include_shape_size, | ||
) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Should create_plate_label
now take plate_formatters
that among other things decides on whether to include_size?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yeah, I think that is fair. Where do you view that being exposed?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I exposed the create_plate_label in both make_graph
and make_networkx
. However, left it out in the model_to_graphviz
and model_to_networkx
functions.
If a user defines Callable[[DimInfo], str]
function, then that can be used in the first two, more general functions
There is also the NodeType which is why I went for the small dataclass wrapper that contains TensorVariable and the preprocessed label. I think have a small data structure isn't the end of the world but also helps structure the problem a bit more. The user can clearly see what is part of the new data structures in my mind |
Need to
The 6335 comes up with this example: # Current main branch
coords = {
"obs": range(5),
}
with pm.Model(coords=coords) as model:
data = pt.as_tensor_variable(
np.ones((5, 3)),
name="data",
)
pm.Deterministic("C", data, dims=("obs", None))
pm.Deterministic("D", data, dims=("obs", None))
pm.Deterministic("E", data, dims=("obs", None))
pm.model_to_graphviz(model) Which makes sense that they will not be on the same plate, right? |
I did just catch this bug: It comes from the from pymc.model_graph import ModelGraph
coords = {
"obs": range(5),
}
with pm.Model(coords=coords) as model:
data = pt.as_tensor_variable(
np.ones((5, 3)),
name="C",
)
pm.Deterministic("C", data, dims=("obs", None))
error_compute_graph = ModelGraph(model).make_compute_graph() # defaultdict(set, {"C": {"C"}})
# Visualize error:
pm.model_to_graphviz(model) Result: Shall I make a separate issue? |
I think they should be in the same plate, because in the absense of dims, the shape is used to cluster RVs? |
Self loop is beautiful :) |
How should the {var_name}_dim{d} be handled then to put them on the same plate? Just "dim{d} ({dlen})"? |
Just the length? how does a plate without any dims look like? I imagine the mix would be 50 x trial(30) or however the trial dim is usually displayed. WDYT? |
This mixing of dlen and "{dname} ({dlen})" is what I had in mind. That is the current behavior. Here are some examples: import numpy as np
import pymc as pm
import pytensor.tensor as pt
coords = {
"obs": range(5),
}
with pm.Model(coords=coords) as model:
data = pt.as_tensor_variable(
np.ones((5, 3)),
name="data",
)
C = pm.Deterministic("C", data, dims=("obs", None))
D = pm.Deterministic("D", data, dims=("obs", None))
E = pm.Deterministic("E", data, dims=("obs", None))
pm.model_to_graphviz(model) # Same as above
pm.model_to_graphviz(model, include_dim_lengths=False) And larger example with various items: import numpy as np
import pymc as pm
import pytensor.tensor as pt
coords = {
"obs": range(5),
"covariates": ["X1", "X2", "X3"],
}
with pm.Model(coords=coords) as model:
data1 = pt.as_tensor_variable(
np.ones((5, 3)),
name="data1",
)
data2 = pt.as_tensor_variable(
np.ones((5, 3)),
name="data2",
)
C = pm.Deterministic("C", data1, dims=("obs", None))
CT = pm.Deterministic("CT", C.T, dims=(None, "obs"))
D = pm.Deterministic("D", C @ CT, dims=("obs", "obs"))
E = pm.Deterministic("E", data2, dims=("obs", None))
beta = pm.Normal("beta", dims="covariates")
pm.Deterministic("product", E[:, None, :] * beta[:, None], dims=("obs", None, "covariates"))
pm.model_to_graphviz(model) |
Codecov ReportAttention: Patch coverage is
Additional details and impacted files@@ Coverage Diff @@
## main #7392 +/- ##
==========================================
- Coverage 92.19% 92.18% -0.01%
==========================================
Files 103 103
Lines 17214 17249 +35
==========================================
+ Hits 15870 15901 +31
- Misses 1344 1348 +4
|
Thanks @wd60622 |
Description
Pulled any graph related information into two methods:
get_plates
: Get plate meta information and the nodes that are associated with each plateedges
: Edges between nodes as a list[tuple[VarName, VarName]]The
get_plates
methods returns a list ofPlate
objects which store all the variable information. That data include:DimInfo
with stores the dim names and lengthsNodeInfo
which stores the model variable and it's NodeType in the graph (introduced in Allow customizing style of model_graph nodes #7302)Plate
which is a collection of theDimInfo
andlist[NodeInfo]
With
list[tuple[VarName, VarName]]
andlist[Plate]
, a user can now make use of the exposedmake_graph
andmake_networkx
functions to create customized graphviz or networkx graphs.The previous behavior of
model_to_graphviz
andmodel_to_networkx
is still maintained. However, there is a newinclude_dim_lengths
parameter that can be used to include the dim lengths in the plate labels.The previous issue #6335 behavior has changed to now include all the variables on a plate with dlen instead of var_name_dim{d}. (See examples below)
Related Issue
Checklist
Type of change
📚 Documentation preview 📚: https://pymc--7392.org.readthedocs.build/en/7392/