Soft min SNR gamma #1068

Open · wants to merge 3 commits into base: main
3 changes: 3 additions & 0 deletions fine_tune.py

```diff
@@ -27,6 +27,7 @@
 import library.custom_train_functions as custom_train_functions
 from library.custom_train_functions import (
     apply_snr_weight,
+    apply_soft_snr_weight,
     get_weighted_text_embeddings,
     prepare_scheduler_for_custom_training,
     scale_v_prediction_loss_like_noise_prediction,
@@ -353,6 +354,8 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module):

         if args.min_snr_gamma:
             loss = apply_snr_weight(loss, timesteps, noise_scheduler, args.min_snr_gamma, args.v_parameterization)
+        if args.soft_min_snr_gamma:
+            loss = apply_soft_snr_weight(loss, timesteps, noise_scheduler, args.soft_min_snr_gamma, args.v_parameterization)
         if args.scale_v_pred_loss_like_noise_pred:
             loss = scale_v_prediction_loss_like_noise_prediction(loss, timesteps, noise_scheduler)
         if args.debiased_estimation_loss:
```
13 changes: 13 additions & 0 deletions library/custom_train_functions.py

```diff
@@ -68,6 +68,13 @@ def apply_snr_weight(loss, timesteps, noise_scheduler, gamma, v_prediction=False
     return loss


+def apply_soft_snr_weight(loss, timesteps, noise_scheduler, gamma, v_prediction=False):
+    snr = torch.stack([noise_scheduler.all_snr[t] for t in timesteps])
+    soft_min_snr_gamma_weight = 1 / (torch.pow(snr if v_prediction is False else snr + 1, 2) + (1 / float(gamma)))
```
@feffy380 (Contributor) commented on Jan 24, 2024:

The math here is incorrect. SNR equals the whole expression 1/sigma**2, not sigma (at least judging from the fact that this code uses Min(1/sigma**2, gamma) while the min-snr paper uses Min(SNR, gamma); the variable names are inconsistent between the papers, so the confusion is understandable).
The correct weight should be:

```python
sigma2 = 1 / snr
1 / (sigma2 + 1/gamma)
1 / (1/snr + 1/gamma)
# simplified
weight = snr * gamma / (snr + gamma)
```

Finally, the given formulation for soft-min-snr is for x_0 prediction. We use epsilon or v-prediction, which according to the original min-snr paper means we need to divide by SNR or SNR+1 respectively, so the final weight calculation should be:

```python
snr_weight = (snr * gamma / (snr + gamma)).float().to(loss.device)
if v_prediction:
    snr_weight /= snr + 1
else:
    snr_weight /= snr
```
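As a quick standalone sanity check on the simplification above (plain Python, independent of the project code; the sample values are arbitrary), the harmonic-mean form matches `1 / (1/snr + 1/gamma)` exactly and tracks `min(snr, gamma)` away from the transition point:

```python
import math

def soft_min(snr, gamma):
    # Harmonic-mean form of the soft minimum: 1 / (1/snr + 1/gamma)
    return snr * gamma / (snr + gamma)

gamma = 5.0
for snr in [0.01, 0.5, 5.0, 50.0, 500.0]:
    # The two forms are algebraically identical...
    assert math.isclose(soft_min(snr, gamma), 1 / (1 / snr + 1 / gamma))
    # ...and approximate min(snr, gamma) away from snr == gamma,
    # where soft_min(gamma, gamma) == gamma / 2.
    print(f"snr={snr:8.2f}  min={min(snr, gamma):7.4f}  soft={soft_min(snr, gamma):7.4f}")
```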

@rockerBOO (Contributor, Author) commented on Jan 24, 2024:

I tried this formula:

```python
def apply_soft_snr_weight(loss, timesteps, noise_scheduler, gamma, v_prediction=False):
    snr = torch.stack([noise_scheduler.all_snr[t] for t in timesteps])
    snr_weight = (snr * gamma / (snr + gamma)).float().to(loss.device)
    if v_prediction:
        snr_weight /= snr + 1
    else:
        snr_weight /= snr
    loss = loss * snr_weight
    return loss
```

It produced the same loss curve as the current implementation. The paper says the two should match except for timesteps close to the transition, so similar loss curves are expected. I'm still seeing a difference from the Min SNR version, though.

- Run 38: `weight = snr * gamma / (snr + gamma)`
- Run 37: the current PR version
- Run 35: the Min SNR version

(Runs 38 and 37 overlap in the graph.)
(Screenshot: Weights & Biases loss-curve graph, 2024-01-24)

Maybe there's something else missing in these calculations.

Here are some example snr,gamma values to use with the test script below:
snr.txt

And I'm using this code to test the formulas:

```python
import math

with open("snr.txt", "r") as f:
    lines = f.readlines()

    print("snr min_snr soft soft2")
    for line in lines:
        snr, gamma = line.split(",")

        snr = float(snr)
        gamma = float(gamma)

        # current PR's formulas
        min_snr = min(1 / math.pow(snr, 2), gamma)
        soft_min_snr_gamma = 1 / (math.pow(snr, 2) + (1 / gamma))

        # suggested simplified weight
        snr_weight = snr * gamma / (snr + gamma)

        print(f"{snr:10.4f} {min_snr:4.4f}, {soft_min_snr_gamma:4.4f}, {snr_weight:4.4f}")
```

I don't really know what I'm doing with the math, I think, but I'm trying to learn.
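For reference, here is a self-contained sketch that applies feffy380's suggested definitions end to end (epsilon-prediction, gamma = 5; the SNR sample values are made up, not taken from snr.txt):

```python
def min_snr_weight(snr, gamma, v_prediction=False):
    # Hard Min-SNR-gamma: clamp SNR at gamma, then divide by SNR
    # (epsilon-prediction) or SNR + 1 (v-prediction), per the min-snr paper.
    w = min(snr, gamma)
    return w / (snr + 1) if v_prediction else w / snr

def soft_min_snr_weight(snr, gamma, v_prediction=False):
    # Soft variant: swap min(snr, gamma) for the harmonic mean.
    w = snr * gamma / (snr + gamma)
    return w / (snr + 1) if v_prediction else w / snr

gamma = 5.0
print("     snr  min_snr soft_min")
for snr in [0.05, 0.5, 5.0, 50.0, 500.0]:
    # The two agree at very low and very high SNR; they differ most near snr == gamma.
    print(f"{snr:8.2f} {min_snr_weight(snr, gamma):8.4f} {soft_min_snr_weight(snr, gamma):8.4f}")
```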

```diff
+    loss = loss * soft_min_snr_gamma_weight
+    return loss


 def scale_v_prediction_loss_like_noise_prediction(loss, timesteps, noise_scheduler):
     scale = get_snr_scale(timesteps, noise_scheduler)
     loss = loss * scale
@@ -106,6 +113,12 @@ def add_custom_train_arguments(parser: argparse.ArgumentParser, support_weighted
         default=None,
         help="gamma for reducing the weight of high loss timesteps. Lower numbers have stronger effect. 5 is recommended by paper. / 低いタイムステップでの高いlossに対して重みを減らすためのgamma値、低いほど効果が強く、論文では5が推奨",
     )
+    parser.add_argument(
+        "--soft_min_snr_gamma",
+        type=float,
+        default=None,
+        help="gamma for reducing the weight of high loss timesteps. Lower numbers have stronger effect. 1 is recommended. / 低いタイムステップでの高いlossに対して重みを減らすためのgamma値、低いほど効果が強く、論文では1が推奨",
+    )
     parser.add_argument(
         "--scale_v_pred_loss_like_noise_pred",
         action="store_true",
```
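The help text's claim that lower gamma values downweight more aggressively can be checked in isolation. A plain-Python sketch using feffy380's corrected epsilon-prediction weight (the snr and gamma values here are arbitrary examples):

```python
def soft_weight(snr, gamma):
    # Soft-min-SNR weight for epsilon-prediction:
    # harmonic mean of snr and gamma, divided by snr.
    return (snr * gamma / (snr + gamma)) / snr

snr = 10.0  # a fixed, fairly low-noise timestep
for gamma in [20.0, 5.0, 1.0, 0.5]:
    # Smaller gamma -> smaller weight on high-SNR timesteps.
    print(f"gamma={gamma:5.1f} -> weight={soft_weight(snr, gamma):.4f}")
```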
3 changes: 3 additions & 0 deletions train_controlnet.py

```diff
@@ -32,6 +32,7 @@
 import library.custom_train_functions as custom_train_functions
 from library.custom_train_functions import (
     apply_snr_weight,
+    apply_soft_snr_weight,
     pyramid_noise_like,
     apply_noise_offset,
 )
@@ -457,6 +458,8 @@ def remove_model(old_ckpt_name):

         if args.min_snr_gamma:
             loss = apply_snr_weight(loss, timesteps, noise_scheduler, args.min_snr_gamma, args.v_parameterization)
+        if args.soft_min_snr_gamma:
+            loss = apply_soft_snr_weight(loss, timesteps, noise_scheduler, args.soft_min_snr_gamma, args.v_parameterization)

         loss = loss.mean()  # 平均なのでbatch_sizeで割る必要なし
```
3 changes: 3 additions & 0 deletions train_db.py

```diff
@@ -28,6 +28,7 @@
 import library.custom_train_functions as custom_train_functions
 from library.custom_train_functions import (
     apply_snr_weight,
+    apply_soft_snr_weight,
     get_weighted_text_embeddings,
     prepare_scheduler_for_custom_training,
     pyramid_noise_like,
@@ -341,6 +342,8 @@ def train(args):

         if args.min_snr_gamma:
             loss = apply_snr_weight(loss, timesteps, noise_scheduler, args.min_snr_gamma, args.v_parameterization)
+        if args.soft_min_snr_gamma:
+            loss = apply_soft_snr_weight(loss, timesteps, noise_scheduler, args.soft_min_snr_gamma, args.v_parameterization)
         if args.scale_v_pred_loss_like_noise_pred:
             loss = scale_v_prediction_loss_like_noise_prediction(loss, timesteps, noise_scheduler)
         if args.debiased_estimation_loss:
```
2 changes: 2 additions & 0 deletions train_network.py

```diff
@@ -35,6 +35,7 @@
 import library.custom_train_functions as custom_train_functions
 from library.custom_train_functions import (
     apply_snr_weight,
+    apply_soft_snr_weight,
     get_weighted_text_embeddings,
     prepare_scheduler_for_custom_training,
     scale_v_prediction_loss_like_noise_prediction,
@@ -522,6 +523,7 @@ def train(self, args):
         "ss_face_crop_aug_range": args.face_crop_aug_range,
         "ss_prior_loss_weight": args.prior_loss_weight,
         "ss_min_snr_gamma": args.min_snr_gamma,
+        "ss_soft_min_snr_gamma": args.soft_min_snr_gamma,
         "ss_scale_weight_norms": args.scale_weight_norms,
         "ss_ip_noise_gamma": args.ip_noise_gamma,
         "ss_debiased_estimation": bool(args.debiased_estimation_loss),
```
3 changes: 3 additions & 0 deletions train_textual_inversion.py

```diff
@@ -27,6 +27,7 @@
 import library.custom_train_functions as custom_train_functions
 from library.custom_train_functions import (
     apply_snr_weight,
+    apply_soft_snr_weight,
     prepare_scheduler_for_custom_training,
     scale_v_prediction_loss_like_noise_prediction,
     add_v_prediction_like_loss,
@@ -590,6 +591,8 @@ def remove_model(old_ckpt_name):

         if args.min_snr_gamma:
             loss = apply_snr_weight(loss, timesteps, noise_scheduler, args.min_snr_gamma, args.v_parameterization)
+        if args.soft_min_snr_gamma:
+            loss = apply_soft_snr_weight(loss, timesteps, noise_scheduler, args.soft_min_snr_gamma, args.v_parameterization)
         if args.scale_v_pred_loss_like_noise_pred:
             loss = scale_v_prediction_loss_like_noise_prediction(loss, timesteps, noise_scheduler)
         if args.v_pred_like_loss:
```