Skip to content

Ensembles of Approximators #532

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 8 commits into
base: dev
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion bayesflow/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
)

from .adapters import Adapter
from .approximators import ContinuousApproximator, PointApproximator
from .approximators import ContinuousApproximator, PointApproximator, ApproximatorEnsemble
from .datasets import OfflineDataset, OnlineDataset, DiskDataset
from .simulators import make_simulator
from .workflows import BasicWorkflow
Expand Down
2 changes: 2 additions & 0 deletions bayesflow/approximators/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@
from .point_approximator import PointApproximator
from .model_comparison_approximator import ModelComparisonApproximator

from .approximator_ensemble import ApproximatorEnsemble

from ..utils._docs import _add_imports_to_all

_add_imports_to_all(include_modules=[])
3 changes: 1 addition & 2 deletions bayesflow/approximators/approximator.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,7 +133,6 @@ def fit(self, *, dataset: keras.utils.PyDataset = None, simulator: Simulator = N
logging.info("Building on a test batch.")
mock_data = dataset[0]
mock_data = keras.tree.map_structure(keras.ops.convert_to_tensor, mock_data)
mock_data_shapes = keras.tree.map_structure(keras.ops.shape, mock_data)
self.build(mock_data_shapes)
self.build_from_data(mock_data)

return super().fit(dataset=dataset, **kwargs)
127 changes: 127 additions & 0 deletions bayesflow/approximators/approximator_ensemble.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,127 @@
from collections.abc import Mapping

import numpy as np

import keras

from bayesflow.types import Tensor


from .approximator import Approximator


class ApproximatorEnsemble(Approximator):
def __init__(self, approximators: dict[str, Approximator], **kwargs):
super().__init__(**kwargs)

Check warning on line 15 in bayesflow/approximators/approximator_ensemble.py

View check run for this annotation

Codecov / codecov/patch

bayesflow/approximators/approximator_ensemble.py#L15

Added line #L15 was not covered by tests

self.approximators = approximators

Check warning on line 17 in bayesflow/approximators/approximator_ensemble.py

View check run for this annotation

Codecov / codecov/patch

bayesflow/approximators/approximator_ensemble.py#L17

Added line #L17 was not covered by tests

self.num_approximators = len(self.approximators)

Check warning on line 19 in bayesflow/approximators/approximator_ensemble.py

View check run for this annotation

Codecov / codecov/patch

bayesflow/approximators/approximator_ensemble.py#L19

Added line #L19 was not covered by tests

def build_from_data(self, adapted_data: dict[str, any]):
data_shapes = keras.tree.map_structure(keras.ops.shape, adapted_data)
if len(data_shapes["inference_variables"]) > 2:

Check warning on line 23 in bayesflow/approximators/approximator_ensemble.py

View check run for this annotation

Codecov / codecov/patch

bayesflow/approximators/approximator_ensemble.py#L22-L23

Added lines #L22 - L23 were not covered by tests
# Remove the ensemble dimension from data_shapes. This expects data_shapes are the shapes of a
# batch of training data, where the second axis corresponds to different approximators.
data_shapes = {k: v[:1] + v[2:] for k, v in data_shapes.items()}
self.build(data_shapes)

Check warning on line 27 in bayesflow/approximators/approximator_ensemble.py

View check run for this annotation

Codecov / codecov/patch

bayesflow/approximators/approximator_ensemble.py#L26-L27

Added lines #L26 - L27 were not covered by tests

def build(self, data_shapes: dict[str, tuple[int] | dict[str, dict]]) -> None:
for approximator in self.approximators.values():
approximator.build(data_shapes)

Check warning on line 31 in bayesflow/approximators/approximator_ensemble.py

View check run for this annotation

Codecov / codecov/patch

bayesflow/approximators/approximator_ensemble.py#L30-L31

Added lines #L30 - L31 were not covered by tests

def compute_metrics(
self,
inference_variables: Tensor,
inference_conditions: Tensor = None,
summary_variables: Tensor = None,
sample_weight: Tensor = None,
stage: str = "training",
) -> dict[str, dict[str, Tensor]]:
# Prepare empty dict for metrics
metrics = {}

Check warning on line 42 in bayesflow/approximators/approximator_ensemble.py

View check run for this annotation

Codecov / codecov/patch

bayesflow/approximators/approximator_ensemble.py#L42

Added line #L42 was not covered by tests

# Define the variable slices as None (default) or respective input
_inference_variables = inference_variables
_inference_conditions = inference_conditions
_summary_variables = summary_variables
_sample_weight = sample_weight

Check warning on line 48 in bayesflow/approximators/approximator_ensemble.py

View check run for this annotation

Codecov / codecov/patch

bayesflow/approximators/approximator_ensemble.py#L45-L48

Added lines #L45 - L48 were not covered by tests

for i, (approx_name, approximator) in enumerate(self.approximators.items()):

Check warning on line 50 in bayesflow/approximators/approximator_ensemble.py

View check run for this annotation

Codecov / codecov/patch

bayesflow/approximators/approximator_ensemble.py#L50

Added line #L50 was not covered by tests
# During training each approximator receives its own separate slice
if stage == "training":

Check warning on line 52 in bayesflow/approximators/approximator_ensemble.py

View check run for this annotation

Codecov / codecov/patch

bayesflow/approximators/approximator_ensemble.py#L52

Added line #L52 was not covered by tests
# Pick out the correct slice for each ensemble member
_inference_variables = inference_variables[:, i]
if inference_conditions is not None:
_inference_conditions = inference_conditions[:, i]
if summary_variables is not None:
_summary_variables = summary_variables[:, i]
if sample_weight is not None:
_sample_weight = sample_weight[:, i]

Check warning on line 60 in bayesflow/approximators/approximator_ensemble.py

View check run for this annotation

Codecov / codecov/patch

bayesflow/approximators/approximator_ensemble.py#L54-L60

Added lines #L54 - L60 were not covered by tests

metrics[approx_name] = approximator.compute_metrics(

Check warning on line 62 in bayesflow/approximators/approximator_ensemble.py

View check run for this annotation

Codecov / codecov/patch

bayesflow/approximators/approximator_ensemble.py#L62

Added line #L62 was not covered by tests
inference_variables=_inference_variables,
inference_conditions=_inference_conditions,
summary_variables=_summary_variables,
sample_weight=_sample_weight,
stage=stage,
)

# Flatten metrics dict
joint_metrics = {}
for approx_name in metrics.keys():
for metric_key, value in metrics[approx_name].items():
joint_metrics[f"{approx_name}/{metric_key}"] = value
metrics = joint_metrics

Check warning on line 75 in bayesflow/approximators/approximator_ensemble.py

View check run for this annotation

Codecov / codecov/patch

bayesflow/approximators/approximator_ensemble.py#L71-L75

Added lines #L71 - L75 were not covered by tests

# Sum over losses
losses = [v for k, v in metrics.items() if "loss" in k]
metrics["loss"] = keras.ops.sum(losses)

Check warning on line 79 in bayesflow/approximators/approximator_ensemble.py

View check run for this annotation

Codecov / codecov/patch

bayesflow/approximators/approximator_ensemble.py#L78-L79

Added lines #L78 - L79 were not covered by tests

return metrics

Check warning on line 81 in bayesflow/approximators/approximator_ensemble.py

View check run for this annotation

Codecov / codecov/patch

bayesflow/approximators/approximator_ensemble.py#L81

Added line #L81 was not covered by tests

def sample(
self,
*,
num_samples: int,
conditions: Mapping[str, np.ndarray],
split: bool = False,
**kwargs,
) -> dict[str, dict[str, np.ndarray]]:
samples = {}
for approx_name, approximator in self.approximators.items():
if self._has_obj_method(approximator, "sample"):
samples[approx_name] = approximator.sample(

Check warning on line 94 in bayesflow/approximators/approximator_ensemble.py

View check run for this annotation

Codecov / codecov/patch

bayesflow/approximators/approximator_ensemble.py#L91-L94

Added lines #L91 - L94 were not covered by tests
num_samples=num_samples, conditions=conditions, split=split, **kwargs
)
return samples

Check warning on line 97 in bayesflow/approximators/approximator_ensemble.py

View check run for this annotation

Codecov / codecov/patch

bayesflow/approximators/approximator_ensemble.py#L97

Added line #L97 was not covered by tests

def log_prob(self, data: Mapping[str, np.ndarray], **kwargs) -> dict[str, np.ndarray]:
log_prob = {}
for approx_name, approximator in self.approximators.items():
if self._has_obj_method(approximator, "log_prob"):
log_prob[approx_name] = approximator.log_prob(data=data, **kwargs)
return log_prob

Check warning on line 104 in bayesflow/approximators/approximator_ensemble.py

View check run for this annotation

Codecov / codecov/patch

bayesflow/approximators/approximator_ensemble.py#L100-L104

Added lines #L100 - L104 were not covered by tests

def estimate(
self,
conditions: Mapping[str, np.ndarray],
split: bool = False,
**kwargs,
) -> dict[str, dict[str, dict[str, np.ndarray]]]:
estimates = {}
for approx_name, approximator in self.approximators.items():
if self._has_obj_method(approximator, "estimate"):
estimates[approx_name] = approximator.estimate(conditions=conditions, split=split, **kwargs)
return estimates

Check warning on line 116 in bayesflow/approximators/approximator_ensemble.py

View check run for this annotation

Codecov / codecov/patch

bayesflow/approximators/approximator_ensemble.py#L112-L116

Added lines #L112 - L116 were not covered by tests

def _has_obj_method(self, obj, name):
method = getattr(obj, name, None)
return callable(method)

Check warning on line 120 in bayesflow/approximators/approximator_ensemble.py

View check run for this annotation

Codecov / codecov/patch

bayesflow/approximators/approximator_ensemble.py#L119-L120

Added lines #L119 - L120 were not covered by tests

def _batch_size_from_data(self, data: Mapping[str, any]) -> int:
"""
Fetches the current batch size from an input dictionary. Can only be used during training when
inference variables as present.
"""
return keras.ops.shape(data["inference_variables"])[0]

Check warning on line 127 in bayesflow/approximators/approximator_ensemble.py

View check run for this annotation

Codecov / codecov/patch

bayesflow/approximators/approximator_ensemble.py#L127

Added line #L127 was not covered by tests
2 changes: 1 addition & 1 deletion bayesflow/approximators/model_comparison_approximator.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,7 +103,7 @@ def build(self, data_shapes: dict[str, tuple[int] | dict[str, dict]]) -> None:
self.built = True

def build_from_data(self, adapted_data: dict[str, any]):
self.build(keras.tree.map_structure(keras.ops.shape(adapted_data)))
self.build(keras.tree.map_structure(keras.ops.shape, adapted_data))

@classmethod
def build_adapter(
Expand Down
1 change: 1 addition & 0 deletions bayesflow/datasets/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
"""

from .offline_dataset import OfflineDataset
from .offline_ensemble_dataset import OfflineEnsembleDataset
from .online_dataset import OnlineDataset
from .disk_dataset import DiskDataset

Expand Down
30 changes: 30 additions & 0 deletions bayesflow/datasets/offline_ensemble_dataset.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
import numpy as np

from .offline_dataset import OfflineDataset


class OfflineEnsembleDataset(OfflineDataset):
"""
A dataset that is pre-simulated and stored in memory, extending :py:class:`OfflineDataset`.

The only difference is that it allows to train an :py:class:`ApproximatorEnsemble` in parallel by returning
batches with ``num_ensemble`` different random subsets of the available data.
"""

def __init__(self, num_ensemble: int, **kwargs):
super().__init__(**kwargs)
self.num_ensemble = num_ensemble

Check warning on line 16 in bayesflow/datasets/offline_ensemble_dataset.py

View check run for this annotation

Codecov / codecov/patch

bayesflow/datasets/offline_ensemble_dataset.py#L15-L16

Added lines #L15 - L16 were not covered by tests

# Create indices with shape (num_samples, num_ensemble)
_indices = np.arange(self.num_samples, dtype="int64")
_indices = np.repeat(_indices[:, None], self.num_ensemble, axis=1)

Check warning on line 20 in bayesflow/datasets/offline_ensemble_dataset.py

View check run for this annotation

Codecov / codecov/patch

bayesflow/datasets/offline_ensemble_dataset.py#L19-L20

Added lines #L19 - L20 were not covered by tests

# Shuffle independently along second axis
for i in range(self.num_ensemble):
np.random.shuffle(_indices[:, i])

Check warning on line 24 in bayesflow/datasets/offline_ensemble_dataset.py

View check run for this annotation

Codecov / codecov/patch

bayesflow/datasets/offline_ensemble_dataset.py#L23-L24

Added lines #L23 - L24 were not covered by tests

self.indices = _indices

Check warning on line 26 in bayesflow/datasets/offline_ensemble_dataset.py

View check run for this annotation

Codecov / codecov/patch

bayesflow/datasets/offline_ensemble_dataset.py#L26

Added line #L26 was not covered by tests

# Shuffle first axis
if self._shuffle:
self.shuffle()

Check warning on line 30 in bayesflow/datasets/offline_ensemble_dataset.py

View check run for this annotation

Codecov / codecov/patch

bayesflow/datasets/offline_ensemble_dataset.py#L29-L30

Added lines #L29 - L30 were not covered by tests
Loading