Skip to content

Add support for graphical simulators #487

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 18 commits into
base: dev
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
18 commits
Select commit Hold shift + click to select a range
20b54ce
initial commit
daniel-habermann May 23, 2025
02cc915
add networkx to project dependencies
daniel-habermann Jun 2, 2025
be9f390
initial implementation GraphicalSimulator
daniel-habermann Jun 15, 2025
4194062
samples method of GraphicalSimulator now returns a dict of appropriat…
daniel-habermann Jun 19, 2025
60e589a
add irt_simulator and threelevel_simulator
daniel-habermann Jun 20, 2025
d1624ee
enable sampling_fn with no arguments for non root nodes, change outpu…
daniel-habermann Jun 20, 2025
ae105f6
add one and three-level example simulators
daniel-habermann Jun 20, 2025
d8ac4fd
allow root node repetitions
daniel-habermann Jun 27, 2025
56f7681
export GraphicalSimulator
daniel-habermann Jun 28, 2025
e59b2b2
rename sampling_fn argument to sample_fn in GraphicalSimulator.add_no…
daniel-habermann Jun 28, 2025
9284333
move example simulators to own submodule
daniel-habermann Jun 28, 2025
b75c5f5
add unit tests for single level graphical model
daniel-habermann Jun 28, 2025
1243f8c
add unit tests for two_level and irt graphical simulators
daniel-habermann Jun 28, 2025
55b6dfd
rename GraphicalSimulator._call_sampling_fn to _call_sample_fn to ref…
daniel-habermann Jun 29, 2025
4035c16
rename examples in graphical_simulator.example_simulators
daniel-habermann Jun 29, 2025
b5a653c
update description of **kwargs parameter in GraphicalSimulator.sample…
daniel-habermann Jun 29, 2025
a3c1fb6
remove unused is_root_node function
daniel-habermann Jun 29, 2025
c5044c1
use 0-based index for internal representation
daniel-habermann Jul 6, 2025
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions bayesflow/experimental/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,11 @@
Unstable or largely untested networks, proceed with caution.
"""

from ..utils._docs import _add_imports_to_all
from .cif import CIF
from .continuous_time_consistency_model import ContinuousTimeConsistencyModel
from .diffusion_model import DiffusionModel
from .free_form_flow import FreeFormFlow

from ..utils._docs import _add_imports_to_all
from .graphical_simulator import GraphicalSimulator

_add_imports_to_all(include_modules=["diffusion_model"])
2 changes: 2 additions & 0 deletions bayesflow/experimental/graphical_simulator/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
from .graphical_simulator import GraphicalSimulator
from . import example_simulators
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
from .single_level_simulator import single_level_simulator
from .two_level_simulator import two_level_simulator
from .crossed_design_irt_simulator import crossed_design_irt_simulator
Original file line number Diff line number Diff line change
@@ -0,0 +1,87 @@
import numpy as np

from ..graphical_simulator import GraphicalSimulator


def crossed_design_irt_simulator():
r"""
Item Response Theory (IRT) model implemented as a graphical simulator.

