Soft min SNR gamma #1068
base: main
Conversation
```diff
@@ -68,6 +68,13 @@ def apply_snr_weight(loss, timesteps, noise_scheduler, gamma, v_prediction=False):
     return loss


+def apply_soft_snr_weight(loss, timesteps, noise_scheduler, gamma, v_prediction=False):
+    snr = torch.stack([noise_scheduler.all_snr[t] for t in timesteps])
+    soft_min_snr_gamma_weight = 1 / (torch.pow(snr if v_prediction is False else snr + 1, 2) + (1 / float(gamma)))
```
The math here is incorrect. SNR is equal to the whole expression `1/sigma**2`, not `sigma` (at least based on the fact that they use `Min(1/sigma**2, gamma)` here while the min-snr paper uses `Min(SNR, gamma)`; the variable names are inconsistent between the papers, so I don't blame you for getting them confused).

The correct weight should be:

```python
sigma2 = 1 / snr
weight = 1 / (sigma2 + 1 / gamma)  # == 1 / (1 / snr + 1 / gamma)
# simplified:
weight = snr * gamma / (snr + gamma)
```
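As a quick sanity check on that simplification (my own sketch, not code from the PR): the harmonic form and the simplified form are algebraically identical, and the resulting weight smoothly interpolates between `snr` and `gamma`:

```python
# Sketch (not from the PR): verify that 1 / (1/snr + 1/gamma) equals
# snr * gamma / (snr + gamma), and that the soft weight behaves like
# min(snr, gamma) away from the transition point snr == gamma.
gamma = 4.0
for snr in (0.01, 1.0, 4.0, 100.0):
    harmonic = 1.0 / (1.0 / snr + 1.0 / gamma)
    simplified = snr * gamma / (snr + gamma)
    assert abs(harmonic - simplified) < 1e-12
    print(f"snr={snr:7.2f}  soft={simplified:.4f}  min(snr, gamma)={min(snr, gamma):.4f}")
```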
Finally, the given formulation for soft-min-snr is for `x_0` prediction. We use epsilon or v-prediction, which according to the original min-snr paper means we need to divide by SNR or SNR+1 respectively, so the final weight calculation should be:

```python
snr_weight = (snr * gamma / (snr + gamma)).float().to(loss.device)
if v_prediction:
    snr_weight /= snr + 1
else:
    snr_weight /= snr
```
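To see how closely the soft weight tracks the hard min-SNR weight away from the transition, here is a small pure-Python comparison for the epsilon-prediction case (my own sketch based on the formulas in this comment, not code from the PR):

```python
# Sketch: epsilon-prediction weights under both schemes.
#   min-SNR:      min(snr, gamma) / snr
#   soft-min-SNR: (snr * gamma / (snr + gamma)) / snr == gamma / (snr + gamma)
# The two agree far from snr == gamma and differ only near the transition.
gamma = 5.0
for snr in (0.1, 1.0, 5.0, 25.0, 100.0):
    w_min = min(snr, gamma) / snr
    w_soft = gamma / (snr + gamma)
    print(f"snr={snr:7.1f}  min={w_min:.4f}  soft={w_soft:.4f}")
```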
Tried this formula:

```python
def apply_soft_snr_weight(loss, timesteps, noise_scheduler, gamma, v_prediction=False):
    snr = torch.stack([noise_scheduler.all_snr[t] for t in timesteps])
    snr_weight = (snr * gamma / (snr + gamma)).float().to(loss.device)
    if v_prediction:
        snr_weight /= snr + 1
    else:
        snr_weight /= snr
    loss = loss * snr_weight
    return loss
```
It produced the same loss curve as the current implementation. The paper says the two should match everywhere except near the transition, so similar loss curves are expected. There is still a difference from the Min SNR version, though. In the graph:

- 38 = `weight = snr * gamma / (snr + gamma)`
- 37 = the current PR version
- 35 = the Min SNR version

(38 and 37 overlap in the graph.)

Maybe something else is still missing in these calculations.
Here is some example `snr,gamma` data to use with the following test script: snr.txt
And this code tests the formulas:

```python
import math

with open("snr.txt", "r") as f:
    lines = f.readlines()

print("snr min_snr soft soft2")
for line in lines:
    snr, gamma = line.split(",")
    snr = float(snr)
    gamma = float(gamma)
    min_snr = min(1 / math.pow(snr, 2), gamma)
    soft_min_snr_gamma = 1 / (math.pow(snr, 2) + (1 / gamma))
    snr_weight = snr * gamma / (snr + gamma)
    print(f"{snr:10.4f} {min_snr:4.4f}, {soft_min_snr_gamma:4.4f}, {snr_weight:4.4f}")
```
I'm not sure I fully understand the math, but I'm trying to learn.
[Three equation images, not reproduced here: how to formulate it as an EDM target; how to formulate it as an x0 loss weighting; and an alternative style for expressing it as an EDM target, where the x0 loss weighting is used with a correction to adapt it for EDM.]
library/custom_train_functions.py (Outdated)

```python
    "--soft_min_snr_gamma",
    type=float,
    default=None,
    help="gamma for reducing the weight of high loss timesteps. Lower numbers have stronger effect. 5 is recommended by paper. / 低いタイムステップでの高いlossに対して重みを減らすためのgamma値、低いほど効果が強く、論文では5が推奨",
```
We don't recommend `gamma=5`; we recommend `gamma = sigma_data**-2`.

For our pixel-space dataset, ImageNet, we declared `sigma_data = 0.5` and hence `gamma = 4`.

For latent datasets, you probably want `sigma_data = 1.0` and hence `gamma = 1`, because you are not training on the raw latents: you first multiply by 0.13025 to standardize their std to 1.

`sigma_data` is the standard deviation of your dataset's pixels (or latents).
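To make the `sigma_data` recommendation concrete, here is a small sketch with synthetic data (my own illustration, not from the repo; the 0.13025 scale factor is the one mentioned above): scaling latents whose raw std is `1 / 0.13025` standardizes them to std ~1, so `gamma = sigma_data**-2 = 1`.

```python
import numpy as np

# Sketch with synthetic "latents": if raw latents have std ~ 1 / 0.13025,
# multiplying by the 0.13025 scale factor standardizes their std to ~1,
# so sigma_data = 1.0 and hence gamma = sigma_data**-2 = 1.0.
rng = np.random.default_rng(0)
raw_latents = rng.normal(0.0, 1.0 / 0.13025, size=(4, 4, 64, 64))
scaled = raw_latents * 0.13025
sigma_data = scaled.std()
gamma = sigma_data ** -2
print(f"sigma_data ~ {sigma_data:.3f}, gamma ~ {gamma:.3f}")
```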
Updated to recommend 1. I mistakenly copied the help text from the `min_snr_gamma` option. Thank you!
Thanks for sharing these. I'm not familiar enough with EDM or x0 loss weighting to know what makes sense to use here; I tried to implement this without fully understanding the underlying math. I'll work through the details you shared to see if I can implement it properly, also integrating what feffy has suggested.
Sorry for the delay in merging. However, it would be nice if this PR could be improved. I am also trying to understand the math, but it is quite difficult... Please let me know when the PR is ready.
In "Scalable High-Resolution Pixel-Space Image Synthesis with Hourglass Diffusion Transformers", they introduced soft-min-SNR-gamma, which smooths out the transition area.
In the graph:

- 31 = `soft_min_snr_gamma = 5`
- 32 = `min_snr_gamma = 5`
Note: I think the math is correct, but I could be wrong, so if anyone wants to correct it I can update the PR.