|
 #       format_version: '1.3'
 #       jupytext_version: 1.14.1
 #   kernelspec:
-#     display_name: Python 3
+#     display_name: Python 3 (ipykernel)
 #     language: python
 #     name: python3
 # ---
|
 # %% [markdown]
 # ### Model training
 # Here, we train our model for 100 epochs (training time: ~40 minutes). We need to train for longer than in other tutorials because the DDIM and PNDM schedulers seem to require a longer-trained model before they start producing good samples, compared to DDPM.
+#
+# If you would like to skip training and use a pre-trained model instead, set `use_pretrained=True`. This model was trained using the code in `tutorials/generative/distributed_training/ddpm_training_ddp.py`.

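As background for the training cell below: with the epsilon parametrisation, the scheduler's `add_noise` corrupts a clean image in a single closed-form step, and the network is trained to recover the noise that was added. Here is a minimal stand-alone sketch of that forward step, assuming a linear beta schedule with 1000 timesteps; `betas`, `alphas_cumprod`, and `q_sample` are illustrative names, not part of the MONAI API.

```python
import torch

# Closed-form forward process: x_t = sqrt(a_bar_t) * x_0 + sqrt(1 - a_bar_t) * eps, eps ~ N(0, I)
T = 1000
betas = torch.linspace(1e-4, 0.02, T)  # assumed linear beta schedule
alphas_cumprod = torch.cumprod(1.0 - betas, dim=0)  # a_bar_t


def q_sample(x0, t, eps):
    # Noise each image in the batch to its own timestep in one shot
    a_bar = alphas_cumprod[t].view(-1, 1, 1, 1)  # broadcast over (B, C, H, W)
    return a_bar.sqrt() * x0 + (1.0 - a_bar).sqrt() * eps


x0 = torch.rand(4, 1, 64, 64)  # dummy batch scaled to [0, 1]
t = torch.randint(0, T, (4,))
x_t = q_sample(x0, t, torch.randn_like(x0))  # what the model sees as input
```

The training loop then regresses the network's prediction on `x_t` against the sampled noise with an L1 loss.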
|
 # %%
-n_epochs = 100
-val_interval = 10
-epoch_loss_list = []
-val_epoch_loss_list = []
-for epoch in range(n_epochs):
-    model.train()
-    epoch_loss = 0
-    progress_bar = tqdm(enumerate(train_loader), total=len(train_loader))
-    progress_bar.set_description(f"Epoch {epoch}")
-    for step, batch in progress_bar:
-        images = batch["image"].to(device)
-        optimizer.zero_grad(set_to_none=True)
-
-        # Randomly select the timesteps to be used for the minibatch
-        timesteps = torch.randint(0, ddpm_scheduler.num_train_timesteps, (images.shape[0],), device=device).long()
-
-        # Add noise to the minibatch images with intensity defined by the scheduler and timesteps
-        noise = torch.randn_like(images).to(device)
-        noisy_image = ddpm_scheduler.add_noise(original_samples=images, noise=noise, timesteps=timesteps)
-
-        # In this example, we are parametrising our DDPM to learn the added noise (epsilon).
-        # For this reason, we are using our network to predict the added noise and then using the L1 loss to
-        # measure its performance.
-        noise_pred = model(x=noisy_image, timesteps=timesteps)
-        loss = F.l1_loss(noise_pred.float(), noise.float())
-
-        loss.backward()
-        optimizer.step()
-        epoch_loss += loss.item()
-
-        progress_bar.set_postfix(
-            {
-                "loss": epoch_loss / (step + 1),
-            }
-        )
-    epoch_loss_list.append(epoch_loss / (step + 1))
-
-    if (epoch + 1) % val_interval == 0:
-        model.eval()
-        val_epoch_loss = 0
-        progress_bar = tqdm(enumerate(val_loader), total=len(val_loader))
-        progress_bar.set_description(f"Epoch {epoch} - Validation set")
+use_pretrained = False
+
+if use_pretrained:
+    model = torch.hub.load("marksgraham/pretrained_generative_models", model="ddpm_2d", verbose=True).to(device)
+else:
+    n_epochs = 100
+    val_interval = 10
+    epoch_loss_list = []
+    val_epoch_loss_list = []
+    for epoch in range(n_epochs):
+        model.train()
+        epoch_loss = 0
+        progress_bar = tqdm(enumerate(train_loader), total=len(train_loader))
+        progress_bar.set_description(f"Epoch {epoch}")
         for step, batch in progress_bar:
             images = batch["image"].to(device)
+            optimizer.zero_grad(set_to_none=True)
+
+            # Randomly select the timesteps to be used for the minibatch
             timesteps = torch.randint(0, ddpm_scheduler.num_train_timesteps, (images.shape[0],), device=device).long()
+
+            # Add noise to the minibatch images with intensity defined by the scheduler and timesteps
             noise = torch.randn_like(images).to(device)
-            with torch.no_grad():
-                noisy_image = ddpm_scheduler.add_noise(original_samples=images, noise=noise, timesteps=timesteps)
-                noise_pred = model(x=noisy_image, timesteps=timesteps)
-                val_loss = F.l1_loss(noise_pred.float(), noise.float())
+            noisy_image = ddpm_scheduler.add_noise(original_samples=images, noise=noise, timesteps=timesteps)
+
+            # In this example, we are parametrising our DDPM to learn the added noise (epsilon).
+            # For this reason, we are using our network to predict the added noise and then using the L1 loss to
+            # measure its performance.
+            noise_pred = model(x=noisy_image, timesteps=timesteps)
+            loss = F.l1_loss(noise_pred.float(), noise.float())
+
+            loss.backward()
+            optimizer.step()
+            epoch_loss += loss.item()

|
-            val_epoch_loss += val_loss.item()
             progress_bar.set_postfix(
                 {
-                    "val_loss": val_epoch_loss / (step + 1),
+                    "loss": epoch_loss / (step + 1),
                 }
             )
-        val_epoch_loss_list.append(val_epoch_loss / (step + 1))
-
-        # Sampling image during training
-        noise = torch.randn((1, 1, 64, 64))
-        noise = noise.to(device)
-        image = inferer.sample(input_noise=noise, diffusion_model=model, scheduler=ddpm_scheduler)
-        plt.figure(figsize=(8, 4))
-        plt.subplot(3, len(sampling_steps), 1)
-        plt.imshow(image[0, 0].cpu(), vmin=0, vmax=1, cmap="gray")
-        plt.tick_params(top=False, bottom=False, left=False, right=False, labelleft=False, labelbottom=False)
-        plt.ylabel("DDPM")
-        plt.title("1000 steps")
-        # DDIM
-        for idx, reduced_sampling_steps in enumerate(sampling_steps):
-            ddim_scheduler.set_timesteps(reduced_sampling_steps)
-            image = inferer.sample(input_noise=noise, diffusion_model=model, scheduler=ddim_scheduler)
-            plt.subplot(3, len(sampling_steps), len(sampling_steps) + idx + 1)
+        epoch_loss_list.append(epoch_loss / (step + 1))
+
+        if (epoch + 1) % val_interval == 0:
+            model.eval()
+            val_epoch_loss = 0
+            progress_bar = tqdm(enumerate(val_loader), total=len(val_loader))
+            progress_bar.set_description(f"Epoch {epoch} - Validation set")
+            for step, batch in progress_bar:
+                images = batch["image"].to(device)
+                timesteps = torch.randint(
+                    0, ddpm_scheduler.num_train_timesteps, (images.shape[0],), device=device
+                ).long()
+                noise = torch.randn_like(images).to(device)
+                with torch.no_grad():
+                    noisy_image = ddpm_scheduler.add_noise(original_samples=images, noise=noise, timesteps=timesteps)
+                    noise_pred = model(x=noisy_image, timesteps=timesteps)
+                    val_loss = F.l1_loss(noise_pred.float(), noise.float())
+
+                val_epoch_loss += val_loss.item()
+                progress_bar.set_postfix(
+                    {
+                        "val_loss": val_epoch_loss / (step + 1),
+                    }
+                )
+            val_epoch_loss_list.append(val_epoch_loss / (step + 1))
+
+            # Sampling image during training
+            noise = torch.randn((1, 1, 64, 64))
+            noise = noise.to(device)
+            image = inferer.sample(input_noise=noise, diffusion_model=model, scheduler=ddpm_scheduler)
+            plt.figure(figsize=(8, 4))
+            plt.subplot(3, len(sampling_steps), 1)
             plt.imshow(image[0, 0].cpu(), vmin=0, vmax=1, cmap="gray")
-            plt.ylabel("DDIM")
-            if idx == 0:
-                plt.tick_params(top=False, bottom=False, left=False, right=False, labelleft=False, labelbottom=False)
-            else:
-                plt.axis("off")
-            plt.title(f"{reduced_sampling_steps} steps")
-        # PNDM
-        for idx, reduced_sampling_steps in enumerate(sampling_steps):
-            pndm_scheduler.set_timesteps(reduced_sampling_steps)
-            image = inferer.sample(input_noise=noise, diffusion_model=model, scheduler=pndm_scheduler)
-            plt.subplot(3, len(sampling_steps), len(sampling_steps) * 2 + idx + 1)
-            plt.imshow(image[0, 0].cpu(), vmin=0, vmax=1, cmap="gray")
-            plt.ylabel("PNDM")
-            if idx == 0:
-                plt.tick_params(top=False, bottom=False, left=False, right=False, labelleft=False, labelbottom=False)
-            else:
-                plt.axis("off")
-            plt.title(f"{reduced_sampling_steps} steps")
-        plt.suptitle(f"Epoch {epoch+1}")
-        plt.show()
+            plt.tick_params(top=False, bottom=False, left=False, right=False, labelleft=False, labelbottom=False)
+            plt.ylabel("DDPM")
+            plt.title("1000 steps")
+            # DDIM
+            for idx, reduced_sampling_steps in enumerate(sampling_steps):
+                ddim_scheduler.set_timesteps(reduced_sampling_steps)
+                image = inferer.sample(input_noise=noise, diffusion_model=model, scheduler=ddim_scheduler)
+                plt.subplot(3, len(sampling_steps), len(sampling_steps) + idx + 1)
+                plt.imshow(image[0, 0].cpu(), vmin=0, vmax=1, cmap="gray")
+                plt.ylabel("DDIM")
+                if idx == 0:
+                    plt.tick_params(
+                        top=False, bottom=False, left=False, right=False, labelleft=False, labelbottom=False
+                    )
+                else:
+                    plt.axis("off")
+                plt.title(f"{reduced_sampling_steps} steps")
+            # PNDM
+            for idx, reduced_sampling_steps in enumerate(sampling_steps):
+                pndm_scheduler.set_timesteps(reduced_sampling_steps)
+                image = inferer.sample(input_noise=noise, diffusion_model=model, scheduler=pndm_scheduler)
+                plt.subplot(3, len(sampling_steps), len(sampling_steps) * 2 + idx + 1)
+                plt.imshow(image[0, 0].cpu(), vmin=0, vmax=1, cmap="gray")
+                plt.ylabel("PNDM")
+                if idx == 0:
+                    plt.tick_params(
+                        top=False, bottom=False, left=False, right=False, labelleft=False, labelbottom=False
+                    )
+                else:
+                    plt.axis("off")
+                plt.title(f"{reduced_sampling_steps} steps")
+            plt.suptitle(f"Epoch {epoch+1}")
+            plt.show()
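The sampling comparison above relies on `set_timesteps(n)` to shrink the reverse process from 1000 steps down to `n`. As a rough sketch of what such a reduced schedule looks like, assuming simple even spacing (the exact spacing rule inside each MONAI scheduler may differ):

```python
import numpy as np


def subsample_schedule(num_train_timesteps: int, num_inference_steps: int) -> np.ndarray:
    # Evenly spaced, descending subset of the training timesteps
    stride = num_train_timesteps // num_inference_steps
    return (np.arange(num_inference_steps) * stride)[::-1].astype(np.int64)


print(subsample_schedule(1000, 50)[:5])  # -> [980 960 940 920 900]
```

The model is only ever evaluated at these surviving timesteps, which is consistent with the note above that DDIM and PNDM need a longer-trained model before the skipped steps stop hurting sample quality.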
 # %% [markdown]
 # ### Learning curves

|
 # %%
-plt.style.use("seaborn")
-plt.title("Learning Curves", fontsize=20)
-plt.plot(np.linspace(1, n_epochs, n_epochs), epoch_loss_list, color="C0", linewidth=2.0, label="Train")
-plt.plot(
-    np.linspace(val_interval, n_epochs, int(n_epochs / val_interval)),
-    val_epoch_loss_list,
-    color="C1",
-    linewidth=2.0,
-    label="Validation",
-)
-plt.yticks(fontsize=12)
-plt.xticks(fontsize=12)
-plt.xlabel("Epochs", fontsize=16)
-plt.ylabel("Loss", fontsize=16)
-plt.legend(prop={"size": 14})
-plt.show()
+if not use_pretrained:
+    plt.style.use("seaborn")
+    plt.title("Learning Curves", fontsize=20)
+    plt.plot(np.linspace(1, n_epochs, n_epochs), epoch_loss_list, color="C0", linewidth=2.0, label="Train")
+    plt.plot(
+        np.linspace(val_interval, n_epochs, int(n_epochs / val_interval)),
+        val_epoch_loss_list,
+        color="C1",
+        linewidth=2.0,
+        label="Validation",
+    )
+    plt.yticks(fontsize=12)
+    plt.xticks(fontsize=12)
+    plt.xlabel("Epochs", fontsize=16)
+    plt.ylabel("Loss", fontsize=16)
+    plt.legend(prop={"size": 14})
+    plt.show()
+
+
+# %% [markdown]
+# ### Compare samples from trained model

|
+# %%
+noise = torch.randn((1, 1, 64, 64))
+noise = noise.to(device)
+image = inferer.sample(input_noise=noise, diffusion_model=model, scheduler=ddpm_scheduler)
+plt.figure(figsize=(8, 4))
+plt.subplot(3, len(sampling_steps), 1)
+plt.imshow(image[0, 0].cpu(), vmin=0, vmax=1, cmap="gray")
+plt.tick_params(top=False, bottom=False, left=False, right=False, labelleft=False, labelbottom=False)
+plt.ylabel("DDPM")
+plt.title("1000 steps")
+# DDIM
+for idx, reduced_sampling_steps in enumerate(sampling_steps):
+    ddim_scheduler.set_timesteps(reduced_sampling_steps)
+    image = inferer.sample(input_noise=noise, diffusion_model=model, scheduler=ddim_scheduler)
+    plt.subplot(3, len(sampling_steps), len(sampling_steps) + idx + 1)
+    plt.imshow(image[0, 0].cpu(), vmin=0, vmax=1, cmap="gray")
+    plt.ylabel("DDIM")
+    if idx == 0:
+        plt.tick_params(top=False, bottom=False, left=False, right=False, labelleft=False, labelbottom=False)
+    else:
+        plt.axis("off")
+    plt.title(f"{reduced_sampling_steps} steps")
+# PNDM
+for idx, reduced_sampling_steps in enumerate(sampling_steps):
+    pndm_scheduler.set_timesteps(reduced_sampling_steps)
+    image = inferer.sample(input_noise=noise, diffusion_model=model, scheduler=pndm_scheduler)
+    plt.subplot(3, len(sampling_steps), len(sampling_steps) * 2 + idx + 1)
+    plt.imshow(image[0, 0].cpu(), vmin=0, vmax=1, cmap="gray")
+    plt.ylabel("PNDM")
+    if idx == 0:
+        plt.tick_params(top=False, bottom=False, left=False, right=False, labelleft=False, labelbottom=False)
+    else:
+        plt.axis("off")
+    plt.title(f"{reduced_sampling_steps} steps")
+plt.show()

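Beyond visual quality, the practical payoff of the DDIM and PNDM schedulers is sampling speed. A quick way to quantify it is to time `inferer.sample` for each scheduler, reusing the objects defined above (a sketch; absolute numbers depend on hardware, the 50-step DDIM schedule is an assumed example, and we assume the schedulers expose their current `timesteps`, as the MONAI ones do):

```python
import time

ddim_scheduler.set_timesteps(50)  # assumed reduced schedule for the comparison
for name, scheduler in [("DDPM", ddpm_scheduler), ("DDIM", ddim_scheduler)]:
    start = time.time()
    _ = inferer.sample(input_noise=noise, diffusion_model=model, scheduler=scheduler)
    print(f"{name}: {len(scheduler.timesteps)} steps in {time.time() - start:.1f}s")
```

With 20x fewer steps, DDIM sampling should take roughly a twentieth of the DDPM wall-clock time, since the UNet forward pass dominates the cost of each step.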
|
 # %% [markdown]
 # ### Cleanup data directory
|