# Simformer #1621
## Conversation
Commits:

- Removing code duplication on embedding net handling
- …t use of simformer (condition is an empty tensor)
- … up time of not slow tests
- …t is default True, in linear gaussian vf test
- …ing a warning in case it is detected
- …xture to gpu
- Pass device information to IID method in VectorFieldBasedPotential
Alright, as requested by Google:
I mark the below as the last commit for my GSoC. Nonetheless, I am still able to work more on this to implement advice and fixes after review 👍
**janfb** left a comment:
I started the review and made some initial comments, but I noticed that there is a conflict with main which should be resolved first.
Overall, it's amazing to have all this implemented - well done!
Here are a couple of general comments:
- `sde_type` handling: score-based uses it; the FM variant drops it. This is correct, but it should be explicit in the docs (and maybe warn if provided to FM). Clarify in `FlowMatchingSimformer` that `sde_type` is ignored; consider explicitly documenting that a provided `sde_type` is dropped. (see above)
- Trainer docstrings list a `prior` argument that does not exist in init. Remove "prior" from both `Simformer` and `FlowMatchingSimformer` init docstrings.
- In both trainer docstrings, the line "kwargs: ... passed to the default builder if score_estimator is a string" names the wrong parameter; it should reference `mvf_estimator` (and also be consistent with the "model" naming in the factory).
- Normalize terminology for `time_emb_type` across the file: consistently use "sinusoidal" | "random_fourier" (some places say "fourier").
- In the `VectorFieldSimformer` docstring, it would be helpful to add expected shapes and dtypes:
  - `inputs`: `[B, T, in_features]`
  - `condition_mask`: `[B, T]`, bool
  - `edge_mask`: `Optional[[B, T, T]]`, bool
  - `t`: `[B]` or `[B, 1]`, float
- Minor: I find it a bit confusing to have `Simformer` and `FMSimformer` as trainer classes and then `VectorFieldSimformer` as the actual NN class. Maybe, in both trainer class docstrings, add the sentence "This trainer uses the Simformer network (`VectorFieldSimformer`) under the hood." Or, a potential renaming could be `SimformerTrainer`, `FlowMatchingSimformerTrainer`, and `SimformerNet`. I prefer these long class names if they add clarity.
- `VectorFieldSimformer` defaults to `num_layers = 4`, but `build_simformer_network` defaults to `num_layers = 5`.
I also made some small typo fixes and docs clarifications already locally and will push them now.
Looking forward to doing the full review.
Resolved review threads (outdated): docs/advanced_tutorials/19_flowmatching_and_scorematching.ipynb, docs/advanced_tutorials/20_score_based_methods_new_features.ipynb
```
sbi.inference.Simformer
sbi.inference.FlowMatchingSimformer
```

Not sorted alphabetically, but maybe that's fine; it seems to be sorted semantically.
```python
class NeuralInference(ABC):
    """Abstract base class for neural inference methods."""


def check_if_proposal_has_default_x(proposal: Any):
```

I suggest moving this helper function somewhere else: into a utils file, or just further down, below the main classes in this file. E.g., it seems to be used only in npe_base.py, so we could move it there?
Also, can we make the type more precise, e.g., `Union[Distribution, NeuralPosterior]`?
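A minimal sketch of the suggested signature change (the import paths are my assumption of where these classes live in sbi; the helper's body is elided and unchanged):

```python
from typing import Union

from torch.distributions import Distribution

from sbi.inference.posteriors.base_posterior import NeuralPosterior


def check_if_proposal_has_default_x(
    proposal: Union[Distribution, NeuralPosterior],
) -> None:
    # Body elided; behavior unchanged from the original helper.
    ...
```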
## Simformer

> **Important:** This PR is part of Google Summer of Code 2025.
> **Note:** Before opening this PR, I initially experimented with the Simformer and auxiliary components in a separate branch of my sbi fork. You can find it here. I used that branch (`simformer-dev`) as a first environment where I could experiment with solutions more freely, then opened this PR once I had a minimum viable product. That branch basically served as my working environment for the first month and a half of the GSoC; nevertheless, all the code I finalized there has been fully incorporated into the PR you are reading. More specifically, in that branch I mainly worked on a first version of the Simformer neural network architecture and the "masked" interface. I also attempted to introduce a Joint distribution interface, i.e., an interface parallel to the current "Posterior" approach in sbi that could generalize better to the Simformer case, as the Simformer does not work in terms of "posterior", "likelihood", and so on, but more generally in terms of arbitrary conditionals. Nevertheless, the latter was dropped in favor of a Wrapper class that adapts the more general Simformer approach to the existing sbi posterior interface (see below for more information).
Implemented the Simformer (Gloeckler et al., 2024, ICML). The Simformer aims to unify the various simulation-based inference paradigms (posterior, likelihood, or arbitrary conditional sampling) within a single framework, allowing users to sample from any conditional distribution of interest, potentially also acting as a novel data generator if one samples the unconditioned joint distribution of all variables.
The Simformer diverges from the standard sbi paradigm of data provided by means of `theta` and `x`; it rather exploits a full tensor `inputs` of data and two masks:

- `condition_mask`: identifies which variables are latent (to be inferred by the Simformer) and which are observed (ground data).
- `edge_mask`: identifies relationships between variables, equivalent to an adjacency matrix for a DAG. This mask is used directly by the transformer attention block to mask out certain attention scores.
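For concreteness, a minimal sketch of these tensors using the shapes discussed in this PR (`[B, T, in_features]` inputs, `[B, T]` condition mask, `[B, T, T]` edge mask); the True-means-observed convention below is an assumption:

```python
import torch

# Hypothetical example: three variables with one feature each; the first two
# are latent, the third is observed; fully connected DAG (all attention allowed).
B, T, F = 1, 3, 1
inputs = torch.randn(B, T, F)                           # [B, T, in_features]
condition_mask = torch.tensor([[False, False, True]])   # [B, T]; True = observed (assumed)
edge_mask = torch.ones(B, T, T, dtype=torch.bool)       # [B, T, T] adjacency matrix
```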
### Design of the Masked Classes

To accomplish this, it was necessary to create some "parallel" classes of the current `ScoreEstimator`, `VectorFieldEstimator`, etc., that work by means of this "masked" paradigm. Generally, each "Masked" version of an object is provided directly below its counterpart in the same Python file, e.g., `MaskedConditionalVectorFieldEstimator` sits directly below the code block of `ConditionalVectorFieldEstimator`. They essentially consist of an overall refactor of the original counterpart, where each use of "`theta` and `x`" or "`inputs` and `condition`" has been replaced with a general "`inputs`, `condition_mask`, and `edge_mask`".

A Wrapper class has also been introduced that adapts the original Posterior API to the Simformer one. Thanks to this class, one can simply call the `build_conditional()` method directly on the Simformer inference object and obtain a standard Posterior object that works as always, given some fixed condition and edge masks. The Wrapper handles all the shapes automatically and performs auxiliary operations to pass the data to the Simformer network and the underlying masked estimator; this is done mainly through two helper functions, `assemble_full_inputs()` and `disassemble_full_inputs()`, which convert from the $(\theta, x)$ setting to the full input tensor, and back.
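As a rough usage sketch of this Wrapper path (the names `Simformer`, `posterior_latent_idx`, `posterior_observed_idx`, and `build_conditional` come from this PR's description; the data shapes, mask convention, and exact call signatures are my assumptions):

```python
import torch

from sbi.inference import Simformer  # added by this PR

# Hypothetical data: 2000 simulations of 3 variables, as one full input tensor.
inputs = torch.randn(2000, 3)
x_o = torch.tensor([1.5])  # observed value of the third variable

trainer = Simformer(posterior_latent_idx=[0, 1], posterior_observed_idx=[2])
trainer.append_simulations(inputs).train()

# Obtain a standard sbi Posterior for a fixed conditioning pattern.
posterior = trainer.build_conditional(
    condition_mask=torch.tensor([False, False, True])  # True = observed (assumed)
)
samples = posterior.sample((1000,), x=x_o)
```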
At inference time, an `edge_mask` can be specified; otherwise it will be `None` (equivalent to a tensor of all ones, but safer memory-wise). A `condition_mask`, instead, must be explicitly passed at `build_conditional` time. Another option is to directly use the `build_posterior()` and `build_likelihood()` methods, which automatically generate an appropriate condition mask based on the `posterior_latent_idx` and `posterior_observed_idx` parameters specified at init() of the Simformer.

Also at training time an `edge_mask` can be specified; if not, the default value will still be `None`. More generally, the user can pass a Callable to generate condition or edge masks, so that one can simply choose the mask distributions they prefer; sets of tensors/lists or even just one tensor can be passed as well. Masks are also generated just in time (JIT) for training, that is, they are not provided at `append_simulations()` but during `train()`, in order to save memory. Differently from inference time, if a condition mask is not specified here, a default generator will be used, producing masks sampled from a $\mathrm{Bernoulli}(p=0.5)$.

Note that the Simformer potentially allows the user to set any mask of their choice both at training and inference time; it is then the user's duty to provide coherent definitions (callables, sets, or fixed tensors) that make sense. E.g., if the user passes a specific edge mask at training time, the Simformer will learn that specific DAG structure, and it is then the user's duty to pass a coherent `edge_mask` also when calling `build_conditional`, `build_posterior`, or `build_likelihood`.
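To illustrate the callable-based mask generation mentioned above, a minimal sketch of a custom condition-mask generator; the callable signature here is hypothetical, not the actual API:

```python
import torch


def bernoulli_condition_mask(batch_size: int, num_variables: int) -> torch.Tensor:
    # Mark each variable independently with probability 0.5, mirroring the
    # default Bernoulli(p=0.5) generator described above.
    return torch.bernoulli(torch.full((batch_size, num_variables), 0.5)).bool()
```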
Furthermore, the Simformer is able to manage invalid inputs (`nan`s and `inf`s) natively: if `handle_invalid_x=True`, the Simformer will automatically spot invalid inputs at training time (still JIT) and flip their state on the condition mask to latent (to be inferred), besides also replacing such values with small Gaussian noise for numerical stability (see the sketch below).

Also, a flow-matching equivalent of the Simformer (the above was assumed to be score-based) has been provided.
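A minimal sketch of what this invalid-input handling amounts to (a hypothetical helper, not the actual sbi code; it assumes True marks observed entries in the condition mask):

```python
import torch


def handle_invalid_inputs(inputs, condition_mask, noise_std=1e-2):
    # Variables with any NaN/inf feature become latent (to be inferred) ...
    invalid = ~torch.isfinite(inputs).all(dim=-1)  # [B, T]
    condition_mask = condition_mask & ~invalid     # flip observed -> latent
    # ... and the invalid values are replaced by small Gaussian noise
    # for numerical stability.
    inputs = torch.where(
        torch.isfinite(inputs), inputs, noise_std * torch.randn_like(inputs)
    )
    return inputs, condition_mask
```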
This PR then includes integration with the `mini-sbibm` benchmark suite, and a notebook tutorial for the Simformer (under `docs/advanced_tutorials`), where I showcase its use. I also tried to make the API Reference as clear as possible in the documentation.

### Refactor of existing code
Parts of the existing code have been refactored, mainly to avoid repetition and keep everything DRY. The most important pieces of code that have been modified are:

- The SDE coefficients (`mean_t`, `std_t`, etc.) were moved into standard Mixins (see the sketch below). E.g., instead of `VEScoreEstimator(ConditionalScoreEstimator)`, one now has `VarianceExplodingSDE`, which defines `mean_t`, `std_t`, etc., and `VEScoreEstimator` becomes `VEScoreEstimator(ConditionalScoreEstimator, VarianceExplodingSDE)`; this way I can also easily define `MaskedVEScoreEstimator(MaskedConditionalScoreEstimator, VarianceExplodingSDE)` without repeating the VE SDE pieces.
- The `NeuralInference` interface has been split using a Mixin too (`BaseNeuralInference`), which defines the shared properties of both `NeuralInference` and `MaskedNeuralInference`. This also required some minor adjustments, mainly to methods such as `_resolve_prior()` and `_resolve_estimator()`; most importantly, a new `NoPrior` object (as distinct from `ImproperPrior`) has been created as a temporary solution for #1635 (Keep prior optional and remove unnecessary copies of the prior).
- `ConditionalVectorFieldEstimator` and `MaskedConditionalVectorFieldEstimator` were simplified by moving shared code into a Mixin called `BaseConditionalVectorFieldEstimator`, mainly regarding the `mean_base` and `std_base` properties and methods such as `diffusion_fn()`.
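As an illustration of this Mixin pattern, a minimal self-contained sketch (the sbi base classes are stubbed out, and the coefficients shown are the standard variance-exploding schedule, not necessarily the exact code of this PR):

```python
import torch


class ConditionalScoreEstimator:        # stub for sbi's real base class
    pass


class MaskedConditionalScoreEstimator:  # stub for the masked counterpart
    pass


class VarianceExplodingSDE:
    """Mixin holding the VE SDE coefficients (mean_t, std_t, ...)."""

    sigma_min: float = 0.01
    sigma_max: float = 10.0

    def mean_t(self, t: torch.Tensor) -> torch.Tensor:
        # The VE SDE does not rescale the mean.
        return torch.ones_like(t)

    def std_t(self, t: torch.Tensor) -> torch.Tensor:
        # Geometric noise schedule from sigma_min to sigma_max.
        return self.sigma_min * (self.sigma_max / self.sigma_min) ** t


class VEScoreEstimator(ConditionalScoreEstimator, VarianceExplodingSDE):
    """Standard estimator: inherits the VE coefficients from the Mixin."""


class MaskedVEScoreEstimator(MaskedConditionalScoreEstimator, VarianceExplodingSDE):
    """Masked estimator: reuses the same VE pieces without duplication."""
```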
### Summary of modified files

The files I modified should be the following:
**sbi/inference**

- `sbi/inference/trainers/base.py`: Added `MaskedNeuralInference`.
- `sbi/inference/trainers/vfpe/base_vf_inference.py`: Added `MaskedVectorFieldEstimatorBuilder` and `MaskedVectorFieldInference` (subclass of `MaskedNeuralInference`).
- `sbi/inference/trainers/vfpe/simformer.py`: New file introducing the Simformer inference class.

**sbi/neural_nets**

- `sbi/neural_nets/factory.py`: Added support for building Simformer networks (`simformer_nn`).
- `sbi/neural_nets/estimators/base.py`: Added `MaskedConditionalEstimator` and `MaskedConditionalVectorFieldEstimator` (subclass of `MaskedConditionalEstimator`).
- `sbi/neural_nets/estimators/score_estimator.py`: `MaskedConditionalScoreEstimator` (subclass of `MaskedConditionalVectorFieldEstimator`), placed directly above `ConditionalScoreEstimator`; `MaskedVEScoreEstimator` (subclass of `MaskedConditionalScoreEstimator`).
- `sbi/neural_nets/net_builders/vector_field_nets.py`: `build_vector_field_estimator` updated to support `simformer` and `masked-score`; added `MaskedSimformerBlock`, `MaskedDiTBlock`, `SimformerNet` (subclass of `MaskedVectorFieldNet`), and `build_simformer_network` (defines default architecture parameters).

**sbi/utils**

- `sbi/utils/vector_field_utils.py`: Added `MaskedVectorFieldNet`.

**sbi/analysis**

- `sbi/analysis/plots.py`: Minor fix to ensure CPU conversion in `ensure_numpy()` (added `.cpu()` before `.numpy()`; sketched below).
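For reference, a simplified sketch of what the `ensure_numpy()` fix amounts to (the actual function in `sbi/analysis/plots.py` may differ in its details):

```python
import numpy as np
import torch


def ensure_numpy(t):
    # Tensors on GPU must be moved to CPU before NumPy conversion.
    if isinstance(t, torch.Tensor):
        return t.detach().cpu().numpy()
    return np.asarray(t)
```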
### Unit Test

Introduced benchmarks (`mini_sbibm`) and tests for the Simformer and related masked objects in:

- `tests/linearGaussian_vector_field_test.py`
- `tests/posterior_nn_test.py`
- `tests/vector_field_nets_test.py`
- `tests/vf_estimator_test.py` (which also includes shape tests on the Wrapper)
- `tests/bm_test.py`

Regarding the linear Gaussian tests, I tried to fold the Simformer tests into existing methods as much as possible; nonetheless, the IID test and the SDE/ODE sampling-equivalence test are still provided as separate dedicated tests and fixtures.
### New files

- `docs/advanced_tutorials/22_simformer.ipynb`
- `sbi/inference/trainers/vfpe/simformer.py`: includes both the score-based and flow-matching Simformer interfaces

### Thank you
Thank you sbi and Google for this opportunity. Implementing the Simformer has been so rewarding: not only did I learn something completely new in itself, but most importantly I understood how to do it. Having to familiarize myself with new concepts, write code within code made by others, and follow my mentors' guidance is the real value of this experience. Special thanks to my mentors Manuel (@manuelgloeckler) and Jan (@janfb) for accepting my proposal, and to @manuelgloeckler in particular for helping me throughout the whole journey!