
Soft min SNR gamma #1068

Open · wants to merge 3 commits into main
Conversation

@rockerBOO (Contributor) commented Jan 24, 2024:

In "Scalable High-Resolution Pixel-Space Image Synthesis with Hourglass Diffusion Transformers" they came up with Soft min SNR gamma which smooths out the transition area.

[Screenshots: soft-min-SNR formulation from the Hourglass Diffusion Transformers paper, arXiv 2401.11605]

[W&B loss curves: run 31 = soft_min_snr_gamma = 5, run 32 = min_snr_gamma = 5]

Note: I think the math is correct, but I could be wrong, so if anyone wants to correct it I can update.

@@ -68,6 +68,13 @@ def apply_snr_weight(loss, timesteps, noise_scheduler, gamma, v_prediction=False
return loss


def apply_soft_snr_weight(loss, timesteps, noise_scheduler, gamma, v_prediction=False):
    snr = torch.stack([noise_scheduler.all_snr[t] for t in timesteps])
    soft_min_snr_gamma_weight = 1 / (torch.pow(snr if v_prediction is False else snr + 1, 2) + (1 / float(gamma)))
@feffy380 (Contributor) commented Jan 24, 2024:

The math here is incorrect. SNR is equal to the whole expression 1/sigma**2, not sigma (at least based on the fact that here they use min(1/sigma**2, gamma) while the min-SNR paper uses min(SNR, gamma); the variable names are inconsistent between the papers, so I don't blame you for getting them confused).
The correct weight should be:

sigma2 = 1 / snr
# weight = 1 / (sigma2 + 1/gamma)
#        = 1 / (1/snr + 1/gamma)
# simplified:
weight = snr * gamma / (snr + gamma)

Finally, the given formulation for soft-min-SNR is for x_0-prediction. We use epsilon- or v-prediction, which according to the original min-SNR paper means we need to divide by SNR or SNR+1 respectively, so the final weight calculation should be:

snr_weight = (snr * gamma / (snr + gamma)).float().to(loss.device)
if v_prediction:
    snr_weight /= snr + 1
else:
    snr_weight /= snr
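
For intuition, a quick numeric check (my own, illustrative values only, not from the thread): the simplified weight tracks min(SNR, gamma) away from the transition and only diverges near snr ≈ gamma.

gamma = 5.0
for snr in (0.1, 5.0, 100.0):
    print(snr, min(snr, gamma), snr * gamma / (snr + gamma))
# snr=0.1:   min 0.1, soft ~0.098  (agree at low SNR)
# snr=5.0:   min 5.0, soft 2.5     (largest gap, at the transition)
# snr=100.0: min 5.0, soft ~4.76   (agree at high SNR)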

@rockerBOO (Contributor, Author) commented Jan 24, 2024:

Tried this formula:

def apply_soft_snr_weight(loss, timesteps, noise_scheduler, gamma, v_prediction=False):
    snr = torch.stack([noise_scheduler.all_snr[t] for t in timesteps])
    snr_weight = (snr * gamma / (snr + gamma)).float().to(loss.device)
    if v_prediction:
        snr_weight /= snr + 1
    else:
        snr_weight /= snr
    loss = loss * snr_weight
    return loss

It produced the same loss curve as the current implementation. The paper says soft-min-SNR should match min-SNR everywhere except near the transition, so similar loss curves are expected; still seeing a difference from the min-SNR version, though.

[W&B loss curves: run 38 = weight = snr * gamma / (snr + gamma), run 37 = the current PR version, run 35 = min-SNR version. Runs 38 and 37 overlap in the graph.]

Maybe there's something else that is missing in these calculations.

Here are some example snr,gamma pairs to use with the test script below:
snr.txt

And this code to test the formulas:

import math

with open("snr.txt", "r") as f:
    lines = f.readlines()

    print("snr min_snr soft soft2")
    for line in lines:
        snr, gamma = line.split(",")

        snr = float(snr)
        gamma = float(gamma)

        # min-SNR and soft-min-SNR as written in the current PR
        min_snr = min(1 / math.pow(snr, 2), gamma)
        soft_min_snr_gamma = 1 / (math.pow(snr, 2) + (1 / gamma))

        # feffy380's simplified soft-min-SNR weight
        snr_weight = snr * gamma / (snr + gamma)

        print(f"{snr:10.4f} {min_snr:4.4f}, {soft_min_snr_gamma:4.4f}, {snr_weight:4.4f}")

I don't really know what I'm doing with the math, I think, but I'm trying to learn.

@Birch-san commented:

here's how to formulate it as an EDM target:
https://github.com/crowsonkb/k-diffusion/blob/6ab5146d4a5ef63901326489f31f1d8e7dd36b48/k_diffusion/layers.py#L65

here's how to formulate it as an x0 loss weighting:
https://github.com/Birch-san/k-diffusion/blob/9bce54aec1e596548cf73f56f4842c11aa6271c6/k_diffusion/layers.py#L160

here's an alternative style for expressing it as an EDM target, where we use the x0 loss weighting and apply a correction to adapt it for EDM:
https://github.com/Birch-san/k-diffusion/blob/9bce54aec1e596548cf73f56f4842c11aa6271c6/k_diffusion/layers.py#L250
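
For anyone following along without opening the links, a minimal sketch of the x0 loss weighting in sigma terms, based on the 1/(sigma**2 + 1/gamma) form discussed above (the function name is mine; the linked k-diffusion code is the authoritative version):

import torch

def soft_min_snr_weight_x0(sigma: torch.Tensor, gamma: float) -> torch.Tensor:
    # soft-min-SNR weight for x0 prediction: 1 / (sigma^2 + 1/gamma).
    # With snr = 1 / sigma^2 this equals snr * gamma / (snr + gamma).
    return 1.0 / (sigma ** 2 + 1.0 / gamma)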

"--soft_min_snr_gamma",
type=float,
default=None,
help="gamma for reducing the weight of high loss timesteps. Lower numbers have stronger effect. 5 is recommended by paper. / 低いタイムステップでの高いlossに対して重みを減らすためのgamma値、低いほど効果が強く、論文では5が推奨",

@Birch-san commented on this line:

we don't recommend gamma=5, we recommend gamma=sigma_data**-2.
for our pixel-space dataset, ImageNet, we declared sigma_data=0.5 and hence gamma=4.
for latent datasets, you probably want sigma_data=1.0 and hence gamma=1. because you are not training on the raw latents, you are first multiplying by 0.13025 to standardize their std to 1.

sigma_data is the standard deviation of your dataset's pixels (or latents).
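
In code form, a small sketch of this recommendation (names mine):

sigma_data = 0.5    # std of raw pixels, HDiT's ImageNet setting -> gamma = 4
# sigma_data = 1.0  # std of latents after the 0.13025 scaling -> gamma = 1
gamma = sigma_data ** -2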

@rockerBOO (Contributor, Author) replied:

Updated to recommend 1. I had mistakenly copied the help text from the min_snr_gamma option. Thank you!

@rockerBOO (Contributor, Author) commented:

> here's how to formulate it as an EDM target: https://github.com/crowsonkb/k-diffusion/blob/6ab5146d4a5ef63901326489f31f1d8e7dd36b48/k_diffusion/layers.py#L65
>
> here's how to formulate it as an x0 loss weighting: https://github.com/Birch-san/k-diffusion/blob/9bce54aec1e596548cf73f56f4842c11aa6271c6/k_diffusion/layers.py#L160
>
> here's an alternative style for expressing it as an EDM target, where we use the x0 loss weighting and apply a correction to adapt it for EDM: https://github.com/Birch-san/k-diffusion/blob/9bce54aec1e596548cf73f56f4842c11aa6271c6/k_diffusion/layers.py#L250

Thanks for sharing these. I'm not sure what EDM or x0 loss weighting are, so I don't know what makes sense to use here. I foolishly tried to implement it without understanding the underlying math.

I will work through the implementations you shared to see if I can figure out how to implement this properly in this code, also integrating what feffy380 has suggested.

@kohya-ss (Owner) commented Apr 2, 2024:

Sorry for the delay in merging. It would be nice if this PR could be improved; I am also trying to understand the math, but it is quite difficult...

Please let me know when the PR is ready.
