Add Nnpe adapter class #488


Merged: 11 commits, May 27, 2025

Conversation

elseml (Member) commented May 23, 2025

Introduces a new adapter class, Nnpe, which implements Noisy Neural Posterior Estimation (NNPE) as described in Ward et al. (2022). NNPE augments training data with additive noise drawn from a spike-and-slab distribution (a mixture of Cauchy and Normal noise), aiming to improve robustness on noisy real-world data by extending the scope of the training data. While the specific noise distributions seem to be mostly motivated by their use in the non-amortized RNPE method proposed in the same paper, NNPE proved in our recent benchmarks to be a quick and easy improvement for situations where one expects similarly distributed noise.
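For intuition, here is a minimal NumPy sketch of the spike-and-slab augmentation described above (parameter names and default values are illustrative placeholders, not the exact signature of the new adapter class):

```python
import numpy as np

def spike_and_slab_noise(data, spike_scale=0.01, slab_scale=0.25, p=0.5, seed=None):
    """Augment data with additive noise from a spike-and-slab mixture.

    Spike: narrow Gaussian noise. Slab: heavy-tailed Cauchy noise.
    A Bernoulli draw with probability p selects the slab component elementwise.
    """
    rng = np.random.default_rng(seed)
    use_slab = rng.binomial(n=1, p=p, size=data.shape).astype(bool)
    spike = rng.normal(loc=0.0, scale=spike_scale, size=data.shape)
    slab = rng.standard_cauchy(size=data.shape) * slab_scale
    return data + np.where(use_slab, slab, spike)
```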

I tried to incorporate all best practices from existing adapter classes, but would be thankful for a thorough check of the implemented changes and the clarity of the provided explanations. I used a seed argument instead of passing an rng for easier serialization, but there might be a more elegant way.

codecov bot commented May 23, 2025

Codecov Report

Attention: Patch coverage is 89.65517% with 6 lines in your changes missing coverage. Please review.

| Files with missing lines | Patch % | Lines |
|---|---|---|
| bayesflow/adapters/transforms/nnpe.py | 88.23% | 6 Missing ⚠️ |

| Files with missing lines | Coverage Δ |
|---|---|
| bayesflow/adapters/adapter.py | 85.06% <100.00%> (+0.38%) ⬆️ |
| bayesflow/adapters/transforms/__init__.py | 100.00% <100.00%> (ø) |
| bayesflow/adapters/transforms/nnpe.py | 88.23% <88.23%> (ø) |

... and 5 files with indirect coverage changes

@stefanradev93 (Contributor)

Thanks for the PR, Lasse! I suggest going for NNPE (all caps for the class name, since it's an abbreviation). Also, can you please add a docstring for the forward() method?

elseml (Member, Author) commented May 26, 2025

Sure, I integrated both suggestions, adding only a brief docstring for the forward() method since most information is already contained in the class docstring.

One more thing we might want to discuss: I followed the Standardize class in setting the stage="inference" default in the forward() method, but we could also leave out the default argument, since the forward() method only becomes active when stage="training" is explicitly passed. What is our current stance here?
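For reference, a rough sketch of the stage gating being discussed (the `_add_noise` helper is hypothetical; the actual signature may differ):

```python
def forward(self, data, stage="inference", **kwargs):
    """Apply spike-and-slab noise only during training; pass data through unchanged otherwise."""
    if stage != "training":
        return data
    return self._add_noise(data)  # hypothetical helper drawing the spike-and-slab noise
```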

elseml (Member, Author) commented May 26, 2025

Since standardization will be moved into the approximators with #486, I added automatic adjustment of the default scales from Ward et al. (2022), which are designed for standardized data: they are now multiplied by the standard deviation of the data. This removes the NNPE transform's reliance on prior data standardization.
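Conceptually, the adjustment scales the defaults (which assume standardized data) by the standard deviation of the current batch; a sketch with placeholder default values:

```python
import numpy as np

DEFAULT_SPIKE_SCALE = 0.01  # placeholder for the default from Ward et al. (2022)
DEFAULT_SLAB_SCALE = 0.25   # placeholder for the default from Ward et al. (2022)

def determine_scales(data, spike_scale=None, slab_scale=None):
    """Fall back to the defaults multiplied by the batch standard deviation."""
    data_std = np.std(data)
    spike = DEFAULT_SPIKE_SCALE * data_std if spike_scale is None else spike_scale
    slab = DEFAULT_SLAB_SCALE * data_std if slab_scale is None else slab_scale
    return spike, slab
```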

@stefanradev93 (Contributor)

Great, thanks Lasse! Looks ready to merge. Last question:

Is the fixed p=0.5 for the binomial always desirable, or should this be a hyperparameter?


# Automatically determine scales if not provided
if self.spike_scale is None or self.slab_scale is None:
    data_std = np.std(data)
Collaborator

This will lead to different scales for each batch; I'm not sure whether this is desirable. If we choose to do this, we should state it more explicitly in the docstring.

Contributor

There is an alternative solution:

  • Standardize data with batch_mean and batch_std
  • Add unscaled spikes and slabs
  • Re-scale back with batch_mean and batch_std

Member Author

Thanks for checking! Yes, this is the drawback of that solution. I still think it is preferable to more complex solutions, since the method is more about adding some noise at all rather than adding a very specific amount of noise (the default scales by Ward et al. also seem quite heuristically chosen to me). If you agree on this, I can add some more info in the docstring.

Collaborator

Would we want to do the automatic scaling globally or per dimension? I think this would be the main difference between what @stefanradev93 proposed and how it is implemented now, right?

Contributor

The way Lasse explained it to me, the approach explicitly wants scale(original) < scale(transformed). In that case, I think fluctuations between batches are fine, as the downstream Standardize layer (which will be part of the approximators) will take care of that.

@vpratz (Collaborator) commented May 26, 2025

Then the question remains how we want to automatically determine the scale: globally or per dimension? If dimensions don't have equal magnitude, we might accidentally erase some of them completely. On the other hand, some dimensions might have zero variation (e.g., in image datasets like MNIST), so we would have to decide how to deal with those...

Contributor

Good question. I would scale dimensionwise.

Member Author

I implemented it globally following the original NNPE implementation, but I agree that dimensionwise scaling would be valuable in many situations and will implement it as an option. I think dimensions with zero variation are not problematic, since in that case nothing breaks; there will simply be no noise added. Dimensionwise instead of global scaling will increase the variability of the std calculation between batches, though.
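For illustration, the difference between the two options boils down to the axes over which the standard deviation is taken (a sketch, not the final implementation):

```python
import numpy as np

def noise_std(data, per_dimension=True):
    """Standard deviation used to scale the noise: per data dimension or one global value."""
    if per_dimension:
        # Reduce over all axes except the last; dimensions with zero variation get a zero
        # scale and therefore receive no noise.
        return np.std(data, axis=tuple(range(data.ndim - 1)), keepdims=True)
    return np.std(data)
```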

Collaborator

Thanks! Just to make sure: please set dimensionwise as the default and make global scaling the option.

elseml (Member, Author) commented May 26, 2025

@stefanradev93: Ward et al. (2022) chose p=0.5 in the context of the non-amortized RNPE method and ported this to NNPE, so in the context of NNPE it is simply one of many possible noise-generating mechanisms with the advantage of being tested in at least two papers so far. I would stay with the fixed value here to reduce complexity.

else:
    try:
        scalar = float(passed)
    except Exception:
Collaborator

Please use more specific exceptions here; I think ValueError and TypeError are the relevant ones.

elseml (Member, Author) commented May 27, 2025

Thanks for the careful review! The changes added quite some complexity, since the expected shape of the passed scales now depends on whether noise is applied globally (a float is expected) or dimensionwise (a 1D array of length data.shape[-1] is expected), and both passed and automatically determined scales should be supported. But the transform should now be flexible enough to handle unstandardized and heterogeneous input data.
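A condensed sketch of the kind of scale resolution this implies (hypothetical helper and parameter names, not the merged code):

```python
import numpy as np

def resolve_scale(passed, data, default, per_dimension=True):
    """Validate a user-supplied scale or fall back to the default scaled by the batch std."""
    if passed is None:
        axes = tuple(range(data.ndim - 1)) if per_dimension else None
        return default * np.std(data, axis=axes)
    if per_dimension:
        arr = np.asarray(passed, dtype=float)
        if arr.ndim != 1 or arr.shape[0] != data.shape[-1]:
            raise ValueError(f"Expected a 1D scale of length {data.shape[-1]}, got shape {arr.shape}.")
        return arr
    try:
        return float(passed)
    except (ValueError, TypeError) as err:
        raise TypeError("Expected a scalar scale when noise is applied globally.") from err
```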

stefanradev93 merged commit 361fa45 into bayesflow-org:dev on May 27, 2025
9 checks passed
stefanradev93 added a commit that referenced this pull request Jun 17, 2025
* Subset arrays (#411)

* made initial backend functions for adapter subsetting, need to still make the squeeze function and link it to the front end

* added subsample functionality, to do would be adding them to testing procedures

* made the take function and ran the linter

* changed name of subsampling function

* changed documentation, to be consistent with external notation, rather than internal shorthand

* small formation change to documentation

* changed subsample to have sample size and axis in the constructor

* moved transforms in the adapter.py so they're in alphabetical order like the other transforms

* changed random_subsample to maptransform rather than filter transform

* updated documentation with new naming convention

* added arguments of take to the constructor

* added feature to specify a percentage of the data to subsample rather than only integer input

* changed subsample in adapter.py to allow float as an input for the sample size

* renamed subsample_array and associated classes/functions to RandomSubsample and random_subsample respectively

* included TypeError to force users to only subsample one dataset at a time

* ran linter

* rerun formatter

* clean up random subsample transform and docs

* clean up take transform and docs

* nitpick clean-up

* skip shape check for subsampled adapter transform inverse

* fix serialization of new transforms

* skip randomly subsampled key in serialization consistency check

---------

Co-authored-by: LarsKue <lars@kuehmichel.de>

* [no ci] docs: start of user guide - draft intro, gen models

* [no ci] add draft for data processing section

* [no ci] user guide: add stub on summary/inference networks

* [no ci] user guide: add stub on additional topics

* [no ci] add early stage disclaimer to user guide

* pin dependencies in docs, fixes snowballstemmer error

* fix: correct check for "no accepted samples" in rejection_sample

Closes #466

* Stabilize MultivariateNormalScore by constraining initialization in PositiveDefinite link (#469)

* Refactor fill_triangular_matrix

* stable positive definite link, fix for #468

* Minor changes to docstring

* Remove self.built=True that prevented registering layer norm in build()

* np -> keras.ops

* Augmentation (#470)

* Remove old rounds data set, add documentation, and augmentation options to data sets

* Enable augmentation to parts of the data or the whole data

* Improve doc

* Enable augmentations in workflow

* Fix silly type check and improve readability of for loop

* Bring back num_batches

* Fixed log det jac computation of standardize transform

y = (x - mu) / sigma
log p(y) = log p(x) - log(sigma)

* Fix fill_triangular_matrix

The two lines were switched, leading to performance degradation.

* Deal with inference_network.log_prob to return dict (as PointInferenceNetwork does)

* Add diffusion model implementation (#408)

This commit contains the following changes (see PR #408 for discussions)

- DiffusionModel following the formalism in Kingma et. al (2023) [1]
- Stochastic sampler to solve SDEs
- Tests for the diffusion model

[1] https://arxiv.org/abs/2303.00848

---------

Co-authored-by: arrjon <jonas.arruda@uni-bonn.de>
Co-authored-by: Jonas Arruda <69197639+arrjon@users.noreply.github.com>
Co-authored-by: LarsKue <lars@kuehmichel.de>

* [no ci] networks docstrings: summary/inference network indicator (#462)

- From the table in the `bayesflow.networks` module overview, one cannot
  tell which network belongs to which group. This commit adds short
  labels to indicate inference networks (IN) and summary networks (SN)

* `ModelComparisonSimulator`: handle different outputs from individual simulators (#452)

Adds option to drop, fill or error when different keys are encountered in the outputs of different simulators. Fixes #441.

---------

Co-authored-by: Valentin Pratz <git@valentinpratz.de>

* Add classes and transforms to simplify multimodal training (#473)

* Add classes and transforms to simplify multimodal training

- Add class `MultimodalSummaryNetwork` to combine multiple summary
  networks, each for one modality.
- Add transforms `Group` and `Ungroup`, to gather the multimodal inputs
  in one variable (usually "summary_variables")
- Add tests for new behavior

* [no ci] add tutorial notebook for multimodal data

* [no ci] add missing training argument

* rename MultimodalSummaryNetwork to FusionNetwork

* [no ci] clarify that the network implements late fusion

* allow dispatch of summary/inference network from type

* add tests for find_network

* Add squeeze transform

Very basic transform, just the inverse of expand_dims

* [no ci] fix examples in ExpandDims docstring

* squeeze: adapt example, add comment for changing batch dims

* Permit Python version 3.12 (#474)

Allow Python version 3.12 after successful CI run: https://github.com/bayesflow-org/bayesflow/actions/runs/14988542031

* Change order in readme and reference new book [skip ci]

* make docs optional dependencies compatible with python 3.10

* Add a custom `Sequential` network to avoid issues with building and serialization in keras (#493)

* add custom sequential to fix #491

* revert using Sequential in classifier_two_sample_test.py

* Add docstring to custom Sequential

Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>

* fix copilot docstring

* remove mlp override methods

---------

Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>

* Add Nnpe adapter class (#488)

* Add NNPE adapter

* Add NNPE adapter tests

* Only apply NNPE during training

* Integrate stage differentiation into tests

* Improve test coverage

* Fix inverse and add to tests

* Adjust class name and add docstring to forward method

* Enable compatibility with #486 by adjusting scales automatically

* Add dimensionwise noise application

* Update exception handling

* Fix tests

* Align diffusion model with other inference networks and remove deprecation warnings (#489)

* Align dm implementation with other networks

* Remove deprecation warning for using subnet_kwargs

* Fix tests

* Remove redundant training arg in get_alpha_sigma and some redundant comments

* Fix configs creation - do not get base config due to fixed call of super().__init__()

* Remove redundant training arg from tests

* Fix dispatch tests for dms

* Improve docs and mark option for x prediction in literal

* Fix start/stop time

* minor cleanup of refactory

---------

Co-authored-by: Valentin Pratz <git@valentinpratz.de>

* add replace nan adapter (#459)

* add replace nan adapter

* improved naming

* _mask as additional key

* update test

* improve

* fix serializable

* changed name to return_mask

* add mask naming

* [no ci] docs: add basic likelihood estimation example

Fixes #476. This is the barebones version showing the technical steps to
do likelihood estimation. Adding more background and motivation would be
nice.

* make metrics serializable

It seems that metrics do not store their state, I'm not sure yet if this
is intended behavior.

* Remove layer norm; add epsilon to std dev for stability of pos def link

this breaks serialization of point estimation with MultivariateNormalScore

* add end-to-end test for fusion network

* fix: ensure that build is called in FusionNetwork

* Correctly track train / validation losses (#485)

* correctly track train / validation losses

* remove mmd from two moons test

* reenable metrics in continuous approximator, add trackers

* readd custom metrics to two_moons test

* take batch size into account when aggregating metrics

* Add docs to backend approximator interfaces

* Add small doc improvements

* Fix typehints to docs.

---------

Co-authored-by: Valentin Pratz <git@valentinpratz.de>
Co-authored-by: stefanradev93 <stefan.radev93@gmail.com>

* Add shuffle parameter to datasets

Adds the option to disable data shuffling

---------

Co-authored-by: Lars <lars@kuehmichel.de>
Co-authored-by: Valentin Pratz <git@valentinpratz.de>

* fix: correct vjp/jvp calls in FreeFormFlow

The signature changed, making it necessary to set return_output=True

* test: add basic compute_metrics test for inference networks

* [no ci] extend point approximator tests

- remove skip for MVN
- add test for log-prob

* [no ci] skip unstable MVN sample test again

* update README with more specific install instructions

* fix FreeFormFlow: remove superfluous index form signature change

* [no ci] FreeFormFlow MLP defaults: set dropout to 0

* Better pairplots (#505)

* Hacky fix for pairplots

* Ensure that target sits in front of other elements

* Ensure consistent spacing between plot and legends + cleanup

* Update docs

* Fix the propagation of `legend_fontsize`

* Minor fix to comply with code style

* [no ci] Formatting: escaped space only in raw strings

* [no ci] fix typo in error message, model comparison approximator

* [no ci] fix: size_of could not handle basic int/float

Passing in basic types would lead to infinite recursion. Checks for
other types than int and float might be necessary as well.

* add tests for model comparison approximator

* Generalize sample shape to arbitrary N-D arrays

* [WIP] Move standardization into approximators and make adapter stateless. (#486)

* Add standardization to continuous approximator and test

* Fix init bugs, adapt tnotebooks

* Add training flag to build_from_data

* Fix inference conditions check

* Fix tests

* Remove unnecessary init calls

* Add deprecation warning

* Refactor compute metrics and add standardization to model comp

* Fix standardization in cont approx

* Fix sample keys -> condition keys

* amazing keras fix

* moving_mean and moving_std still not loading [WIP]

* remove hacky approximator serialization test

* fix building of models in tests

* Fix standardization

* Add standardizatrion to model comp and let it use inheritance

* make assert_models/layers_equal more thorough

* [no ci] use map_shape_structure to convert shapes to arrays

This automatically takes care of nested structures.

* Extend Standardization to support nested inputs (#501)

* extend Standardization to nested inputs

By using `keras.tree.flatten` und `keras.tree.pack_sequence_as`, we can
support arbitrary nested structures. A `flatten_shape` function is
introduced, analogous to `map_shape_structure`, for use in the build
function.

* keep tree utils in submodule

* Streamline call

* Fix typehint

---------

Co-authored-by: stefanradev93 <stefan.radev93@gmail.com>

* Update moments before transform and update test

* Update notebooks

* Refactor and simplify due to standardize

* Add comment for fetching the dict's first item, deprecate logits arg and fix typehint

* add missing import in test

* Refactor preparation of data for networks and new point_appr.log_prob

* ContinuousApproximator._prepare_data unifies all preparation in
  sample, log_prob and estimate for both ContinuousApproximator and
  PointApproximator
* PointApproximator now overrides log_prob

* Add class attributes to inform proper standardization

* Implement stable moving mean and std

* Adapt and fix tests

* minor adaptations to moving average (update time, init)

We should put the update before the standardization, to use the maximum
amount of information available. We can then also initialize the moving
M^2 with zero, as it will be filled immediately.

The special case of M^2 = 0 is not problematic, as no variance
automatically indicates that all entries are equal, and we can set
them to zero  (see my comment).

I added another test case to cover that case, and added a test for the
standard deviation to the existing test.

* increase tolerance of allclose tests

* [no ci] set trainable to False explicitly in ModelComparisonApproximator

* point estimate of covariance compatible with standardization

* properly set values to zero if std is zero

Cases for inf and -inf were missing

* fix sample post-processing in point approximator

* activate tests for multivariate normal score

* [no ci] undo prev commit: MVN test still not stable, was hidden by std of 0

* specify explicit build functions for approximators

* set std for untrained standardization layer to one

An untrained layer thereby does not modify the input.

* [no ci] reformulate zero std case

* approximator builds: add guards against building networks twice

* [no ci] add comparison with loaded approx to workflow test

* Cleanup and address building standardization layers  when None specified

* Cleanup and address building standardization layers when None specified 2

* Add default case for std transform and add transformation to doc.

* adapt handling of the special case M^2=0

* [no ci] minor fix in concatenate_valid_shapes

* [no ci] extend test suite for approximators

* fixes for standardize=None case

* skip unstable MVN score case

* Better transformation types

* Add test for both_sides_scale inverse standardization

* Add test for left_side_scale inverse standardization

* Remove flaky test failing due to sampling error

* Fix input dtypes in inverse standardization transformation_type tests

* Use concatenate_valid in _sample

* Replace PositiveDefinite link with CholeskyFactor

This finally makes the MVN score sampling test stable for the jax backend,
for which the keras.ops.cholesky operation is numerically unstable.

The score's sample method avoids calling keras.ops.cholesky to resolve
the issue. Instead the estimation head returns the Cholesky factor
directly rather than the covariance matrix (as it used to be).

* Reintroduce test sampling with MVN score

* Address TODOs and adapt docstrings and workflow

* Adapt notebooks

* Fix in model comparison

* Update readme and add point estimation nb

---------

Co-authored-by: LarsKue <lars@kuehmichel.de>
Co-authored-by: Valentin Pratz <git@valentinpratz.de>
Co-authored-by: Valentin Pratz <112951103+vpratz@users.noreply.github.com>
Co-authored-by: han-ol <g@hans.olischlaeger.com>
Co-authored-by: Hans Olischläger <106988117+han-ol@users.noreply.github.com>

* Replace deprecation with FutureWarning

* Adjust filename for LV

* Fix types for subnets

* [no ci] minor fixes to RandomSubsample transform

* [no ci] remove subnet deprecation in cont-time CM

* Remove empty file [no ci]

* Revert layer type for coupling flow [skip ci]

* remove failing import due to removed find_noise_schedule.py [no ci]

* Add utility function for batched simulations (#511)

The implementation is a simple wrapper leveraging the batching
capabilities of `rejection_sample`.

* Restore PositiveDefinite link with deprecation warning

* skip cycle consistency test for diffusion models

- the test is unstable for untrained diffusion models, as the networks
  output is not sufficiently smooth for the step size we use
- remove the diffusion_model marker

* Implement changes to NNPE adapter for #510 (#514)

* Move docstring to comment

* Always cast to _resolve_scale

* Fix typo

* [no ci] remove unnecessary serializable decorator on rmse

* fix type hint in squeeze [no ci]

* reintroduce comment in jax approximator [no ci]

* remove unnecessary getattr calls [no ci]

* Rename local variable transformation_type

* fix error type in diffusion model [no ci]

* remove non-functional per_training_step from plots.loss

* Update doc [skip ci]

* rename approximator.summaries to summarize with deprecation

* address remaining comments

---------

Co-authored-by: Leona Odole <88601208+eodole@users.noreply.github.com>
Co-authored-by: LarsKue <lars@kuehmichel.de>
Co-authored-by: Valentin Pratz <git@valentinpratz.de>
Co-authored-by: Hans Olischläger <106988117+han-ol@users.noreply.github.com>
Co-authored-by: han-ol <g@hans.olischlaeger.com>
Co-authored-by: Valentin Pratz <112951103+vpratz@users.noreply.github.com>
Co-authored-by: arrjon <jonas.arruda@uni-bonn.de>
Co-authored-by: Jonas Arruda <69197639+arrjon@users.noreply.github.com>
Co-authored-by: Simon Kucharsky <kucharssim@gmail.com>
Co-authored-by: Daniel Habermann <133031176+daniel-habermann@users.noreply.github.com>
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
Co-authored-by: Lasse Elsemüller <60779710+elseml@users.noreply.github.com>
Co-authored-by: Jerry Huang <57327805+jerrymhuang@users.noreply.github.com>