[WIP] Move standardization into approximators and make adapter stateless. #486

Merged (71 commits, Jun 12, 2025)

Commits
ceab303
Add standardization to continuous approximator and test
stefanradev93 May 22, 2025
d79b17a
Fix init bugs, adapt notebooks
stefanradev93 May 23, 2025
c777122
Add training flag to build_from_data
stefanradev93 May 23, 2025
7aeb9cb
Fix inference conditions check
stefanradev93 May 23, 2025
45ab9ea
Fix tests
stefanradev93 May 23, 2025
a83770a
Remove unnecessary init calls
stefanradev93 May 23, 2025
4df270a
Add deprecation warning
stefanradev93 May 24, 2025
8ea6782
Refactor compute metrics and add standardization to model comp
stefanradev93 May 25, 2025
b2a4f76
Fix standardization in cont approx
stefanradev93 May 26, 2025
deffc27
Fix sample keys -> condition keys
stefanradev93 May 26, 2025
43af4bd
amazing keras fix
LarsKue May 26, 2025
039fc8d
moving_mean and moving_std still not loading [WIP]
stefanradev93 May 26, 2025
02ded97
remove hacky approximator serialization test
LarsKue May 27, 2025
54d860e
fix building of models in tests
LarsKue May 27, 2025
2a86cc3
Fix standardization
stefanradev93 May 27, 2025
1df9269
Add standardization to model comp and let it use inheritance
stefanradev93 May 27, 2025
49af469
make assert_models/layers_equal more thorough
LarsKue May 27, 2025
1fdde32
Merge remote-tracking branch 'origin/standardize_in_approx' into stan…
LarsKue May 27, 2025
0869e3f
[no ci] use map_shape_structure to convert shapes to arrays
vpratz May 31, 2025
1a845e3
Merge remote-tracking branch 'upstream/dev' into standardize_in_approx
vpratz Jun 1, 2025
bd2725d
Extend Standardization to support nested inputs (#501)
vpratz Jun 1, 2025
c5fb949
Update moments before transform and update test
stefanradev93 Jun 1, 2025
100d7c0
Update notebooks
stefanradev93 Jun 1, 2025
905bf05
Merge dev into branch
stefanradev93 Jun 1, 2025
38f2228
Refactor and simplify due to standardize
stefanradev93 Jun 1, 2025
0c24db2
Add comment for fetching the dict's first item, deprecate logits arg …
stefanradev93 Jun 2, 2025
5755135
Merge remote-tracking branch 'upstream/dev' into standardize_in_approx
vpratz Jun 2, 2025
4fa1bbb
add missing import in test
vpratz Jun 2, 2025
b2bfeea
Refactor preparation of data for networks and new point_appr.log_prob
han-ol Jun 3, 2025
5773d28
Merge branch 'standardize_in_approx' of https://github.com/bayesflow-…
stefanradev93 Jun 3, 2025
392d9f7
Add class attributes to inform proper standardization
han-ol Jun 4, 2025
2d5b2fb
Implement stable moving mean and std
stefanradev93 Jun 4, 2025
bde587c
Merge and add incremental moments
stefanradev93 Jun 4, 2025
1b2b5be
Adapt and fix tests
stefanradev93 Jun 4, 2025
d406a29
minor adaptations to moving average (update time, init)
vpratz Jun 5, 2025
a503bd9
increase tolerance of allclose tests
vpratz Jun 5, 2025
caf0491
[no ci] set trainable to False explicitly in ModelComparisonApproximator
vpratz Jun 5, 2025
dd24941
Merge branch 'standardize_in_approx' of https://github.com/bayesflow-…
stefanradev93 Jun 5, 2025
8268128
point estimate of covariance compatible with standardization
han-ol Jun 6, 2025
e32ae2e
properly set values to zero if std is zero
vpratz Jun 6, 2025
b7d6c0e
fix sample post-processing in point approximator
vpratz Jun 6, 2025
00d72ab
activate tests for multivariate normal score
vpratz Jun 6, 2025
c2ebd23
[no ci] undo prev commit: MVN test still not stable, was hidden by st…
vpratz Jun 6, 2025
cd45b85
specify explicit build functions for approximators
vpratz Jun 6, 2025
3f28f34
Merge remote-tracking branch 'upstream/dev' into standardize_in_approx
vpratz Jun 6, 2025
0952a29
set std for untrained standardization layer to one
vpratz Jun 6, 2025
5c529a2
[no ci] reformulate zero std case
vpratz Jun 6, 2025
399a1b4
approximator builds: add guards against building networks twice
vpratz Jun 6, 2025
dd0dc87
[no ci] add comparison with loaded approx to workflow test
vpratz Jun 6, 2025
d28df75
Cleanup and address building standardization layers when None specified
stefanradev93 Jun 6, 2025
40d2d1d
Cleanup and address building standardization layers when None specifi…
stefanradev93 Jun 6, 2025
c6d79ae
Add default case for std transform and add transformation to doc.
stefanradev93 Jun 6, 2025
df1761b
adapt handling of the special case M^2=0
vpratz Jun 7, 2025
3b93251
[no ci] minor fix in concatenate_valid_shapes
vpratz Jun 7, 2025
65cac46
Merge remote-tracking branch 'upstream/dev' into standardize_in_approx
vpratz Jun 7, 2025
1944186
[no ci] extend test suite for approximators
vpratz Jun 7, 2025
1ebf1cd
fixes for standardize=None case
vpratz Jun 7, 2025
71cd6b9
skip unstable MVN score case
vpratz Jun 7, 2025
3f0f9d1
Better transformation types
han-ol Jun 9, 2025
a3b59c3
Add test for both_sides_scale inverse standardization
han-ol Jun 9, 2025
183f608
Add test for left_side_scale inverse standardization
han-ol Jun 9, 2025
f0de38b
Remove flaky test failing due to sampling error
han-ol Jun 9, 2025
43ced5b
Fix input dtypes in inverse standardization transformation_type tests
han-ol Jun 9, 2025
c3e945e
Merge branch 'dev' into standardize_in_approx
han-ol Jun 9, 2025
82e28a7
Use concatenate_valid in _sample
han-ol Jun 10, 2025
ef97a6c
Replace PositiveDefinite link with CholeskyFactor
han-ol Jun 10, 2025
24c268b
Reintroduce test sampling with MVN score
han-ol Jun 10, 2025
e45f260
Address TODOs and adapt docstrings and workflow
stefanradev93 Jun 11, 2025
333c30f
Adapt notebooks
stefanradev93 Jun 11, 2025
48bb190
Fix in model comparison
stefanradev93 Jun 11, 2025
fd83567
Update readme and add point estimation nb
stefanradev93 Jun 12, 2025
2 changes: 1 addition & 1 deletion README.md
@@ -130,7 +130,7 @@ Many examples from [Bayesian Cognitive Modeling: A Practical Course](https://bay
1. [Linear regression starter example](examples/Linear_Regression_Starter.ipynb)
2. [From ABC to BayesFlow](examples/From_ABC_to_BayesFlow.ipynb)
3. [Two moons starter example](examples/Two_Moons_Starter.ipynb)
- 4. [Rapid iteration with point estimators](examples/Lotka_Volterra_Point_Estimation_and_Expert_Stats.ipynb)
+ 4. [Rapid iteration with point estimators](examples/Lotka_Volterra_Point_Estimation.ipynb)
5. [SIR model with custom summary network](examples/SIR_Posterior_Estimation.ipynb)
6. [Bayesian experimental design](examples/Bayesian_Experimental_Design.ipynb)
7. [Simple model comparison example](examples/One_Sample_TTest.ipynb)
9 changes: 9 additions & 0 deletions bayesflow/adapters/transforms/standardize.py
@@ -1,4 +1,5 @@
from collections.abc import Sequence
import warnings

import numpy as np

@@ -69,6 +70,14 @@ def __init__(
):
super().__init__()

if mean is None or std is None:
warnings.warn(
"Dynamic standardization is deprecated and will be removed in later versions. "
"Instead, use the standardize argument of the approximator / workflow instance or provide "
"fixed mean and std arguments. You may incur some redundant computations if you keep this transform.",
DeprecationWarning,
)

Collaborator:

I think it would be nice to have a convenience function that calculates mean and std for a dataset, in the format that would be required here. We could also advertise it in the deprecation warning. What do you think?

Contributor Author:

Agreed, but such a function will not be very efficient when the entire data set is not (yet) in memory. I see its use mainly for OfflineDataset.
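The helper discussed in this thread does not exist in the PR; below is a minimal sketch of what it could look like. The name `dataset_moments`, the iterable-of-batches interface, and the streaming (Chan-style) moment update are all illustrative assumptions, not BayesFlow API, but the batch-wise update addresses the memory concern raised above:

```python
import numpy as np


def dataset_moments(batches):
    """Compute per-feature mean and std over an iterable of 2D arrays.

    Uses streaming moment updates, so the full dataset never has to be
    in memory at once. `batches` is any iterable of arrays with shape
    (batch_size, num_features).
    """
    count = 0
    mean = None
    m2 = None  # running sum of squared deviations from the mean
    for batch in batches:
        batch = np.asarray(batch, dtype=np.float64)
        n = batch.shape[0]
        batch_mean = batch.mean(axis=0)
        batch_m2 = ((batch - batch_mean) ** 2).sum(axis=0)
        if mean is None:
            count, mean, m2 = n, batch_mean, batch_m2
        else:
            # Chan et al. pairwise update for combining two partitions.
            delta = batch_mean - mean
            total = count + n
            mean = mean + delta * (n / total)
            m2 = m2 + batch_m2 + delta**2 * (count * n / total)
            count = total
    return mean, np.sqrt(m2 / count)
```

For an `OfflineDataset`, one could iterate over its batches once before training and pass the resulting mean/std as the fixed arguments of this transform.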

self.mean = mean
self.std = std

1 change: 1 addition & 0 deletions bayesflow/adapters/transforms/transform.py
@@ -22,6 +22,7 @@ def __repr__(self):

@classmethod
def from_config(cls, config: dict, custom_objects=None):
# noinspection PyArgumentList
return cls(**deserialize(config, custom_objects=custom_objects))

def get_config(self) -> dict:
Expand Down
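A minimal sketch of why the `# noinspection PyArgumentList` hint was added: `from_config` is defined on the base class, but `cls` may be any subclass, so a static analyzer cannot match the config keys against the constructor signature. The `Shift` class below is a hypothetical stand-in; the real transforms additionally run keras (de)serialization on the config values.

```python
class Transform:
    @classmethod
    def from_config(cls, config: dict):
        # cls may be any subclass, so a static analyzer cannot verify
        # that the keyword arguments match the constructor; this is
        # what "# noinspection PyArgumentList" silences in the source.
        return cls(**config)

    def get_config(self) -> dict:
        return {}


class Shift(Transform):
    def __init__(self, offset: float = 0.0):
        self.offset = offset

    def get_config(self) -> dict:
        return {"offset": self.offset}


# Round trip: serialize a configured transform, then rebuild it.
restored = Shift.from_config(Shift(offset=2.5).get_config())
```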
16 changes: 9 additions & 7 deletions bayesflow/approximators/approximator.py
@@ -11,18 +11,16 @@


class Approximator(BackendApproximator):
-    def build(self, data_shapes: any) -> None:
-        mock_data = keras.tree.map_structure(keras.ops.zeros, data_shapes)
-        self.build_from_data(mock_data)
+    def build(self, data_shapes: dict[str, tuple[int] | dict[str, dict]]) -> None:
+        raise NotImplementedError

@classmethod
def build_adapter(cls, **kwargs) -> Adapter:
# implemented by each respective architecture
raise NotImplementedError

-    def build_from_data(self, data: dict[str, any]) -> None:
-        self.compute_metrics(**filter_kwargs(data, self.compute_metrics), stage="training")
-        self.built = True
+    def build_from_data(self, adapted_data: dict[str, any]) -> None:
+        raise NotImplementedError

@classmethod
def build_dataset(
@@ -61,6 +59,9 @@ def build_dataset(
max_queue_size=max_queue_size,
)

+    def call(self, *args, **kwargs):
+        return self.compute_metrics(*args, **kwargs)

def fit(self, *, dataset: keras.utils.PyDataset = None, simulator: Simulator = None, **kwargs):
"""
Trains the approximator on the provided dataset or on-demand data generated from the given simulator.
@@ -132,6 +133,7 @@ def fit(self, *, dataset: keras.utils.PyDataset = None, simulator: Simulator = N
logging.info("Building on a test batch.")
mock_data = dataset[0]
mock_data = keras.tree.map_structure(keras.ops.convert_to_tensor, mock_data)
-        self.build_from_data(mock_data)
+        mock_data_shapes = keras.tree.map_structure(keras.ops.shape, mock_data)
+        self.build(mock_data_shapes)

return super().fit(dataset=dataset, **kwargs)
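The `fit` change above boils down to a two-step flow: reduce one real batch to its nested shape structure, then build from those shapes. A rough, dependency-free sketch follows, with a toy `map_structure` standing in for `keras.tree.map_structure` and NumPy arrays standing in for backend tensors; the dict keys and shapes are hypothetical.

```python
import numpy as np


def map_structure(fn, structure):
    """Toy stand-in for keras.tree.map_structure: applies fn to every
    leaf of a nested dict, preserving the dict layout."""
    if isinstance(structure, dict):
        return {key: map_structure(fn, value) for key, value in structure.items()}
    return fn(structure)


# One adapted batch, as `fit` would draw it from the dataset.
mock_data = {
    "inference_variables": np.zeros((8, 2)),
    "inference_conditions": {"summary": np.zeros((8, 16))},
}

# Step 1: reduce the batch to its nested shape structure.
mock_data_shapes = map_structure(np.shape, mock_data)
# -> {"inference_variables": (8, 2), "inference_conditions": {"summary": (8, 16)}}

# Step 2: a concrete `build` can recreate zero tensors from the shapes,
# much like the old base-class implementation did with keras.ops.zeros.
mock_tensors = map_structure(np.zeros, mock_data_shapes)
```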