diff --git a/README.md b/README.md
index 6a05127..d0468fc 100644
--- a/README.md
+++ b/README.md
@@ -748,16 +748,6 @@ Anything! It is MIT licensed. In other words, you can freely copy / paste for yo
 }
 ```
 
-```bibtex
-@article{Choi2022PerceptionPT,
-    title = {Perception Prioritized Training of Diffusion Models},
-    author = {Jooyoung Choi and Jungbeom Lee and Chaehun Shin and Sungwon Kim and Hyunwoo J. Kim and Sung-Hoon Yoon},
-    journal = {ArXiv},
-    year = {2022},
-    volume = {abs/2204.00227}
-}
-```
-
 ```bibtex
 @inproceedings{Sankararaman2022BayesFormerTW,
     title = {BayesFormer: Transformer with Uncertainty Estimation},
@@ -898,3 +888,11 @@ Anything! It is MIT licensed. In other words, you can freely copy / paste for yo
     status = {to be published - one attention stabilization technique is circulating within Google Brain, being used by multiple teams}
 }
 ```
+
+```bibtex
+@inproceedings{Hang2023EfficientDT,
+    title = {Efficient Diffusion Training via Min-SNR Weighting Strategy},
+    author = {Tiankai Hang and Shuyang Gu and Chen Li and Jianmin Bao and Dong Chen and Han Hu and Xin Geng and Baining Guo},
+    year = {2023}
+}
+```
diff --git a/imagen_pytorch/imagen_pytorch.py b/imagen_pytorch/imagen_pytorch.py
index b3fa919..794e20d 100644
--- a/imagen_pytorch/imagen_pytorch.py
+++ b/imagen_pytorch/imagen_pytorch.py
@@ -1802,14 +1802,14 @@ def __init__(
         per_sample_random_aug_noise_level = False, # unclear when conditioning on augmentation noise level, whether each batch element receives a random aug noise value - turning off due to @marunine's find
         condition_on_text = True,
         auto_normalize_img = True, # whether to take care of normalizing the image from [0, 1] to [-1, 1] and back automatically - you can turn this off if you want to pass in the [-1, 1] ranged image yourself from the dataloader
-        p2_loss_weight_gamma = 0.5, # p2 loss weight, from https://arxiv.org/abs/2204.00227 - 0 is equivalent to weight of 1 across time
-        p2_loss_weight_k = 1,
         dynamic_thresholding = True,
         dynamic_thresholding_percentile = 0.95, # unsure what this was based on perusal of paper
         only_train_unet_number = None,
         temporal_downsample_factor = 1,
         resize_cond_video_frames = True,
-        resize_mode = 'nearest'
+        resize_mode = 'nearest',
+        min_snr_loss_weight = True, # https://arxiv.org/abs/2303.09556
+        min_snr_gamma = 5
     ):
         super().__init__()
 
@@ -1956,12 +1956,13 @@
         self.dynamic_thresholding = cast_tuple(dynamic_thresholding, num_unets)
         self.dynamic_thresholding_percentile = dynamic_thresholding_percentile
 
-        # p2 loss weight
+        # min snr loss weight
 
-        self.p2_loss_weight_k = p2_loss_weight_k
-        self.p2_loss_weight_gamma = cast_tuple(p2_loss_weight_gamma, num_unets)
+        min_snr_loss_weight = cast_tuple(min_snr_loss_weight, num_unets)
+        min_snr_gamma = cast_tuple(min_snr_gamma, num_unets)
 
-        assert all([(gamma_value <= 2) for gamma_value in self.p2_loss_weight_gamma]), 'in paper, they noticed any gamma greater than 2 is harmful'
+        assert len(min_snr_loss_weight) == len(min_snr_gamma) == num_unets
+        self.min_snr_gamma = tuple((gamma if use_min_snr else None) for use_min_snr, gamma in zip(min_snr_loss_weight, min_snr_gamma))
 
         # one temp parameter for keeping track of device
 
@@ -2494,7 +2495,7 @@ def p_losses(
         noise = None,
         times_next = None,
         pred_objective = 'noise',
-        p2_loss_weight_gamma = 0.,
+        min_snr_gamma = None,
        random_crop_size = None,
        **kwargs
    ):
@@ -2600,12 +2601,22 @@
         losses = self.loss_fn(pred, target, reduction = 'none')
         losses = reduce(losses, 'b ... -> b', 'mean')
 
-        # p2 loss reweighting
+        # min snr loss reweighting
+
+        snr = log_snr.exp()
+        maybe_clipped_snr = snr.clone()
+
+        if exists(min_snr_gamma):
+            maybe_clipped_snr.clamp_(max = min_snr_gamma)
 
-        if p2_loss_weight_gamma > 0:
-            loss_weight = (self.p2_loss_weight_k + log_snr.exp()) ** -p2_loss_weight_gamma
-            losses = losses * loss_weight
+        if pred_objective == 'noise':
+            loss_weight = maybe_clipped_snr / snr
+        elif pred_objective == 'x_start':
+            loss_weight = maybe_clipped_snr
+        elif pred_objective == 'v':
+            loss_weight = maybe_clipped_snr / (snr + 1)
 
+        losses = losses * loss_weight
         return losses.mean()
 
     @beartype
@@ -2641,7 +2652,7 @@ def forward(
         assert not isinstance(unet, NullUnet), 'null unet cannot and should not be trained'
 
         noise_scheduler = self.noise_schedulers[unet_index]
-        p2_loss_weight_gamma = self.p2_loss_weight_gamma[unet_index]
+        min_snr_gamma = self.min_snr_gamma[unet_index]
         pred_objective = self.pred_objectives[unet_index]
         target_image_size = self.image_sizes[unet_index]
         random_crop_size = self.random_crop_sizes[unet_index]
@@ -2702,4 +2713,4 @@
 
         images = self.resize_to(images, target_image_size, **frames_to_resize_kwargs(target_frame_size))
 
-        return self.p_losses(unet, images, times, text_embeds = text_embeds, text_mask = text_masks, cond_images = cond_images, noise_scheduler = noise_scheduler, lowres_cond_img = lowres_cond_img, lowres_aug_times = lowres_aug_times, pred_objective = pred_objective, p2_loss_weight_gamma = p2_loss_weight_gamma, random_crop_size = random_crop_size, **kwargs)
+        return self.p_losses(unet, images, times, text_embeds = text_embeds, text_mask = text_masks, cond_images = cond_images, noise_scheduler = noise_scheduler, lowres_cond_img = lowres_cond_img, lowres_aug_times = lowres_aug_times, pred_objective = pred_objective, min_snr_gamma = min_snr_gamma, random_crop_size = random_crop_size, **kwargs)
diff --git a/imagen_pytorch/version.py b/imagen_pytorch/version.py
index c38eab6..3335368 100644
--- a/imagen_pytorch/version.py
+++ b/imagen_pytorch/version.py
@@ -1 +1 @@
-__version__ = '1.22.4'
+__version__ = '1.23.0'
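For reference, the weighting introduced above reduces to a simple rule: recover SNR(t) = exp(log-SNR), cap it at gamma, and turn the capped value into a per-sample loss weight according to the prediction objective (min(SNR, γ)/SNR for noise prediction, min(SNR, γ) for x_start prediction, min(SNR, γ)/(SNR + 1) for v prediction). The snippet below is a minimal standalone sketch of that rule under those assumptions; the function name and signature are illustrative and not part of the library.

```python
import torch

def min_snr_weight(log_snr, pred_objective = 'noise', min_snr_gamma = 5.):
    # illustrative sketch of Min-SNR loss weighting (https://arxiv.org/abs/2303.09556)
    # names and signature are hypothetical, not the library's API

    snr = log_snr.exp()                           # SNR(t) recovered from log-SNR
    clipped_snr = snr.clamp(max = min_snr_gamma)  # min(SNR(t), gamma)

    if pred_objective == 'noise':                 # epsilon prediction
        return clipped_snr / snr
    elif pred_objective == 'x_start':             # x0 prediction
        return clipped_snr
    elif pred_objective == 'v':                   # v prediction
        return clipped_snr / (snr + 1)

    raise ValueError(f'unknown pred_objective {pred_objective}')

# example: weights for a batch of log-SNR values, one per sampled timestep
log_snr = torch.tensor([-6., 0., 6.])   # low, medium, high SNR
print(min_snr_weight(log_snr))          # roughly tensor([1.0000, 1.0000, 0.0124])
```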