-
Notifications
You must be signed in to change notification settings - Fork 72
[WIP] Move standardization into approximators and make adapter stateless. #486
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
Merged
Changes from all commits
Commits
Show all changes
71 commits
Select commit
Hold shift + click to select a range
ceab303
Add standardization to continuous approximator and test
stefanradev93 d79b17a
Fix init bugs, adapt tnotebooks
stefanradev93 c777122
Add training flag to build_from_data
stefanradev93 7aeb9cb
Fix inference conditions check
stefanradev93 45ab9ea
Fix tests
stefanradev93 a83770a
Remove unnecessary init calls
stefanradev93 4df270a
Add deprecation warning
stefanradev93 8ea6782
Refactor compute metrics and add standardization to model comp
stefanradev93 b2a4f76
Fix standardization in cont approx
stefanradev93 deffc27
Fix sample keys -> condition keys
stefanradev93 43af4bd
amazing keras fix
LarsKue 039fc8d
moving_mean and moving_std still not loading [WIP]
stefanradev93 02ded97
remove hacky approximator serialization test
LarsKue 54d860e
fix building of models in tests
LarsKue 2a86cc3
Fix standardization
stefanradev93 1df9269
Add standardizatrion to model comp and let it use inheritance
stefanradev93 49af469
make assert_models/layers_equal more thorough
LarsKue 1fdde32
Merge remote-tracking branch 'origin/standardize_in_approx' into stan…
LarsKue 0869e3f
[no ci] use map_shape_structure to convert shapes to arrays
vpratz 1a845e3
Merge remote-tracking branch 'upstream/dev' into standardize_in_approx
vpratz bd2725d
Extend Standardization to support nested inputs (#501)
vpratz c5fb949
Update moments before transform and update test
stefanradev93 100d7c0
Update notebooks
stefanradev93 905bf05
Merge dev into branch
stefanradev93 38f2228
Refactor and simplify due to standardize
stefanradev93 0c24db2
Add comment for fetching the dict's first item, deprecate logits arg …
stefanradev93 5755135
Merge remote-tracking branch 'upstream/dev' into standardize_in_approx
vpratz 4fa1bbb
add missing import in test
vpratz b2bfeea
Refactor preparation of data for networks and new point_appr.log_prob
han-ol 5773d28
Merge branch 'standardize_in_approx' of https://github.com/bayesflow-…
stefanradev93 392d9f7
Add class attributes to inform proper standardization
han-ol 2d5b2fb
Implement stable moving mean and std
stefanradev93 bde587c
Merge and add incremental moments
stefanradev93 1b2b5be
Adapt and fix tests
stefanradev93 d406a29
minor adaptations to moving average (update time, init)
vpratz a503bd9
increase tolerance of allclose tests
vpratz caf0491
[no ci] set trainable to False explicitly in ModelComparisonApproximator
vpratz dd24941
Merge branch 'standardize_in_approx' of https://github.com/bayesflow-…
stefanradev93 8268128
point estimate of covariance compatible with standardization
han-ol e32ae2e
properly set values to zero if std is zero
vpratz b7d6c0e
fix sample post-processing in point approximator
vpratz 00d72ab
activate tests for multivariate normal score
vpratz c2ebd23
[no ci] undo prev commit: MVN test still not stable, was hidden by st…
vpratz cd45b85
specify explicit build functions for approximators
vpratz 3f28f34
Merge remote-tracking branch 'upstream/dev' into standardize_in_approx
vpratz 0952a29
set std for untrained standardization layer to one
vpratz 5c529a2
[no ci] reformulate zero std case
vpratz 399a1b4
approximator builds: add guards against building networks twice
vpratz dd0dc87
[no ci] add comparison with loaded approx to workflow test
vpratz d28df75
Cleanup and address building standardization layers when None specified
stefanradev93 40d2d1d
Cleanup and address building standardization layers when None specifi…
stefanradev93 c6d79ae
Add default case for std transform and add transformation to doc.
stefanradev93 df1761b
adapt handling of the special case M^2=0
vpratz 3b93251
[no ci] minor fix in concatenate_valid_shapes
vpratz 65cac46
Merge remote-tracking branch 'upstream/dev' into standardize_in_approx
vpratz 1944186
[no ci] extend test suite for approximators
vpratz 1ebf1cd
fixes for standardize=None case
vpratz 71cd6b9
skip unstable MVN score case
vpratz 3f0f9d1
Better transformation types
han-ol a3b59c3
Add test for both_sides_scale inverse standardization
han-ol 183f608
Add test for left_side_scale inverse standardization
han-ol f0de38b
Remove flaky test failing due to sampling error
han-ol 43ced5b
Fix input dtypes in inverse standardization transformation_type tests
han-ol c3e945e
Merge branch 'dev' into standardize_in_approx
han-ol 82e28a7
Use concatenate_valid in _sample
han-ol ef97a6c
Replace PositiveDefinite link with CholeskyFactor
han-ol 24c268b
Reintroduce test sampling with MVN score
han-ol e45f260
Address TODOs and adapt docstrings and workflow
stefanradev93 333c30f
Adapt notebooks
stefanradev93 48bb190
Fix in model comparison
stefanradev93 fd83567
Update readme and add point estimation nb
stefanradev93 File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think it would be nice to have a convenience function that calculates mean and std for a dataset, in the format that would be required here. We could also advertise it in the deprecation warning. What do you think?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Agreed, but such a function will not be very efficient when the entire data set is not (yet) in memory. I see its use mainly for
OfflineDataset
.