-
Notifications
You must be signed in to change notification settings - Fork 206
Simformer #1621
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Open
nMaax
wants to merge
299
commits into
sbi-dev:main
Choose a base branch
from
nMaax:simformer
base: main
Could not load branches
Branch not found: {{ refName }}
Loading
Could not load tags
Nothing to show
Loading
Are you sure you want to change the base?
Some commits from the old base branch may be removed from the timeline,
and old review comments may become outdated.
+5,911
−662
Open
Simformer #1621
Changes from all commits
Commits
Show all changes
299 commits
Select commit
Hold shift + click to select a range
14262f1
Typing with vectorfield net
manuelgloeckler cf762ea
Simplify score estimator
manuelgloeckler 2839aec
Updates
manuelgloeckler 151973a
Fixing transformer with cross attn
manuelgloeckler b23418b
Add error msg for unsupported shapes
manuelgloeckler 689a727
Better tests
manuelgloeckler 0b3dc40
refactored tests
manuelgloeckler 49f8caf
Reverting wierd reshapings in score estimator.
manuelgloeckler 66a68eb
Merge branch 'main' into merge_flow_builders_to_current_main
manuelgloeckler 0293383
Fix formating issues
manuelgloeckler c585eb1
Fixing inconsistencies
manuelgloeckler 42c1d56
Fixing pyright
manuelgloeckler 024f54e
Fix embedding_net not passed
manuelgloeckler 59e35a0
Fix embedding net bug
manuelgloeckler f1cf710
Remove redundant "num_blocks"
manuelgloeckler bed97b6
Merge branch 'main' into merge_flow_builders_to_current_main
manuelgloeckler a0fb73d
Adding some degree of backward compatibility on user interface.
manuelgloeckler 387f965
Fixing failing test on new convergence check
manuelgloeckler fe00c02
Add transformer to bm
manuelgloeckler 1c00bcd
Must be okay that the files already exits bm
manuelgloeckler 27a9c3a
Fix merge bug. Add deprecation warnings for Score estimator keyword a…
manuelgloeckler a9445ab
Fixing transformers... (no pos emb. and others)
manuelgloeckler 0dca35b
Refactorings and tunings
manuelgloeckler 25af448
deprecation warnings and small refactorings
manuelgloeckler 2eef0c6
Backwards compatibility
manuelgloeckler 587eda3
Move score_estimator tests to vf_estimator_tests, run doc notebook once
manuelgloeckler af6acd1
remove random wierd comment
manuelgloeckler 4e15f02
Remove tolerance special cases
manuelgloeckler 27dda74
Consistent naming
manuelgloeckler 9215cfa
Added main simformer neural net architecture
nMaax b42b767
Added masked conditional vector field and score estimator classes
nMaax a1ac170
Added Simformer inference classes
nMaax 536a232
Added cpu() device move over ensure_numpy plot helper function
nMaax 56234fc
Added tests for Simformer and related classes
nMaax d18696a
Merge remote-tracking branch 'upstream/merge_flow_builders_to_current…
nMaax b4a4b04
Fixed overall simformer implementation and tests to adapt to recent m…
nMaax da5a5b6
Increased simformer on linear gaussian num of simulations for test co…
nMaax ed1902c
Faster convergence for slighly worse performance
manuelgloeckler 870a754
Backward compatibility for imports of NPSE and FMPE
manuelgloeckler 8e2ed99
Docstring update
manuelgloeckler 1e073e3
Imporve docstrings
manuelgloeckler 1fb0664
Backward compatibility
manuelgloeckler 0a04880
Use new keywords
manuelgloeckler 4156549
Merge branch 'main' into merge_flow_builders_to_current_main
manuelgloeckler 39c3353
Cleaned mvfe wrapper for generalized num trials shapes
nMaax 98ebd43
Minor bugfix and cleaned simformer linear guassian tests
nMaax c364c74
Generalized condition mask handling, now variable index can be passed…
nMaax b43b793
Added mvfe wrapper tests, re-factored masked tests
nMaax 595e062
Shape fix on wrapper for multiple num_trials, minor fixes
nMaax d005ee5
Generalized conditiong variables idx at Simformer inference
nMaax d6d2908
Moved condition indexes to Masked Neural Inference
nMaax dd1a6d0
Fixed minor bug on condition indexes handling
nMaax bb10046
Enforced typing and converting lists to tensors in indexes setter for…
nMaax c7b21a6
Added note on simformer iid inference test for future debugging (when…
nMaax e3640bc
Added note on simformer iid inference test for future fine-tuning of …
nMaax 98861ae
Skipping iid_sampling over Simformer
nMaax 550f7d1
Removed automatic masks generations in score estimator loss, improved…
nMaax 5f4ad34
Handling invalid inputs as latent
nMaax 0a23c65
Improved loss shape handling and error messages in masked score estim…
nMaax 39ab03e
Renamed num_blocks as num_layers for coherence
nMaax 73afe93
Renamed num_blocks as num_layers for coherence withing internals of s…
nMaax 0679c22
Merge branch 'main' into merge_flow_builders_to_current_main
manuelgloeckler 342da44
Format
manuelgloeckler e9e6f64
Add missing headers
manuelgloeckler ddc4309
Refactored Simformer: condition and edge mask are generated at traini…
nMaax d554a08
Refactor mask generator type hints and ensure masks are moved to devi…
nMaax d74e3be
Re-added nan and Inf handling for refactored Simformer
nMaax f22a28f
Refactored ode_fn for None edge masks
nMaax 8124a62
Adapted tests on Simformer after refactor on mask handling
nMaax b5ffddc
Adapted build_conditional and related methods over last NeuralInferen…
nMaax ab19628
Refactored Variance Exploding Score Estimator inheritance using a sha…
nMaax ab6242b
Refactored VP and subVP Score Estimators using Mixin as done with VE,…
nMaax de6a86d
Refactored DiT to act as MaskedDiT (which has been removed), minor fi…
nMaax 4983514
Tuned params for linear gaussian inference on Simformer to optimize t…
nMaax ee09291
Improved and cleaned DiT and Time Additive transformer blocks
nMaax 02d4480
Cleaned typing and docstrings for vector field simformer net and scor…
nMaax 30b8be9
Provided Flow Matching Simformer
nMaax 62e82d8
Updated docstring API reference in inference trainers for Simformer
nMaax e2877d9
Fixed linear gaussian test on Simformer, now calls correct latent and…
nMaax 14a6650
Updated docstring API reference in neural_nets amd inference trainers…
nMaax 4e220ba
Documentation for Simformer
nMaax 4c483b1
Update sbi/inference/trainers/vfpe/base_vf_inference.py
manuelgloeckler 7e272b8
Update sbi/inference/trainers/fmpe/__init__.py
manuelgloeckler 6b99778
Update sbi/inference/trainers/npse/__init__.py
manuelgloeckler 427f3c0
Update sbi/inference/trainers/vfpe/fmpe.py
manuelgloeckler 5d03a84
Update sbi/inference/trainers/vfpe/fmpe.py
manuelgloeckler a5e216e
Update sbi/inference/trainers/vfpe/npse.py
manuelgloeckler f4ca1f8
Update sbi/inference/trainers/vfpe/npse.py
manuelgloeckler 0a09f67
Update sbi/neural_nets/__init__.py
manuelgloeckler 099ad7d
Update sbi/neural_nets/estimators/flowmatching_estimator.py
manuelgloeckler 663088e
Update sbi/neural_nets/estimators/score_estimator.py
manuelgloeckler 43f21be
Update sbi/neural_nets/factory.py
manuelgloeckler dfd51e4
Update sbi/neural_nets/factory.py
manuelgloeckler c2bd2d3
Update sbi/neural_nets/net_builders/vector_field_nets.py
manuelgloeckler 8edbfaa
Update sbi/neural_nets/net_builders/vector_field_nets.py
manuelgloeckler 846debb
Update sbi/neural_nets/factory.py
manuelgloeckler 09cd09e
Update tests/bm_test.py
manuelgloeckler 8c613ce
Update tests/bm_test.py
manuelgloeckler f8a8b50
Update tests/bm_test.py
manuelgloeckler 9276a78
Update tests/bm_test.py
manuelgloeckler cc97264
Merge branch 'main' into merge_flow_builders_to_current_main
manuelgloeckler 457a0aa
Merge remote-tracking branch 'origin/main' into merge_flow_builders_t…
manuelgloeckler fab3c62
Add nugget as keyward argument to train
manuelgloeckler a1d9b98
Imporve converged docstring
manuelgloeckler 1ced05a
Better typing and docstrings and so on
manuelgloeckler 34f396f
docstring
manuelgloeckler 4b09869
Add context
manuelgloeckler 06b84a5
Extended docstring
manuelgloeckler 7eeaa7d
move protocol
manuelgloeckler 57fdf39
Formating fix
manuelgloeckler f661a05
Revert "move protocol"
manuelgloeckler 50f7443
fix formating
manuelgloeckler 5cb49b7
removing deprecated
manuelgloeckler 3caf37a
Fix typing
manuelgloeckler 4ea5e3a
Positional argument for default builder model name
manuelgloeckler 7f281cd
Formating
manuelgloeckler 4f88303
unify nets test
manuelgloeckler 1691a63
fixing builder
manuelgloeckler 77ba7a6
update notebooks
manuelgloeckler 10f61b6
Formating and some text updates
manuelgloeckler 314f45e
formating
manuelgloeckler ccee887
Fixed passing sde_type to builders, minor fixes and clean up
nMaax c69a74b
Re-adjusted test params for different SDE
nMaax b1eb2eb
Fixed flow matching simformer noising only latent input
nMaax 495458a
Improved condition mask generation to manage only fully observed samp…
nMaax 606db0f
Re-formated Flow Matching Simformer and Masked VF Inference interface…
nMaax 70f52bd
Fixed gpu use in linear gaussian tests on simformer
nMaax 8820c35
Removed slow and gpu marks on flow simformer linear gaussian tests
nMaax 170872d
Partially refactored NeuralInference classes with new Mixin
nMaax d2228b0
Moved simulations handling and train functions to Base Neural Inferen…
nMaax 04a971b
Detached Masked Neural Inference from Neural Inference, now inheritin…
nMaax 16eae9f
Adapted new NeuralPosterior build_conditional to Simformer
nMaax c0ceecf
Further simplified resolve estimator and resolve prior helper functio…
nMaax 85a11d6
Moved _create_posterior to parent Base Neural Inference Mixin
nMaax cc8fae4
Minor comments adjustments
nMaax 14f2795
Completely removed inf and nan data handling from Masked Interfaces a…
nMaax 811ca3a
Provided handling of nan/inf values for simformer, still presenting p…
nMaax 69dea89
Added default exclude invalid x parameter for handling invalid inputs…
nMaax 8cf0566
Addressed invalid inputs to ImproperEmprical prior to make Simformer …
nMaax 146f56e
Re-ordered BaseNeuralInference methods and properties
nMaax 77b4139
fix deprecation warning on default args
janfb b8d2ffb
Solved redundant code in init for Masked and non-Masked Neural Inference
nMaax b07d188
unnecessary
manuelgloeckler ebb8706
remove unecessary notes
manuelgloeckler d449e46
refactor check for deprecation warning
janfb 109f6be
fix mcmc params passing in test
janfb fca9e5a
Fix mnle_test
manuelgloeckler b646608
Merge branch 'main' into merge_flow_builders_to_current_main
janfb d1c1346
add missing import
janfb 0a72a41
Further cleaned docstring of NeuralInference and minor re-factoring
nMaax 89eecf0
Merge remote-tracking branch 'upstream/merge_flow_builders_to_current…
nMaax 943ac82
Fixed missing posterior_parameter use for masked interfaces and simfo…
nMaax 22eeef4
Renamed posterior parameters for simformer build conditional and like…
nMaax 1b83d0f
Restored max num epoch of linear gaussian over simformer vp 1 ndims t…
nMaax f07f9b5
Merge remote-tracking branch 'upstream/main' into simformer
nMaax 49ba6bf
Re-factored simformer vector field net test code
nMaax fcce081
Cleaned docstring for API reference
nMaax 6134df7
Improved simformer advanced tutorial notebook with later modifications
nMaax ba69658
Cleaned linear gaussian test on Simformer
nMaax bd1485a
Removed optional typing in neural net definition for NeuralInference …
nMaax f5208ed
Merge remote-tracking branch 'upstream/main' into simformer
nMaax 28494cb
Extended Simformer and tests to handle non 3-dimensional data
nMaax 21fcc18
Specified typing of neural net and prior attributes in NLE and NPE cl…
nMaax ad4f05e
Re-organized linear gaussian tests on simformer, capped number of epo…
nMaax 40ab9ed
Updated simformer tutorial with later capacity of accepting 2-dimensi…
nMaax c15d8dd
Changed edge mask generator callable definition over condition mask r…
nMaax 5318dc3
Refactor dataloader methods and convergence checks in BaseNeuralInfer…
nMaax a719d72
Defined default values for MaskedNeuralInference build conditional
nMaax 0a40736
Refactored use of prior in NeuralInference methods, introduced NoPrio…
nMaax 6b7a4a2
Moved NoPrior() to sbi utils
nMaax 8a899e4
Refactored FlowMatchingSimformer as child of MaskedVectorFieldTrainer…
nMaax a665045
Minor change in FlowMatchingSimformer init docstring
nMaax 3bf4cd6
BaseNeuralInference prior attribute refactoring using NoPrior for avo…
nMaax 71d1f0b
Re-factored ConditionalVectorFieldEstimator and Masked counterpart us…
nMaax bdae1e7
Removed useless exclude invalid x attribute from Vector Field Trainer
nMaax 230bfce
Fixed NoPrior to return random noise with correct shapes. Refactor pr…
nMaax ed58502
Handled linitng typing in missspecification given new Neural Inferenc…
nMaax 480cd7e
Refactor neural network attributes in NeuralInference and MaskedNeura…
nMaax 591cc75
Added loss_proposal_posterior equivalent for masked vf trainer
nMaax 9bf7a91
Renamed attribute masks in prior_masks in vf trainer _loss for clarity
nMaax bc955ac
Removing prior from kwargs in masked vf trainer in case user mistaken…
nMaax 6a7fd5d
Setting use of DiT Blocks as default in Simformer
nMaax 15a1db7
Refactored posterior_nn test functions to avoid duplicate code for Si…
nMaax 91196a4
Removed depcrated simformer unit test in posterior_nn tests
nMaax 681bda2
Generalized Masked classes and Flow Simformer to feature-absent data,…
nMaax a601d33
Removed has_sample_dim check in mean and std functions of MaskedCondi…
nMaax 6b33f82
Introduced Simformer and FlowMatchingSimformer to linear gaussian dif…
nMaax c3fe087
Adapted vector field linear gaussian test for simformer
nMaax 9092297
Revert "Introduced Simformer and FlowMatchingSimformer to linear gaus…
nMaax 988a9f0
Marked Simformer vp and subvp linear gaussian tests as gpu and slow
nMaax d759e15
Increased default hidden features for Simformer to 128
nMaax 92f5748
Avoding dkl gaussian prior check for Simformer in linear gaussian tests
nMaax b07c6d5
Removed simformer dedicated linear gaussian test as it has been merge…
nMaax 15904ab
Adapted gpu linear gaussian tests to proper device handling
nMaax 4e59cfd
Adapted map test on vector fields to include simformer
nMaax f6d292c
Revert "Generalized Masked classes and Flow Simformer to feature-abse…
nMaax 2f70bd7
Removed simformer dedicated linear gaussian map test as it has been m…
nMaax c119a5d
Minor fix on input shape for forward pass in Masked Conditional Score…
nMaax a1c78a6
Adapated masked conditional score estimator to natively use 2-dim ten…
nMaax 89b55ee
Also adapated masked flow matching estimator to natively use 2-dim te…
nMaax fad4564
Simplifies rebalance loss in masked conditional score estimator
nMaax 82b7460
Re-introduced 3-dim specific loss handling in rebalance loss for Mask…
nMaax b08008e
Improved shape handling for masked flow matching estimator loss pass
nMaax aadda6f
Added and adapted test on masked flow vf estimator
nMaax 22f43d9
Added input shape 2-dim tests on vf estimator
nMaax 3c4e566
Adapted Masked Conditional VF Estimator Wrapper to handle 2-dim input…
nMaax 8620911
Adapted tests on Masked Conditional VF Estimator Wrapper to 2-dim inputs
nMaax b640e7d
Runned simformer advanced tutorial, stripped output
nMaax 63eec83
Minor docstring change for Flow Matching Simformer
nMaax f7381ae
Provided Simformer in benchmarks
nMaax 467fdb8
Removed reference to device in benchmarks for Simformer
nMaax 92d4b91
Improved mask generator function of masked neural inference to proces…
nMaax aa77cb3
Introduced device attribute for NoPrior
nMaax 0ce778d
Checking at least 2d on log prob of NoPrior
nMaax d8253be
Update NoPrior batch and even shape at append simulation if none has …
nMaax ee1ff2a
Removed wrong dimension in condition reshaping/repetition for _assemb…
nMaax b56fdd5
Added comments for assemble and disassemble full inputs helper functi…
nMaax 018c301
Improved overall MVF Wrapper comments for clarity
nMaax eb73ffe
Adapted samples rejection logic for NoPrior poor shape informatiom
nMaax 6bf3111
Added Siformer and FlowMatchingSimformer to default benchmark run
nMaax d7ee92a
Removed useless unsqueeze for benchmark on Simformer
nMaax 91738a2
Improved references to Simformer paper and minor docstrings
nMaax 666c097
Minor docustring improvement
nMaax e4d62fb
Removed sde_type from FlowMatchingSimformer as unused
nMaax 2d4fe62
Minor docstring improvement
nMaax ac521ed
Further added comments and docstring into MVF Wrapper
nMaax 51c2b9a
Adapted Wrapper and Posterior to handle degenerate case of full laten…
nMaax 578bdd8
Marked two linear gaussian tests simformer as slow to speed up time o…
nMaax 43fa23a
Further marked other linear gaussian tests simformer as slow to speed…
nMaax 8a5b2c4
Improved comments and typing, removed parameter for progress bar as i…
nMaax a83b674
Ensuring tensor on device for VE sde drift fn
nMaax 0528680
Running simformer fixtures on gpu
nMaax db24a5e
Ensuring masks on the same device as input in Simformer network, rais…
nMaax 6997551
Moved ode/sde equivalence and iid linear gaussian checks simformer fi…
nMaax 2472a47
Separated Simformer and FlowSimformer benchmark cases
nMaax 6c8c64c
Revert "Adapted samples rejection logic for NoPrior poor shape inform…
nMaax d86fc69
Revert "Update NoPrior batch and even shape at append simulation if n…
nMaax 0ba5863
Introduced support constraint real in NoPrior, fix errors in shape ch…
nMaax a646c25
Merge remote-tracking branch 'upstream/main' into simformer
nMaax eee967c
Re-ordered params for build posterior and build likelihood in Simform…
nMaax 5fff8db
Attempted to fix bug in iid lin gauss tests over simformer
nMaax 90d3d5e
Re-inserted check on calc misspecificaton mmd
nMaax 08aeb73
minor typo and docs fixes
janfb 394a48d
Revert advanced tutorials counters modifications
nMaax 6887172
Revert advanced tutorials counters modifications again
nMaax f183ff2
Sorted inference init imports alphabetically
nMaax 18c6e52
Avoided assert for an if-else statement instead
nMaax b9af9ae
Avoided assert for an if-else statement
nMaax 4ea0c03
Properly typed ada_time flag for Simformer Architecture as bool inste…
nMaax c7e9f5f
Merge remote-tracking branch 'upstream/main' into simformer
nMaax File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Large diffs are not rendered by default.
Oops, something went wrong.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
not sorted alphabetically, but maybe that's fine, it seems to be sorted semantically.