I have reason to believe "scale v-loss like epsilon loss" and Min-SNR-Gamma are implemented wrong. #673
There's some discussion about this in the original pull request for Min-SNR-gamma, and AI-Casanova (the PR's author) also thinks the denominator should probably be SNR(t) + 1.
Thank you very much for this! I am not a math person and my understanding may be incorrect, but does this mean we can modify the following?

```python
def apply_snr_weight_noise_pred(loss, timesteps, noise_scheduler, gamma):
    snr = torch.stack([noise_scheduler.all_snr[t] for t in timesteps])
    gamma_over_snr = torch.div(torch.ones_like(snr) * gamma, snr)
    snr_weight = torch.minimum(gamma_over_snr, torch.ones_like(gamma_over_snr)).float().to(loss.device)  # from paper
    loss = loss * snr_weight
    return loss


def apply_snr_weight_alt(v_prediction, loss, timesteps, noise_scheduler, gamma):
    if not v_prediction:
        return apply_snr_weight_noise_pred(loss, timesteps, noise_scheduler, gamma)
    snr = torch.stack([noise_scheduler.all_snr[t] for t in timesteps])
    min_snr_gamma = torch.minimum(snr, torch.full_like(snr, gamma))  # min(SNR(t), gamma)
    snr_weight = torch.div(min_snr_gamma, snr + 1).float().to(loss.device)
    loss = loss * snr_weight
    return loss


# we can remove this function
def scale_v_prediction_loss_like_noise_prediction(loss, timesteps, noise_scheduler):
    snr_t = torch.stack([noise_scheduler.all_snr[t] for t in timesteps])  # batch_size
    snr_t = torch.minimum(snr_t, torch.ones_like(snr_t) * 1000)  # if timestep is 0, snr_t is inf, so limit it to 1000
    scale = snr_t / (snr_t + 1)
    loss = loss * scale
    return loss
```
That should work, but I think the cleanest way to implement it would be to change the denominator based on v-prediction: for epsilon-prediction the denominator should stay `SNR(t)`, and for v-prediction it should be `SNR(t) + 1`.

```python
def apply_snr_weight_alt(v_prediction, loss, timesteps, noise_scheduler, gamma):
    snr = torch.stack([noise_scheduler.all_snr[t] for t in timesteps])
    min_snr_gamma = torch.minimum(snr, torch.full_like(snr, gamma))  # min(SNR(t), gamma)
    if v_prediction:
        snr_weight = torch.div(min_snr_gamma, snr + 1).float().to(loss.device)
    else:
        snr_weight = torch.div(min_snr_gamma, snr).float().to(loss.device)
    loss = loss * snr_weight
    return loss
```
Thank you for the clarification! The formulas seem to say that if we apply the current min-SNR-gamma weighting and the current v-loss scaling at the same time, the result should be equivalent. Both options can be specified at the same time.
No.

Edit: I'd also like to emphasize that the formula used in the v-prediction code path should outright be the correct implementation of min-SNR-gamma, as in there shouldn't be a separate loss-rescale function for v-prediction that is optional; min-SNR-gamma used on v-prediction should always behave like this. Compatibility is a possible concern, but at least from the testing I've seen so far, this implementation gives better results than the two loss rescales.
By the way, is clipping necessary?

```python
import torch
from diffusers import DDPMScheduler

scheduler = DDPMScheduler.from_pretrained("stabilityai/stable-diffusion-2", subfolder="scheduler")

def get_snr(
    scheduler,
    timesteps: torch.IntTensor,
) -> torch.FloatTensor:
    sqrt_alpha_prod = scheduler.alphas_cumprod[timesteps] ** 0.5
    sqrt_alpha_prod = sqrt_alpha_prod.flatten()
    sqrt_one_minus_alpha_prod = (1 - scheduler.alphas_cumprod[timesteps]) ** 0.5
    sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.flatten()
    return (sqrt_alpha_prod / sqrt_one_minus_alpha_prod) ** 2

get_snr(scheduler, torch.tensor(0))
# tensor([1175.4406])
```
Any updates?
As laksjdjf wrote, I believe it is OK when we specify both options at the same time.
Unfortunately not. Combining both leads to loss being scaled twice.
I should reiterate that the corrected formula shouldn't be something to combine with the existing options; it should simply be how min-SNR-gamma behaves under v-prediction.
I can confirm this after having discussed it with Tian, one of the original paper authors. Additionally, I've implemented the fix in SimpleTuner as a non-conditional fix for v_prediction-type models when min-SNR gamma is in use.
Maybe it's my lack of knowledge, but I had always wondered why, for my dataset, the min-SNR setting seemed to yield worse results in a way I couldn't really explain. Glad to know I wasn't imagining things. Hope this can be solved soon.
Fixed with merge of #934
I've been training a model using Kohya's implementation of Min-SNR-Gamma and the more recent option for scaling v-prediction like epsilon loss. I am also training it on v-prediction and zero terminal SNR, which is important.
I first found that the v-loss rescaling actually prevents a zero-terminal-SNR model from learning to produce fully black images, even after about 5 million training samples, but it learned this immediately once I turned that setting off. However, others noticed that the option nevertheless improved quality in other areas, suggesting that there was likely a proper way to correct the flaw.
Looking further into the paper, it seems that the authors of Min-SNR-Gamma stated that the formula should be modified for V-loss, but they may have been somewhat unclear in their wording.
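Roughly (this is my transcription, not the paper's exact typesetting), the ε-prediction weighting is, with $w_t = \min\{\mathrm{SNR}(t),\ \gamma\}$ being the paper's $x_0$-prediction weight:

$$\frac{w_t}{\mathrm{SNR}(t)} \;=\; \frac{\min\{\mathrm{SNR}(t),\ \gamma\}}{\mathrm{SNR}(t)} \;=\; \min\!\left\{\frac{\gamma}{\mathrm{SNR}(t)},\ 1\right\}$$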
Kohya implements the simplified formula on the right hand side.
I have implemented and tested this alternative function for min_snr_gamma, based on the middle formula -- it is the same as the middle formula except that the denominator is replaced with `SNR(t) + 1`. My implementation is in JAX, since that is what my current training script uses, but converting it to PyTorch should pretty much just be a matter of removing the `expand_dims` line and replacing `jnp` with `torch`:
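(A minimal sketch of that function, assuming the per-example SNR values have already been gathered for the sampled timesteps; the names here are illustrative:)

```python
import jax.numpy as jnp

def apply_min_snr_gamma_v_pred(loss, snr, gamma):
    # loss: per-pixel v-prediction loss, shape (batch, height, width, channels)
    # snr:  SNR(t) for each example's sampled timestep, shape (batch,)
    # weight each example by min(SNR(t), gamma) / (SNR(t) + 1)
    snr_weight = jnp.minimum(snr, gamma) / (snr + 1.0)
    # broadcast the per-example weight over the spatial and channel dimensions
    snr_weight = jnp.expand_dims(snr_weight, axis=(1, 2, 3))
    return loss * snr_weight
```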
This is, as far as I am aware, the correct function for min_snr_gamma for V-loss. It has performed well in my tests and has improved the quality of my outputs without compromising contrast range. It serves the same purpose as the "scale v-loss like epsilon loss" option and results in loss metrics that are in the same range as epsilon loss. It should be how Min-SNR-Gamma behaves under v-prediction, and it should fully replace the "scale v-loss like epsilon loss" option.
Others I worked with on this problem have also tested this function and found that it improves performance compared to the current implementation of the aforementioned options. If you need a model to test it on, I can release one of my prototypes (an SD 1.5 model trained on V-loss and zero terminal SNR) for testing purposes. I would imagine SD 2.1 768-v would work as well.