Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[PNDM] Stable diffusion #186

Merged
merged 2 commits into from
Aug 16, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
53 changes: 41 additions & 12 deletions src/diffusers/schedulers/scheduling_pndm.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@ def __init__(
beta_end=0.02,
beta_schedule="linear",
tensor_format="pt",
skip_prk_steps=False,
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This should default to True for all stable diffusion models

):

if beta_schedule == "linear":
Expand Down Expand Up @@ -88,24 +89,35 @@ def __init__(
# setable values
self.num_inference_steps = None
self._timesteps = np.arange(0, num_train_timesteps)[::-1].copy()
self._offset = 0
self.prk_timesteps = None
self.plms_timesteps = None
self.timesteps = None

self.tensor_format = tensor_format
self.set_format(tensor_format=tensor_format)

def set_timesteps(self, num_inference_steps):
def set_timesteps(self, num_inference_steps, offset=0):
self.num_inference_steps = num_inference_steps
self._timesteps = list(
range(0, self.config.num_train_timesteps, self.config.num_train_timesteps // num_inference_steps)
)
self._offset = offset
self._timesteps = [t + self._offset for t in self._timesteps]

if self.config.skip_prk_steps:
# for some models like stable diffusion the prk steps can/should be skipped to
# produce better results. When using PNDM with `self.config.skip_prk_steps` the implementation
# is based on crowsonkb's PLMS sampler implementation: https://github.com/CompVis/latent-diffusion/pull/51
self.prk_timesteps = []
self.plms_timesteps = list(reversed(self._timesteps[:-1] + self._timesteps[-2:-1] + self._timesteps[-1:]))
else:
prk_timesteps = np.array(self._timesteps[-self.pndm_order :]).repeat(2) + np.tile(
np.array([0, self.config.num_train_timesteps // num_inference_steps // 2]), self.pndm_order
)
self.prk_timesteps = list(reversed(prk_timesteps[:-1].repeat(2)[1:-1]))
self.plms_timesteps = list(reversed(self._timesteps[:-3]))

prk_timesteps = np.array(self._timesteps[-self.pndm_order :]).repeat(2) + np.tile(
np.array([0, self.config.num_train_timesteps // num_inference_steps // 2]), self.pndm_order
)
self.prk_timesteps = list(reversed(prk_timesteps[:-1].repeat(2)[1:-1]))
self.plms_timesteps = list(reversed(self._timesteps[:-3]))
self.timesteps = self.prk_timesteps + self.plms_timesteps

self.counter = 0
Expand All @@ -117,7 +129,7 @@ def step(
timestep: int,
sample: Union[torch.FloatTensor, np.ndarray],
):
if self.counter < len(self.prk_timesteps):
if self.counter < len(self.prk_timesteps) and not self.config.skip_prk_steps:
return self.step_prk(model_output=model_output, timestep=timestep, sample=sample)
else:
return self.step_plms(model_output=model_output, timestep=timestep, sample=sample)
Expand Down Expand Up @@ -166,7 +178,7 @@ def step_plms(
Step function propagating the sample with the linear multi-step method. This has one forward pass with multiple
times to approximate the solution.
"""
if len(self.ets) < 3:
if not self.config.skip_prk_steps and len(self.ets) < 3:
raise ValueError(
f"{self.__class__} can only be run AFTER scheduler has been run "
"in 'prk' mode for at least 12 iterations "
Expand All @@ -175,9 +187,26 @@ def step_plms(
)

prev_timestep = max(timestep - self.config.num_train_timesteps // self.num_inference_steps, 0)
self.ets.append(model_output)

model_output = (1 / 24) * (55 * self.ets[-1] - 59 * self.ets[-2] + 37 * self.ets[-3] - 9 * self.ets[-4])
if self.counter != 1:
self.ets.append(model_output)
else:
prev_timestep = timestep
timestep = timestep + self.config.num_train_timesteps // self.num_inference_steps

if len(self.ets) == 1 and self.counter == 0:
model_output = model_output
self.cur_sample = sample
elif len(self.ets) == 1 and self.counter == 1:
model_output = (model_output + self.ets[-1]) / 2
sample = self.cur_sample
self.cur_sample = None
elif len(self.ets) == 2:
model_output = (3 * self.ets[-1] - self.ets[-2]) / 2
elif len(self.ets) == 3:
model_output = (23 * self.ets[-1] - 16 * self.ets[-2] + 5 * self.ets[-3]) / 12
else:
model_output = (1 / 24) * (55 * self.ets[-1] - 59 * self.ets[-2] + 37 * self.ets[-3] - 9 * self.ets[-4])
Comment on lines +206 to +209
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Now it looks much more familiar :D


prev_sample = self._get_prev_sample(sample, timestep, prev_timestep, model_output)
self.counter += 1
Expand All @@ -197,8 +226,8 @@ def _get_prev_sample(self, sample, timestep, timestep_prev, model_output):
# sample -> x_t
# model_output -> e_θ(x_t, t)
# prev_sample -> x_(t−δ)
alpha_prod_t = self.alphas_cumprod[timestep + 1]
alpha_prod_t_prev = self.alphas_cumprod[timestep_prev + 1]
alpha_prod_t = self.alphas_cumprod[timestep + 1 - self._offset]
alpha_prod_t_prev = self.alphas_cumprod[timestep_prev + 1 - self._offset]
beta_prod_t = 1 - alpha_prod_t
beta_prod_t_prev = 1 - alpha_prod_t_prev

Expand Down
3 changes: 2 additions & 1 deletion tests/test_modeling_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -843,6 +843,7 @@ def test_ldm_text2img_fast(self):
@slow
@unittest.skipIf(torch_device == "cpu", "Stable diffusion is suppused to run on GPU")
def test_stable_diffusion(self):
# make sure here that pndm scheduler skips prk
sd_pipe = StableDiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-1-diffusers")

prompt = "A painting of a squirrel eating a burger"
Expand All @@ -857,7 +858,7 @@ def test_stable_diffusion(self):
image_slice = image[0, -3:, -3:, -1]

assert image.shape == (1, 512, 512, 3)
expected_slice = np.array([0.898, 0.9194, 0.91, 0.8955, 0.915, 0.919, 0.9233, 0.9307, 0.8887])
expected_slice = np.array([0.8887, 0.915, 0.91, 0.894, 0.909, 0.912, 0.919, 0.925, 0.883])
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2

@slow
Expand Down