Skip to content

Commit 3b83a56

Browse files
committed
add rectified flow noise scheduler to monai
Signed-off-by: Can-Zhao <canz@nvidia.com> Signed-off-by: Can-Zhao <volcanofly@gmail.com>
1 parent 621fc5f commit 3b83a56

File tree

5 files changed

+317
-6
lines changed

5 files changed

+317
-6
lines changed

monai/inferers/inferer.py

Lines changed: 15 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,7 @@
3939
SPADEAutoencoderKL,
4040
SPADEDiffusionModelUNet,
4141
)
42-
from monai.networks.schedulers import Scheduler
42+
from monai.networks.schedulers import RFlowScheduler, Scheduler
4343
from monai.transforms import CenterSpatialCrop, SpatialPad
4444
from monai.utils import BlendMode, Ordering, PatchKeys, PytorchPadMode, ensure_tuple, optional_import
4545
from monai.visualize import CAM, GradCAM, GradCAMpp
@@ -859,12 +859,19 @@ def sample(
859859
if not scheduler:
860860
scheduler = self.scheduler
861861
image = input_noise
862+
863+
all_next_timesteps = torch.cat((scheduler.timesteps[1:], torch.tensor([0], dtype=scheduler.timesteps.dtype)))
862864
if verbose and has_tqdm:
863-
progress_bar = tqdm(scheduler.timesteps)
865+
progress_bar = tqdm(
866+
zip(scheduler.timesteps, all_next_timesteps),
867+
total=min(len(scheduler.timesteps), len(all_next_timesteps)),
868+
)
864869
else:
865870
progress_bar = iter(scheduler.timesteps)
871+
progress_bar = iter(zip(scheduler.timesteps, all_next_timesteps))
866872
intermediates = []
867-
for t in progress_bar:
873+
874+
for t, next_t in progress_bar:
868875
# 1. predict noise model_output
869876
diffusion_model = (
870877
partial(diffusion_model, seg=seg)
@@ -882,9 +889,13 @@ def sample(
882889
)
883890

884891
# 2. compute previous image: x_t -> x_t-1
885-
image, _ = scheduler.step(model_output, t, image)
892+
if not isinstance(scheduler, RFlowScheduler):
893+
image, _ = scheduler.step(model_output, t, image)
894+
else:
895+
image, _ = scheduler.step(model_output, t, image, next_t)
886896
if save_intermediates and t % intermediate_steps == 0:
887897
intermediates.append(image)
898+
888899
if save_intermediates:
889900
return image, intermediates
890901
else:

monai/networks/schedulers/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,4 +14,5 @@
1414
from .ddim import DDIMScheduler
1515
from .ddpm import DDPMScheduler
1616
from .pndm import PNDMScheduler
17+
from .rectified_flow import RFlowScheduler
1718
from .scheduler import NoiseSchedules, Scheduler
Lines changed: 283 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,283 @@
1+
# Copyright (c) MONAI Consortium
2+
# Licensed under the Apache License, Version 2.0 (the "License");
3+
# you may not use this file except in compliance with the License.
4+
# You may obtain a copy of the License at
5+
# http://www.apache.org/licenses/LICENSE-2.0
6+
# Unless required by applicable law or agreed to in writing, software
7+
# distributed under the License is distributed on an "AS IS" BASIS,
8+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
9+
# See the License for the specific language governing permissions and
10+
# limitations under the License.
11+
#
12+
# =========================================================================
13+
# Adapted from https://github.com/hpcaitech/Open-Sora/blob/main/opensora/schedulers/rf/rectified_flow.py
14+
# which has the following license:
15+
# https://github.com/hpcaitech/Open-Sora/blob/main/LICENSE
16+
# Licensed under the Apache License, Version 2.0 (the "License");
17+
# you may not use this file except in compliance with the License.
18+
# You may obtain a copy of the License at
19+
#
20+
# http://www.apache.org/licenses/LICENSE-2.0
21+
#
22+
# Unless required by applicable law or agreed to in writing, software
23+
# distributed under the License is distributed on an "AS IS" BASIS,
24+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
25+
# See the License for the specific language governing permissions and
26+
# limitations under the License.
27+
# =========================================================================
28+
29+
from __future__ import annotations
30+
31+
from typing import Any
32+
33+
import numpy as np
34+
import torch
35+
from torch.distributions import LogisticNormal
36+
37+
from .scheduler import Scheduler
38+
39+
40+
def timestep_transform(
    t, input_img_size_numel, base_img_size_numel=32 * 32 * 32, scale=1.0, num_train_timesteps=1000, spatial_dim=3
):
    """
    Applies a transformation to the timestep based on image resolution scaling.

    The timestep is normalized by ``num_train_timesteps``, warped with
    ``ratio * t / (1 + (ratio - 1) * t)`` and scaled back, where
    ``ratio = scale * (input_img_size_numel / base_img_size_numel) ** (1 / spatial_dim)``.
    A ratio of 1 leaves the timestep unchanged.

    Args:
        t (torch.Tensor | float | int): The original timestep(s).
        input_img_size_numel (torch.Tensor | float | int): The input image's size (H * W * D).
        base_img_size_numel (int): reference H*W*D size, usually smaller than input_img_size_numel.
        scale (float): Scaling factor for the transformation.
        num_train_timesteps (int): Total number of training timesteps.
        spatial_dim (int): Number of spatial dimensions in the image.

    Returns:
        Transformed timestep(s): a tensor if any input is a tensor, otherwise a Python float.
    """
    t = t / num_train_timesteps

    # Use the `**` operator rather than the tensor-only `.pow` method so that
    # plain Python numbers are accepted as well as tensors.
    ratio_space = (input_img_size_numel / base_img_size_numel) ** (1.0 / spatial_dim)
    ratio = ratio_space * scale

    new_t = ratio * t / (1 + (ratio - 1) * t)

    return new_t * num_train_timesteps
65+
66+
67+
class RFlowScheduler(Scheduler):
    """
    A rectified flow scheduler for guiding the diffusion process in a generative model.

    Supports uniform and logit-normal sampling methods, timestep transformation for
    different resolutions, and noise addition during diffusion.

    Attributes:
        num_train_timesteps (int): Total number of training timesteps.
        use_discrete_timesteps (bool): Whether to use discrete timesteps.
        sample_method (str): Training time step sampling method ('uniform' or 'logit-normal').
        loc (float): Location parameter for logit-normal distribution, used only if sample_method='logit-normal'.
        scale (float): Scale parameter for logit-normal distribution, used only if sample_method='logit-normal'.
        use_timestep_transform (bool): Whether to apply timestep transformation.
            If true, there will be more inference timesteps at early(noisy) stages for larger image volumes.
        transform_scale (float): Scaling factor for timestep transformation, used only if use_timestep_transform=True.
        steps_offset (int): Offset added to the computed inference timesteps in ``set_timesteps``.
        base_img_size_numel (int): Reference image volume size for scaling, used only if use_timestep_transform=True.
        spatial_dim (int): Number of spatial dimensions of the images, used only if use_timestep_transform=True.

    Example:

        .. code-block:: python

            # define a scheduler
            noise_scheduler = RFlowScheduler(
                num_train_timesteps = 1000,
                use_discrete_timesteps = True,
                sample_method = 'logit-normal',
                use_timestep_transform = True,
                base_img_size_numel = 32 * 32 * 32
            )

            # during training
            inputs = torch.ones(2,4,64,64,64)
            noise = torch.randn_like(inputs)
            timesteps = noise_scheduler.sample_timesteps(inputs)
            noisy_inputs = noise_scheduler.add_noise(original_samples=inputs, noise=noise, timesteps=timesteps)
            predicted_velocity = diffusion_unet(
                x=noisy_inputs,
                timesteps=timesteps
            )
            loss = loss_l1(predicted_velocity, (inputs - noise))

            # during inference
            noisy_inputs = torch.randn(2,4,64,64,64)
            input_img_size_numel = torch.prod(torch.tensor(noisy_inputs.shape[-3:]))
            noise_scheduler.set_timesteps(
                num_inference_steps=30, input_img_size_numel=input_img_size_numel
            )
            all_next_timesteps = torch.cat(
                (
                    noise_scheduler.timesteps[1:],
                    torch.tensor(
                        [0], dtype=noise_scheduler.timesteps.dtype, device=noise_scheduler.timesteps.device
                    ),
                )
            )
            for t, next_t in tqdm(
                zip(noise_scheduler.timesteps, all_next_timesteps),
                total=min(len(noise_scheduler.timesteps), len(all_next_timesteps)),
            ):
                predicted_velocity = diffusion_unet(
                    x=noisy_inputs,
                    timesteps=t.repeat(noisy_inputs.shape[0])  # one timestep per batch item
                )
                noisy_inputs, _ = noise_scheduler.step(predicted_velocity, t, noisy_inputs, next_t)
            final_output = noisy_inputs
    """

    def __init__(
        self,
        num_train_timesteps: int = 1000,
        use_discrete_timesteps: bool = True,
        sample_method: str = "uniform",
        loc: float = 0.0,
        scale: float = 1.0,
        use_timestep_transform: bool = False,
        transform_scale: float = 1.0,
        steps_offset: int = 0,
        base_img_size_numel: int = 32 * 32 * 32,
        spatial_dim: int = 3,
    ):
        """
        Store the scheduler configuration and, for 'logit-normal', build the sampling distribution.

        Raises:
            ValueError: if ``sample_method`` is neither 'uniform' nor 'logit-normal'.
        """
        # NOTE(review): the base-class initializer is not invoked here, as in the
        # original implementation — confirm whether `Scheduler.__init__` should be called.
        self.num_train_timesteps = num_train_timesteps
        self.use_discrete_timesteps = use_discrete_timesteps
        self.base_img_size_numel = base_img_size_numel
        self.spatial_dim = spatial_dim

        # sample method
        if sample_method not in ["uniform", "logit-normal"]:
            raise ValueError(
                f"sample_method = {sample_method}, which has to be chosen from ['uniform', 'logit-normal']."
            )
        self.sample_method = sample_method
        if sample_method == "logit-normal":
            self.distribution = LogisticNormal(torch.tensor([loc]), torch.tensor([scale]))
            # draw one value per batch item and keep the first component of each draw
            self.sample_t = lambda x: self.distribution.sample((x.shape[0],))[:, 0].to(x.device)

        # timestep transform
        self.use_timestep_transform = use_timestep_transform
        self.transform_scale = transform_scale
        self.steps_offset = steps_offset

    def add_noise(self, original_samples: torch.Tensor, noise: torch.Tensor, timesteps: torch.Tensor) -> torch.Tensor:
        """
        Adds noise to the original samples based on the given timesteps.

        The result is the linear interpolation ``(1 - t) * original_samples + t * noise``
        with ``t = timesteps / num_train_timesteps``.

        Args:
            original_samples (torch.Tensor): The original sample tensor.
            noise (torch.Tensor): Noise tensor to be added, same shape as ``original_samples``.
            timesteps (torch.Tensor): Timesteps corresponding to each sample, one per batch item.

        Returns:
            torch.Tensor: The noisy sample tensor.
        """
        timepoints = timesteps.float() / self.num_train_timesteps
        timepoints = 1 - timepoints  # weight of the clean sample, [1,1/1000]

        # Reshape (bsz,) -> (bsz, 1, ..., 1) so the weight broadcasts over every
        # non-batch dimension, whatever the rank of `noise` is.
        timepoints = timepoints.reshape(-1, *([1] * (noise.ndim - 1)))

        return timepoints * original_samples + (1 - timepoints) * noise

    def set_timesteps(
        self,
        num_inference_steps: int,
        device: str | torch.device | None = None,
        input_img_size_numel: int | None = None,
    ) -> None:
        """
        Sets the discrete timesteps used for the diffusion chain. Supporting function to be run before inference.

        Args:
            num_inference_steps: number of diffusion steps used when generating samples with a pre-trained model.
            device: target device to put the data.
            input_img_size_numel: H*W*D of the image (int or tensor), required when
                self.use_timestep_transform is True.

        Raises:
            ValueError: if ``num_inference_steps`` is smaller than 1 or larger than
                ``self.num_train_timesteps``, or if the timestep transform is enabled
                but ``input_img_size_numel`` is not given.
        """
        if num_inference_steps > self.num_train_timesteps:
            raise ValueError(
                f"`num_inference_steps`: {num_inference_steps} cannot be larger than `self.num_train_timesteps`:"
                f" {self.num_train_timesteps} as the unet model trained with this scheduler can only handle"
                f" maximal {self.num_train_timesteps} timesteps."
            )
        if num_inference_steps < 1:
            raise ValueError(f"`num_inference_steps`: {num_inference_steps} has to be at least 1.")
        if self.use_timestep_transform and input_img_size_numel is None:
            raise ValueError("`input_img_size_numel` has to be provided when `use_timestep_transform` is True.")

        self.num_inference_steps = num_inference_steps

        # prepare timesteps: evenly spaced from num_train_timesteps down towards (but excluding) 0
        timesteps: list[Any] = [
            (1.0 - i / num_inference_steps) * self.num_train_timesteps for i in range(num_inference_steps)
        ]
        if self.use_discrete_timesteps:
            timesteps = [int(round(t)) for t in timesteps]
        if self.use_timestep_transform:
            # Wrap in a tensor so both ints and tensors are accepted, and convert each
            # result back to a Python float so numpy can build a plain numeric array.
            numel = torch.as_tensor(input_img_size_numel)
            timesteps = [
                float(
                    timestep_transform(
                        t,
                        input_img_size_numel=numel,
                        base_img_size_numel=self.base_img_size_numel,
                        scale=self.transform_scale,
                        num_train_timesteps=self.num_train_timesteps,
                        spatial_dim=self.spatial_dim,
                    )
                )
                for t in timesteps
            ]

        # NOTE(review): the float16 cast is kept from the original implementation;
        # confirm its precision is sufficient when use_discrete_timesteps is False.
        timesteps_np = np.array(timesteps).astype(np.float16)
        if self.use_discrete_timesteps:
            timesteps_np = timesteps_np.astype(np.int64)
        self.timesteps = torch.from_numpy(timesteps_np).to(device)
        self.timesteps += self.steps_offset

    def sample_timesteps(self, x_start):
        """
        Randomly samples training timesteps using the chosen sampling method.

        Args:
            x_start (torch.Tensor): The input tensor for sampling; its first dimension
                gives the number of timesteps drawn and its last ``spatial_dim``
                dimensions give the image size used by the timestep transform.

        Returns:
            torch.Tensor: Sampled timesteps, one per batch item.
        """
        if self.sample_method == "uniform":
            t = torch.rand((x_start.shape[0],), device=x_start.device) * self.num_train_timesteps
        elif self.sample_method == "logit-normal":
            t = self.sample_t(x_start) * self.num_train_timesteps

        if self.use_discrete_timesteps:
            t = t.long()

        if self.use_timestep_transform:
            input_img_size_numel = torch.prod(torch.tensor(x_start.shape[-self.spatial_dim :]))
            t = timestep_transform(
                t,
                input_img_size_numel=input_img_size_numel,
                base_img_size_numel=self.base_img_size_numel,
                scale=self.transform_scale,
                num_train_timesteps=self.num_train_timesteps,
                spatial_dim=self.spatial_dim,
            )

        return t

    def step(
        self, model_output: torch.Tensor, timestep: int, sample: torch.Tensor, next_timestep=None
    ) -> tuple[torch.Tensor, Any]:
        """
        Predict the sample at the previous timestep. Core function to propagate the diffusion
        process from the learned model outputs.

        Args:
            model_output: direct output from learned diffusion model.
            timestep: current discrete timestep in the diffusion chain.
            sample: current instance of sample being created by diffusion process.
            next_timestep: next discrete timestep in the diffusion chain. If None, a uniform
                step of ``1 / num_inference_steps`` is used, which requires ``set_timesteps``
                to have been called.
        Returns:
            pred_prev_sample: Predicted previous sample
            None
        """
        if next_timestep is None:
            dt = 1.0 / self.num_inference_steps
        else:
            # step size expressed as a fraction of the training timestep range
            dt = (timestep - next_timestep) / self.num_train_timesteps

        # one Euler step along the predicted velocity
        return sample + model_output * dt, None

monai/utils/jupyter_utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -234,7 +234,7 @@ def plot_engine_status(
234234

235235

236236
def _get_loss_from_output(
237-
output: list[torch.Tensor | dict[str, torch.Tensor]] | dict[str, torch.Tensor] | torch.Tensor,
237+
output: list[torch.Tensor | dict[str, torch.Tensor]] | dict[str, torch.Tensor] | torch.Tensor
238238
) -> torch.Tensor:
239239
"""Returns a single value from the network output, which is a dict or tensor."""
240240

tests/test_diffusion_inferer.py

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919

2020
from monai.inferers import DiffusionInferer
2121
from monai.networks.nets import DiffusionModelUNet
22-
from monai.networks.schedulers import DDIMScheduler, DDPMScheduler
22+
from monai.networks.schedulers import DDIMScheduler, DDPMScheduler, RFlowScheduler
2323
from monai.utils import optional_import
2424

2525
_, has_scipy = optional_import("scipy")
@@ -120,6 +120,22 @@ def test_ddim_sampler(self, model_params, input_shape):
120120
)
121121
self.assertEqual(len(intermediates), 10)
122122

123+
@parameterized.expand(TEST_CASES)
@skipUnless(has_einops, "Requires einops")
def test_rflow_sampler(self, model_params, input_shape):
    """Sample with an RFlowScheduler and check that one intermediate is kept per inference step."""
    unet = DiffusionModelUNet(**model_params)
    target_device = "cpu"
    if torch.cuda.is_available():
        target_device = "cuda:0"
    unet.to(target_device)
    unet.eval()

    input_noise = torch.randn(input_shape).to(target_device)

    rflow = RFlowScheduler(num_train_timesteps=1000)
    inferer = DiffusionInferer(scheduler=rflow)
    rflow.set_timesteps(num_inference_steps=10)

    # intermediate_steps=1 means every step of the loop is recorded
    _, intermediates = inferer.sample(
        input_noise=input_noise,
        diffusion_model=unet,
        scheduler=rflow,
        save_intermediates=True,
        intermediate_steps=1,
    )
    self.assertEqual(len(intermediates), 10)
138+
123139
@parameterized.expand(TEST_CASES)
124140
@skipUnless(has_einops, "Requires einops")
125141
def test_sampler_conditioned(self, model_params, input_shape):

0 commit comments

Comments
 (0)