ModelComparisonSimulator: handle different outputs from individual simulators #452

Merged · 8 commits · May 13, 2025
Changes from all commits
79 changes: 76 additions & 3 deletions bayesflow/simulators/model_comparison_simulator.py
@@ -6,8 +6,10 @@
from bayesflow.utils.decorators import allow_batch_size

from bayesflow.utils import numpy_utils as npu
from bayesflow.utils import logging

from types import FunctionType
from typing import Literal

from .simulator import Simulator
from .lambda_simulator import LambdaSimulator
@@ -22,6 +24,8 @@ def __init__(
p: Sequence[float] = None,
logits: Sequence[float] = None,
use_mixed_batches: bool = True,
key_conflicts: Literal["drop", "fill", "error"] = "drop",
fill_value: float = np.nan,
shared_simulator: Simulator | Callable[[Sequence[int]], dict[str, any]] = None,
):
"""
@@ -38,11 +42,21 @@
A sequence of logits corresponding to model probabilities. Mutually exclusive with `p`.
If neither `p` nor `logits` is provided, defaults to uniform logits.
use_mixed_batches : bool, optional
Whether to draw samples in a batch from different models.

- If True (default), each sample in a batch may come from a different model.
- If False, the entire batch is drawn from a single model, selected according to model probabilities.
key_conflicts : str, optional
Policy for handling keys that are missing from the output of some models when using mixed batches.

- "drop" (default): Drop conflicting keys from the batch output.
- "fill": Fill missing keys with `fill_value`.
- "error": Raise an error when key conflicts are detected.
fill_value : float, optional
If `key_conflicts == "fill"`, missing keys are filled with this value.
shared_simulator : Simulator or Callable, optional
A shared simulator whose outputs are passed to all model simulators. If a function is
provided, it is wrapped in a :py:class:`~bayesflow.simulators.LambdaSimulator` with batching enabled.
"""
self.simulators = simulators

@@ -68,6 +82,9 @@ def __init__(

self.logits = logits
self.use_mixed_batches = use_mixed_batches
self.key_conflicts = key_conflicts
self.fill_value = fill_value
self._key_conflicts_warning = True

@allow_batch_size
def sample(self, batch_shape: Shape, **kwargs) -> dict[str, np.ndarray]:
@@ -105,6 +122,7 @@ def sample(self, batch_shape: Shape, **kwargs) -> dict[str, np.ndarray]:
sims = [
simulator.sample(n, **(kwargs | data)) for simulator, n in zip(self.simulators, model_counts) if n > 0
]
sims = self._handle_key_conflicts(sims, model_counts)
sims = tree_concatenate(sims, numpy=True)
data |= sims

@@ -118,3 +136,58 @@ def sample(self, batch_shape: Shape, **kwargs) -> dict[str, np.ndarray]:
model_indices = npu.one_hot(np.full(batch_shape, model_index, dtype="int32"), num_models)

return data | {"model_indices": model_indices}

def _handle_key_conflicts(self, sims, batch_sizes):
batch_sizes = [b for b in batch_sizes if b > 0]

keys, all_keys, common_keys, missing_keys = self._determine_key_conflicts(sims=sims)

# all sims have the same keys
if all_keys == common_keys:
return sims

if self.key_conflicts == "drop":
sims = [{k: v for k, v in sim.items() if k in common_keys} for sim in sims]
return sims
elif self.key_conflicts == "fill":
combined_sims = {}
for sim in sims:
combined_sims = combined_sims | sim
for i, sim in enumerate(sims):
for missing_key in missing_keys[i]:
shape = combined_sims[missing_key].shape
shape = list(shape)
shape[0] = batch_sizes[i]
sim[missing_key] = np.full(shape=shape, fill_value=self.fill_value)
return sims
elif self.key_conflicts == "error":
raise ValueError(
"Different simulators provide outputs with different keys, cannot combine them into one batch."
)

def _determine_key_conflicts(self, sims):
keys = [set(sim.keys()) for sim in sims]
all_keys = set.union(*keys)
common_keys = set.intersection(*keys)
missing_keys = [all_keys - k for k in keys]

if all_keys == common_keys:
return keys, all_keys, common_keys, missing_keys

if self._key_conflicts_warning:
# issue warning only once
self._key_conflicts_warning = False

if self.key_conflicts == "drop":
logging.info(
f"Incompatible simulator output. \
The following keys will be dropped: {', '.join(sorted(all_keys - common_keys))}."
)
elif self.key_conflicts == "fill":
logging.info(
f"Incompatible simulator output. \
Attempting to replace keys: {', '.join(sorted(all_keys - common_keys))}, where missing, \
with value {self.fill_value}."
)

return keys, all_keys, common_keys, missing_keys
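
For orientation, here is a minimal usage sketch of the new `key_conflicts` option. It is not part of the diff; it mirrors the test fixtures below, and the key sets shown in the comments are expectations derived from the docstring rather than captured output.

import numpy as np
from bayesflow.simulators import ModelComparisonSimulator, make_simulator

rng = np.random.default_rng(seed=1)

# Two toy models whose priors use different parameter names ("w" vs. "c"),
# so their outputs conflict when combined into a mixed batch.
def prior_1():
    return dict(w=rng.uniform())

def model_1(w):
    return dict(x=w)

def prior_2():
    return dict(c=rng.uniform())

def model_2(c):
    return dict(x=c)

simulators = [make_simulator([prior_1, model_1]), make_simulator([prior_2, model_2])]

# "drop": keep only the keys common to all simulators.
sim_drop = ModelComparisonSimulator(simulators=simulators, key_conflicts="drop")
print(sorted(sim_drop.sample(16)))  # expected: ['model_indices', 'x']

# "fill": keep conflicting keys and pad them with `fill_value` where missing.
sim_fill = ModelComparisonSimulator(simulators=simulators, key_conflicts="fill", fill_value=np.nan)
print(sorted(sim_fill.sample(16)))  # expected: ['c', 'model_indices', 'w', 'x']

# "error": sampling raises a ValueError as soon as a key conflict is detected.
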
50 changes: 50 additions & 0 deletions tests/test_simulators/conftest.py
@@ -167,6 +167,56 @@ def likelihood(mu, n):
return make_simulator([prior, likelihood], meta_fn=context)


@pytest.fixture()
def multimodel():
from bayesflow.simulators import make_simulator, ModelComparisonSimulator

def context(batch_size):
return dict(n=np.random.randint(10, 100))

def prior_0():
return dict(mu=0)

def prior_1():
return dict(mu=np.random.standard_normal())

def likelihood(n, mu):
return dict(y=np.random.normal(mu, 1, n))

simulator_0 = make_simulator([prior_0, likelihood])
simulator_1 = make_simulator([prior_1, likelihood])

simulator = ModelComparisonSimulator(simulators=[simulator_0, simulator_1], shared_simulator=context)

return simulator


@pytest.fixture(params=["drop", "fill", "error"])
def multimodel_key_conflicts(request):
from bayesflow.simulators import make_simulator, ModelComparisonSimulator

rng = np.random.default_rng()

def prior_1():
return dict(w=rng.uniform())

def prior_2():
return dict(c=rng.uniform())

def model_1(w):
return dict(x=w)

def model_2(c):
return dict(x=c)

simulator_1 = make_simulator([prior_1, model_1])
simulator_2 = make_simulator([prior_2, model_2])

simulator = ModelComparisonSimulator(simulators=[simulator_1, simulator_2], key_conflicts=request.param)

return simulator


@pytest.fixture()
def fixed_n():
return 5
22 changes: 22 additions & 0 deletions tests/test_simulators/test_simulators.py
@@ -1,3 +1,4 @@
import pytest
import keras
import numpy as np

@@ -47,3 +48,24 @@ def test_fixed_sample(composite_gaussian, batch_size, fixed_n, fixed_mu):
assert samples["mu"].shape == (batch_size, 1)
assert np.all(samples["mu"] == fixed_mu)
assert samples["y"].shape == (batch_size, fixed_n)


def test_multimodel_sample(multimodel, batch_size):
samples = multimodel.sample(batch_size)

assert set(samples) == {"n", "mu", "y", "model_indices"}
assert samples["mu"].shape == (batch_size, 1)
assert samples["y"].shape == (batch_size, samples["n"])


def test_multimodel_key_conflicts_sample(multimodel_key_conflicts, batch_size):
if multimodel_key_conflicts.key_conflicts == "drop":
samples = multimodel_key_conflicts.sample(batch_size)
assert set(samples) == {"x", "model_indices"}
elif multimodel_key_conflicts.key_conflicts == "fill":
samples = multimodel_key_conflicts.sample(batch_size)
assert set(samples) == {"x", "model_indices", "c", "w"}
assert np.sum(np.isnan(samples["c"])) + np.sum(np.isnan(samples["w"])) == batch_size
elif multimodel_key_conflicts.key_conflicts == "error":
with pytest.raises(ValueError):
samples = multimodel_key_conflicts.sample(batch_size)
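
A side note for anyone experimenting with these fixtures: `model_indices` is returned one-hot encoded, so the model that generated each sample can be recovered with an argmax. A small sketch, continuing from the hypothetical `sim_fill` object in the usage example after the simulator diff above:

import numpy as np

samples = sim_fill.sample(8)
model_idx = np.argmax(samples["model_indices"], axis=-1)  # shape (8,), values in {0, 1}
# Each sample comes from exactly one model, so exactly one of "w"/"c" is padded per sample,
# which is why the two NaN counts in test_multimodel_key_conflicts_sample add up to batch_size.
print(model_idx)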