Unify flow matching and score-based models #1497

Merged (77 commits, Mar 25, 2025)

Conversation

@StarostinV (Collaborator) commented on Mar 19, 2025

What does this PR do?

This PR unifies score-based models with flow matching under a common API. Furthermore, based on the paper https://arxiv.org/abs/2410.02217, the vector field learned with flow matching can be used to compute the time-dependent score and marginal distributions, as well as the drift and diffusion functions for SDE-based sampling. This effectively enables all score-based methods (SDE-based sampling, iid inference, gradient evaluation, MAP with score-based gradients, guidance, etc.) to be used with flow matching.
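As a rough illustration of the conversion the paper enables (not the PR's actual implementation): for an affine Gaussian probability path x_t = alpha_t * x_1 + sigma_t * eps, the marginal score can be recovered from the learned vector field by rearranging the known identity between the two. A minimal sketch; the function name, argument layout, and time convention are assumptions:

import torch


def score_from_vector_field(
    v: torch.Tensor,  # learned vector field v_t(x)
    x: torch.Tensor,  # current state
    alpha: float,     # path scale alpha_t
    d_alpha: float,   # time derivative of alpha_t
    sigma: float,     # path noise scale sigma_t
    d_sigma: float,   # time derivative of sigma_t
) -> torch.Tensor:
    """Recover the time-dependent score from a flow-matching vector field.

    Assumes an affine Gaussian path x_t = alpha_t * x_1 + sigma_t * eps, for which
        v_t(x) = (d_alpha / alpha) * x
                 + sigma**2 * (d_alpha / alpha - d_sigma / sigma) * score_t(x),
    so the score follows by rearranging. Names and conventions are illustrative only.
    """
    ratio = d_alpha / alpha
    # The denominator vanishes for scale-only paths where d_alpha / alpha == d_sigma / sigma.
    denom = sigma**2 * (ratio - d_sigma / sigma)
    return (v - ratio * x) / denom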

The API of FMPE and NPSE did not change, but they are now wrappers around the same class, VectorFieldInference, and differ only in the following (a usage sketch follows the list):

  • the default sample_with value in the build_posterior method ("sde" for NPSE, "ode" for FMPE),
  • the estimator builder (FlowMatchingEstimator and ScoreBasedEstimator, both subclasses of ConditionalVectorFieldEstimator),
  • the loss function and the concrete SDE and ODE functions provided by the estimator.
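A minimal usage sketch of the two wrappers, assuming the standard sbi training loop; the exact import paths, the toy simulator, and the defaults shown here are assumptions and may differ from the final PR:

import torch
from sbi.inference import FMPE, NPSE
from sbi.utils import BoxUniform

# Toy setup: 2-d parameters and a trivial Gaussian "simulator".
prior = BoxUniform(low=-2 * torch.ones(2), high=2 * torch.ones(2))
theta = prior.sample((1000,))
x = theta + 0.1 * torch.randn_like(theta)

# Flow matching: ODE-based sampling is the default.
fmpe = FMPE(prior=prior)
fmpe.append_simulations(theta, x).train()
posterior_ode = fmpe.build_posterior(sample_with="ode")

# Score-based: SDE-based sampling is the default, but the same options apply.
npse = NPSE(prior=prior)
npse.append_simulations(theta, x).train()
posterior_sde = npse.build_posterior(sample_with="sde")

samples = posterior_ode.sample((100,), x=x[:1])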

Additionally, the ODE solver is isolated into a separate API, which makes it possible to swap ODE backends; so far, only the zuko backend is implemented.
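A backend-agnostic solver interface might look roughly like the sketch below; the class and method names here are hypothetical, not the ones introduced by the PR (the actual zuko-backed implementation lives under sbi/samplers/ode_solvers/):

from abc import ABC, abstractmethod
from typing import Callable

import torch


class NeuralODESolver(ABC):
    """Integrates dx/dt = v(x, t) for a learned vector field (hypothetical interface)."""

    def __init__(self, vector_field: Callable[[torch.Tensor, torch.Tensor], torch.Tensor]):
        self.vector_field = vector_field

    @abstractmethod
    def solve(self, x0: torch.Tensor, t0: float, t1: float) -> torch.Tensor:
        """Return the state at t1, starting from x0 at t0."""


class EulerSolver(NeuralODESolver):
    """A toy fixed-step backend; a zuko-backed solver would plug in the same way."""

    def __init__(self, vector_field, num_steps: int = 100):
        super().__init__(vector_field)
        self.num_steps = num_steps

    def solve(self, x0: torch.Tensor, t0: float, t1: float) -> torch.Tensor:
        x = x0
        dt = (t1 - t0) / self.num_steps
        for i in range(self.num_steps):
            t = torch.as_tensor(t0 + i * dt)
            x = x + dt * self.vector_field(x, t)
        return x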

Does this close any issues?

#1440 and #1462

Anything else we should know?

The following features have been implemented:

  • Add ode_solvers to make ODE backends swappable
  • Extend the API of ConditionalVectorFieldEstimator to enable SDE-based sampling (see the interface sketch after this list)
  • Make FlowMatchingEstimator a subclass of ConditionalVectorFieldEstimator
  • Implement score-based methods in FlowMatchingEstimator
  • Implement unified VectorFieldPotential, VectorFieldPosterior, and VectorFieldInference classes
  • Update the FMPE and NPSE classes
  • Update score_fn_iid to support the new API
  • Remove tests that rely on the old API and on DirectPosterior for flow matching
  • Improve documentation
  • Add tests for the new functionality
  • Add benchmark tests
  • Remove the previous implementation of score-based methods
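For orientation, here is a hypothetical sketch of the kind of interface such a unified estimator could expose; all method names below are assumptions for illustration, not the PR's exact signatures:

from abc import ABC, abstractmethod

import torch
from torch import nn


class VectorFieldEstimatorSketch(nn.Module, ABC):
    """A single estimator exposing both ODE- and SDE-related quantities (hypothetical)."""

    @abstractmethod
    def forward(self, theta: torch.Tensor, x: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
        """Vector field v(theta, t | x), used for ODE-based sampling."""

    @abstractmethod
    def score(self, theta: torch.Tensor, x: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
        """Time-dependent score; for flow matching it is derived from the vector field."""

    @abstractmethod
    def drift(self, theta: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
        """Drift term of the sampling SDE."""

    @abstractmethod
    def diffusion(self, theta: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
        """Diffusion term of the sampling SDE."""

    @abstractmethod
    def loss(self, theta: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
        """Training loss (flow-matching or denoising score-matching objective)."""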

Related issues

✅ Checklist

Put an x in the boxes that apply. If you're unsure about any of them, no worries — just ask!

  • I have read and followed the contribution guidelines
  • I have added helpful comments to my code where needed
  • I have added tests for new functionality
  • (If applicable) I have reported how long new tests run and marked them with pytest.mark.slow

For reviewers:

  • I have reviewed every file
  • All comments have been addressed

…ple_and_log_prob

The test relies on the sample and log_prob methods of the estimators; FlowMatchingEstimator does not implement these methods.
VectorFieldPosterior does not yet support norm_posterior.
The test assumes DirectPosterior and does not support the new vector field implementation of FMPE.
@manuelgloeckler marked this pull request as ready for review on March 20, 2025, 08:21
@janfb (Contributor) commented on Mar 25, 2025

> @janfb @manuelgloeckler I have some good news: the iid methods work great with flow matching! However, so far the metrics in (slow iid) tests are not so great, and that is because of the neural network architecture. I trained a model independently using a very simple architecture and c2st is much better for iid sampling with flow matching. That is another PR though.

that's great! but what is the difference in the NN architecture between ours and your simple one?

@janfb (Contributor) left a comment

Thanks again @StarostinV for another round of big effort! This looks almost done now 👍

Added a couple of minor comments and questions.

@StarostinV requested a review from janfb on March 25, 2025, 16:46
@StarostinV (Collaborator, Author) commented:

> @janfb @manuelgloeckler I have some good news: the iid methods work great with flow matching! However, so far the metrics in (slow iid) tests are not so great, and that is because of the neural network architecture. I trained a model independently using a very simple architecture and c2st is much better for iid sampling with flow matching. That is another PR though.
>
> that's great! but what is the difference in the NN architecture between ours and your simple one?

I haven't looked into the current architecture, but we ran some quick tests during the hackathon and it was clear that the current flow-matching architecture performs substantially worse than the simple one below. However, this could also be due to the number of parameters. It will become clear after unifying the net builders for scores and flows.

I just use an MLP with skip connections and time embeddings for tests, but I wouldn't advertise it since it could be improved in many ways :) I'll just put it here for reference:

import torch
from torch import nn

# VectorFieldNet and SinusoidalTimeEmbedding are assumed to come from the existing
# sbi neural-net code; they are not defined in this snippet.


class SimpleNet(VectorFieldNet):
    def __init__(
        self,
        in_dim: int = 2,
        condition_dim: int = 2,
        out_dim: int = 2,
        hid_dim: int = 256,
        time_emb_dim: int = 16,
        num_blocks: int = 3,
    ):
        super().__init__()
        self.time_embedding = SinusoidalTimeEmbedding(time_emb_dim)

        # First block takes [theta, condition, time embedding]; the rest are hidden-to-hidden.
        in_dims = [in_dim + time_emb_dim + condition_dim] + [hid_dim] * (num_blocks - 1)
        out_dims = [hid_dim] * (num_blocks - 1) + [out_dim]

        self.net = nn.Sequential(
            *[
                ResidualBlock(block_in, block_out, hid_dim=hid_dim)
                for block_in, block_out in zip(in_dims, out_dims)
            ],
        )

    def forward(self, theta: torch.Tensor, x: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
        h = torch.cat([theta, x, self.time_embedding(t)], dim=-1)
        return self.net(h)


class ResidualBlock(nn.Module):
    def __init__(self, in_dim: int, out_dim: int, hid_dim: int = 128):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(in_dim, hid_dim),
            nn.LeakyReLU(),
            nn.LayerNorm(hid_dim),
            nn.Linear(hid_dim, out_dim),
        )

        # Linear skip connection so that input and output dimensions may differ.
        self.residual = nn.Linear(in_dim, out_dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.net(x) + self.residual(x)
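For completeness, a hypothetical shape check of the network above; it assumes SinusoidalTimeEmbedding maps a batch of times to a (batch, time_emb_dim) tensor:

net = SimpleNet(in_dim=2, condition_dim=3, out_dim=2)
theta = torch.randn(16, 2)
x_cond = torch.randn(16, 3)
t = torch.rand(16)
v = net(theta, x_cond, t)  # expected shape: (16, 2)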

@janfb (Contributor) left a comment

Heroic effort @StarostinV 🏅 🚀

Looks all good now, thanks for your patience with all my comments! 🙏
Looking forward to seeing this in action.

codecov bot commented on Apr 25, 2025

Codecov Report

Attention: Patch coverage is 90.73171% with 38 lines in your changes missing coverage. Please review.

Project coverage is 79.56%. Comparing base (8900ca0) to head (369477d).
Report is 33 commits behind head on main.

| Files with missing lines | Patch % | Lines |
| --- | --- | --- |
| ...i/neural_nets/estimators/flowmatching_estimator.py | 85.45% | 8 Missing ⚠️ |
| .../inference/trainers/npse/vector_field_inference.py | 95.42% | 7 Missing ⚠️ |
| sbi/neural_nets/estimators/base.py | 75.00% | 7 Missing ⚠️ |
| sbi/inference/potentials/vector_field_potential.py | 75.00% | 6 Missing ⚠️ |
| sbi/inference/potentials/score_fn_iid.py | 80.00% | 5 Missing ⚠️ |
| sbi/inference/posteriors/vector_field_posterior.py | 91.30% | 2 Missing ⚠️ |
| sbi/samplers/ode_solvers/base.py | 92.85% | 2 Missing ⚠️ |
| sbi/samplers/ode_solvers/ode_builder.py | 85.71% | 1 Missing ⚠️ |
Additional details and impacted files
@@            Coverage Diff             @@
##             main    #1497      +/-   ##
==========================================
- Coverage   89.45%   79.56%   -9.90%     
==========================================
  Files         128      133       +5     
  Lines       10170    10201      +31     
==========================================
- Hits         9098     8116     -982     
- Misses       1072     2085    +1013     
| Flag | Coverage Δ |
| --- | --- |
| unittests | 79.56% <90.73%> (-9.90%) ⬇️ |

Flags with carried forward coverage won't be shown.

| Files with missing lines | Coverage Δ |
| --- | --- |
| sbi/diagnostics/sbc.py | 92.95% <100.00%> (ø) |
| sbi/inference/__init__.py | 100.00% <ø> (ø) |
| sbi/inference/posteriors/__init__.py | 100.00% <100.00%> (ø) |
| sbi/inference/potentials/__init__.py | 100.00% <100.00%> (ø) |
| sbi/inference/trainers/fmpe/fmpe.py | 100.00% <100.00%> (+5.71%) ⬆️ |
| sbi/inference/trainers/npse/npse.py | 100.00% <100.00%> (+2.79%) ⬆️ |
| sbi/neural_nets/estimators/__init__.py | 100.00% <ø> (ø) |
| sbi/neural_nets/estimators/score_estimator.py | 93.25% <100.00%> (+0.35%) ⬆️ |
| sbi/samplers/ode_solvers/__init__.py | 100.00% <100.00%> (ø) |
| sbi/samplers/ode_solvers/zuko_ode.py | 100.00% <100.00%> (ø) |

... and 10 more

... and 33 files with indirect coverage changes

Labels: hackathon, score-matching-performance (Improving the performance of score- and flow-matching methods)
Projects: None yet
Development: Successfully merging this pull request may close these issues.
3 participants