
Efficient Diffusion Training via Min-SNR Weighting Strategy #308

Merged · 7 commits · Mar 26, 2023

Conversation

AI-Casanova
Contributor

Implementation of https://arxiv.org/abs/2303.09556

Low-noise timesteps produce outsized loss (as I discovered on my own in #294), which can lead to training instability, since single samples can push large update steps in a direction that may not be advantageous.

This paper introduces a scaling factor gamma, exposed here via the new argument --min_snr_gamma, that lowers the weight of these low-noise timesteps based on their signal-to-noise ratio.
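For reference, a minimal sketch of the weighting this boils down to for epsilon-prediction (the helper name is illustrative, not the exact code in this PR):

import torch

def min_snr_loss_weight(snr_t: torch.Tensor, gamma: float) -> torch.Tensor:
    # min(gamma / SNR_t, 1): low-noise timesteps have high SNR, so their weight
    # drops below 1, while noisier timesteps keep full weight.
    return torch.minimum(gamma / snr_t, torch.ones_like(snr_t))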

[Image: loss curves, from highest loss to lowest: gamma = 20, 5, 4, 3, 2, 1]

[Image: sample grid generated from the losses above]

@AI-Casanova
Contributor Author

@bmaltais I thought you may be interested in trying this.

@bmaltais
Contributor

@bmaltais I thought you may be interested in trying this.

Thank you, looks interesting.

@ryukra

ryukra commented Mar 21, 2023

What does the comparison show, given that the paper talks about converging faster?

@bmaltais
Contributor

One possible improvement to the code would be to create a function in a separate custom_train_functions.py file to be called by all trainers:

import torch

def apply_snr_weight(loss, noisy_latents, latents, gamma):
    if gamma:
        # Estimate the noise component of the noisy latents
        sigma = torch.sub(noisy_latents, latents)
        zeros = torch.zeros_like(sigma)
        # Per-sample mean squares of the signal and noise components
        alpha_mean_sq = torch.nn.functional.mse_loss(latents.float(), zeros.float(), reduction="none").mean([1, 2, 3])
        sigma_mean_sq = torch.nn.functional.mse_loss(sigma.float(), zeros.float(), reduction="none").mean([1, 2, 3])
        # Per-sample SNR and the min-SNR weight min(gamma / SNR, 1)
        snr = torch.div(alpha_mean_sq, sigma_mean_sq)
        gamma_over_snr = torch.div(torch.ones_like(snr) * gamma, snr)
        snr_weight = torch.minimum(gamma_over_snr, torch.ones_like(gamma_over_snr)).float()
        loss = loss * snr_weight
    return loss

That way, all you need to do is add this to each trainer:

from library.custom_train_functions import apply_snr_weight

loss = apply_snr_weight(loss, noisy_latents, latents, args.min_snr_gamma)
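For illustration, a rough sketch of where that call might sit in a trainer's loss computation (this wrapper and its argument names are hypothetical, not quoted from the actual trainer code):

import torch
import torch.nn.functional as F
from library.custom_train_functions import apply_snr_weight

def compute_weighted_loss(noise_pred, target, noisy_latents, latents, min_snr_gamma):
    # Hypothetical wrapper showing where apply_snr_weight slots in, per the suggestion above.
    loss = F.mse_loss(noise_pred.float(), target.float(), reduction="none")
    loss = loss.mean([1, 2, 3])  # one loss value per sample in the batch
    loss = apply_snr_weight(loss, noisy_latents, latents, min_snr_gamma)
    return loss.mean()  # scalar loss used for backpropagation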

@bmaltais
Contributor

I gave it a try and it does appear to have a significantly positive effect on the training. This is a keeper!

@AI-Casanova
Contributor Author

One possible improvement to the code would be to create a function in a separate custom_train_functions.py file to be called by all trainers:

I do like the cleanliness of that. It could also have the associated argparse lines there as well.

@AI-Casanova
Contributor Author

What does the comparison show, given that the paper talks about converging faster?

I didn't so much notice faster convergence as I did better likeness, but that may have to do with the fact that we're fine-tuning, not training a model from scratch.

In this implementation, I noticed less overtraining on superficial details like eye bags and facial imperfections, which can occur when the randomly drawn timesteps are very low. Low timesteps (low noise) force backpropagation into a big update step, and that step tends to be superficial.

@TingTingin
Contributor

What are the recommended values from your testing?

@AI-Casanova
Contributor Author

AI-Casanova commented Mar 21, 2023

@TingTingin --min_snr_gamma=5 follows the paper's most common example and seems a reasonable starting point, remembering that lower numbers have a stronger effect.

During my testing I added print(timesteps) and print(snr_weight) to the subroutine, found my highest tested value (gamma 20) to have very minimal impact, and would consider it an upper bound.
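To make "lower numbers have a stronger effect" concrete, here is a tiny sketch (the SNR values are just illustrative placeholders, not measured ones):

import torch

snr = torch.tensor([0.1, 1.0, 5.0, 20.0, 100.0])  # illustrative per-timestep SNR values
for gamma in (1, 5, 20):
    weight = torch.minimum(gamma / snr, torch.ones_like(snr))
    # Smaller gamma pushes the high-SNR (low-noise) weights further below 1;
    # gamma=20 barely changes anything, consistent with treating it as an upper bound.
    print(gamma, weight)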

@bmaltais
Contributor

bmaltais commented Mar 22, 2023

I would think that you will need to add the min-SNR weighting to all trainers before @kohya-ss can merge. In the current state it would only apply to one of the 4 trainers...

@AI-Casanova
Contributor Author

Just sitting down to refactor it now.

@AI-Casanova
Contributor Author

Alright, I refactored into a separate file per @bmaltais's suggestion and added the function to all 4 trainers.

I do not have a dataset on hand to test fine_tune.py, but the other 3 function as intended.

@laksjdjf
Contributor

After reading the paper, I thought that $\alpha_t$ and $\sigma_t$ were determined uniformly based on the timesteps and sampling scheduler settings. However, it seems that this PR calculates $\alpha_t$ and $\sigma_t$ based on the noise and latents. Is there any specific reason for this approach?

SNR calculation code by author.

The $\alpha_t$ and $\sigma_t$ are here in diffusers.

@AI-Casanova
Contributor Author

@laksjdjf thanks for linking the official code; it wasn't out at the time I coded this up.

I'm not a math guy by any stretch, but the code sanity-checks, and it has been checked over by better math brains than mine. If there are improvements, I'm completely game.

As to their code, I see "Compute training losses for a single timestep". I'm trying to isolate the SNR for each item in the batch (each of which has its own timestep) at runtime and calculate the weight adjustment from that.
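For clarity, the per-batch lookup I have in mind would be something along these lines (a sketch; all_snr here is a placeholder table, in practice it would come from the scheduler):

import torch

all_snr = torch.linspace(100.0, 0.01, 1000)   # placeholder SNR-per-timestep table (1000 train steps)
timesteps = torch.randint(0, 1000, (4,))      # one randomly drawn timestep per batch item
snr_per_sample = all_snr[timesteps]           # shape (batch_size,), one SNR per item
print(timesteps, snr_per_sample)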

I was just asking said math brain today how we could determine the SNR-timestep correlation, but we hadn't worked it out. If the ratio of mean squares is not accurate enough, we can change approaches.

If you add a print(timesteps) and print(snr_weight) in the appropriate places, you can see it in action.

@AI-Casanova
Contributor Author

You know what? I think you're right, let me do some tests to see where the ratios end up.

@laksjdjf
Contributor

In noise_scheduler.add_noise:

noisy_samples = sqrt_alpha_prod * original_samples + sqrt_one_minus_alpha_prod * noise

Since the variance of both noise and original_samples is 1, the SNR of the noisy latents becomes

(sqrt_alpha_prod / sqrt_one_minus_alpha_prod)**2.

I tested the SNR values with the code below.

import torch
from diffusers import DDPMScheduler
import matplotlib.pyplot as plt
scheduler = DDPMScheduler.from_pretrained("stabilityai/stable-diffusion-2", subfolder="scheduler")

timesteps = torch.arange(0,1000)

def get_snr(
    scheduler, 
    timesteps: torch.IntTensor,
) -> torch.FloatTensor:

    sqrt_alpha_prod = scheduler.alphas_cumprod[timesteps] ** 0.5
    sqrt_alpha_prod = sqrt_alpha_prod.flatten()

    sqrt_one_minus_alpha_prod = (1 - scheduler.alphas_cumprod[timesteps]) ** 0.5
    sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.flatten()

    return (sqrt_alpha_prod / sqrt_one_minus_alpha_prod) ** 2

snr = get_snr(scheduler, timesteps)

plt.xlabel("timesteps")
plt.ylabel("snr")
plt.plot(timesteps,snr)

[Image: plot of SNR versus timestep]

Since the full range is hard to read, here it is limited to timesteps 100 to 1000:

[Image: plot of SNR versus timestep, limited to timesteps 100–1000]

I am not sure if these numbers are reasonable.

@AI-Casanova
Contributor Author

AI-Casanova commented Mar 23, 2023

Pardon my git illiteracy there.

Implemented the necessary changes to align with the authors' calculations.

[Image: loss curves comparing the old and new SNR calculations]
The lower line is the new loss, so there is definitely a touch more scaling going on.

Your graphs are correct, @laksjdjf, and here is a Google Sheet with a graph of snr_weight scalable by gamma:
https://docs.google.com/spreadsheets/d/1Vq6NwUrG0AGPHe06nGzIUTda1Ul1RSjNbMjv5MgtEwU/edit?usp=sharing
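For anyone who would rather reproduce that curve locally instead of opening the sheet, a small sketch building on @laksjdjf's get_snr snippet above (same scheduler checkpoint as that snippet; everything else is illustrative):

import torch
import matplotlib.pyplot as plt
from diffusers import DDPMScheduler

scheduler = DDPMScheduler.from_pretrained("stabilityai/stable-diffusion-2", subfolder="scheduler")
timesteps = torch.arange(0, 1000)

alphas_cumprod = scheduler.alphas_cumprod[timesteps]
snr = alphas_cumprod / (1.0 - alphas_cumprod)  # same quantity as get_snr above

for gamma in (1, 3, 5, 20):
    weight = torch.minimum(gamma / snr, torch.ones_like(snr))
    plt.plot(timesteps, weight, label=f"gamma={gamma}")

plt.xlabel("timestep")
plt.ylabel("snr_weight")
plt.legend()
plt.show()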

Thank you for bringing this to my attention!

@AI-Casanova
Contributor Author

Comparisons, from highest loss to lowest (same order in the image grid):
no gamma, old gamma 5, new gamma 5, old gamma 3, new gamma 3

[Image: loss curves for the five runs]
[Image: generated sample grid for the five runs]

Differences between SNR calculations are quite subtle, but I'd still rather have them correct than not.

@AI-Casanova
Contributor Author

Trying to git push at 1 a.m. is too much for my brain, I guess. Thanks to the kind soul from Discord who caught my mistake.

@noseconexoes

Congratulations on the work, Casanova!
Thanks for the contribution! And thanks to @kohya-ss and @bmaltais for everything!
I hope this implementation reaches the main version as soon as possible!

@kohya-ss
Owner

Thank you @AI-Casanova and everyone for the great discussion! I don't fully understand the background of the theory, but the results are excellent!

I think this PR is ready to merge. I will be merging this today. If you have any concerns, please let me know.

@AI-Casanova
Contributor Author

I see no issues with merging; base functionality shouldn't be affected at all without a --min_snr_gamma argument, and thanks to @laksjdjf I've got the algorithm working nicely.

Thanks for everything @kohya-ss!

@mgz-dev
Contributor

mgz-dev commented Mar 25, 2023

Great work implementing this! Very excited to do some comparisons with this modified loss calculation.

I did run into a small issue while testing this repo due to a datatype incompatibility with numpy (specifically bf16):

File "\train_network.py", line 552, in train
    loss = apply_snr_weight(loss, timesteps, noise_scheduler, args.min_snr_gamma)
  File "\library\custom_train_functions.py", line 9, in apply_snr_weight
    sqrt_alphas_cumprod = np.sqrt(alphas_cumprod)
  File "sd\lib\site-packages\torch\_tensor.py", line 970, in __array__
    return self.numpy()
TypeError: Got unsupported ScalarType BFloat16

It should be an easy fix: just change the sqrt functions to their torch equivalents. It may also be slightly cleaner to have min_snr_gamma default to None.

import torch
import argparse


def apply_snr_weight(loss, timesteps, noise_scheduler, gamma):
  alphas_cumprod = noise_scheduler.alphas_cumprod.cpu()
  sqrt_alphas_cumprod = torch.sqrt(alphas_cumprod)
  sqrt_one_minus_alphas_cumprod = torch.sqrt(1.0 - alphas_cumprod)
  alpha = sqrt_alphas_cumprod
  sigma = sqrt_one_minus_alphas_cumprod
  all_snr = (alpha / sigma) ** 2                       # SNR for every timestep of the scheduler
  all_snr = all_snr.to(loss.device)                    # move the lookup table to the same device as the loss
  snr = torch.stack([all_snr[t] for t in timesteps])   # per-sample SNR for this batch
  gamma_over_snr = torch.div(torch.ones_like(snr) * gamma, snr)
  snr_weight = torch.minimum(gamma_over_snr, torch.ones_like(gamma_over_snr)).float().to(loss.device)  # from paper
  loss = loss * snr_weight
  return loss

def add_custom_train_arguments(parser: argparse.ArgumentParser):
  parser.add_argument("--min_snr_gamma", type=float, default=None, help="gamma for reducing the weight of high-loss timesteps. Lower numbers have a stronger effect. 5 is recommended by the paper.")

Thanks again for finding this paper and sharing with everyone!

@AI-Casanova
Contributor Author

Thanks for testing, @mgz-dev, and finding the TypeError with bf16.

Switching to torch.sqrt is even slicker, because I no longer have to pull to CPU to do the calculations.
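For what it's worth, a minimal sketch of the all-on-device version I mean (not the exact committed code):

import torch

def apply_snr_weight(loss, timesteps, noise_scheduler, gamma):
    # Sketch: (sqrt(a) / sqrt(1 - a))**2 simplifies to a / (1 - a),
    # so no sqrt, and no .cpu() round trip, is actually needed.
    alphas_cumprod = noise_scheduler.alphas_cumprod.to(device=loss.device, dtype=torch.float32)
    snr = alphas_cumprod / (1.0 - alphas_cumprod)   # SNR for every timestep
    snr_t = snr[timesteps]                          # per-sample SNR for this batch
    snr_weight = torch.minimum(gamma / snr_t, torch.ones_like(snr_t))
    return loss * snr_weight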

Ran a quick test with fp16 to ensure that the loss graphs were a match. Let me know if you find anything else.

@kohya-ss kohya-ss changed the base branch from main to dev March 26, 2023 08:09
@CCRcmcpe

The current implementation is not applicable to SD2 (v-pred); it only considers SD1 (eps-pred).
You can refer to Appendix B of the original paper for a derivation of the SNR weight for x- and v-prediction.

Here's an improved implementation I came up with:

import torch

def min_snr_weight(scheduler, t, gamma):
    alpha_cp = scheduler.alphas_cumprod
    sigma_pow_2 = 1.0 - alpha_cp
    snr = alpha_cp / sigma_pow_2   # SNR_t = alphas_cumprod / (1 - alphas_cumprod)
    snr_t = snr[t]

    match scheduler.config.prediction_type:
        case "epsilon":
            min_snr_w = torch.minimum(gamma / snr_t, torch.ones_like(t, dtype=torch.float32))
        case "sample":
            min_snr_w = torch.minimum(snr_t, torch.full_like(t, gamma, dtype=torch.float32))
        case "v_prediction":
            min_snr_w = torch.minimum(snr_t + 1, torch.full_like(t, gamma, dtype=torch.float32))
        case _:
            raise Exception("Unknown prediction type")

    return min_snr_w

@AI-Casanova
Contributor Author

@CCRcmcpe Gimme a bit to think through this; I did forget to disclaim V2 incompatibility.

@AI-Casanova
Contributor Author

Also, @CCRcmcpe, wouldn't the velocity factor be `min(SNR, gamma) / (SNR + 1)`?

That's how I'm reading this and the appendix.
[Image: screenshot of the relevant derivation from the paper's appendix]

@CCRcmcpe

Yeah I'm mistaken.

@laksjdjf
Contributor

laksjdjf commented Jun 2, 2023

Has anything been done regarding the v_prediction model?
I agree with the velocity factor being min(SNR, gamma) / (SNR + 1).
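For reference, a small sketch of what a prediction-type-aware weight could look like with that correction applied (an illustration of the discussion above, not code from this PR):

import torch

def min_snr_weight(noise_scheduler, timesteps, gamma):
    alphas_cumprod = noise_scheduler.alphas_cumprod.to(timesteps.device)
    snr = alphas_cumprod[timesteps] / (1.0 - alphas_cumprod[timesteps])
    if noise_scheduler.config.prediction_type == "epsilon":
        # eps-prediction: min(SNR, gamma) / SNR == min(gamma / SNR, 1)
        return torch.minimum(gamma / snr, torch.ones_like(snr)).float()
    elif noise_scheduler.config.prediction_type == "v_prediction":
        # v-prediction: min(SNR, gamma) / (SNR + 1), per the Appendix B reading above
        return (torch.minimum(snr, torch.full_like(snr, gamma)) / (snr + 1)).float()
    else:
        raise ValueError(f"Unsupported prediction type: {noise_scheduler.config.prediction_type}")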
