Add support for graphical simulators #487
base: dev
Conversation
Codecov Report: All modified and coverable lines are covered by tests ✅
I just thought about an alternative interface and would be glad for some feedback:

```python
import numpy as np

# we could make passing graph optional later
def sample_tau(graph):
    tau = np.abs(np.random.normal())
    return tau

def sample_omega(graph):
    omega = np.abs(np.random.normal())
    return omega

def sample_lambda(graph):
    ncol = graph.sample_int("ncol", 5, 10)
    tau = graph.sample_var("tau", shape=(ncol,))
    lamb = np.random.normal(loc=0, scale=tau)
    return lamb

def sample_x(graph):
    nrow = graph.sample_int("nrow", 1, 10)
    ncol = graph.sample_int("ncol", 5, 10)  # cached between sample_lambda and sample_x
    lamb = graph.sample_var("lambda", shape=(1, ncol))
    omega = graph.sample_var("omega", shape=(1, 1))
    x = np.random.normal(loc=lamb, scale=omega, size=(nrow, ncol))
    return x

graph = GraphSimulator()

# each node returns exactly one variable
graph.add_node("tau", sample_tau)
graph.add_node("omega", sample_omega)
graph.add_node("lambda", sample_lambda)
graph.add_node("x", sample_x)

# returns a list of dicts
samples = graph.sample(10)
```

I already have a working implementation for this, but it might not exactly fit the needs of the rest of the library, so I would like to discuss it first.
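A minimal runnable sketch of how such a cached-sampling interface might behave (my assumption for illustration, not the working implementation mentioned above): per-data-set values are cached on the graph, so `"ncol"` drawn in `sample_lambda` and `sample_x` agrees within one data set.

```python
import numpy as np

# Toy sketch of the proposed interface (an assumption, not the actual
# implementation): node functions pull cached values from the graph.
class GraphSimulator:
    def __init__(self):
        self.nodes = {}
        self._cache = {}

    def add_node(self, name, fn):
        self.nodes[name] = fn

    def sample_int(self, name, low, high):
        # drawn once per data set, then reused under the same name
        if name not in self._cache:
            self._cache[name] = int(np.random.randint(low, high))
        return self._cache[name]

    def sample_var(self, name, shape=None):
        # evaluate the node lazily and cache its raw value
        if name not in self._cache:
            self._cache[name] = self.nodes[name](self)
        value = np.asarray(self._cache[name])
        return np.broadcast_to(value, shape) if shape is not None else value

    def sample(self, n):
        # returns a list of dicts, one per data set
        draws = []
        for _ in range(n):
            self._cache = {}
            for name in self.nodes:
                self.sample_var(name)
            draws.append(dict(self._cache))
        return draws
```

With the node functions above, each returned dict would contain the cached integers (`"nrow"`, `"ncol"`) alongside the node values, and `samples[i]["x"]` would have shape `(nrow, ncol)` for that data set.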
Thank you @daniel-habermann for your PR! I will review it (specifically the interface) in the next couple of days. @LarsKue since Daniel has already spent quite a bit of time and thought on this PR, I would like to go with his implementation for now. If we end up not liking it for some reason, we can still discuss alternatives then.
I'm not against radically changing the suggested interface, but one design consideration is consistency: In the current interface, a user can return dictionaries with arbitrary keys, so it would be quite difficult to explain why this is possible when using

The same is true for

What problem were you trying to resolve with your suggestion? If the concern is boilerplate and adding edges to the networks, I expect we can resolve almost all cases with code introspection, i.e. all a user has to provide are function definitions as in the usual
Just tagging everyone :) @paul-buerkner @stefanradev93 @LarsKue @elseml @arrjon

I've finally reached a stage where I'm happy with the general design and thought it would be a good moment to get your feedback.

```python
from bayesflow.experimental.graphical_simulator import GraphicalSimulator
import numpy as np

def sample_tau():
    tau = np.abs(np.random.normal())
    return dict(tau=tau)

def sample_omega():
    omega = np.abs(np.random.normal())
    return dict(omega=omega)

def sample_lambda_j(tau):
    lambda_j = np.abs(np.random.normal(loc=0, scale=tau))
    return dict(lambda_j=lambda_j)

def sample_x_ij(lambda_j, omega):
    x_ij = np.random.normal(loc=lambda_j, scale=omega)
    return dict(x_ij=x_ij)

def meta():
    return {
        "num_groups": np.random.randint(5, 10),
        "num_obs": np.random.randint(1, 10),
    }

simulator = GraphicalSimulator(meta_fn=meta)

simulator.add_node("tau", sampling_fn=sample_tau)
simulator.add_node("omega", sampling_fn=sample_omega)
simulator.add_node("lambda_j", sampling_fn=sample_lambda_j, reps="num_groups")
simulator.add_node("x_ij", sampling_fn=sample_x_ij, reps="num_obs")

simulator.add_edge("tau", "lambda_j")
simulator.add_edge("lambda_j", "x_ij")
simulator.add_edge("omega", "x_ij")
```

Major changes to the previous versions are:
The main design goal was consistency with our other interfaces. Concretely, I wanted the output of a single-level model implemented as a

```python
import numpy as np

def prior():
    beta = np.random.normal([2, 0], [3, 1])
    sigma = np.random.gamma(1, 1)
    return {"beta": beta, "sigma": sigma}

def likelihood(beta, sigma, N):
    x = np.random.normal(0, 1, size=N)
    y = np.random.normal(beta[0] + beta[1] * x, sigma, size=N)
    return {"x": x, "y": y}

def meta():
    N = np.random.randint(5, 15)
    return {"N": N}

simulator = GraphicalSimulator(meta_fn=meta)

simulator.add_node("prior", sampling_fn=prior)
simulator.add_node("likelihood", sampling_fn=likelihood)

simulator.add_edge("prior", "likelihood")

sim_draws = simulator.sample(500)

sim_draws["N"]             # 13
sim_draws["beta"].shape    # (500, 2)
sim_draws["sigma"].shape   # (500, 1)
sim_draws["x"].shape       # (500, 13)
sim_draws["y"].shape       # (500, 13)
```

Of course, the preferred way to define the number of observations would be to remove

To channel feedback, here is a list of points that I believe are most important to agree on (but of course all other comments are also highly welcome):

How do we vary the number of groups and observations during training?

Our current simulators just return data sets with a varying number of observations. This is fine for online training, but doesn't work well for offline training, which has always been the default for my workflows. An alternative would be to allow simulation of non-rectangular datasets. The internal representation of the

What is the output of `GraphicalSimulator.sample`?
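One option for the offline-training question above is to keep simulating data sets of varying size and pad them to a rectangular batch afterwards. A minimal sketch with a hypothetical `pad_ragged` helper (not part of this PR):

```python
import numpy as np

# Hypothetical helper: pad per-data-set draws of varying length N to a
# common maximum so they fit into one rectangular array for offline
# training, with a boolean mask marking the real observations.
def pad_ragged(draws, key):
    lengths = [len(d[key]) for d in draws]
    n_max = max(lengths)
    padded = np.zeros((len(draws), n_max))
    mask = np.zeros((len(draws), n_max), dtype=bool)
    for i, d in enumerate(draws):
        padded[i, : lengths[i]] = d[key]
        mask[i, : lengths[i]] = True
    return padded, mask

# e.g. three data sets with N = 3, 5, 2 observations each
draws = [{"y": np.random.normal(size=n)} for n in (3, 5, 2)]
y, mask = pad_ragged(draws, "y")
print(y.shape)           # (3, 5)
print(mask.sum(axis=1))  # [3 5 2]
```

The mask can then be passed along so that padded entries are ignored by the summary network.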
Hi Daniel, the interface is a fine job! Here are some more or less detailed thoughts regarding your questions:

How do we vary the number of groups and observations during training?

What is the output of `GraphicalSimulator.sample`?

Do we want to allow repetitions of root nodes?
I discussed with @daniel-habermann this week and we agree with @stefanradev93. I believe everything is in place now. @daniel-habermann what are the next steps for this PR?
I've just renamed the
Thank you Daniel! Some quick comments and requests from my side. @stefanradev93, could you provide another review focusing more on the Python code patterns (which I cannot properly review yet)?
Looks really cool as far as I can judge from the outside!

Both of your suggested solutions for varying numbers in offline data sound reasonable to me. I also have a strong preference for retaining the additional-dimensions output shape approach from BF v1.

Some thoughts on user friendliness (fine to delegate to future work), since the graph specification approach might be quite unfamiliar to more applied users:

- Would it make sense to provide explicit error messages for errors related to graph construction (like cycles), or are the networkx error messages already expressive enough?
- Could we provide an automatic check with a corresponding warning for unused/unconnected nodes (e.g., via a `_validate_graph()` function)?
- It might be easy to miss some edges for more complex probabilistic structures -> would it be feasible to add a visual diagnostic that automatically visualizes the graph? From what I saw of networkx visualization, this can get quite cluttered, but I guess the comparably sparse structure of DAGs helps here.
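A standalone sketch of what such a `_validate_graph()` check could look like (pure Python, names and data layout assumed: `edges` maps each node to its children):

```python
import warnings

# Sketch of a graph validation check: raises on cycles, warns on nodes
# that are neither a parent nor a child of any other node.
def validate_graph(nodes, edges):
    # cycle detection via depth-first search with three node colors
    WHITE, GRAY, BLACK = 0, 1, 2
    color = {n: WHITE for n in nodes}

    def visit(n):
        color[n] = GRAY  # currently on the DFS stack
        for child in edges.get(n, []):
            if color[child] == GRAY:
                raise ValueError(f"cycle detected involving node '{child}'")
            if color[child] == WHITE:
                visit(child)
        color[n] = BLACK  # fully explored

    for n in nodes:
        if color[n] == WHITE:
            visit(n)

    # warn about nodes that appear in no edge at all
    connected = set(edges) | {c for cs in edges.values() for c in cs}
    for n in nodes:
        if n not in connected:
            warnings.warn(f"node '{n}' is not connected to any other node")
```

Whether this beats the networkx error messages is exactly the open question above; the point of a dedicated check is that the wording can reference the simulator's own concepts (nodes, sampling functions) rather than generic graph terms.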
Hey, thank you for the comments! Part of the GraphicalApproximator design will include a module for graph introspection (e.g., inferring inverse structures, factorizations, conditional dependencies, ...), and all of this functionality will probably live there. I also have some basic plotting capabilities; it is actually not that bad, even for larger graphs. While not the focus currently, it is also possible to infer the graph structure entirely from the input and output signatures of the sample functions. In this way, no graph structure has to be specified explicitly, and it could eventually become a drop-in replacement for the current
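A sketch of the signature-based inference idea (assumed names, not the PR's code): a node's parameter names are matched against the output keys produced by other nodes via `inspect.signature`. Since Python cannot statically introspect the keys of a returned dict, the output keys are passed explicitly here; unmatched parameters (like `N`) are treated as meta variables.

```python
import inspect

# Infer DAG edges from function parameter names and declared output keys.
def infer_edges(sampling_fns, outputs):
    """sampling_fns: {node: fn}; outputs: {node: [keys it produces]}"""
    produced_by = {key: node for node, keys in outputs.items() for key in keys}
    edges = []
    for node, fn in sampling_fns.items():
        for param in inspect.signature(fn).parameters:
            parent = produced_by.get(param)
            if parent is not None and parent != node and (parent, node) not in edges:
                edges.append((parent, node))
    return edges

# stubs with the signatures from the single-level example in this thread
def prior():
    ...

def likelihood(beta, sigma, N):
    ...

edges = infer_edges(
    {"prior": prior, "likelihood": likelihood},
    {"prior": ["beta", "sigma"], "likelihood": ["x", "y"]},
)
print(edges)  # [('prior', 'likelihood')]
```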
Looks good to me now. @stefanradev93 can you review it too so we can then merge?
WIP pull request to add support for graphical simulators. I'm going to update this message and tag some people once it has reached a state where it makes sense to read further. All discussion and feedback welcome!
Summary and Motivation
This PR introduces initial support for graphical simulators. The main idea is to represent a complex simulation program as a directed acyclic graph (DAG), where nodes represent sets of parameters and edges denote conditional dependencies.

Such a structure is a natural representation for many Bayesian models because the joint distribution of parameters $p(\theta)$ can often be expressed in the form of a factorization along a DAG:

$$p(\theta_1, \dots, \theta_N) = \prod_{i=1}^{N} p(\theta_i \mid \text{Parents}(\theta_i))$$
The benefit of making these dependency structures explicit is that the converse also holds: by stating the conditional dependencies, the corresponding DAG also encodes the conditional independencies implied by the distribution, which we can then use to automatically build efficient network architectures, for example for multilevel models.
Current Implementation
Consider a standard two-level hierarchical model:
Such a model can be represented by the following diagram:
where the dashed boxes denote that parameters are exchangeable. Currently, such a diagram would be implemented like this:
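This corresponds to the `GraphicalSimulator` example with `tau`, `omega`, `lambda_j`, and `x_ij` shown earlier in the thread. As a self-contained illustration, the generative process of such a two-level model can also be written directly in NumPy (hyperpriors and sizes assumed for the sketch):

```python
import numpy as np

# Direct NumPy sketch of the two-level generative process
rng = np.random.default_rng(1)
num_groups, num_obs = 6, 4

tau = np.abs(rng.normal())                                # group-level scale
omega = np.abs(rng.normal())                              # observation noise
lambda_j = np.abs(rng.normal(0.0, tau, size=num_groups))  # one mean per group
x_ij = rng.normal(lambda_j[:, None], omega,
                  size=(num_groups, num_obs))             # observations

print(x_ij.shape)  # (6, 4)
```

The dashed boxes in the diagram map to the two `size` arguments: `lambda_j` is repeated over groups, and `x_ij` over observations within each group.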
Design space
There is still a long list of design choices:
How to determine how often each node is executed for each batch?
For multilevel models, we want to vary the number of groups and observations within each group for each batch. Currently, this is achieved by the
sample_size
function argument, which expects a callable returning an integer.One question is if we even need such an argument, or could remove it by relying on something like the current
meta_fn
.If we go the
meta_fn
route, do we have a singlemeta_fn
for each node or a global oneHow can we represent more exotic models or non-DAG structures, like state space models
How do we handle which nodes return observed data?
This becomes important when talking about graph inversions. Currently, we can attach arbitrary metadata to each node and the graph inversion algorithm searches for an "observed" keyword, but from a user perspective this should probably be improved. We might even not care about this at all because the adapter defines
summary_conditions
orinference_conditions
.How is all of this represented internally?
The resulting data structure is non-rectangular because each batch might have a different number of calls for each node.