Skip to content
This repository was archived by the owner on Feb 7, 2025. It is now read-only.

Commit d37a752

Browse files
authored
Pretrained DDPM (#177)
* Working DDP training run
* Allows for choice of output dir
* Fixes conflicting arg
* Save checkpoints so they can be loaded without DataParallel
* Tidies up
* Adds option for pretrained in inpainting tutorial
* Updates compare schedulers with pretrained option
* Updates 2d tutorial
* Updates v prediction
* Reverts update v prediction
* Adds type hinting and removes unused arg
1 parent 5adb94c commit d37a752

File tree

7 files changed

+1053
-540
lines changed

7 files changed

+1053
-540
lines changed

tutorials/generative/2d_ddpm/2d_ddpm_compare_schedulers.ipynb

Lines changed: 187 additions & 101 deletions
Large diffs are not rendered by default.

tutorials/generative/2d_ddpm/2d_ddpm_compare_schedulers.py

Lines changed: 155 additions & 102 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
# format_version: '1.3'
99
# jupytext_version: 1.14.1
1010
# kernelspec:
11-
# display_name: Python 3
11+
# display_name: Python 3 (ipykernel)
1212
# language: python
1313
# name: python3
1414
# ---
@@ -204,124 +204,177 @@
204204
# %% [markdown]
205205
# ### Model training
206206
# Here, we train our model for 100 epochs (training time: ~40 minutes). We train for longer than in the other tutorials because the DDIM and PNDM schedulers seem to need a model that has been trained for longer than DDPM does before they start producing good samples.
207+
#
208+
# If you would like to skip the training and use a pre-trained model instead, set `use_pretrained=True`. This model was trained using the code in `tutorials/generative/distributed_training/ddpm_training_ddp.py`.
207209

208210
# %%
209-
n_epochs = 100
210-
val_interval = 10
211-
epoch_loss_list = []
212-
val_epoch_loss_list = []
213-
for epoch in range(n_epochs):
214-
model.train()
215-
epoch_loss = 0
216-
progress_bar = tqdm(enumerate(train_loader), total=len(train_loader))
217-
progress_bar.set_description(f"Epoch {epoch}")
218-
for step, batch in progress_bar:
219-
images = batch["image"].to(device)
220-
optimizer.zero_grad(set_to_none=True)
221-
222-
# Randomly select the timesteps to be used for the minibatch
223-
timesteps = torch.randint(0, ddpm_scheduler.num_train_timesteps, (images.shape[0],), device=device).long()
224-
225-
# Add noise to the minibatch images with intensity defined by the scheduler and timesteps
226-
noise = torch.randn_like(images).to(device)
227-
noisy_image = ddpm_scheduler.add_noise(original_samples=images, noise=noise, timesteps=timesteps)
228-
229-
# In this example, we are parametrising our DDPM to learn the added noise (epsilon).
230-
# For this reason, we are using our network to predict the added noise and then using L1 loss to measure
231-
# its performance.
232-
noise_pred = model(x=noisy_image, timesteps=timesteps)
233-
loss = F.l1_loss(noise_pred.float(), noise.float())
234-
235-
loss.backward()
236-
optimizer.step()
237-
epoch_loss += loss.item()
238-
239-
progress_bar.set_postfix(
240-
{
241-
"loss": epoch_loss / (step + 1),
242-
}
243-
)
244-
epoch_loss_list.append(epoch_loss / (step + 1))
245-
246-
if (epoch + 1) % val_interval == 0:
247-
model.eval()
248-
val_epoch_loss = 0
249-
progress_bar = tqdm(enumerate(val_loader), total=len(train_loader))
250-
progress_bar.set_description(f"Epoch {epoch} - Validation set")
211+
use_pretrained = False
212+
213+
if use_pretrained:
214+
model = torch.hub.load("marksgraham/pretrained_generative_models", model="ddpm_2d", verbose=True).to(device)
215+
else:
216+
n_epochs = 100
217+
val_interval = 10
218+
epoch_loss_list = []
219+
val_epoch_loss_list = []
220+
for epoch in range(n_epochs):
221+
model.train()
222+
epoch_loss = 0
223+
progress_bar = tqdm(enumerate(train_loader), total=len(train_loader))
224+
progress_bar.set_description(f"Epoch {epoch}")
251225
for step, batch in progress_bar:
252226
images = batch["image"].to(device)
227+
optimizer.zero_grad(set_to_none=True)
228+
229+
# Randomly select the timesteps to be used for the minibatch
253230
timesteps = torch.randint(0, ddpm_scheduler.num_train_timesteps, (images.shape[0],), device=device).long()
231+
232+
# Add noise to the minibatch images with intensity defined by the scheduler and timesteps
254233
noise = torch.randn_like(images).to(device)
255-
with torch.no_grad():
256-
noisy_image = ddpm_scheduler.add_noise(original_samples=images, noise=noise, timesteps=timesteps)
257-
noise_pred = model(x=noisy_image, timesteps=timesteps)
258-
val_loss = F.l1_loss(noise_pred.float(), noise.float())
234+
noisy_image = ddpm_scheduler.add_noise(original_samples=images, noise=noise, timesteps=timesteps)
235+
236+
# In this example, we are parametrising our DDPM to learn the added noise (epsilon).
237+
# For this reason, we are using our network to predict the added noise and then using L1 loss to measure
238+
# its performance.
239+
noise_pred = model(x=noisy_image, timesteps=timesteps)
240+
loss = F.l1_loss(noise_pred.float(), noise.float())
241+
242+
loss.backward()
243+
optimizer.step()
244+
epoch_loss += loss.item()
259245

260-
val_epoch_loss += val_loss.item()
261246
progress_bar.set_postfix(
262247
{
263-
"val_loss": val_epoch_loss / (step + 1),
248+
"loss": epoch_loss / (step + 1),
264249
}
265250
)
266-
val_epoch_loss_list.append(val_epoch_loss / (step + 1))
267-
268-
# Sampling image during training
269-
noise = torch.randn((1, 1, 64, 64))
270-
noise = noise.to(device)
271-
image = inferer.sample(input_noise=noise, diffusion_model=model, scheduler=ddpm_scheduler)
272-
plt.figure(figsize=(8, 4))
273-
plt.subplot(3, len(sampling_steps), 1)
274-
plt.imshow(image[0, 0].cpu(), vmin=0, vmax=1, cmap="gray")
275-
plt.tick_params(top=False, bottom=False, left=False, right=False, labelleft=False, labelbottom=False)
276-
plt.ylabel("DDPM")
277-
plt.title("1000 steps")
278-
# DDIM
279-
for idx, reduced_sampling_steps in enumerate(sampling_steps):
280-
ddim_scheduler.set_timesteps(reduced_sampling_steps)
281-
image = inferer.sample(input_noise=noise, diffusion_model=model, scheduler=ddim_scheduler)
282-
plt.subplot(3, len(sampling_steps), len(sampling_steps) + idx + 1)
251+
epoch_loss_list.append(epoch_loss / (step + 1))
252+
253+
if (epoch + 1) % val_interval == 0:
254+
model.eval()
255+
val_epoch_loss = 0
256+
progress_bar = tqdm(enumerate(val_loader), total=len(train_loader))
257+
progress_bar.set_description(f"Epoch {epoch} - Validation set")
258+
for step, batch in progress_bar:
259+
images = batch["image"].to(device)
260+
timesteps = torch.randint(
261+
0, ddpm_scheduler.num_train_timesteps, (images.shape[0],), device=device
262+
).long()
263+
noise = torch.randn_like(images).to(device)
264+
with torch.no_grad():
265+
noisy_image = ddpm_scheduler.add_noise(original_samples=images, noise=noise, timesteps=timesteps)
266+
noise_pred = model(x=noisy_image, timesteps=timesteps)
267+
val_loss = F.l1_loss(noise_pred.float(), noise.float())
268+
269+
val_epoch_loss += val_loss.item()
270+
progress_bar.set_postfix(
271+
{
272+
"val_loss": val_epoch_loss / (step + 1),
273+
}
274+
)
275+
val_epoch_loss_list.append(val_epoch_loss / (step + 1))
276+
277+
# Sampling image during training
278+
noise = torch.randn((1, 1, 64, 64))
279+
noise = noise.to(device)
280+
image = inferer.sample(input_noise=noise, diffusion_model=model, scheduler=ddpm_scheduler)
281+
plt.figure(figsize=(8, 4))
282+
plt.subplot(3, len(sampling_steps), 1)
283283
plt.imshow(image[0, 0].cpu(), vmin=0, vmax=1, cmap="gray")
284-
plt.ylabel("DDIM")
285-
if idx == 0:
286-
plt.tick_params(top=False, bottom=False, left=False, right=False, labelleft=False, labelbottom=False)
287-
else:
288-
plt.axis("off")
289-
plt.title(f"{reduced_sampling_steps} steps")
290-
# PNDM
291-
for idx, reduced_sampling_steps in enumerate(sampling_steps):
292-
pndm_scheduler.set_timesteps(reduced_sampling_steps)
293-
image = inferer.sample(input_noise=noise, diffusion_model=model, scheduler=pndm_scheduler)
294-
plt.subplot(3, len(sampling_steps), len(sampling_steps) * 2 + idx + 1)
295-
plt.imshow(image[0, 0].cpu(), vmin=0, vmax=1, cmap="gray")
296-
plt.ylabel("PNDM")
297-
if idx == 0:
298-
plt.tick_params(top=False, bottom=False, left=False, right=False, labelleft=False, labelbottom=False)
299-
else:
300-
plt.axis("off")
301-
plt.title(f"{reduced_sampling_steps} steps")
302-
plt.suptitle(f"Epoch {epoch+1}")
303-
plt.show()
284+
plt.tick_params(top=False, bottom=False, left=False, right=False, labelleft=False, labelbottom=False)
285+
plt.ylabel("DDPM")
286+
plt.title("1000 steps")
287+
# DDIM
288+
for idx, reduced_sampling_steps in enumerate(sampling_steps):
289+
ddim_scheduler.set_timesteps(reduced_sampling_steps)
290+
image = inferer.sample(input_noise=noise, diffusion_model=model, scheduler=ddim_scheduler)
291+
plt.subplot(3, len(sampling_steps), len(sampling_steps) + idx + 1)
292+
plt.imshow(image[0, 0].cpu(), vmin=0, vmax=1, cmap="gray")
293+
plt.ylabel("DDIM")
294+
if idx == 0:
295+
plt.tick_params(
296+
top=False, bottom=False, left=False, right=False, labelleft=False, labelbottom=False
297+
)
298+
else:
299+
plt.axis("off")
300+
plt.title(f"{reduced_sampling_steps} steps")
301+
# PNDM
302+
for idx, reduced_sampling_steps in enumerate(sampling_steps):
303+
pndm_scheduler.set_timesteps(reduced_sampling_steps)
304+
image = inferer.sample(input_noise=noise, diffusion_model=model, scheduler=pndm_scheduler)
305+
plt.subplot(3, len(sampling_steps), len(sampling_steps) * 2 + idx + 1)
306+
plt.imshow(image[0, 0].cpu(), vmin=0, vmax=1, cmap="gray")
307+
plt.ylabel("PNDM")
308+
if idx == 0:
309+
plt.tick_params(
310+
top=False, bottom=False, left=False, right=False, labelleft=False, labelbottom=False
311+
)
312+
else:
313+
plt.axis("off")
314+
plt.title(f"{reduced_sampling_steps} steps")
315+
plt.suptitle(f"Epoch {epoch+1}")
316+
plt.show()
304317
# %% [markdown]
305318
# ### Learning curves
306319

307320
# %%
308-
plt.style.use("seaborn")
309-
plt.title("Learning Curves", fontsize=20)
310-
plt.plot(np.linspace(1, n_epochs, n_epochs), epoch_loss_list, color="C0", linewidth=2.0, label="Train")
311-
plt.plot(
312-
np.linspace(val_interval, n_epochs, int(n_epochs / val_interval)),
313-
val_epoch_loss_list,
314-
color="C1",
315-
linewidth=2.0,
316-
label="Validation",
317-
)
318-
plt.yticks(fontsize=12)
319-
plt.xticks(fontsize=12)
320-
plt.xlabel("Epochs", fontsize=16)
321-
plt.ylabel("Loss", fontsize=16)
322-
plt.legend(prop={"size": 14})
323-
plt.show()
321+
if not use_pretrained:
322+
plt.style.use("seaborn")
323+
plt.title("Learning Curves", fontsize=20)
324+
plt.plot(np.linspace(1, n_epochs, n_epochs), epoch_loss_list, color="C0", linewidth=2.0, label="Train")
325+
plt.plot(
326+
np.linspace(val_interval, n_epochs, int(n_epochs / val_interval)),
327+
val_epoch_loss_list,
328+
color="C1",
329+
linewidth=2.0,
330+
label="Validation",
331+
)
332+
plt.yticks(fontsize=12)
333+
plt.xticks(fontsize=12)
334+
plt.xlabel("Epochs", fontsize=16)
335+
plt.ylabel("Loss", fontsize=16)
336+
plt.legend(prop={"size": 14})
337+
plt.show()
338+
339+
340+
# %% [markdown]
341+
# ### Compare samples from trained model
324342

343+
# %%
344+
noise = torch.randn((1, 1, 64, 64))
345+
noise = noise.to(device)
346+
image = inferer.sample(input_noise=noise, diffusion_model=model, scheduler=ddpm_scheduler)
347+
plt.figure(figsize=(8, 4))
348+
plt.subplot(3, len(sampling_steps), 1)
349+
plt.imshow(image[0, 0].cpu(), vmin=0, vmax=1, cmap="gray")
350+
plt.tick_params(top=False, bottom=False, left=False, right=False, labelleft=False, labelbottom=False)
351+
plt.ylabel("DDPM")
352+
plt.title("1000 steps")
353+
# DDIM
354+
for idx, reduced_sampling_steps in enumerate(sampling_steps):
355+
ddim_scheduler.set_timesteps(reduced_sampling_steps)
356+
image = inferer.sample(input_noise=noise, diffusion_model=model, scheduler=ddim_scheduler)
357+
plt.subplot(3, len(sampling_steps), len(sampling_steps) + idx + 1)
358+
plt.imshow(image[0, 0].cpu(), vmin=0, vmax=1, cmap="gray")
359+
plt.ylabel("DDIM")
360+
if idx == 0:
361+
plt.tick_params(top=False, bottom=False, left=False, right=False, labelleft=False, labelbottom=False)
362+
else:
363+
plt.axis("off")
364+
plt.title(f"{reduced_sampling_steps} steps")
365+
# PNDM
366+
for idx, reduced_sampling_steps in enumerate(sampling_steps):
367+
pndm_scheduler.set_timesteps(reduced_sampling_steps)
368+
image = inferer.sample(input_noise=noise, diffusion_model=model, scheduler=pndm_scheduler)
369+
plt.subplot(3, len(sampling_steps), len(sampling_steps) * 2 + idx + 1)
370+
plt.imshow(image[0, 0].cpu(), vmin=0, vmax=1, cmap="gray")
371+
plt.ylabel("PNDM")
372+
if idx == 0:
373+
plt.tick_params(top=False, bottom=False, left=False, right=False, labelleft=False, labelbottom=False)
374+
else:
375+
plt.axis("off")
376+
plt.title(f"{reduced_sampling_steps} steps")
377+
plt.show()
325378

326379
# %% [markdown]
327380
# ### Cleanup data directory

0 commit comments

Comments
 (0)