
Commit 956bdcc

Flag Flax schedulers as deprecated (#13031)
1 parent: 2af7baa · commit: 956bdcc

File tree

6 files changed: +141 -24 lines changed

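The substantive change, repeated across the Flax scheduler modules below, is that each scheduler's __init__ now emits a deprecation warning through the diffusers logger (for the Karras VE scheduler this replaces a previously empty `pass` body); the remaining diff lines are formatting-only reflows of long signatures and calls. A minimal sketch of what a user would see after this commit, assuming `jax` and `flax` are installed so the Flax classes are importable:

# Hypothetical usage sketch, not part of the commit: constructing any of the
# flagged Flax schedulers now logs the deprecation message before __init__ returns.
from diffusers import FlaxDPMSolverMultistepScheduler

scheduler = FlaxDPMSolverMultistepScheduler()
# Expected warning in the logs (wording taken from the diff below):
#   Flax classes are deprecated and will be removed in Diffusers v1.0.0. We
#   recommend migrating to PyTorch classes or pinning your version of Diffusers.
state = scheduler.create_state()  # scheduler behavior itself is unchanged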

src/diffusers/schedulers/scheduling_dpmsolver_multistep_flax.py

Lines changed: 32 additions & 6 deletions
@@ -22,6 +22,7 @@
 import jax.numpy as jnp
 
 from ..configuration_utils import ConfigMixin, register_to_config
+from ..utils import logging
 from .scheduling_utils_flax import (
     CommonSchedulerState,
     FlaxKarrasDiffusionSchedulers,
@@ -31,6 +32,9 @@
 )
 
 
+logger = logging.get_logger(__name__)
+
+
 @flax.struct.dataclass
 class DPMSolverMultistepSchedulerState:
     common: CommonSchedulerState
@@ -171,6 +175,10 @@ def __init__(
         timestep_spacing: str = "linspace",
         dtype: jnp.dtype = jnp.float32,
     ):
+        logger.warning(
+            "Flax classes are deprecated and will be removed in Diffusers v1.0.0. We "
+            "recommend migrating to PyTorch classes or pinning your version of Diffusers."
+        )
         self.dtype = dtype
 
     def create_state(self, common: Optional[CommonSchedulerState] = None) -> DPMSolverMultistepSchedulerState:
@@ -203,7 +211,10 @@ def create_state(self, common: Optional[CommonSchedulerState] = None) -> DPMSolv
         )
 
     def set_timesteps(
-        self, state: DPMSolverMultistepSchedulerState, num_inference_steps: int, shape: Tuple
+        self,
+        state: DPMSolverMultistepSchedulerState,
+        num_inference_steps: int,
+        shape: Tuple,
     ) -> DPMSolverMultistepSchedulerState:
         """
         Sets the discrete timesteps used for the diffusion chain. Supporting function to be run before inference.
@@ -301,10 +312,13 @@ def convert_model_output(
         if self.config.thresholding:
             # Dynamic thresholding in https://huggingface.co/papers/2205.11487
             dynamic_max_val = jnp.percentile(
-                jnp.abs(x0_pred), self.config.dynamic_thresholding_ratio, axis=tuple(range(1, x0_pred.ndim))
+                jnp.abs(x0_pred),
+                self.config.dynamic_thresholding_ratio,
+                axis=tuple(range(1, x0_pred.ndim)),
             )
             dynamic_max_val = jnp.maximum(
-                dynamic_max_val, self.config.sample_max_value * jnp.ones_like(dynamic_max_val)
+                dynamic_max_val,
+                self.config.sample_max_value * jnp.ones_like(dynamic_max_val),
             )
             x0_pred = jnp.clip(x0_pred, -dynamic_max_val, dynamic_max_val) / dynamic_max_val
         return x0_pred
@@ -385,7 +399,11 @@ def multistep_dpm_solver_second_order_update(
         """
         t, s0, s1 = prev_timestep, timestep_list[-1], timestep_list[-2]
         m0, m1 = model_output_list[-1], model_output_list[-2]
-        lambda_t, lambda_s0, lambda_s1 = state.lambda_t[t], state.lambda_t[s0], state.lambda_t[s1]
+        lambda_t, lambda_s0, lambda_s1 = (
+            state.lambda_t[t],
+            state.lambda_t[s0],
+            state.lambda_t[s1],
+        )
         alpha_t, alpha_s0 = state.alpha_t[t], state.alpha_t[s0]
         sigma_t, sigma_s0 = state.sigma_t[t], state.sigma_t[s0]
         h, h_0 = lambda_t - lambda_s0, lambda_s0 - lambda_s1
@@ -443,7 +461,12 @@ def multistep_dpm_solver_third_order_update(
         Returns:
             `jnp.ndarray`: the sample tensor at the previous timestep.
         """
-        t, s0, s1, s2 = prev_timestep, timestep_list[-1], timestep_list[-2], timestep_list[-3]
+        t, s0, s1, s2 = (
+            prev_timestep,
+            timestep_list[-1],
+            timestep_list[-2],
+            timestep_list[-3],
+        )
         m0, m1, m2 = model_output_list[-1], model_output_list[-2], model_output_list[-3]
         lambda_t, lambda_s0, lambda_s1, lambda_s2 = (
             state.lambda_t[t],
@@ -615,7 +638,10 @@ def step_3(state: DPMSolverMultistepSchedulerState) -> jnp.ndarray:
         return FlaxDPMSolverMultistepSchedulerOutput(prev_sample=prev_sample, state=state)
 
     def scale_model_input(
-        self, state: DPMSolverMultistepSchedulerState, sample: jnp.ndarray, timestep: Optional[int] = None
+        self,
+        state: DPMSolverMultistepSchedulerState,
+        sample: jnp.ndarray,
+        timestep: Optional[int] = None,
     ) -> jnp.ndarray:
         """
         Ensures interchangeability with schedulers that need to scale the denoising model input depending on the

src/diffusers/schedulers/scheduling_euler_discrete_flax.py

Lines changed: 29 additions & 4 deletions
@@ -19,6 +19,7 @@
 import jax.numpy as jnp
 
 from ..configuration_utils import ConfigMixin, register_to_config
+from ..utils import logging
 from .scheduling_utils_flax import (
     CommonSchedulerState,
     FlaxKarrasDiffusionSchedulers,
@@ -28,6 +29,9 @@
 )
 
 
+logger = logging.get_logger(__name__)
+
+
 @flax.struct.dataclass
 class EulerDiscreteSchedulerState:
     common: CommonSchedulerState
@@ -40,9 +44,18 @@ class EulerDiscreteSchedulerState:
 
     @classmethod
     def create(
-        cls, common: CommonSchedulerState, init_noise_sigma: jnp.ndarray, timesteps: jnp.ndarray, sigmas: jnp.ndarray
+        cls,
+        common: CommonSchedulerState,
+        init_noise_sigma: jnp.ndarray,
+        timesteps: jnp.ndarray,
+        sigmas: jnp.ndarray,
     ):
-        return cls(common=common, init_noise_sigma=init_noise_sigma, timesteps=timesteps, sigmas=sigmas)
+        return cls(
+            common=common,
+            init_noise_sigma=init_noise_sigma,
+            timesteps=timesteps,
+            sigmas=sigmas,
+        )
 
 
 @dataclass
@@ -99,6 +112,10 @@ def __init__(
         timestep_spacing: str = "linspace",
         dtype: jnp.dtype = jnp.float32,
     ):
+        logger.warning(
+            "Flax classes are deprecated and will be removed in Diffusers v1.0.0. We "
+            "recommend migrating to PyTorch classes or pinning your version of Diffusers."
+        )
         self.dtype = dtype
 
     def create_state(self, common: Optional[CommonSchedulerState] = None) -> EulerDiscreteSchedulerState:
@@ -146,7 +163,10 @@ def scale_model_input(self, state: EulerDiscreteSchedulerState, sample: jnp.ndar
         return sample
 
     def set_timesteps(
-        self, state: EulerDiscreteSchedulerState, num_inference_steps: int, shape: Tuple = ()
+        self,
+        state: EulerDiscreteSchedulerState,
+        num_inference_steps: int,
+        shape: Tuple = (),
     ) -> EulerDiscreteSchedulerState:
         """
         Sets the timesteps used for the diffusion chain. Supporting function to be run before inference.
@@ -159,7 +179,12 @@ def set_timesteps(
         """
 
         if self.config.timestep_spacing == "linspace":
-            timesteps = jnp.linspace(self.config.num_train_timesteps - 1, 0, num_inference_steps, dtype=self.dtype)
+            timesteps = jnp.linspace(
+                self.config.num_train_timesteps - 1,
+                0,
+                num_inference_steps,
+                dtype=self.dtype,
+            )
         elif self.config.timestep_spacing == "leading":
             step_ratio = self.config.num_train_timesteps // num_inference_steps
             timesteps = (jnp.arange(0, num_inference_steps) * step_ratio).round()[::-1].copy().astype(float)

src/diffusers/schedulers/scheduling_karras_ve_flax.py

Lines changed: 8 additions & 2 deletions
@@ -22,10 +22,13 @@
 from jax import random
 
 from ..configuration_utils import ConfigMixin, register_to_config
-from ..utils import BaseOutput
+from ..utils import BaseOutput, logging
 from .scheduling_utils_flax import FlaxSchedulerMixin
 
 
+logger = logging.get_logger(__name__)
+
+
 @flax.struct.dataclass
 class KarrasVeSchedulerState:
     # setable values
@@ -102,7 +105,10 @@ def __init__(
         s_min: float = 0.05,
         s_max: float = 50,
     ):
-        pass
+        logger.warning(
+            "Flax classes are deprecated and will be removed in Diffusers v1.0.0. We "
+            "recommend migrating to PyTorch classes or pinning your version of Diffusers."
+        )
 
     def create_state(self):
         return KarrasVeSchedulerState.create()

src/diffusers/schedulers/scheduling_lms_discrete_flax.py

Lines changed: 29 additions & 4 deletions
@@ -20,6 +20,7 @@
 from scipy import integrate
 
 from ..configuration_utils import ConfigMixin, register_to_config
+from ..utils import logging
 from .scheduling_utils_flax import (
     CommonSchedulerState,
     FlaxKarrasDiffusionSchedulers,
@@ -29,6 +30,9 @@
 )
 
 
+logger = logging.get_logger(__name__)
+
+
 @flax.struct.dataclass
 class LMSDiscreteSchedulerState:
     common: CommonSchedulerState
@@ -44,9 +48,18 @@ class LMSDiscreteSchedulerState:
 
     @classmethod
     def create(
-        cls, common: CommonSchedulerState, init_noise_sigma: jnp.ndarray, timesteps: jnp.ndarray, sigmas: jnp.ndarray
+        cls,
+        common: CommonSchedulerState,
+        init_noise_sigma: jnp.ndarray,
+        timesteps: jnp.ndarray,
+        sigmas: jnp.ndarray,
     ):
-        return cls(common=common, init_noise_sigma=init_noise_sigma, timesteps=timesteps, sigmas=sigmas)
+        return cls(
+            common=common,
+            init_noise_sigma=init_noise_sigma,
+            timesteps=timesteps,
+            sigmas=sigmas,
+        )
 
 
 @dataclass
@@ -101,6 +114,10 @@ def __init__(
         prediction_type: str = "epsilon",
         dtype: jnp.dtype = jnp.float32,
     ):
+        logger.warning(
+            "Flax classes are deprecated and will be removed in Diffusers v1.0.0. We "
+            "recommend migrating to PyTorch classes or pinning your version of Diffusers."
+        )
         self.dtype = dtype
 
     def create_state(self, common: Optional[CommonSchedulerState] = None) -> LMSDiscreteSchedulerState:
@@ -165,7 +182,10 @@ def lms_derivative(tau):
         return integrated_coeff
 
     def set_timesteps(
-        self, state: LMSDiscreteSchedulerState, num_inference_steps: int, shape: Tuple = ()
+        self,
+        state: LMSDiscreteSchedulerState,
+        num_inference_steps: int,
+        shape: Tuple = (),
     ) -> LMSDiscreteSchedulerState:
         """
         Sets the timesteps used for the diffusion chain. Supporting function to be run before inference.
@@ -177,7 +197,12 @@ def set_timesteps(
             the number of diffusion steps used when generating samples with a pre-trained model.
         """
 
-        timesteps = jnp.linspace(self.config.num_train_timesteps - 1, 0, num_inference_steps, dtype=self.dtype)
+        timesteps = jnp.linspace(
+            self.config.num_train_timesteps - 1,
+            0,
+            num_inference_steps,
+            dtype=self.dtype,
+        )
 
         low_idx = jnp.floor(timesteps).astype(jnp.int32)
         high_idx = jnp.ceil(timesteps).astype(jnp.int32)

src/diffusers/schedulers/scheduling_pndm_flax.py

Lines changed: 25 additions & 5 deletions
@@ -22,6 +22,7 @@
 import jax.numpy as jnp
 
 from ..configuration_utils import ConfigMixin, register_to_config
+from ..utils import logging
 from .scheduling_utils_flax import (
     CommonSchedulerState,
     FlaxKarrasDiffusionSchedulers,
@@ -31,6 +32,9 @@
 )
 
 
+logger = logging.get_logger(__name__)
+
+
 @flax.struct.dataclass
 class PNDMSchedulerState:
     common: CommonSchedulerState
@@ -131,6 +135,10 @@ def __init__(
         prediction_type: str = "epsilon",
         dtype: jnp.dtype = jnp.float32,
     ):
+        logger.warning(
+            "Flax classes are deprecated and will be removed in Diffusers v1.0.0. We "
+            "recommend migrating to PyTorch classes or pinning your version of Diffusers."
+        )
         self.dtype = dtype
 
         # For now we only support F-PNDM, i.e. the runge-kutta method
@@ -190,7 +198,10 @@ def set_timesteps(self, state: PNDMSchedulerState, num_inference_steps: int, sha
 
         else:
             prk_timesteps = _timesteps[-self.pndm_order :].repeat(2) + jnp.tile(
-                jnp.array([0, self.config.num_train_timesteps // num_inference_steps // 2], dtype=jnp.int32),
+                jnp.array(
+                    [0, self.config.num_train_timesteps // num_inference_steps // 2],
+                    dtype=jnp.int32,
+                ),
                 self.pndm_order,
             )
 
@@ -218,7 +229,10 @@ def set_timesteps(self, state: PNDMSchedulerState, num_inference_steps: int, sha
         )
 
     def scale_model_input(
-        self, state: PNDMSchedulerState, sample: jnp.ndarray, timestep: Optional[int] = None
+        self,
+        state: PNDMSchedulerState,
+        sample: jnp.ndarray,
+        timestep: Optional[int] = None,
     ) -> jnp.ndarray:
         """
         Ensures interchangeability with schedulers that need to scale the denoising model input depending on the
@@ -320,7 +334,9 @@ def step_prk(
         )
 
         diff_to_prev = jnp.where(
-            state.counter % 2, 0, self.config.num_train_timesteps // state.num_inference_steps // 2
+            state.counter % 2,
+            0,
+            self.config.num_train_timesteps // state.num_inference_steps // 2,
         )
         prev_timestep = timestep - diff_to_prev
        timestep = state.prk_timesteps[state.counter // 4 * 4]
@@ -401,7 +417,9 @@ def step_plms(
 
         prev_timestep = jnp.where(state.counter == 1, timestep, prev_timestep)
         timestep = jnp.where(
-            state.counter == 1, timestep + self.config.num_train_timesteps // state.num_inference_steps, timestep
+            state.counter == 1,
+            timestep + self.config.num_train_timesteps // state.num_inference_steps,
+            timestep,
         )
 
         # Reference:
@@ -466,7 +484,9 @@ def _get_prev_sample(self, state: PNDMSchedulerState, sample, timestep, prev_tim
         # prev_sample -> x_(t−δ)
         alpha_prod_t = state.common.alphas_cumprod[timestep]
         alpha_prod_t_prev = jnp.where(
-            prev_timestep >= 0, state.common.alphas_cumprod[prev_timestep], state.final_alpha_cumprod
+            prev_timestep >= 0,
+            state.common.alphas_cumprod[prev_timestep],
+            state.final_alpha_cumprod,
         )
         beta_prod_t = 1 - alpha_prod_t
         beta_prod_t_prev = 1 - alpha_prod_t_prev
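Because the message goes through the diffusers logger rather than Python's `warnings` module, users who pin their Diffusers version instead of migrating can quiet it by lowering the library's log verbosity. A minimal sketch, assuming the standard `diffusers.utils.logging` helpers (the same module the patched files import):

# Hypothetical sketch, not part of the commit: raising the verbosity threshold
# to ERROR suppresses the deprecation warning (along with other library warnings).
from diffusers.utils import logging

logging.set_verbosity_error()

from diffusers import FlaxPNDMScheduler

scheduler = FlaxPNDMScheduler()  # constructs without printing the warning; behavior is unchanged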
