Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
299 commits
Select commit Hold shift + click to select a range
14262f1
Typing with vectorfield net
manuelgloeckler Jul 3, 2025
cf762ea
Simplify score estimator
manuelgloeckler Jul 3, 2025
2839aec
Updates
manuelgloeckler Jul 3, 2025
151973a
Fixing transformer with cross attn
manuelgloeckler Jul 9, 2025
b23418b
Add error msg for unsupported shapes
manuelgloeckler Jul 9, 2025
689a727
Better tests
manuelgloeckler Jul 9, 2025
0b3dc40
refactored tests
manuelgloeckler Jul 9, 2025
49f8caf
Reverting wierd reshapings in score estimator.
manuelgloeckler Jul 9, 2025
66a68eb
Merge branch 'main' into merge_flow_builders_to_current_main
manuelgloeckler Jul 9, 2025
0293383
Fix formating issues
manuelgloeckler Jul 9, 2025
c585eb1
Fixing inconsistencies
manuelgloeckler Jul 9, 2025
42c1d56
Fixing pyright
manuelgloeckler Jul 9, 2025
024f54e
Fix embedding_net not passed
manuelgloeckler Jul 9, 2025
59e35a0
Fix embedding net bug
manuelgloeckler Jul 9, 2025
f1cf710
Remove redundant "num_blocks"
manuelgloeckler Jul 9, 2025
bed97b6
Merge branch 'main' into merge_flow_builders_to_current_main
manuelgloeckler Jul 11, 2025
a0fb73d
Adding some degree of backward compatibility on user interface.
manuelgloeckler Jul 11, 2025
387f965
Fixing failing test on new convergence check
manuelgloeckler Jul 11, 2025
fe00c02
Add transformer to bm
manuelgloeckler Jul 15, 2025
1c00bcd
Must be okay that the files already exits bm
manuelgloeckler Jul 15, 2025
27a9c3a
Fix merge bug. Add deprecation warnings for Score estimator keyword a…
manuelgloeckler Jul 15, 2025
a9445ab
Fixing transformers... (no pos emb. and others)
manuelgloeckler Jul 15, 2025
0dca35b
Refactorings and tunings
manuelgloeckler Jul 15, 2025
25af448
deprecation warnings and small refactorings
manuelgloeckler Jul 15, 2025
2eef0c6
Backwards compatibility
manuelgloeckler Jul 15, 2025
587eda3
Move score_estimator tests to vf_estimator_tests, run doc notebook once
manuelgloeckler Jul 15, 2025
af6acd1
remove random wierd comment
manuelgloeckler Jul 15, 2025
4e15f02
Remove tolerance special cases
manuelgloeckler Jul 15, 2025
27dda74
Consistent naming
manuelgloeckler Jul 15, 2025
9215cfa
Added main simformer neural net architecture
nMaax Jul 15, 2025
b42b767
Added masked conditional vector field and score estimator classes
nMaax Jul 15, 2025
a1ac170
Added Simformer inference classes
nMaax Jul 15, 2025
536a232
Added cpu() device move over ensure_numpy plot helper function
nMaax Jul 15, 2025
56234fc
Added tests for Simformer and related classes
nMaax Jul 15, 2025
d18696a
Merge remote-tracking branch 'upstream/merge_flow_builders_to_current…
nMaax Jul 15, 2025
b4a4b04
Fixed overall simformer implementation and tests to adapt to recent m…
nMaax Jul 15, 2025
da5a5b6
Increased simformer on linear gaussian num of simulations for test co…
nMaax Jul 15, 2025
ed1902c
Faster convergence for slighly worse performance
manuelgloeckler Jul 16, 2025
870a754
Backward compatibility for imports of NPSE and FMPE
manuelgloeckler Jul 16, 2025
8e2ed99
Docstring update
manuelgloeckler Jul 16, 2025
1e073e3
Imporve docstrings
manuelgloeckler Jul 16, 2025
1fb0664
Backward compatibility
manuelgloeckler Jul 16, 2025
0a04880
Use new keywords
manuelgloeckler Jul 16, 2025
4156549
Merge branch 'main' into merge_flow_builders_to_current_main
manuelgloeckler Jul 16, 2025
39c3353
Cleaned mvfe wrapper for generalized num trials shapes
nMaax Jul 17, 2025
98ebd43
Minor bugfix and cleaned simformer linear guassian tests
nMaax Jul 17, 2025
c364c74
Generalized condition mask handling, now variable index can be passed…
nMaax Jul 17, 2025
b43b793
Added mvfe wrapper tests, re-factored masked tests
nMaax Jul 17, 2025
595e062
Shape fix on wrapper for multiple num_trials, minor fixes
nMaax Jul 17, 2025
d005ee5
Generalized conditiong variables idx at Simformer inference
nMaax Jul 18, 2025
d6d2908
Moved condition indexes to Masked Neural Inference
nMaax Jul 18, 2025
dd1a6d0
Fixed minor bug on condition indexes handling
nMaax Jul 18, 2025
bb10046
Enforced typing and converting lists to tensors in indexes setter for…
nMaax Jul 18, 2025
c7b21a6
Added note on simformer iid inference test for future debugging (when…
nMaax Jul 19, 2025
e3640bc
Added note on simformer iid inference test for future fine-tuning of …
nMaax Jul 19, 2025
98861ae
Skipping iid_sampling over Simformer
nMaax Jul 19, 2025
550f7d1
Removed automatic masks generations in score estimator loss, improved…
nMaax Jul 21, 2025
5f4ad34
Handling invalid inputs as latent
nMaax Jul 22, 2025
0a23c65
Improved loss shape handling and error messages in masked score estim…
nMaax Jul 22, 2025
39ab03e
Renamed num_blocks as num_layers for coherence
nMaax Jul 23, 2025
73afe93
Renamed num_blocks as num_layers for coherence withing internals of s…
nMaax Jul 23, 2025
0679c22
Merge branch 'main' into merge_flow_builders_to_current_main
manuelgloeckler Jul 23, 2025
342da44
Format
manuelgloeckler Jul 23, 2025
e9e6f64
Add missing headers
manuelgloeckler Jul 24, 2025
ddc4309
Refactored Simformer: condition and edge mask are generated at traini…
nMaax Jul 26, 2025
d554a08
Refactor mask generator type hints and ensure masks are moved to devi…
nMaax Jul 26, 2025
d74e3be
Re-added nan and Inf handling for refactored Simformer
nMaax Jul 26, 2025
f22a28f
Refactored ode_fn for None edge masks
nMaax Jul 26, 2025
8124a62
Adapted tests on Simformer after refactor on mask handling
nMaax Jul 26, 2025
b5ffddc
Adapted build_conditional and related methods over last NeuralInferen…
nMaax Jul 27, 2025
ab19628
Refactored Variance Exploding Score Estimator inheritance using a sha…
nMaax Jul 27, 2025
ab6242b
Refactored VP and subVP Score Estimators using Mixin as done with VE,…
nMaax Jul 27, 2025
de6a86d
Refactored DiT to act as MaskedDiT (which has been removed), minor fi…
nMaax Jul 27, 2025
4983514
Tuned params for linear gaussian inference on Simformer to optimize t…
nMaax Jul 27, 2025
ee09291
Improved and cleaned DiT and Time Additive transformer blocks
nMaax Jul 28, 2025
02d4480
Cleaned typing and docstrings for vector field simformer net and scor…
nMaax Jul 28, 2025
30b8be9
Provided Flow Matching Simformer
nMaax Jul 28, 2025
62e82d8
Updated docstring API reference in inference trainers for Simformer
nMaax Jul 28, 2025
e2877d9
Fixed linear gaussian test on Simformer, now calls correct latent and…
nMaax Jul 29, 2025
14a6650
Updated docstring API reference in neural_nets amd inference trainers…
nMaax Jul 29, 2025
4e220ba
Documentation for Simformer
nMaax Jul 31, 2025
4c483b1
Update sbi/inference/trainers/vfpe/base_vf_inference.py
manuelgloeckler Aug 7, 2025
7e272b8
Update sbi/inference/trainers/fmpe/__init__.py
manuelgloeckler Aug 7, 2025
6b99778
Update sbi/inference/trainers/npse/__init__.py
manuelgloeckler Aug 7, 2025
427f3c0
Update sbi/inference/trainers/vfpe/fmpe.py
manuelgloeckler Aug 7, 2025
5d03a84
Update sbi/inference/trainers/vfpe/fmpe.py
manuelgloeckler Aug 7, 2025
a5e216e
Update sbi/inference/trainers/vfpe/npse.py
manuelgloeckler Aug 7, 2025
f4ca1f8
Update sbi/inference/trainers/vfpe/npse.py
manuelgloeckler Aug 7, 2025
0a09f67
Update sbi/neural_nets/__init__.py
manuelgloeckler Aug 7, 2025
099ad7d
Update sbi/neural_nets/estimators/flowmatching_estimator.py
manuelgloeckler Aug 7, 2025
663088e
Update sbi/neural_nets/estimators/score_estimator.py
manuelgloeckler Aug 7, 2025
43f21be
Update sbi/neural_nets/factory.py
manuelgloeckler Aug 7, 2025
dfd51e4
Update sbi/neural_nets/factory.py
manuelgloeckler Aug 7, 2025
c2bd2d3
Update sbi/neural_nets/net_builders/vector_field_nets.py
manuelgloeckler Aug 7, 2025
8edbfaa
Update sbi/neural_nets/net_builders/vector_field_nets.py
manuelgloeckler Aug 7, 2025
846debb
Update sbi/neural_nets/factory.py
manuelgloeckler Aug 7, 2025
09cd09e
Update tests/bm_test.py
manuelgloeckler Aug 7, 2025
8c613ce
Update tests/bm_test.py
manuelgloeckler Aug 7, 2025
f8a8b50
Update tests/bm_test.py
manuelgloeckler Aug 7, 2025
9276a78
Update tests/bm_test.py
manuelgloeckler Aug 7, 2025
cc97264
Merge branch 'main' into merge_flow_builders_to_current_main
manuelgloeckler Aug 7, 2025
457a0aa
Merge remote-tracking branch 'origin/main' into merge_flow_builders_t…
manuelgloeckler Aug 7, 2025
fab3c62
Add nugget as keyward argument to train
manuelgloeckler Aug 7, 2025
a1d9b98
Imporve converged docstring
manuelgloeckler Aug 7, 2025
1ced05a
Better typing and docstrings and so on
manuelgloeckler Aug 7, 2025
34f396f
docstring
manuelgloeckler Aug 7, 2025
4b09869
Add context
manuelgloeckler Aug 7, 2025
06b84a5
Extended docstring
manuelgloeckler Aug 7, 2025
7eeaa7d
move protocol
manuelgloeckler Aug 7, 2025
57fdf39
Formating fix
manuelgloeckler Aug 7, 2025
f661a05
Revert "move protocol"
manuelgloeckler Aug 7, 2025
50f7443
fix formating
manuelgloeckler Aug 7, 2025
5cb49b7
removing deprecated
manuelgloeckler Aug 7, 2025
3caf37a
Fix typing
manuelgloeckler Aug 7, 2025
4ea5e3a
Positional argument for default builder model name
manuelgloeckler Aug 8, 2025
7f281cd
Formating
manuelgloeckler Aug 8, 2025
4f88303
unify nets test
manuelgloeckler Aug 8, 2025
1691a63
fixing builder
manuelgloeckler Aug 8, 2025
77ba7a6
update notebooks
manuelgloeckler Aug 8, 2025
10f61b6
Formating and some text updates
manuelgloeckler Aug 8, 2025
314f45e
formating
manuelgloeckler Aug 8, 2025
ccee887
Fixed passing sde_type to builders, minor fixes and clean up
nMaax Aug 9, 2025
c69a74b
Re-adjusted test params for different SDE
nMaax Aug 9, 2025
b1eb2eb
Fixed flow matching simformer noising only latent input
nMaax Aug 9, 2025
495458a
Improved condition mask generation to manage only fully observed samp…
nMaax Aug 9, 2025
606db0f
Re-formated Flow Matching Simformer and Masked VF Inference interface…
nMaax Aug 9, 2025
70f52bd
Fixed gpu use in linear gaussian tests on simformer
nMaax Aug 9, 2025
8820c35
Removed slow and gpu marks on flow simformer linear gaussian tests
nMaax Aug 10, 2025
170872d
Partially refactored NeuralInference classes with new Mixin
nMaax Aug 10, 2025
d2228b0
Moved simulations handling and train functions to Base Neural Inferen…
nMaax Aug 10, 2025
04a971b
Detached Masked Neural Inference from Neural Inference, now inheritin…
nMaax Aug 10, 2025
16eae9f
Adapted new NeuralPosterior build_conditional to Simformer
nMaax Aug 11, 2025
c0ceecf
Further simplified resolve estimator and resolve prior helper functio…
nMaax Aug 11, 2025
85a11d6
Moved _create_posterior to parent Base Neural Inference Mixin
nMaax Aug 11, 2025
cc8fae4
Minor comments adjustments
nMaax Aug 11, 2025
14f2795
Completely removed inf and nan data handling from Masked Interfaces a…
nMaax Aug 11, 2025
811ca3a
Provided handling of nan/inf values for simformer, still presenting p…
nMaax Aug 12, 2025
69dea89
Added default exclude invalid x parameter for handling invalid inputs…
nMaax Aug 12, 2025
8cf0566
Addressed invalid inputs to ImproperEmprical prior to make Simformer …
nMaax Aug 13, 2025
146f56e
Re-ordered BaseNeuralInference methods and properties
nMaax Aug 13, 2025
77b4139
fix deprecation warning on default args
janfb Aug 13, 2025
b8d2ffb
Solved redundant code in init for Masked and non-Masked Neural Inference
nMaax Aug 13, 2025
b07d188
unnecessary
manuelgloeckler Aug 14, 2025
ebb8706
remove unecessary notes
manuelgloeckler Aug 14, 2025
d449e46
refactor check for deprecation warning
janfb Aug 14, 2025
109f6be
fix mcmc params passing in test
janfb Aug 14, 2025
fca9e5a
Fix mnle_test
manuelgloeckler Aug 14, 2025
b646608
Merge branch 'main' into merge_flow_builders_to_current_main
janfb Aug 14, 2025
d1c1346
add missing import
janfb Aug 14, 2025
0a72a41
Further cleaned docstring of NeuralInference and minor re-factoring
nMaax Aug 14, 2025
89eecf0
Merge remote-tracking branch 'upstream/merge_flow_builders_to_current…
nMaax Aug 14, 2025
943ac82
Fixed missing posterior_parameter use for masked interfaces and simfo…
nMaax Aug 15, 2025
22eeef4
Renamed posterior parameters for simformer build conditional and like…
nMaax Aug 15, 2025
1b83d0f
Restored max num epoch of linear gaussian over simformer vp 1 ndims t…
nMaax Aug 15, 2025
f07f9b5
Merge remote-tracking branch 'upstream/main' into simformer
nMaax Aug 15, 2025
49ba6bf
Re-factored simformer vector field net test code
nMaax Aug 16, 2025
fcce081
Cleaned docstring for API reference
nMaax Aug 16, 2025
6134df7
Improved simformer advanced tutorial notebook with later modifications
nMaax Aug 16, 2025
ba69658
Cleaned linear gaussian test on Simformer
nMaax Aug 18, 2025
bd1485a
Removed optional typing in neural net definition for NeuralInference …
nMaax Aug 18, 2025
f5208ed
Merge remote-tracking branch 'upstream/main' into simformer
nMaax Aug 18, 2025
28494cb
Extended Simformer and tests to handle non 3-dimensional data
nMaax Aug 18, 2025
21fcc18
Specified typing of neural net and prior attributes in NLE and NPE cl…
nMaax Aug 18, 2025
ad4f05e
Re-organized linear gaussian tests on simformer, capped number of epo…
nMaax Aug 18, 2025
40ab9ed
Updated simformer tutorial with later capacity of accepting 2-dimensi…
nMaax Aug 19, 2025
c15d8dd
Changed edge mask generator callable definition over condition mask r…
nMaax Aug 19, 2025
5318dc3
Refactor dataloader methods and convergence checks in BaseNeuralInfer…
nMaax Aug 20, 2025
a719d72
Defined default values for MaskedNeuralInference build conditional
nMaax Aug 21, 2025
0a40736
Refactored use of prior in NeuralInference methods, introduced NoPrio…
nMaax Aug 21, 2025
6b7a4a2
Moved NoPrior() to sbi utils
nMaax Aug 21, 2025
8a899e4
Refactored FlowMatchingSimformer as child of MaskedVectorFieldTrainer…
nMaax Aug 21, 2025
a665045
Minor change in FlowMatchingSimformer init docstring
nMaax Aug 21, 2025
3bf4cd6
BaseNeuralInference prior attribute refactoring using NoPrior for avo…
nMaax Aug 21, 2025
71d1f0b
Re-factored ConditionalVectorFieldEstimator and Masked counterpart us…
nMaax Aug 21, 2025
bdae1e7
Removed useless exclude invalid x attribute from Vector Field Trainer
nMaax Aug 22, 2025
230bfce
Fixed NoPrior to return random noise with correct shapes. Refactor pr…
nMaax Aug 22, 2025
ed58502
Handled linitng typing in missspecification given new Neural Inferenc…
nMaax Aug 22, 2025
480cd7e
Refactor neural network attributes in NeuralInference and MaskedNeura…
nMaax Aug 22, 2025
591cc75
Added loss_proposal_posterior equivalent for masked vf trainer
nMaax Aug 22, 2025
9bf7a91
Renamed attribute masks in prior_masks in vf trainer _loss for clarity
nMaax Aug 22, 2025
bc955ac
Removing prior from kwargs in masked vf trainer in case user mistaken…
nMaax Aug 22, 2025
6a7fd5d
Setting use of DiT Blocks as default in Simformer
nMaax Aug 22, 2025
15a1db7
Refactored posterior_nn test functions to avoid duplicate code for Si…
nMaax Aug 22, 2025
91196a4
Removed depcrated simformer unit test in posterior_nn tests
nMaax Aug 22, 2025
681bda2
Generalized Masked classes and Flow Simformer to feature-absent data,…
nMaax Aug 25, 2025
a601d33
Removed has_sample_dim check in mean and std functions of MaskedCondi…
nMaax Aug 25, 2025
6b33f82
Introduced Simformer and FlowMatchingSimformer to linear gaussian dif…
nMaax Aug 25, 2025
c3fe087
Adapted vector field linear gaussian test for simformer
nMaax Aug 25, 2025
9092297
Revert "Introduced Simformer and FlowMatchingSimformer to linear gaus…
nMaax Aug 25, 2025
988a9f0
Marked Simformer vp and subvp linear gaussian tests as gpu and slow
nMaax Aug 25, 2025
d759e15
Increased default hidden features for Simformer to 128
nMaax Aug 25, 2025
92f5748
Avoding dkl gaussian prior check for Simformer in linear gaussian tests
nMaax Aug 25, 2025
b07c6d5
Removed simformer dedicated linear gaussian test as it has been merge…
nMaax Aug 25, 2025
15904ab
Adapted gpu linear gaussian tests to proper device handling
nMaax Aug 25, 2025
4e59cfd
Adapted map test on vector fields to include simformer
nMaax Aug 25, 2025
f6d292c
Revert "Generalized Masked classes and Flow Simformer to feature-abse…
nMaax Aug 26, 2025
2f70bd7
Removed simformer dedicated linear gaussian map test as it has been m…
nMaax Aug 26, 2025
c119a5d
Minor fix on input shape for forward pass in Masked Conditional Score…
nMaax Aug 26, 2025
a1c78a6
Adapated masked conditional score estimator to natively use 2-dim ten…
nMaax Aug 27, 2025
89b55ee
Also adapated masked flow matching estimator to natively use 2-dim te…
nMaax Aug 27, 2025
fad4564
Simplifies rebalance loss in masked conditional score estimator
nMaax Aug 27, 2025
82b7460
Re-introduced 3-dim specific loss handling in rebalance loss for Mask…
nMaax Aug 28, 2025
b08008e
Improved shape handling for masked flow matching estimator loss pass
nMaax Aug 28, 2025
aadda6f
Added and adapted test on masked flow vf estimator
nMaax Aug 28, 2025
22f43d9
Added input shape 2-dim tests on vf estimator
nMaax Aug 28, 2025
3c4e566
Adapted Masked Conditional VF Estimator Wrapper to handle 2-dim input…
nMaax Aug 29, 2025
8620911
Adapted tests on Masked Conditional VF Estimator Wrapper to 2-dim inputs
nMaax Aug 29, 2025
b640e7d
Runned simformer advanced tutorial, stripped output
nMaax Aug 29, 2025
63eec83
Minor docstring change for Flow Matching Simformer
nMaax Aug 30, 2025
f7381ae
Provided Simformer in benchmarks
nMaax Aug 30, 2025
467fdb8
Removed reference to device in benchmarks for Simformer
nMaax Aug 30, 2025
92d4b91
Improved mask generator function of masked neural inference to proces…
nMaax Aug 30, 2025
aa77cb3
Introduced device attribute for NoPrior
nMaax Aug 30, 2025
0ce778d
Checking at least 2d on log prob of NoPrior
nMaax Aug 30, 2025
d8253be
Update NoPrior batch and even shape at append simulation if none has …
nMaax Aug 30, 2025
ee1ff2a
Removed wrong dimension in condition reshaping/repetition for _assemb…
nMaax Aug 30, 2025
b56fdd5
Added comments for assemble and disassemble full inputs helper functi…
nMaax Aug 30, 2025
018c301
Improved overall MVF Wrapper comments for clarity
nMaax Aug 30, 2025
eb73ffe
Adapted samples rejection logic for NoPrior poor shape informatiom
nMaax Aug 30, 2025
6bf3111
Added Siformer and FlowMatchingSimformer to default benchmark run
nMaax Aug 30, 2025
d7ee92a
Removed useless unsqueeze for benchmark on Simformer
nMaax Aug 30, 2025
91738a2
Improved references to Simformer paper and minor docstrings
nMaax Aug 30, 2025
666c097
Minor docustring improvement
nMaax Aug 30, 2025
e4d62fb
Removed sde_type from FlowMatchingSimformer as unused
nMaax Aug 30, 2025
2d4fe62
Minor docstring improvement
nMaax Aug 30, 2025
ac521ed
Further added comments and docstring into MVF Wrapper
nMaax Aug 31, 2025
51c2b9a
Adapted Wrapper and Posterior to handle degenerate case of full laten…
nMaax Aug 31, 2025
578bdd8
Marked two linear gaussian tests simformer as slow to speed up time o…
nMaax Aug 31, 2025
43fa23a
Further marked other linear gaussian tests simformer as slow to speed…
nMaax Aug 31, 2025
8a5b2c4
Improved comments and typing, removed parameter for progress bar as i…
nMaax Aug 31, 2025
a83b674
Ensuring tensor on device for VE sde drift fn
nMaax Aug 31, 2025
0528680
Running simformer fixtures on gpu
nMaax Aug 31, 2025
db24a5e
Ensuring masks on the same device as input in Simformer network, rais…
nMaax Aug 31, 2025
6997551
Moved ode/sde equivalence and iid linear gaussian checks simformer fi…
nMaax Aug 31, 2025
2472a47
Separated Simformer and FlowSimformer benchmark cases
nMaax Sep 1, 2025
6c8c64c
Revert "Adapted samples rejection logic for NoPrior poor shape inform…
nMaax Sep 10, 2025
d86fc69
Revert "Update NoPrior batch and even shape at append simulation if n…
nMaax Sep 10, 2025
0ba5863
Introduced support constraint real in NoPrior, fix errors in shape ch…
nMaax Sep 10, 2025
a646c25
Merge remote-tracking branch 'upstream/main' into simformer
nMaax Sep 12, 2025
eee967c
Re-ordered params for build posterior and build likelihood in Simform…
nMaax Sep 12, 2025
5fff8db
Attempted to fix bug in iid lin gauss tests over simformer
nMaax Sep 12, 2025
90d3d5e
Re-inserted check on calc misspecificaton mmd
nMaax Sep 12, 2025
08aeb73
minor typo and docs fixes
janfb Sep 19, 2025
394a48d
Revert advanced tutorials counters modifications
nMaax Oct 31, 2025
6887172
Revert advanced tutorials counters modifications again
nMaax Oct 31, 2025
f183ff2
Sorted inference init imports alphabetically
nMaax Oct 31, 2025
18c6e52
Avoided assert for an if-else statement instead
nMaax Oct 31, 2025
b9af9ae
Avoided assert for an if-else statement
nMaax Oct 31, 2025
4ea0c03
Properly typed ada_time flag for Simformer Architecture as bool inste…
nMaax Oct 31, 2025
c7e9f5f
Merge remote-tracking branch 'upstream/main' into simformer
nMaax Nov 22, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions docs/advanced_tutorials.rst
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ Neural nets
advanced_tutorials/12_iid_data_and_permutation_invariant_embeddings.ipynb
advanced_tutorials/19_flowmatching_and_scorematching.ipynb
advanced_tutorials/20_score_based_methods_new_features.ipynb
advanced_tutorials/22_simformer.ipynb


Training
Expand Down
496 changes: 496 additions & 0 deletions docs/advanced_tutorials/22_simformer.ipynb

Large diffs are not rendered by default.

4 changes: 4 additions & 0 deletions docs/sbi.rst
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,8 @@ Neural nets
sbi.neural_nets.classifier_nn
sbi.neural_nets.flowmatching_nn
sbi.neural_nets.posterior_score_nn
sbi.neural_nets.simformer_score_nn
sbi.neural_nets.simformer_flow_nn
sbi.neural_nets.marginal_nn


Expand Down Expand Up @@ -60,6 +62,8 @@ Training
sbi.inference.MNPE
sbi.inference.FMPE
sbi.inference.NPSE
sbi.inference.Simformer
sbi.inference.FlowMatchingSimformer
Comment on lines +65 to +66
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

not sorted alphabetically, but maybe that's fine, it seems to be sorted semantically.

sbi.inference.NLE_A
sbi.inference.NRE_A
sbi.inference.NRE_B
Expand Down
8 changes: 3 additions & 5 deletions sbi/diagnostics/misspecification.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,7 @@
from typing import Optional

import torch
import torch.nn as nn
from torch import Tensor
from torch import Tensor, nn

from sbi.inference.trainers.npe.npe_base import PosteriorEstimatorTrainer
from sbi.neural_nets.estimators import UnconditionalDensityEstimator
Expand Down Expand Up @@ -157,10 +156,9 @@ def calc_misspecification_mmd(
)
if inference._neural_net.embedding_net is None:
raise AttributeError(
"embedding_net attribute is None but is required for misspecification"
" detection."
"embedding_net attribute is None but is required for misspecification "
"detection."
)

z_obs = inference._neural_net.embedding_net(x_obs).detach()
z = inference._neural_net.embedding_net(x).detach()
else:
Expand Down
23 changes: 20 additions & 3 deletions sbi/inference/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

from sbi.inference.abc import MCABC, SMCABC
from sbi.inference.trainers.base import (
MaskedNeuralInference, # noqa: F401
NeuralInference, # noqa: F401
check_if_proposal_has_default_x,
infer,
Expand All @@ -11,7 +12,7 @@
from sbi.inference.trainers.nle import MNLE, NLE_A
from sbi.inference.trainers.npe import MNPE, NPE_A, NPE_B, NPE_C # noqa: F401
from sbi.inference.trainers.nre import BNRE, NRE_A, NRE_B, NRE_C # noqa: F401
from sbi.inference.trainers.vfpe import FMPE, NPSE
from sbi.inference.trainers.vfpe import FMPE, NPSE, FlowMatchingSimformer, Simformer

SNL = SNLE = SNLE_A = NLE = NLE_A
_nle_family = ["NLE"]
Expand All @@ -33,7 +34,13 @@
_abc_family = ["ABC", "MCABC", "SMC", "SMCABC"]


__all__ = _npe_family + _nre_family + _nle_family + _abc_family + ["FMPE", "NPSE"]
__all__ = (
_npe_family
+ _nre_family
+ _nle_family
+ _abc_family
+ ["FMPE", "NPSE", "Simformer", "FlowMatchingSimformer"]
)

from sbi.inference.posteriors import (
DirectPosterior,
Expand All @@ -53,4 +60,14 @@
)
from sbi.utils.simulation_utils import simulate_for_sbi

__all__ = ["FMPE", "MarginalTrainer", "NLE", "NPE", "NPSE", "NRE", "simulate_for_sbi"]
__all__ = [
"FlowMatchingSimformer",
"FMPE",
"MarginalTrainer",
"NLE",
"NPE",
"NPSE",
"NRE",
"Simformer",
"simulate_for_sbi",
]
24 changes: 23 additions & 1 deletion sbi/inference/posteriors/vector_field_posterior.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
from sbi.samplers.score.predictors import Predictor
from sbi.sbi_types import Shape
from sbi.utils import check_prior
from sbi.utils.sbiutils import gradient_ascent, within_support
from sbi.utils.sbiutils import gradient_ascent, handle_invalid_x, within_support
from sbi.utils.torchutils import ensure_theta_batched


Expand Down Expand Up @@ -202,6 +202,17 @@ def sample(

x = self._x_else_default_x(x)
x = reshape_to_batch_event(x, self.vector_field_estimator.condition_shape)
_, num_nans, num_infs = handle_invalid_x(x)

if num_nans + num_infs == 0:
ValueError(
"Some invalid entries (NaN/Infs) were "
"found in x. You probably passed these as the ground observed "
"x's `x_obs`. "
"Please, remove these values and provide reasonable observed x's "
"to avoid the sampling process to run indefinitely."
)

is_iid = x.shape[0] > 1
self.potential_fn.set_x(
x,
Expand Down Expand Up @@ -455,6 +466,17 @@ def sample_batched(
"""
num_samples = torch.Size(sample_shape).numel()
x = reshape_to_batch_event(x, self.vector_field_estimator.condition_shape)
_, num_nans, num_infs = handle_invalid_x(x)

if num_nans + num_infs == 0:
ValueError(
"Some invalid entries (NaN/Infs) were "
"found in x. You probably passed these as the ground observed "
"x's `x_obs`. "
"Please, remove these values and provide reasonable observed x's "
"to avoid the sampling process to run indefinitely."
)

condition_dim = len(self.vector_field_estimator.condition_shape)
batch_shape = x.shape[:-condition_dim]
batch_size = batch_shape.numel()
Expand Down
Loading
Loading