Unify flow matching and score-based models #1497
…to unify-scores-and-flows
…ple_and_log_prob The test relies on the sample and log_prob methods of the estimators. FlowMatchingEstimator does not implement these methods.
VectorFieldPosterior does not yet support norm_posterior
the test assumes DirectPosterior and does not support new vector field implementation of FMPE
that's great! but what is the difference in the NN architecture between ours and your simple one?
Thanks again @StarostinV for another round of big effort! This looks almost done now 👍
Added a couple of minor comments and questions.
I haven't looked into the current architecture, but we made some quick tests during the hackathon and it was clear that the architecture for flow matching performs substantially worse than the simple one below. However, it could also be because of the number of parameters. It will become clear after unifying net builders for scores and flows. I just use some MLP with skip connections and time embeddings for tests, but I wouldn't advertise it since it could be improved in many ways :) Maybe I'll just put it here for reference:

```python
import torch
from torch import nn

# VectorFieldNet and SinusoidalTimeEmbedding are not defined in this snippet.


class SimpleNet(VectorFieldNet):
    def __init__(
        self,
        in_dim: int = 2,
        condition_dim: int = 2,
        out_dim: int = 2,
        hid_dim: int = 256,
        time_emb_dim: int = 16,
        num_blocks: int = 3,
    ):
        super().__init__()
        self.time_embedding = SinusoidalTimeEmbedding(time_emb_dim)
        # Per-block widths; the first block takes [theta, x, time embedding].
        in_dims = [in_dim + time_emb_dim + condition_dim] + [hid_dim] * (num_blocks - 1)
        out_dims = [hid_dim] * (num_blocks - 1) + [out_dim]
        self.net = nn.Sequential(
            *[
                ResidualBlock(in_dim, out_dim, hid_dim=hid_dim)
                for in_dim, out_dim in zip(in_dims, out_dims)
            ],
        )

    def forward(self, theta: torch.Tensor, x: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
        # Concatenate parameters, condition and time embedding along the feature axis.
        x = torch.cat([theta, x, self.time_embedding(t)], dim=-1)
        x = self.net(x)
        return x


class ResidualBlock(nn.Module):
    def __init__(self, in_dim: int, out_dim: int, hid_dim: int = 128):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(in_dim, hid_dim),
            nn.LeakyReLU(),
            nn.LayerNorm(hid_dim),
            nn.Linear(hid_dim, out_dim),
        )
        # Linear skip path so the residual still works when in_dim != out_dim.
        self.residual = nn.Linear(in_dim, out_dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.net(x) + self.residual(x)
```
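The `SinusoidalTimeEmbedding` referenced in the snippet is not shown in this thread. Purely as an illustration, a common sinusoidal embedding of a scalar time can be sketched as follows. The function name, the `max_period` default and the sin/cos layout are assumptions made for this sketch and may not match the class used in the PR; it is written without a framework so it runs on its own.

```python
import math


def sinusoidal_time_embedding(
    t: float, dim: int, max_period: float = 10_000.0
) -> list[float]:
    """Hypothetical helper: embed a scalar time t into `dim` sin/cos features."""
    if dim % 2 != 0:
        raise ValueError("dim must be even so sin and cos halves have equal size")
    half = dim // 2
    # Geometrically spaced frequencies, from 1 down towards 1 / max_period.
    freqs = [math.exp(-math.log(max_period) * i / half) for i in range(half)]
    # First half: sines; second half: cosines.
    return [math.sin(t * f) for f in freqs] + [math.cos(t * f) for f in freqs]


embedding = sinusoidal_time_embedding(0.37, dim=16)
print(len(embedding))
```

With `time_emb_dim = 16`, as in the snippet, each time value becomes 16 features that are concatenated with `theta` and `x` before the first block.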
Heroic effort @StarostinV 🏅 🚀
looks all good now, thanks for your patience with all my comments! 🙏
looking forward to seeing this in action.
Codecov Report
Attention: Patch coverage is …
Additional details and impacted files

```diff
@@            Coverage Diff             @@
##             main    #1497      +/-   ##
==========================================
- Coverage   89.45%   79.56%   -9.90%
==========================================
  Files         128      133       +5
  Lines       10170    10201      +31
==========================================
- Hits         9098     8116     -982
- Misses       1072     2085    +1013
```
What does this PR do?
This unifies score-based models with flow matching under a common API. Furthermore, based on the following paper https://arxiv.org/abs/2410.02217, one can use the vector field learnt with flow matching to calculate the time-dependent score and marginal distributions, as well as drift and diffusion functions for SDE-based sampling. This effectively enables the use of all score-based methods (SDE-based sampling, iid, gradient evaluation, map with score-based gradient, guidance, etc) with flow matching.
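To illustrate the vector-field-to-score conversion mentioned above, here is a minimal sketch for one specific Gaussian probability path. It assumes the linear path x_t = t * x_1 + (1 - t) * eps with data x_1 at t = 1 and standard normal noise eps at t = 0, for which the marginal score is s_t(x) = (t * u_t(x) - x) / (1 - t). This convention, and all function names below, are assumptions made for the example; the parameterization used in the paper or in this PR may differ. The check uses a 1D Gaussian data distribution, where both the marginal vector field and the score are available in closed form.

```python
def score_from_vector_field(u: float, x: float, t: float) -> float:
    """Score of p_t at x, given the marginal vector field value u = u_t(x).

    Assumes the path x_t = t * x_1 + (1 - t) * eps with eps ~ N(0, 1), t < 1.
    """
    return (t * u - x) / (1.0 - t)


def gaussian_vector_field(x: float, t: float, mean: float, std: float) -> float:
    """Closed-form marginal vector field when x_1 ~ N(mean, std^2)."""
    var_t = (t * std) ** 2 + (1.0 - t) ** 2  # variance of x_t
    resid = x - t * mean
    e_x1 = mean + t * std**2 * resid / var_t  # E[x_1 | x_t = x]
    e_eps = (1.0 - t) * resid / var_t  # E[eps | x_t = x]
    # u_t(x) = d(alpha_t)/dt * E[x_1 | x] + d(sigma_t)/dt * E[eps | x]
    return e_x1 - e_eps


def gaussian_score(x: float, t: float, mean: float, std: float) -> float:
    """Closed-form score of p_t = N(t * mean, t^2 * std^2 + (1 - t)^2)."""
    var_t = (t * std) ** 2 + (1.0 - t) ** 2
    return -(x - t * mean) / var_t


x, t, mean, std = 0.3, 0.4, 1.0, 0.5
u = gaussian_vector_field(x, t, mean, std)
recovered = score_from_vector_field(u, x, t)
exact = gaussian_score(x, t, mean, std)
print(abs(recovered - exact) < 1e-9)
```

The conversion divides by (1 - t), so under this convention it is only usable away from the data end of the path.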
The API of `FMPE` and `NPSE` didn't change, but they are now wrappers around the same class, `VectorFieldInference`, and only differ in the following:
- the `sample_with` value in the `build_posterior` method (`sde` for NPSE and `ode` for FMPE),
- the estimator (`FlowMatchingEstimator` and `ScoreBasedEstimator`, which are subclasses of `ConditionalVectorFieldEstimator`).
Additionally, the ODE solver is isolated into a separate API, which enables swapping ODE backends. So far, only the `zuko` backend is implemented.
Does this close any issues?
#1440 and #1462
Anything else we should know?
The following features have been implemented:
- `ode_solvers` to be able to swap backends
- `ConditionalVectorFieldEstimator` to enable SDE
- `FlowMatchingEstimator` a subclass of `ConditionalVectorFieldEstimator`
- `FlowMatchingEstimator`
- `VectorFieldPotential`, `VectorFieldPosterior`, `VectorFieldInference` classes
- `FMPE` and `NPSE` classes
- `score_fn_iid` to support new API
- `DirectPosterior` for flow matching
Related issues
✅ Checklist
Put an x in the boxes that apply. If you're unsure about any of them, no worries — just ask!
For reviewers: