Commit bd8df2d
[Pytorch] Pytorch only schedulers (#534)
* pytorch only schedulers
* fix style
* remove match_shape
* pytorch only ddpm
* remove SchedulerMixin
* remove numpy from karras_ve
* fix types
* remove numpy from lms_discrete
* remove numpy from pndm
* fix typo
* remove mixin and numpy from sde_vp and ve
* remove remaining tensor_format
* fix style
* sigmas has to be torch tensor
* removed set_format in readme
* remove set format from docs
* remove set_format from pipelines
* update tests
* fix typo
* continue to use mixin
* fix imports
* removed unsed imports
* match shape instead of assuming image shapes
* remove import typo
* update call to add_noise
* use math instead of numpy
* fix t_index
* removed commented out numpy tests
* timesteps needs to be discrete
* cast timesteps to int in flax scheduler too
* fix device mismatch issue
* small fix
* Update src/diffusers/schedulers/scheduling_pndm.py

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
1 parent 3b747de commit bd8df2d

27 files changed, +232 / -465 lines

docs/source/api/schedulers.mdx

Lines changed: 1 addition & 2 deletions
@@ -44,8 +44,7 @@ To this end, the design of schedulers is such that:
 The core API for any new scheduler must follow a limited structure.
 - Schedulers should provide one or more `def step(...)` functions that should be called to update the generated sample iteratively.
 - Schedulers should provide a `set_timesteps(...)` method that configures the parameters of a schedule function for a specific inference task.
-- Schedulers should be framework-agnostic, but provide a simple functionality to convert the scheduler into a specific framework, such as PyTorch
-with a `set_format(...)` method.
+- Schedulers should be framework-specific.
 
 The base class [`SchedulerMixin`] implements low level utilities used by multiple schedulers.
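To make the new framework-specific design concrete, here is a minimal sketch of the `set_timesteps(...)` / `step(...)` loop in PyTorch. The untrained `UNet2DModel`, the tensor shapes, and the `.prev_sample` access are illustrative assumptions based on the diffusers API around this commit, not part of the documented change itself.

```python
# Minimal sketch of the framework-specific scheduler API (set_timesteps + step),
# assuming the diffusers PyTorch API around the time of this commit.
import torch
from diffusers import DDIMScheduler, UNet2DModel

unet = UNet2DModel(sample_size=32, in_channels=3, out_channels=3)  # untrained, for illustration only
scheduler = DDIMScheduler(num_train_timesteps=1000)                # no tensor_format argument anymore

scheduler.set_timesteps(50)                 # configure the inference schedule
sample = torch.randn(1, 3, 32, 32)          # start from Gaussian noise

for t in scheduler.timesteps:               # timesteps are plain torch tensors now
    with torch.no_grad():
        noise_pred = unet(sample, t).sample
    # step(...) is assumed to return an output object exposing prev_sample
    sample = scheduler.step(noise_pred, t, sample).prev_sample
```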

examples/community/clip_guided_stable_diffusion.py

Lines changed: 1 addition & 1 deletion
@@ -274,7 +274,7 @@ def __call__(
             # the model input needs to be scaled to match the continuous ODE formulation in K-LMS
             latent_model_input = latent_model_input / ((sigma**2 + 1) ** 0.5)
 
-            # # predict the noise residual
+            # predict the noise residual
             noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample
 
             # perform classifier free guidance
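The scaling shown in this hunk is specific to K-LMS sampling, where the model input is divided by the square root of (sigma² + 1). Below is a hedged sketch of that scaling with the PyTorch-only `LMSDiscreteScheduler`, whose `sigmas` are torch tensors after this change (see the "sigmas has to be torch tensor" and "fix t_index" commit notes); the latent shape and the elided UNet and step calls are illustrative.

```python
# Hedged sketch of the K-LMS input scaling from the hunk above, assuming an
# LMSDiscreteScheduler whose sigmas are torch tensors after this change.
import torch
from diffusers import LMSDiscreteScheduler

scheduler = LMSDiscreteScheduler(beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear")
scheduler.set_timesteps(50)

latents = torch.randn(1, 4, 64, 64)  # illustrative latent shape
for i, t in enumerate(scheduler.timesteps):
    sigma = scheduler.sigmas[i]                              # indexed by step position, not timestep value
    latent_model_input = latents / ((sigma**2 + 1) ** 0.5)   # match the continuous ODE formulation in K-LMS
    # noise_pred = unet(latent_model_input, t, ...).sample   # UNet call elided here
    # latents = scheduler.step(...)                          # step arguments differ across diffusers versions
```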

examples/textual_inversion/textual_inversion.py

Lines changed: 4 additions & 1 deletion
@@ -424,7 +424,10 @@ def main():
 
     # TODO (patil-suraj): load scheduler using args
     noise_scheduler = DDPMScheduler(
-        beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", num_train_timesteps=1000, tensor_format="pt"
+        beta_start=0.00085,
+        beta_end=0.012,
+        beta_schedule="scaled_linear",
+        num_train_timesteps=1000,
     )
 
     train_dataset = TextualInversionDataset(
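For training scripts like this one, the PyTorch-only scheduler also changes how noise is added: `add_noise(...)` now works on torch tensors with discrete integer timesteps (see the "update call to add_noise" and "timesteps needs to be discrete" commit notes). A hedged sketch with placeholder image shapes:

```python
# Hedged training-time sketch: add_noise with torch tensors and discrete timesteps.
import torch
from diffusers import DDPMScheduler

noise_scheduler = DDPMScheduler(
    beta_start=0.00085,
    beta_end=0.012,
    beta_schedule="scaled_linear",
    num_train_timesteps=1000,
)

clean_images = torch.randn(4, 3, 64, 64)   # placeholder batch
noise = torch.randn_like(clean_images)
timesteps = torch.randint(                 # discrete integer timesteps
    0, noise_scheduler.config.num_train_timesteps, (clean_images.shape[0],), dtype=torch.int64
)
noisy_images = noise_scheduler.add_noise(clean_images, noise, timesteps)
```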

examples/unconditional_image_generation/train_unconditional.py

Lines changed: 1 addition & 1 deletion
@@ -59,7 +59,7 @@ def main(args):
             "UpBlock2D",
         ),
     )
-    noise_scheduler = DDPMScheduler(num_train_timesteps=1000, tensor_format="pt")
+    noise_scheduler = DDPMScheduler(num_train_timesteps=1000)
     optimizer = torch.optim.AdamW(
         model.parameters(),
         lr=args.learning_rate,

src/diffusers/pipelines/ddim/pipeline_ddim.py

Lines changed: 0 additions & 1 deletion
@@ -35,7 +35,6 @@ class DDIMPipeline(DiffusionPipeline):
 
     def __init__(self, unet, scheduler):
         super().__init__()
-        scheduler = scheduler.set_format("pt")
         self.register_modules(unet=unet, scheduler=scheduler)
 
     @torch.no_grad()
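Since the pipelines no longer call `scheduler.set_format("pt")`, a scheduler is passed in and used as-is. A usage sketch, under the assumption that `DDIMPipeline` accepts `batch_size`/`num_inference_steps` and returns an object with an `images` field as in the diffusers API of this period; the untrained model is for illustration only:

```python
# Hedged usage sketch: the pipeline takes the PyTorch-only scheduler directly,
# with no set_format("pt") conversion step in __init__.
from diffusers import DDIMPipeline, DDIMScheduler, UNet2DModel

unet = UNet2DModel(sample_size=32, in_channels=3, out_channels=3)  # untrained, for illustration only
scheduler = DDIMScheduler(num_train_timesteps=1000)

pipeline = DDIMPipeline(unet=unet, scheduler=scheduler)
images = pipeline(batch_size=1, num_inference_steps=10).images     # argument and output names assumed
```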

src/diffusers/pipelines/ddpm/pipeline_ddpm.py

Lines changed: 0 additions & 1 deletion
@@ -35,7 +35,6 @@ class DDPMPipeline(DiffusionPipeline):
 
     def __init__(self, unet, scheduler):
         super().__init__()
-        scheduler = scheduler.set_format("pt")
         self.register_modules(unet=unet, scheduler=scheduler)
 
     @torch.no_grad()

src/diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion.py

Lines changed: 0 additions & 1 deletion
@@ -45,7 +45,6 @@ def __init__(
         scheduler: Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler],
     ):
         super().__init__()
-        scheduler = scheduler.set_format("pt")
         self.register_modules(vqvae=vqvae, bert=bert, tokenizer=tokenizer, unet=unet, scheduler=scheduler)
 
     @torch.no_grad()

src/diffusers/pipelines/latent_diffusion_uncond/pipeline_latent_diffusion_uncond.py

Lines changed: 0 additions & 1 deletion
@@ -23,7 +23,6 @@ class LDMPipeline(DiffusionPipeline):
 
     def __init__(self, vqvae: VQModel, unet: UNet2DModel, scheduler: DDIMScheduler):
         super().__init__()
-        scheduler = scheduler.set_format("pt")
         self.register_modules(vqvae=vqvae, unet=unet, scheduler=scheduler)
 
     @torch.no_grad()

src/diffusers/pipelines/pndm/pipeline_pndm.py

Lines changed: 0 additions & 1 deletion
@@ -39,7 +39,6 @@ class PNDMPipeline(DiffusionPipeline):
 
     def __init__(self, unet: UNet2DModel, scheduler: PNDMScheduler):
         super().__init__()
-        scheduler = scheduler.set_format("pt")
         self.register_modules(unet=unet, scheduler=scheduler)
 
     @torch.no_grad()

src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py

Lines changed: 0 additions & 1 deletion
@@ -57,7 +57,6 @@ def __init__(
         feature_extractor: CLIPFeatureExtractor,
     ):
         super().__init__()
-        scheduler = scheduler.set_format("pt")
 
         if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1:
             warnings.warn(
