Skip to content

Commit

Permalink
substitute p2 loss weight with min snr loss weight
Browse files Browse the repository at this point in the history
  • Loading branch information
lucidrains committed Mar 18, 2023
1 parent 726c11a commit 7a21a30
Show file tree
Hide file tree
Showing 3 changed files with 34 additions and 25 deletions.
18 changes: 8 additions & 10 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -748,16 +748,6 @@ Anything! It is MIT licensed. In other words, you can freely copy / paste for yo
}
```

```bibtex
@article{Choi2022PerceptionPT,
title = {Perception Prioritized Training of Diffusion Models},
author = {Jooyoung Choi and Jungbeom Lee and Chaehun Shin and Sungwon Kim and Hyunwoo J. Kim and Sung-Hoon Yoon},
journal = {ArXiv},
year = {2022},
volume = {abs/2204.00227}
}
```

```bibtex
@inproceedings{Sankararaman2022BayesFormerTW,
title = {BayesFormer: Transformer with Uncertainty Estimation},
Expand Down Expand Up @@ -898,3 +888,11 @@ Anything! It is MIT licensed. In other words, you can freely copy / paste for yo
status = {to be published - one attention stabilization technique is circulating within Google Brain, being used by multiple teams}
}
```

```bibtex
@inproceedings{Hang2023EfficientDT,
title = {Efficient Diffusion Training via Min-SNR Weighting Strategy},
author = {Tiankai Hang and Shuyang Gu and Chen Li and Jianmin Bao and Dong Chen and Han Hu and Xin Geng and Baining Guo},
year = {2023}
}
```
39 changes: 25 additions & 14 deletions imagen_pytorch/imagen_pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -1802,14 +1802,14 @@ def __init__(
per_sample_random_aug_noise_level = False, # unclear when conditioning on augmentation noise level, whether each batch element receives a random aug noise value - turning off due to @marunine's find
condition_on_text = True,
auto_normalize_img = True, # whether to take care of normalizing the image from [0, 1] to [-1, 1] and back automatically - you can turn this off if you want to pass in the [-1, 1] ranged image yourself from the dataloader
p2_loss_weight_gamma = 0.5, # p2 loss weight, from https://arxiv.org/abs/2204.00227 - 0 is equivalent to weight of 1 across time
p2_loss_weight_k = 1,
dynamic_thresholding = True,
dynamic_thresholding_percentile = 0.95, # unsure what this was based on perusal of paper
only_train_unet_number = None,
temporal_downsample_factor = 1,
resize_cond_video_frames = True,
resize_mode = 'nearest'
resize_mode = 'nearest',
min_snr_loss_weight = True, # https://arxiv.org/abs/2303.09556
min_snr_gamma = 5
):
super().__init__()

Expand Down Expand Up @@ -1956,12 +1956,13 @@ def __init__(
self.dynamic_thresholding = cast_tuple(dynamic_thresholding, num_unets)
self.dynamic_thresholding_percentile = dynamic_thresholding_percentile

# p2 loss weight
# min snr loss weight

self.p2_loss_weight_k = p2_loss_weight_k
self.p2_loss_weight_gamma = cast_tuple(p2_loss_weight_gamma, num_unets)
min_snr_loss_weight = cast_tuple(min_snr_loss_weight, num_unets)
min_snr_gamma = cast_tuple(min_snr_gamma, num_unets)

assert all([(gamma_value <= 2) for gamma_value in self.p2_loss_weight_gamma]), 'in paper, they noticed any gamma greater than 2 is harmful'
assert len(min_snr_loss_weight) == len(min_snr_gamma) == num_unets
self.min_snr_gamma = tuple((gamma if use_min_snr else None) for use_min_snr, gamma in zip(min_snr_loss_weight, min_snr_gamma))

# one temp parameter for keeping track of device

Expand Down Expand Up @@ -2494,7 +2495,7 @@ def p_losses(
noise = None,
times_next = None,
pred_objective = 'noise',
p2_loss_weight_gamma = 0.,
min_snr_gamma = None,
random_crop_size = None,
**kwargs
):
Expand Down Expand Up @@ -2600,12 +2601,22 @@ def p_losses(
losses = self.loss_fn(pred, target, reduction = 'none')
losses = reduce(losses, 'b ... -> b', 'mean')

# p2 loss reweighting
# min snr loss reweighting

snr = log_snr.exp()
maybe_clipped_snr = snr.clone()

if exists(min_snr_gamma):
maybe_clipped_snr.clamp_(min = min_snr_gamma)

if p2_loss_weight_gamma > 0:
loss_weight = (self.p2_loss_weight_k + log_snr.exp()) ** -p2_loss_weight_gamma
losses = losses * loss_weight
if pred_objective == 'noise':
loss_weight = maybe_clipped_snr / snr
elif pred_objective == 'x_start':
loss_weight = maybe_clipped_snr
elif pred_objective == 'v':
loss_weight = maybe_clipped_snr / (snr + 1)

losses = losses * loss_weight
return losses.mean()

@beartype
Expand Down Expand Up @@ -2641,7 +2652,7 @@ def forward(
assert not isinstance(unet, NullUnet), 'null unet cannot and should not be trained'

noise_scheduler = self.noise_schedulers[unet_index]
p2_loss_weight_gamma = self.p2_loss_weight_gamma[unet_index]
min_snr_gamma = self.min_snr_gamma[unet_index]
pred_objective = self.pred_objectives[unet_index]
target_image_size = self.image_sizes[unet_index]
random_crop_size = self.random_crop_sizes[unet_index]
Expand Down Expand Up @@ -2702,4 +2713,4 @@ def forward(

images = self.resize_to(images, target_image_size, **frames_to_resize_kwargs(target_frame_size))

return self.p_losses(unet, images, times, text_embeds = text_embeds, text_mask = text_masks, cond_images = cond_images, noise_scheduler = noise_scheduler, lowres_cond_img = lowres_cond_img, lowres_aug_times = lowres_aug_times, pred_objective = pred_objective, p2_loss_weight_gamma = p2_loss_weight_gamma, random_crop_size = random_crop_size, **kwargs)
return self.p_losses(unet, images, times, text_embeds = text_embeds, text_mask = text_masks, cond_images = cond_images, noise_scheduler = noise_scheduler, lowres_cond_img = lowres_cond_img, lowres_aug_times = lowres_aug_times, pred_objective = pred_objective, min_snr_gamma = min_snr_gamma, random_crop_size = random_crop_size, **kwargs)
2 changes: 1 addition & 1 deletion imagen_pytorch/version.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
__version__ = '1.22.4'
__version__ = '1.23.0'

0 comments on commit 7a21a30

Please sign in to comment.