Skip to content

Commit 34f3797

Browse files
add prediction type for rflow scheduler (#8386)
Add prediction type to RFlow ### Description A few sentences describing the changes proposed in this pull request. ### Types of changes <!--- Put an `x` in all the boxes that apply, and remove the not applicable items --> - [x] Non-breaking change (fix or new feature that would not break existing functionality). - [ ] Breaking change (fix or new feature that would cause existing functionality to change). - [ ] New tests added to cover the changes. - [ ] Integration tests passed locally by running `./runtests.sh -f -u --net --coverage`. - [ ] Quick tests passed locally by running `./runtests.sh --quick --unittests --disttests`. - [ ] In-line docstrings updated. - [ ] Documentation updated, tested `make html` command in the `docs/` folder. --------- Signed-off-by: Can-Zhao <canz@nvidia.com> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
1 parent 90de55b commit 34f3797

File tree

1 file changed

+16
-0
lines changed

1 file changed

+16
-0
lines changed

monai/networks/schedulers/rectified_flow.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,9 +34,22 @@
3434
import torch
3535
from torch.distributions import LogisticNormal
3636

37+
from monai.utils import StrEnum
38+
39+
from .ddpm import DDPMPredictionType
3740
from .scheduler import Scheduler
3841

3942

43+
class RFlowPredictionType(StrEnum):
44+
"""
45+
Set of valid prediction type names for the RFlow scheduler's `prediction_type` argument.
46+
47+
v_prediction: velocity prediction, see section 2.4 https://imagen.research.google/video/paper.pdf
48+
"""
49+
50+
V_PREDICTION = DDPMPredictionType.V_PREDICTION
51+
52+
4053
def timestep_transform(
4154
t, input_img_size_numel, base_img_size_numel=32 * 32 * 32, scale=1.0, num_train_timesteps=1000, spatial_dim=3
4255
):
@@ -143,6 +156,9 @@ def __init__(
143156
base_img_size_numel: int = 32 * 32 * 32,
144157
spatial_dim: int = 3,
145158
):
159+
# rectified flow only accepts velocity prediction
160+
self.prediction_type = RFlowPredictionType.V_PREDICTION
161+
146162
self.num_train_timesteps = num_train_timesteps
147163
self.use_discrete_timesteps = use_discrete_timesteps
148164
self.base_img_size_numel = base_img_size_numel

0 commit comments

Comments
 (0)