`ModelComparisonSimulator`: handle different outputs from individual simulators #452
Conversation
Codecov Report: All modified and coverable lines are covered by tests ✅
Thanks for the PR, I skimmed it and like the idea behind the changes. I'll try to conduct a proper review some time this week.
Looks good from my side. See individual comments.
Can we also add tests for the (few) missed edge-case lines?
See my comment on one edge case, I'm not sure if it is a relevant one. What do you think?
Apart from that, the PR looks good to me, I only added minor formatting fixes to the docstring.
The changes are looking great. Can we address Valentin's comments before we merge? I have also left some minor comments.
add newlines to correctly render lists, make reference to other class a link (e77fb51 to fc7cf0d)
Thanks a lot for the revised changes, they look good to me now.
* Subset arrays (#411)
  * made initial backend functions for adapter subsetting, need to still make the squeeze function and link it to the front end
  * added subsample functionality, to do would be adding them to testing procedures
  * made the take function and ran the linter
  * changed name of subsampling function
  * changed documentation, to be consistent with external notation, rather than internal shorthand
  * small formatting change to documentation
  * changed subsample to have sample size and axis in the constructor
  * moved transforms in the adapter.py so they're in alphabetical order like the other transforms
  * changed random_subsample to MapTransform rather than FilterTransform
  * updated documentation with new naming convention
  * added arguments of take to the constructor
  * added feature to specify a percentage of the data to subsample rather than only integer input
  * changed subsample in adapter.py to allow float as an input for the sample size
  * renamed subsample_array and associated classes/functions to RandomSubsample and random_subsample respectively
  * included TypeError to force users to only subsample one dataset at a time
  * ran linter
  * reran formatter
  * clean up random subsample transform and docs
  * clean up take transform and docs
  * nitpick clean-up
  * skip shape check for subsampled adapter transform inverse
  * fix serialization of new transforms
  * skip randomly subsampled key in serialization consistency check

  Co-authored-by: LarsKue <lars@kuehmichel.de>

* [no ci] docs: start of user guide - draft intro, gen models
* [no ci] add draft for data processing section
* [no ci] user guide: add stub on summary/inference networks
* [no ci] user guide: add stub on additional topics
* [no ci] add early stage disclaimer to user guide
* pin dependencies in docs, fixes snowballstemmer error
* fix: correct check for "no accepted samples" in rejection_sample (closes #466)
* Stabilize MultivariateNormalScore by constraining initialization in PositiveDefinite link (#469)
  * Refactor fill_triangular_matrix
  * stable positive definite link, fix for #468
  * Minor changes to docstring
  * Remove self.built=True that prevented registering layer norm in build()
  * np -> keras.ops
* Augmentation (#470)
  * Remove old rounds data set, add documentation, and augmentation options to data sets
  * Enable augmentation to parts of the data or the whole data
  * Improve doc
  * Enable augmentations in workflow
  * Fix silly type check and improve readability of for loop
  * Bring back num_batches
* Fixed log det jac computation of standardize transform: for y = (x - mu) / sigma, log p(y) = log p(x) - log(sigma)
* Fix fill_triangular_matrix: the two lines were switched, leading to performance degradation
* Deal with inference_network.log_prob returning a dict (as PointInferenceNetwork does)
* Add diffusion model implementation (#408). This commit contains the following changes (see PR #408 for discussions):
  * DiffusionModel following the formalism in Kingma et al. (2023) [1]
  * Stochastic sampler to solve SDEs
  * Tests for the diffusion model

  [1] https://arxiv.org/abs/2303.00848

  Co-authored-by: arrjon <jonas.arruda@uni-bonn.de>
  Co-authored-by: Jonas Arruda <69197639+arrjon@users.noreply.github.com>
  Co-authored-by: LarsKue <lars@kuehmichel.de>

* [no ci] networks docstrings: summary/inference network indicator (#462). From the table in the `bayesflow.networks` module overview, one cannot tell which network belongs to which group. This commit adds short labels to indicate inference networks (IN) and summary networks (SN).
* `ModelComparisonSimulator`: handle different outputs from individual simulators (#452). Adds option to drop, fill or error when different keys are encountered in the outputs of different simulators. Fixes #441.

  Co-authored-by: Valentin Pratz <git@valentinpratz.de>

* Add classes and transforms to simplify multimodal training (#473)
  * Add class `MultimodalSummaryNetwork` to combine multiple summary networks, each for one modality
  * Add transforms `Group` and `Ungroup`, to gather the multimodal inputs in one variable (usually "summary_variables")
  * Add tests for new behavior
  * [no ci] add tutorial notebook for multimodal data
  * [no ci] add missing training argument
  * rename MultimodalSummaryNetwork to FusionNetwork
  * [no ci] clarify that the network implements late fusion
* allow dispatch of summary/inference network from type
* add tests for find_network
* Add squeeze transform: very basic transform, just the inverse of expand_dims
* [no ci] fix examples in ExpandDims docstring
* squeeze: adapt example, add comment for changing batch dims
* Permit Python version 3.12 (#474), after a successful CI run: https://github.com/bayesflow-org/bayesflow/actions/runs/14988542031
* Change order in readme and reference new book [skip ci]
* make docs optional dependencies compatible with python 3.10
* Add a custom `Sequential` network to avoid issues with building and serialization in keras (#493)
  * add custom sequential to fix #491
  * revert using Sequential in classifier_two_sample_test.py
  * Add docstring to custom Sequential
  * fix copilot docstring
  * remove mlp override methods

  Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>

* Add NNPE adapter class (#488)
  * Add NNPE adapter
  * Add NNPE adapter tests
  * Only apply NNPE during training
  * Integrate stage differentiation into tests
  * Improve test coverage
  * Fix inverse and add to tests
  * Adjust class name and add docstring to forward method
  * Enable compatibility with #486 by adjusting scales automatically
  * Add dimensionwise noise application
  * Update exception handling
  * Fix tests
* Align diffusion model with other inference networks and remove deprecation warnings (#489)
  * Align dm implementation with other networks
  * Remove deprecation warning for using subnet_kwargs
  * Fix tests
  * Remove redundant training arg in get_alpha_sigma and some redundant comments
  * Fix configs creation - do not get base config due to fixed call of super().__init__()
  * Remove redundant training arg from tests
  * Fix dispatch tests for dms
  * Improve docs and mark option for x prediction in literal
  * Fix start/stop time
  * minor cleanup of refactoring

  Co-authored-by: Valentin Pratz <git@valentinpratz.de>

* add replace nan adapter (#459)
  * add replace nan adapter
  * improved naming
  * _mask as additional key
  * update test
  * improve
  * fix serializable
  * changed name to return_mask
  * add mask naming
* [no ci] docs: add basic likelihood estimation example. Fixes #476. This is the barebones version showing the technical steps to do likelihood estimation. Adding more background and motivation would be nice.
* make metrics serializable. It seems that metrics do not store their state; I'm not sure yet if this is intended behavior.
* Remove layer norm; add epsilon to std dev for stability of pos def link; this breaks serialization of point estimation with MultivariateNormalScore
* add end-to-end test for fusion network
* fix: ensure that build is called in FusionNetwork
* Correctly track train / validation losses (#485)
  * correctly track train / validation losses
  * remove mmd from two moons test
  * reenable metrics in continuous approximator, add trackers
  * readd custom metrics to two_moons test
  * take batch size into account when aggregating metrics
  * Add docs to backend approximator interfaces
  * Add small doc improvements
  * Fix typehints in docs

  Co-authored-by: Valentin Pratz <git@valentinpratz.de>
  Co-authored-by: stefanradev93 <stefan.radev93@gmail.com>

* Add shuffle parameter to datasets: adds the option to disable data shuffling

  Co-authored-by: Lars <lars@kuehmichel.de>
  Co-authored-by: Valentin Pratz <git@valentinpratz.de>

* fix: correct vjp/jvp calls in FreeFormFlow. The signature changed, making it necessary to set return_output=True
* test: add basic compute_metrics test for inference networks
* [no ci] extend point approximator tests: remove skip for MVN, add test for log-prob
* [no ci] skip unstable MVN sample test again
* update README with more specific install instructions
* fix FreeFormFlow: remove superfluous index from signature change
* [no ci] FreeFormFlow MLP defaults: set dropout to 0
* Better pairplots (#505)
  * Hacky fix for pairplots
  * Ensure that target sits in front of other elements
  * Ensure consistent spacing between plot and legends + cleanup
  * Update docs
  * Fix the propagation of `legend_fontsize`
  * Minor fix to comply with code style
* [no ci] Formatting: escaped space only in raw strings
* [no ci] fix typo in error message, model comparison approximator
* [no ci] fix: size_of could not handle basic int/float. Passing in basic types would lead to infinite recursion; checks for other types than int and float might be necessary as well.
* add tests for model comparison approximator
* Generalize sample shape to arbitrary N-D arrays
* [WIP] Move standardization into approximators and make adapter stateless (#486)
  * Add standardization to continuous approximator and test
  * Fix init bugs, adapt notebooks
  * Add training flag to build_from_data
  * Fix inference conditions check
  * Fix tests
  * Remove unnecessary init calls
  * Add deprecation warning
  * Refactor compute metrics and add standardization to model comp
  * Fix standardization in cont approx
  * Fix sample keys -> condition keys
  * amazing keras fix
  * moving_mean and moving_std still not loading [WIP]
  * remove hacky approximator serialization test
  * fix building of models in tests
  * Fix standardization
  * Add standardization to model comp and let it use inheritance
  * make assert_models/layers_equal more thorough
  * [no ci] use map_shape_structure to convert shapes to arrays; this automatically takes care of nested structures
  * Extend Standardization to support nested inputs (#501): by using `keras.tree.flatten` and `keras.tree.pack_sequence_as`, we can support arbitrary nested structures. A `flatten_shape` function is introduced, analogous to `map_shape_structure`, for use in the build function.
  * keep tree utils in submodule
  * Streamline call
  * Fix typehint

  Co-authored-by: stefanradev93 <stefan.radev93@gmail.com>

  * Update moments before transform and update test
  * Update notebooks
  * Refactor and simplify due to standardize
  * Add comment for fetching the dict's first item, deprecate logits arg and fix typehint
  * add missing import in test
  * Refactor preparation of data for networks and new point_appr.log_prob: `ContinuousApproximator._prepare_data` unifies all preparation in sample, log_prob and estimate for both ContinuousApproximator and PointApproximator; PointApproximator now overrides log_prob
  * Add class attributes to inform proper standardization
  * Implement stable moving mean and std (a Welford-style sketch follows this list)
  * Adapt and fix tests
  * minor adaptations to moving average (update time, init): we should put the update before the standardization, to use the maximum amount of information available. We can then also initialize the moving M^2 with zero, as it will be filled immediately. The special case of M^2 = 0 is not problematic, as no variance automatically indicates that all entries are equal, and we can set them to zero (see my comment). I added another test case to cover that case, and added a test for the standard deviation to the existing test.
  * increase tolerance of allclose tests
  * [no ci] set trainable to False explicitly in ModelComparisonApproximator
  * point estimate of covariance compatible with standardization
  * properly set values to zero if std is zero; cases for inf and -inf were missing
  * fix sample post-processing in point approximator
  * activate tests for multivariate normal score
  * [no ci] undo prev commit: MVN test still not stable, was hidden by std of 0
  * specify explicit build functions for approximators
  * set std for untrained standardization layer to one; an untrained layer thereby does not modify the input
  * [no ci] reformulate zero std case
  * approximator builds: add guards against building networks twice
  * [no ci] add comparison with loaded approx to workflow test
  * Cleanup and address building standardization layers when None specified
  * Cleanup and address building standardization layers when None specified 2
  * Add default case for std transform and add transformation to doc
  * adapt handling of the special case M^2=0
  * [no ci] minor fix in concatenate_valid_shapes
  * [no ci] extend test suite for approximators
  * fixes for standardize=None case
  * skip unstable MVN score case
  * Better transformation types
  * Add test for both_sides_scale inverse standardization
  * Add test for left_side_scale inverse standardization
  * Remove flaky test failing due to sampling error
  * Fix input dtypes in inverse standardization transformation_type tests
  * Use concatenate_valid in _sample
  * Replace PositiveDefinite link with CholeskyFactor: this finally makes the MVN score sampling test stable for the jax backend, for which the keras.ops.cholesky operation is numerically unstable. The score's sample method avoids calling keras.ops.cholesky to resolve the issue; instead the estimation head returns the Cholesky factor directly rather than the covariance matrix (as it used to be).
  * Reintroduce test sampling with MVN score
  * Address TODOs and adapt docstrings and workflow
  * Adapt notebooks
  * Fix in model comparison
  * Update readme and add point estimation nb

  Co-authored-by: LarsKue <lars@kuehmichel.de>
  Co-authored-by: Valentin Pratz <git@valentinpratz.de>
  Co-authored-by: Valentin Pratz <112951103+vpratz@users.noreply.github.com>
  Co-authored-by: han-ol <g@hans.olischlaeger.com>
  Co-authored-by: Hans Olischläger <106988117+han-ol@users.noreply.github.com>

* Replace deprecation with FutureWarning
* Adjust filename for LV
* Fix types for subnets
* [no ci] minor fixes to RandomSubsample transform
* [no ci] remove subnet deprecation in cont-time CM
* Remove empty file [no ci]
* Revert layer type for coupling flow [skip ci]
* remove failing import due to removed find_noise_schedule.py [no ci]
* Add utility function for batched simulations (#511): the implementation is a simple wrapper leveraging the batching capabilities of `rejection_sample`
* Restore PositiveDefinite link with deprecation warning
* skip cycle consistency test for diffusion models: the test is unstable for untrained diffusion models, as the network's output is not sufficiently smooth for the step size we use; remove the diffusion_model marker
* Implement changes to NNPE adapter for #510 (#514)
  * Move docstring to comment
  * Always cast to _resolve_scale
  * Fix typo
* [no ci] remove unnecessary serializable decorator on rmse
* fix type hint in squeeze [no ci]
* reintroduce comment in jax approximator [no ci]
* remove unnecessary getattr calls [no ci]
* Rename local variable transformation_type
* fix error type in diffusion model [no ci]
* remove non-functional per_training_step from plots.loss
* Update doc [skip ci]
* rename approximator.summaries to summarize with deprecation
* address remaining comments

Co-authored-by: Leona Odole <88601208+eodole@users.noreply.github.com>
Co-authored-by: LarsKue <lars@kuehmichel.de>
Co-authored-by: Valentin Pratz <git@valentinpratz.de>
Co-authored-by: Hans Olischläger <106988117+han-ol@users.noreply.github.com>
Co-authored-by: han-ol <g@hans.olischlaeger.com>
Co-authored-by: Valentin Pratz <112951103+vpratz@users.noreply.github.com>
Co-authored-by: arrjon <jonas.arruda@uni-bonn.de>
Co-authored-by: Jonas Arruda <69197639+arrjon@users.noreply.github.com>
Co-authored-by: Simon Kucharsky <kucharssim@gmail.com>
Co-authored-by: Daniel Habermann <133031176+daniel-habermann@users.noreply.github.com>
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
Co-authored-by: Lasse Elsemüller <60779710+elseml@users.noreply.github.com>
Co-authored-by: Jerry Huang <57327805+jerrymhuang@users.noreply.github.com>
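Several of the standardization commits above concern a numerically stable moving mean and standard deviation, including the zero-variance and untrained special cases. As a rough illustration of that idea only (a Welford-style sketch with names of my own choosing, not the library's actual `Standardization` layer):

```python
import numpy as np

class RunningStandardizer:
    """Welford-style running moments. Illustrative sketch only; the real
    Standardization layer in bayesflow is a Keras layer and differs in detail."""

    def __init__(self):
        self.count = 0
        self.mean = 0.0
        # Running sum of squared deviations (the "M^2" in the commits above);
        # initializing at zero is safe because the first update fills it.
        self.m2 = 0.0

    def update(self, batch):
        # Update the moments *before* standardizing, so the transform uses
        # the maximum amount of information available (per the commits above).
        for x in np.ravel(np.asarray(batch, dtype=float)):
            self.count += 1
            delta = x - self.mean
            self.mean += delta / self.count
            self.m2 += delta * (x - self.mean)

    def standardize(self, x):
        x = np.asarray(x, dtype=float)
        if self.count == 0:
            return x  # untrained: act as mean 0 / std 1, leave input unchanged
        std = np.sqrt(self.m2 / self.count)
        if std == 0.0:
            # M^2 = 0 means all entries seen so far were equal, so the
            # standardized values can simply be set to zero.
            return np.zeros_like(x)
        return (x - self.mean) / std
```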
Fixes #441.
As per #441 (comment), this PR implements options to drop, fill, or raise an error when different keys are encountered in the outputs of the individual simulators.
However, by default the simulator will just drop (with an info warning) keys that are not common to all simulators, since in most situations we would not need those outputs in the first place.
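For illustration, a minimal usage sketch of this behavior follows. The keyword name `key_conflicts` and the exact call signature are assumptions inferred from the commit message above; treat the `ModelComparisonSimulator` docstring as authoritative.

```python
import numpy as np
import bayesflow as bf

# Two candidate models whose simulators return different keys:
# both produce "x", but only model_a also produces "scale".
def model_a():
    scale = np.random.gamma(shape=1.0)
    return dict(x=np.random.normal(0.0, scale, size=10), scale=scale)

def model_b():
    return dict(x=np.random.standard_t(df=3.0, size=10))

simulator = bf.simulators.ModelComparisonSimulator(
    simulators=[bf.make_simulator(model_a), bf.make_simulator(model_b)],
    key_conflicts="drop",  # assumed keyword; "fill" and "error" per this PR
)

# "scale" is dropped (with an info message) because it is not returned by
# every simulator; shared keys and the model index variable remain.
samples = simulator.sample(500)
print(sorted(samples.keys()))
```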