Merged
67 commits
e329f4b
fix trainable parameters in distributions (#520)
vpratz Jun 22, 2025
e4e6da4
Merge branch 'dev' of https://github.com/bayesflow-org/bayesflow into…
stefanradev93 Jun 22, 2025
f916855
Improve numerical precision in MVNScore.log_prob
han-ol Jun 23, 2025
0c99bd9
add log_gamma diagnostic (#522)
daniel-habermann Jun 30, 2025
17540b1
Merge remote-tracking branch 'upstream/main' into dev [skip ci]
vpratz Jun 30, 2025
55d51df
Breaking changes: Fix bugs regarding counts in standardization layer …
vpratz Jul 1, 2025
7b27f14
Merge branch 'dev' of https://github.com/bayesflow-org/bayesflow into…
stefanradev93 Jul 1, 2025
2a19d32
rename log_gamma to calibration_log_gamma (#527)
daniel-habermann Jul 1, 2025
c9feff2
Merge branch 'dev' of https://github.com/bayesflow-org/bayesflow into…
stefanradev93 Jul 2, 2025
b23ebf5
simple fix
arrjon Jul 2, 2025
47bf8e8
Hotfix: numercial stability of non-log-stabilized sinkhorn plan (#531)
LarsKue Jul 8, 2025
36b38f0
isinstance sequence
arrjon Jul 8, 2025
f1c0c87
Merge pull request #530 from bayesflow-org/529-bug-serialization-of-t…
vpratz Jul 8, 2025
13112bc
Pass correct training stage in compute_metrics (#534)
han-ol Jul 9, 2025
2038d66
Custom test quantity support for calibration_ecdf (#528)
han-ol Jul 9, 2025
0ea79d7
Merge branch 'dev' of https://github.com/bayesflow-org/bayesflow into…
stefanradev93 Jul 11, 2025
0f729a4
Log gamma test fix (#535)
daniel-habermann Jul 11, 2025
d29a743
Stateless adapters (#536)
stefanradev93 Jul 14, 2025
86731f7
Fix training strategies in BasicWorkflow
elseml Jul 14, 2025
b54271f
move multimodal data notebook to regular examples [no ci]
vpratz Jul 14, 2025
61e0e88
make pip install call on homepage more verbose [no ci]
vpratz Jul 14, 2025
3b380f8
Merge remote-tracking branch 'upstream/main' into dev
vpratz Jul 14, 2025
887fbbc
remove deprecated summaries function
vpratz Jul 14, 2025
ebbddce
detail subsampling behavior docs for SIR simulator [no ci]
vpratz Jul 15, 2025
eb1f6ce
move DiffusionModel from experimental to networks
vpratz Jul 15, 2025
10db509
Add citation for resnet (#537) [no ci]
eodole Jul 15, 2025
706e3fd
Merge pull request #538 from bayesflow-org/stabilize-diffusion-model
vpratz Jul 15, 2025
5d3594c
Merge branch 'dev' of https://github.com/bayesflow-org/bayesflow into…
stefanradev93 Jul 19, 2025
13c1201
Bump up version [skip ci]
stefanradev93 Jul 19, 2025
47d2766
Merge remote-tracking branch 'upstream/main' into dev
vpratz Jul 22, 2025
d9e9782
Allow separate inputs to subnets for continuous models (#521)
arrjon Jul 24, 2025
8afff13
Auto-select backend (#543)
LarsKue Jul 24, 2025
fb3191b
Merge remote-tracking branch 'upstream/main' into dev [no ci]
vpratz Jul 28, 2025
d68c9dd
Breaking: parameterize MVNormalScore by inverse cholesky factor to im…
vpratz Jul 28, 2025
f6a1708
fix unconditional sampling in ContinuousApproximator (#548)
vpratz Jul 28, 2025
d333870
Test quantities Linear Regression Starter notebook (#544)
han-ol Jul 29, 2025
326f05a
fix: optimizer was not used in workflow with multiple fits
vpratz Aug 4, 2025
3d391d8
fix: remove extra deserialize call for SummaryNetwork
vpratz Aug 5, 2025
952862c
Compatibility: deserialize when get_config was overridden
vpratz Aug 5, 2025
58ad41b
unify log_prob signature in PointApproximator [no ci]
vpratz Aug 5, 2025
292419c
Tutorial on spatial data with Gaussian Random Fields (#540) [no ci]
vpratz Aug 5, 2025
f180ab4
Support non-array data in test_quantity calibration ecdf [no ci]
han-ol Aug 6, 2025
d2ac255
import calibration_log_gamma in diagnostics namespace [no ci]
han-ol Aug 7, 2025
7cabf17
Add wrapper around scipy.integrate.solve_ivp for integration
vpratz Aug 9, 2025
8567049
minor fixes and improvements to the pairs plot functions
vpratz Aug 10, 2025
0b09487
fix: layers were not deserialized for Sequential and Residual
vpratz Aug 13, 2025
d582111
add serialization tests for Sequential and Residual
vpratz Aug 13, 2025
954c16c
Fix: ensure checkpoint filepath exists before training
han-ol Aug 15, 2025
74a5036
Revert 954c16c0 since it was unnecessary
han-ol Aug 15, 2025
eead4f3
improvements to diagnostic plots (#556)
vpratz Aug 17, 2025
c1407df
Add pairs plot for arbitrary quantities (#550)
vpratz Aug 19, 2025
f269d8f
minor fix in diffusion edm schedule (#560)
arrjon Aug 21, 2025
4e47cc4
minor fix in diffusion edm schedule
arrjon Aug 21, 2025
12b06b9
DeepSet: Adapt output dimension of invariant module inside the equiva…
vpratz Aug 21, 2025
6914baf
pairs_postorior: inconsistent type hint fix (#562)
thegialeo Aug 22, 2025
3ff135d
allow exploding variance type in EDM schedule
arrjon Aug 25, 2025
2aa0c02
Merge remote-tracking branch 'origin/dev' into dev
arrjon Aug 25, 2025
55c18e2
fix type hint
arrjon Aug 25, 2025
ffda7a1
Bump up version [skip ci]
stefanradev93 Aug 26, 2025
076fdc8
Fix instructions for backend spec [skip ci]
stefanradev93 Aug 26, 2025
cd2d093
Add New Flow Matching Schedules (#565)
arrjon Aug 26, 2025
747fe5e
change default integration method to rk45
vpratz Aug 26, 2025
a28afb6
fix nan to num inverse
arrjon Aug 26, 2025
04bb665
fix setting markersize in lotka volterra notebook
vpratz Aug 26, 2025
e0ec0e4
Merge pull request #567 from bayesflow-org/fix_nan_tonum
arrjon Aug 26, 2025
958661c
fix: actually set KERAS_BACKEND to chosen backend
vpratz Aug 26, 2025
a0d8d5b
Fix warning msg
stefanradev93 Aug 26, 2025
4 changes: 1 addition & 3 deletions README.md
@@ -76,7 +76,7 @@ Note that BayesFlow **will not run** without a backend.

If you don't know which backend to use, we recommend JAX as it is currently the fastest backend.

Once installed, [set the backend environment variable as required by keras](https://keras.io/getting_started/#configuring-your-backend).
As of version ``2.0.7``, the backend will be set automatically. If you have multiple backends, you can manually [set the backend environment variable as described by keras](https://keras.io/getting_started/#configuring-your-backend).
For example, inside your Python script write:

```python
@@ -97,8 +97,6 @@ Or just plainly set the environment variable in your shell:
export KERAS_BACKEND=jax
```

This way, you also don't have to manually set the backend every time you are starting Python to use BayesFlow.

## Getting Started

Using the high-level interface is easy, as demonstrated by the minimal working example below:
101 changes: 80 additions & 21 deletions bayesflow/__init__.py
@@ -1,29 +1,12 @@
from . import (
approximators,
adapters,
augmentations,
datasets,
diagnostics,
distributions,
experimental,
networks,
simulators,
utils,
workflows,
wrappers,
)

from .adapters import Adapter
from .approximators import ContinuousApproximator, PointApproximator
from .datasets import OfflineDataset, OnlineDataset, DiskDataset
from .simulators import make_simulator
from .workflows import BasicWorkflow
# ruff: noqa: E402
# disable E402 to allow for setup code before importing any internals (which could import keras)


def setup():
# perform any necessary setup without polluting the namespace
import keras
import os
import logging
from importlib.util import find_spec

# set the basic logging level if the user hasn't already
logging.basicConfig(level=logging.INFO)
@@ -32,8 +15,63 @@ def setup():
logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)

issue_url = "https://github.com/bayesflow-org/bayesflow/issues/new?template=bug_report.md"

if "KERAS_BACKEND" not in os.environ:
# check for available backends and automatically set the KERAS_BACKEND env variable or raise an error
class Backend:
def __init__(self, display_name, package_name, env_name, install_url, priority):
self.display_name = display_name
self.package_name = package_name
self.env_name = env_name
self.install_url = install_url
self.priority = priority

backends = [
Backend("JAX", "jax", "jax", "https://docs.jax.dev/en/latest/quickstart.html#installation", 0),
Backend("PyTorch", "torch", "torch", "https://pytorch.org/get-started/locally/", 1),
Backend("TensorFlow", "tensorflow", "tensorflow", "https://www.tensorflow.org/install", 2),
]

found_backends = []
for backend in backends:
if find_spec(backend.package_name) is not None:
found_backends.append(backend)

if not found_backends:
message = "No suitable backend found. Please install one of the following:\n"
for backend in backends:
message += f"{backend.display_name}\n"
message += "\n"

message += f"If you continue to see this error, please file a bug report at {issue_url}.\n"
message += (
"You can manually select a backend by setting the KERAS_BACKEND environment variable as shown below:\n"
)
message += "https://keras.io/getting_started/#configuring-your-backend"

raise ImportError(message)

if len(found_backends) > 1:
found_backends.sort(key=lambda b: b.priority)
chosen_backend = found_backends[0]
os.environ["KERAS_BACKEND"] = chosen_backend.env_name

logging.warning(
f"Multiple Keras-compatible backends detected ({', '.join(b.display_name for b in found_backends)}).\n"
f"Defaulting to {chosen_backend.display_name}.\n"
"To override, set the KERAS_BACKEND environment variable before importing bayesflow.\n"
"See: https://keras.io/getting_started/#configuring-your-backend"
)
else:
os.environ["KERAS_BACKEND"] = found_backends[0].env_name

import keras
from bayesflow.utils import logging

if keras.backend.backend().lower() != os.environ["KERAS_BACKEND"].lower():
logging.warning("Automatic backend selection failed, most likely because Keras was imported before BayesFlow.")

logging.info(f"Using backend {keras.backend.backend()!r}")

if keras.backend.backend() == "torch":
@@ -60,3 +98,24 @@ def setup():
# call and clean up namespace
setup()
del setup

from . import (
approximators,
adapters,
augmentations,
datasets,
diagnostics,
distributions,
experimental,
networks,
simulators,
utils,
workflows,
wrappers,
)

from .adapters import Adapter
from .approximators import ContinuousApproximator, PointApproximator
from .datasets import OfflineDataset, OnlineDataset, DiskDataset
from .simulators import make_simulator
from .workflows import BasicWorkflow
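
The backend auto-selection added to `setup()` above comes down to three rules: respect an existing `KERAS_BACKEND`, otherwise probe for installed packages in priority order (JAX, PyTorch, TensorFlow), and raise if none is found. A minimal sketch of that decision follows. The helper `pick_backend` and its injectable `is_installed` predicate are hypothetical and not part of BayesFlow; they are introduced here only so the logic can be exercised without importing Keras.

```python
from importlib.util import find_spec

# (package name, KERAS_BACKEND value), in the priority order used by the diff
CANDIDATES = [("jax", "jax"), ("torch", "torch"), ("tensorflow", "tensorflow")]


def pick_backend(environ, is_installed=lambda pkg: find_spec(pkg) is not None):
    """Return the backend name to use (hypothetical helper, not BayesFlow API)."""
    # An explicit user choice always wins.
    if "KERAS_BACKEND" in environ:
        return environ["KERAS_BACKEND"]

    # Keep only the candidates whose package can be located.
    found = [env_name for pkg, env_name in CANDIDATES if is_installed(pkg)]
    if not found:
        raise ImportError("No suitable backend found.")

    # CANDIDATES is already sorted by priority, so the first hit is the default.
    environ["KERAS_BACKEND"] = found[0]
    return found[0]
```

In real use `environ` would be `os.environ`, and the call has to happen before `keras` is first imported, which is why the diff moves the package imports below `setup()`.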
2 changes: 2 additions & 0 deletions bayesflow/adapters/transforms/nan_to_num.py
@@ -80,6 +80,8 @@ def inverse(self, data: dict[str, any], **kwargs) -> dict[str, any]:
data = data.copy()

# Retrieve mask and values to reconstruct NaNs
if self.key not in data.keys():
return data
values = data[self.key]

if not self.return_mask:
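
The two added lines make `inverse` a no-op when the transformed key is absent from the dictionary. A rough sketch of that guard in isolation, assuming a hypothetical function `restore_nans` that works on plain lists and a mask that is `True` where a value was originally NaN (both are assumptions for illustration, not the actual `NanToNum` implementation):

```python
import math


def restore_nans(data, key, mask_key):
    """Hypothetical stand-in for the inverse of a NaN-to-number transform."""
    data = dict(data)  # shallow copy so the caller's dict is left untouched

    # Guard from the diff: nothing to reconstruct if the key is missing.
    if key not in data:
        return data

    # Assumed convention: mask entries are True where the value was NaN.
    mask = data.pop(mask_key, None)
    if mask is not None:
        data[key] = [math.nan if m else v for v, m in zip(data[key], mask)]
    return data
```

Without the guard, the lookup `data[key]` would raise a `KeyError` whenever the inverse pass runs on a dictionary that only contains other variables.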
2 changes: 1 addition & 1 deletion bayesflow/approximators/continuous_approximator.py
@@ -537,7 +537,7 @@ def _sample(
)
batch_shape = keras.ops.shape(inference_conditions)[:-1]
else:
batch_shape = keras.ops.shape(inference_conditions)[1:-1]
batch_shape = (num_samples,)

return self.inference_network.sample(
batch_shape, conditions=inference_conditions, **filter_kwargs(kwargs, self.inference_network.sample)
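
In this hunk the `else` branch changes from `keras.ops.shape(inference_conditions)[1:-1]` to `(num_samples,)`. Going by the commit title ("fix unconditional sampling"), the `else` branch is presumably the case where no conditions are available, so taking their shape cannot work. A small sketch of the resulting rule, using plain shape tuples and a hypothetical helper name `sample_batch_shape`:

```python
def sample_batch_shape(conditions_shape, num_samples):
    """Pick the batch shape for sampling (illustrative helper, not BayesFlow API).

    conditions_shape: shape tuple of the conditions, or None if unconditional.
    """
    if conditions_shape is not None:
        # Conditional case: every leading axis of the conditions is a batch axis.
        return tuple(conditions_shape[:-1])
    # Unconditional case: the only batch axis is the number of samples.
    return (num_samples,)
```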
7 changes: 1 addition & 6 deletions bayesflow/approximators/point_approximator.py
@@ -143,12 +143,7 @@ def sample(

return samples

def log_prob(
self,
*,
data: Mapping[str, np.ndarray],
**kwargs,
) -> np.ndarray | dict[str, np.ndarray]:
def log_prob(self, data: Mapping[str, np.ndarray], **kwargs) -> np.ndarray | dict[str, np.ndarray]:
"""
Computes the log-probability of given data under the parametric distribution(s) for given input conditions.
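
The signature change drops the bare `*`, so `data` is no longer keyword-only and can also be passed positionally. The difference can be seen with two stand-in functions (hypothetical names, with trivial bodies in place of the real computation):

```python
def log_prob_keyword_only(*, data, **kwargs):
    # Old style: the bare * forces callers to write data=...
    return sorted(data)


def log_prob_unified(data, **kwargs):
    # New style: data may be passed positionally or by keyword.
    return sorted(data)


def accepts_positional(fn):
    """Return True if fn can be called with data as a positional argument."""
    try:
        fn({"theta": 0.0})
        return True
    except TypeError:
        return False
```

Existing callers that use `data=...` keep working with the new signature; only positional calls behave differently.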

3 changes: 3 additions & 0 deletions bayesflow/diagnostics/__init__.py
@@ -5,6 +5,7 @@
from .metrics import (
bootstrap_comparison,
calibration_error,
calibration_log_gamma,
posterior_contraction,
summary_space_comparison,
)
@@ -18,7 +19,9 @@
mc_confusion_matrix,
mmd_hypothesis_test,
pairs_posterior,
pairs_quantity,
pairs_samples,
plot_quantity,
recovery,
recovery_from_estimates,
z_score_contraction,
10 changes: 6 additions & 4 deletions bayesflow/diagnostics/metrics/posterior_contraction.py
@@ -10,7 +10,7 @@ def posterior_contraction(
targets: Mapping[str, np.ndarray] | np.ndarray,
variable_keys: Sequence[str] = None,
variable_names: Sequence[str] = None,
aggregation: Callable = np.median,
aggregation: Callable | None = np.median,
) -> dict[str, any]:
"""
Computes the posterior contraction (PC) from prior to posterior for the given samples.
@@ -27,16 +27,17 @@ def posterior_contraction(
By default, select all keys.
variable_names : Sequence[str], optional (default = None)
Optional variable names to show in the output.
aggregation : callable, optional (default = np.median)
aggregation : callable or None, optional (default = np.median)
Function to aggregate the PC across draws. Typically `np.mean` or `np.median`.
If None is provided, the individual values are returned.

Returns
-------
result : dict
Dictionary containing:

- "values" : float or np.ndarray
The aggregated posterior contraction per variable
The (optionally aggregated) posterior contraction per variable
- "metric_name" : str
The name of the metric ("Posterior Contraction").
- "variable_names" : str
@@ -59,6 +60,7 @@ def posterior_contraction(
post_vars = samples["estimates"].var(axis=1, ddof=1)
prior_vars = samples["targets"].var(axis=0, keepdims=True, ddof=1)
contraction = np.clip(1 - (post_vars / prior_vars), 0, 1)
contraction = aggregation(contraction, axis=0)
if aggregation is not None:
contraction = aggregation(contraction, axis=0)
variable_names = samples["estimates"].variable_names
return {"values": contraction, "metric_name": "Posterior Contraction", "variable_names": variable_names}
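
With this change, `aggregation=None` skips the reduction and returns one contraction value per dataset. The computation is `clip(1 - posterior_variance / prior_variance, 0, 1)`, followed by the optional aggregation. A simplified single-variable sketch using only the standard library follows; the function name `contraction_1d` and the list-based inputs are assumptions for illustration, whereas the real function works on NumPy arrays and aggregates along `axis=0`.

```python
from statistics import median, variance


def contraction_1d(post_draws, prior_draws, aggregation=median):
    """Posterior contraction for one variable (illustrative, not BayesFlow API).

    post_draws: one list of posterior draws per dataset.
    prior_draws: one ground-truth (prior) draw per dataset.
    """
    prior_var = variance(prior_draws)  # sample variance, like ddof=1
    values = [
        min(max(1 - variance(draws) / prior_var, 0.0), 1.0)  # clip to [0, 1]
        for draws in post_draws
    ]
    # None means: return the individual per-dataset values.
    return values if aggregation is None else aggregation(values)
```

Passing `aggregation=None` is useful when the per-dataset values are needed for plotting rather than as a single summary.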
2 changes: 2 additions & 0 deletions bayesflow/diagnostics/plots/__init__.py
@@ -6,6 +6,8 @@
from .mc_confusion_matrix import mc_confusion_matrix
from .mmd_hypothesis_test import mmd_hypothesis_test
from .pairs_posterior import pairs_posterior
from .pairs_quantity import pairs_quantity
from .plot_quantity import plot_quantity
from .pairs_samples import pairs_samples
from .recovery import recovery
from .recovery_from_estimates import recovery_from_estimates
40 changes: 12 additions & 28 deletions bayesflow/diagnostics/plots/calibration_ecdf.py
@@ -1,9 +1,9 @@
from collections.abc import Callable, Mapping, Sequence

import numpy as np
import keras
import matplotlib.pyplot as plt

from ...utils.dict_utils import compute_test_quantities
from ...utils.plot_utils import prepare_plot_data, add_titles_and_labels, prettify_subplots
from ...utils.ecdf import simultaneous_ecdf_bands
from ...utils.ecdf.ranks import fractional_ranks, distance_ranks
@@ -136,33 +136,17 @@ def calibration_ecdf(

# Optionally, compute and prepend test quantities from draws
if test_quantities is not None:
test_quantities_estimates = {}
test_quantities_targets = {}

for key, test_quantity_fn in test_quantities.items():
# Apply test_quantity_func to ground-truths
tq_targets = test_quantity_fn(data=targets)
test_quantities_targets[key] = np.expand_dims(tq_targets, axis=1)

# # Flatten estimates for batch processing in test_quantity_fn, apply function, and restore shape
num_conditions, num_samples = next(iter(estimates.values())).shape[:2]
flattened_estimates = keras.tree.map_structure(lambda t: np.reshape(t, (-1, *t.shape[2:])), estimates)
flat_tq_estimates = test_quantity_fn(data=flattened_estimates)
test_quantities_estimates[key] = np.reshape(flat_tq_estimates, (num_conditions, num_samples, 1))

# Add custom test quantities to variable keys and names for plotting
# keys and names are set to the test_quantities dict keys
test_quantities_names = list(test_quantities.keys())

if variable_keys is None:
variable_keys = list(estimates.keys())

if isinstance(variable_names, list):
variable_names = test_quantities_names + variable_names

variable_keys = test_quantities_names + variable_keys
estimates = test_quantities_estimates | estimates
targets = test_quantities_targets | targets
updated_data = compute_test_quantities(
targets=targets,
estimates=estimates,
variable_keys=variable_keys,
variable_names=variable_names,
test_quantities=test_quantities,
)
variable_names = updated_data["variable_names"]
variable_keys = updated_data["variable_keys"]
estimates = updated_data["estimates"]
targets = updated_data["targets"]

plot_data = prepare_plot_data(
estimates=estimates,
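
This hunk moves the inline test-quantity logic into `compute_test_quantities`, whose body is not part of the diff shown here. Judging only from the removed inline code and the keys read from `updated_data`, the helper plausibly looks like the sketch below. The name `prepend_test_quantities` is hypothetical, and the sketch uses a dict comprehension where the removed code used `keras.tree.map_structure`, so it assumes flat dictionaries of NumPy arrays.

```python
import numpy as np


def prepend_test_quantities(targets, estimates, test_quantities,
                            variable_keys=None, variable_names=None):
    """Reconstruction of the removed inline logic (not the actual BayesFlow helper)."""
    tq_targets, tq_estimates = {}, {}
    for key, fn in test_quantities.items():
        # Ground truths: (num_conditions,) -> (num_conditions, 1)
        tq_targets[key] = np.expand_dims(fn(data=targets), axis=1)

        # Estimates: merge the first two axes, apply fn, then restore the shape.
        num_conditions, num_samples = next(iter(estimates.values())).shape[:2]
        flat = {k: np.reshape(v, (-1, *v.shape[2:])) for k, v in estimates.items()}
        tq_estimates[key] = np.reshape(fn(data=flat), (num_conditions, num_samples, 1))

    names = list(test_quantities.keys())
    if variable_keys is None:
        variable_keys = list(estimates.keys())
    if isinstance(variable_names, list):
        variable_names = names + variable_names

    # Test quantities come first so that they are plotted first.
    return {
        "targets": tq_targets | targets,
        "estimates": tq_estimates | estimates,
        "variable_keys": names + variable_keys,
        "variable_names": variable_names,
    }
```

A test quantity here is any callable that takes `data=` and returns one scalar per row, for example `lambda data: data["theta"].sum(axis=-1)`.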
13 changes: 10 additions & 3 deletions bayesflow/diagnostics/plots/calibration_ecdf_from_quantiles.py
@@ -26,6 +26,7 @@ def calibration_ecdf_from_quantiles(
fill_color: str = "grey",
num_row: int = None,
num_col: int = None,
markersize: float = None,
**kwargs,
) -> plt.Figure:
"""
@@ -97,6 +98,8 @@ def calibration_ecdf_from_quantiles(
num_col : int, optional, default: None
The number of columns for the subplots.
Dynamically determined if None.
markersize : float, optional, default: None
The marker size in points.
**kwargs : dict, optional, default: {}
Keyword arguments can be passed to control the behavior of
ECDF simultaneous band computation through the ``ecdf_bands_kwargs``
@@ -142,11 +145,15 @@ def calibration_ecdf_from_quantiles(

if stacked:
if j == 0:
plot_data["axes"][0].plot(xx, yy, marker="o", color=rank_ecdf_color, alpha=0.95, label="Rank ECDFs")
plot_data["axes"][0].plot(
xx, yy, marker="o", color=rank_ecdf_color, markersize=markersize, alpha=0.95, label="Rank ECDFs"
)
else:
plot_data["axes"][0].plot(xx, yy, marker="o", color=rank_ecdf_color, alpha=0.95)
plot_data["axes"][0].plot(xx, yy, marker="o", color=rank_ecdf_color, markersize=markersize, alpha=0.95)
else:
plot_data["axes"].flat[j].plot(xx, yy, marker="o", color=rank_ecdf_color, alpha=0.95, label="Rank ECDF")
plot_data["axes"].flat[j].plot(
xx, yy, marker="o", color=rank_ecdf_color, markersize=markersize, alpha=0.95, label="Rank ECDF"
)

# Compute uniform ECDF and bands
alpha, z, L, U = pointwise_ecdf_bands(estimates.shape[0], **kwargs.pop("ecdf_bands_kwargs", {}))
5 changes: 4 additions & 1 deletion bayesflow/diagnostics/plots/mc_calibration.py
@@ -27,6 +27,7 @@ def mc_calibration(
color: str = "#132a70",
num_col: int = None,
num_row: int = None,
markersize: float = None,
) -> plt.Figure:
"""Plots the calibration curves, the ECEs and the marginal histograms of predicted posterior model probabilities
for a model comparison problem. The marginal histograms inform about the fraction of predictions in each bin.
@@ -60,6 +61,8 @@ def mc_calibration(
The number of rows for the subplots. Dynamically determined if None.
num_col : int, optional, default: None
The number of columns for the subplots. Dynamically determined if None.
markersize : float, optional, default: None
The marker size in points.

Returns
-------
@@ -88,7 +91,7 @@ def mc_calibration(

for j, ax in enumerate(plot_data["axes"].flat):
# Plot calibration curve
ax.plot(ece["probs_pred"][j], ece["probs_true"][j], "o-", color=color)
ax.plot(ece["probs_pred"][j], ece["probs_true"][j], "o-", color=color, markersize=markersize)

# Plot PMP distribution over bins
uniform_bins = np.linspace(0.0, 1.0, num_bins + 1)