schools
/ \
exams students
| |
questions |
\ /
observations
"""

# schools have different exam difficulties
def sample_school():
mu_exam_mean = np.random.normal(loc=1.1, scale=0.2)
sigma_exam_mean = abs(np.random.normal(loc=0, scale=1))

# hierarchical mu/sigma for the exam difficulty standard deviation (logscale)
mu_exam_std = np.random.normal(loc=0.5, scale=0.3)
sigma_exam_std = abs(np.random.normal(loc=0, scale=0.5))

return dict(
mu_exam_mean=mu_exam_mean,
sigma_exam_mean=sigma_exam_mean,
mu_exam_std=mu_exam_std,
sigma_exam_std=sigma_exam_std,
)

# exams have different question difficulties
def sample_exam(mu_exam_mean, sigma_exam_mean, mu_exam_std, sigma_exam_std):
# mean question difficulty for an exam
exam_mean = np.random.normal(loc=mu_exam_mean, scale=sigma_exam_mean)

# standard deviation of question difficulty
log_exam_std = np.random.normal(loc=mu_exam_std, scale=sigma_exam_std)
exam_std = float(np.exp(log_exam_std))

return dict(exam_mean=exam_mean, exam_std=exam_std)

# realizations of individual question difficulties
def sample_question(exam_mean, exam_std):
question_difficulty = np.random.normal(loc=exam_mean, scale=exam_std)

return dict(question_difficulty=question_difficulty)

# realizations of individual student abilities
def sample_student():
student_ability = np.random.normal(loc=0, scale=1)

return dict(student_ability=student_ability)

# realizations of individual observations
def sample_observation(question_difficulty, student_ability):
theta = np.exp(question_difficulty + student_ability) / (1 + np.exp(question_difficulty + student_ability))

obs = np.random.binomial(n=1, p=theta)

return dict(obs=obs)

def meta_fn():
return {
"num_exams": np.random.randint(2, 4),
"num_questions": np.random.randint(10, 21),
"num_students": np.random.randint(100, 201),
}

simulator = GraphicalSimulator(meta_fn=meta_fn)

simulator.add_node("schools", sample_fn=sample_school)
simulator.add_node("exams", sample_fn=sample_exam, reps="num_exams")
simulator.add_node("questions", sample_fn=sample_question, reps="num_questions")
simulator.add_node("students", sample_fn=sample_student, reps="num_students")
simulator.add_node("observations", sample_fn=sample_observation)

simulator.add_edge("schools", "exams")
simulator.add_edge("schools", "students")
simulator.add_edge("exams", "questions")
simulator.add_edge("questions", "observations")
simulator.add_edge("students", "observations")

return simulator
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
import numpy as np

from ..graphical_simulator import GraphicalSimulator


def single_level_simulator():
"""
Simple single-level simulator that implements the same model as in
https://bayesflow.org/main/_examples/Linear_Regression_Starter.html
"""

def prior():
beta = np.random.normal([2, 0], [3, 1])
sigma = np.random.gamma(1, 1)

return {"beta": beta, "sigma": sigma}

def likelihood(beta, sigma, N):
x = np.random.normal(0, 1, size=N)
y = np.random.normal(beta[0] + beta[1] * x, sigma, size=N)

return {"x": x, "y": y}

def meta():
N = np.random.randint(5, 15)

return {"N": N}

simulator = GraphicalSimulator(meta_fn=meta)

simulator.add_node("prior", sample_fn=prior)
simulator.add_node("likelihood", sample_fn=likelihood)

simulator.add_edge("prior", "likelihood")

return simulator
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
import numpy as np

from ..graphical_simulator import GraphicalSimulator


def two_level_simulator(repeated_roots=False):
r"""
Simple hierarchical model with two levels of parameters: hyperparameters
and local parameters, along with a shared parameter:

hypers
|
locals shared
\ /
\ /
y

Parameters
----------
repeated_roots : bool, default false.

"""

def sample_hypers():
hyper_mean = np.random.normal()
hyper_std = np.abs(np.random.normal())

return {"hyper_mean": hyper_mean, "hyper_std": hyper_std}

def sample_locals(hyper_mean, hyper_std):
local_mean = np.random.normal(hyper_mean, hyper_std)

return {"local_mean": local_mean}

def sample_shared():
shared_std = np.abs(np.random.normal())

return {"shared_std": shared_std}

def sample_y(local_mean, shared_std):
y = np.random.normal(local_mean, shared_std)

return {"y": y}

simulator = GraphicalSimulator()

if not repeated_roots:
simulator.add_node("hypers", sample_fn=sample_hypers)
else:
simulator.add_node("hypers", sample_fn=sample_hypers, reps=5)

simulator.add_node(
"locals",
sample_fn=sample_locals,
reps=6,
)

simulator.add_node("shared", sample_fn=sample_shared)
simulator.add_node("y", sample_fn=sample_y, reps=10)

simulator.add_edge("hypers", "locals")
simulator.add_edge("locals", "y")
simulator.add_edge("shared", "y")

return simulator
Loading