Skip to content

Commit 202f2c3

Browse files
authored
Debias Estimation loss (#889)
* update for bnb 0.41.1 * fixed generate_controlnet_subsets_config for training * Revert "update for bnb 0.41.1" This reverts commit 70bd361. * add debiased_estimation_loss * add train_network * Revert "add train_network" This reverts commit 6539363. * Update train_network.py
1 parent 681034d commit 202f2c3

9 files changed

+38
-2
lines changed

fine_tune.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@
3232
get_weighted_text_embeddings,
3333
prepare_scheduler_for_custom_training,
3434
scale_v_prediction_loss_like_noise_prediction,
35+
apply_debiased_estimation,
3536
)
3637

3738

@@ -339,7 +340,7 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module):
339340
else:
340341
target = noise
341342

342-
if args.min_snr_gamma or args.scale_v_pred_loss_like_noise_pred:
343+
if args.min_snr_gamma or args.scale_v_pred_loss_like_noise_pred or args.debiased_estimation_loss:
343344
# do not mean over batch dimension for snr weight or scale v-pred loss
344345
loss = torch.nn.functional.mse_loss(noise_pred.float(), target.float(), reduction="none")
345346
loss = loss.mean([1, 2, 3])
@@ -348,6 +349,8 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module):
348349
loss = apply_snr_weight(loss, timesteps, noise_scheduler, args.min_snr_gamma)
349350
if args.scale_v_pred_loss_like_noise_pred:
350351
loss = scale_v_prediction_loss_like_noise_prediction(loss, timesteps, noise_scheduler)
352+
if args.debiased_estimation_loss:
353+
loss = apply_debiased_estimation(loss, timesteps, noise_scheduler)
351354

352355
loss = loss.mean() # mean over batch dimension
353356
else:

library/custom_train_functions.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -86,6 +86,12 @@ def add_v_prediction_like_loss(loss, timesteps, noise_scheduler, v_pred_like_los
8686
loss = loss + loss / scale * v_pred_like_loss
8787
return loss
8888

89+
def apply_debiased_estimation(loss, timesteps, noise_scheduler):
90+
snr_t = torch.stack([noise_scheduler.all_snr[t] for t in timesteps]) # batch_size
91+
snr_t = torch.minimum(snr_t, torch.ones_like(snr_t) * 1000) # if timestep is 0, snr_t is inf, so limit it to 1000
92+
weight = 1/torch.sqrt(snr_t)
93+
loss = weight * loss
94+
return loss
8995

9096
# TODO train_utilと分散しているのでどちらかに寄せる
9197

@@ -108,6 +114,11 @@ def add_custom_train_arguments(parser: argparse.ArgumentParser, support_weighted
108114
default=None,
109115
help="add v-prediction like loss multiplied by this value / v-prediction lossをこの値をかけたものをlossに加算する",
110116
)
117+
parser.add_argument(
118+
"--debiased_estimation_loss",
119+
action="store_true",
120+
help="debiased estimation loss / debiased estimation loss",
121+
)
111122
if support_weighted_captions:
112123
parser.add_argument(
113124
"--weighted_captions",

sdxl_train.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@
3434
prepare_scheduler_for_custom_training,
3535
scale_v_prediction_loss_like_noise_prediction,
3636
add_v_prediction_like_loss,
37+
apply_debiased_estimation,
3738
)
3839
from library.sdxl_original_unet import SdxlUNet2DConditionModel
3940

@@ -548,7 +549,7 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module):
548549

549550
target = noise
550551

551-
if args.min_snr_gamma or args.scale_v_pred_loss_like_noise_pred or args.v_pred_like_loss:
552+
if args.min_snr_gamma or args.scale_v_pred_loss_like_noise_pred or args.v_pred_like_loss or args.debiased_estimation_loss:
552553
# do not mean over batch dimension for snr weight or scale v-pred loss
553554
loss = torch.nn.functional.mse_loss(noise_pred.float(), target.float(), reduction="none")
554555
loss = loss.mean([1, 2, 3])
@@ -559,6 +560,8 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module):
559560
loss = scale_v_prediction_loss_like_noise_prediction(loss, timesteps, noise_scheduler)
560561
if args.v_pred_like_loss:
561562
loss = add_v_prediction_like_loss(loss, timesteps, noise_scheduler, args.v_pred_like_loss)
563+
if args.debiased_estimation_loss:
564+
loss = apply_debiased_estimation(loss, timesteps, noise_scheduler)
562565

563566
loss = loss.mean() # mean over batch dimension
564567
else:

sdxl_train_control_net_lllite.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,7 @@
4444
pyramid_noise_like,
4545
apply_noise_offset,
4646
scale_v_prediction_loss_like_noise_prediction,
47+
apply_debiased_estimation,
4748
)
4849
import networks.control_net_lllite_for_train as control_net_lllite_for_train
4950

@@ -465,6 +466,8 @@ def remove_model(old_ckpt_name):
465466
loss = scale_v_prediction_loss_like_noise_prediction(loss, timesteps, noise_scheduler)
466467
if args.v_pred_like_loss:
467468
loss = add_v_prediction_like_loss(loss, timesteps, noise_scheduler, args.v_pred_like_loss)
469+
if args.debiased_estimation_loss:
470+
loss = apply_debiased_estimation(loss, timesteps, noise_scheduler)
468471

469472
loss = loss.mean() # 平均なのでbatch_sizeで割る必要なし
470473

sdxl_train_control_net_lllite_old.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,7 @@
4040
pyramid_noise_like,
4141
apply_noise_offset,
4242
scale_v_prediction_loss_like_noise_prediction,
43+
apply_debiased_estimation,
4344
)
4445
import networks.control_net_lllite as control_net_lllite
4546

@@ -435,6 +436,8 @@ def remove_model(old_ckpt_name):
435436
loss = scale_v_prediction_loss_like_noise_prediction(loss, timesteps, noise_scheduler)
436437
if args.v_pred_like_loss:
437438
loss = add_v_prediction_like_loss(loss, timesteps, noise_scheduler, args.v_pred_like_loss)
439+
if args.debiased_estimation_loss:
440+
loss = apply_debiased_estimation(loss, timesteps, noise_scheduler)
438441

439442
loss = loss.mean() # 平均なのでbatch_sizeで割る必要なし
440443

train_db.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@
3535
pyramid_noise_like,
3636
apply_noise_offset,
3737
scale_v_prediction_loss_like_noise_prediction,
38+
apply_debiased_estimation,
3839
)
3940

4041
# perlin_noise,
@@ -336,6 +337,8 @@ def train(args):
336337
loss = apply_snr_weight(loss, timesteps, noise_scheduler, args.min_snr_gamma)
337338
if args.scale_v_pred_loss_like_noise_pred:
338339
loss = scale_v_prediction_loss_like_noise_prediction(loss, timesteps, noise_scheduler)
340+
if args.debiased_estimation_loss:
341+
loss = apply_debiased_estimation(loss, timesteps, noise_scheduler)
339342

340343
loss = loss.mean() # 平均なのでbatch_sizeで割る必要なし
341344

train_network.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,7 @@
4343
prepare_scheduler_for_custom_training,
4444
scale_v_prediction_loss_like_noise_prediction,
4545
add_v_prediction_like_loss,
46+
apply_debiased_estimation,
4647
)
4748

4849

@@ -528,6 +529,7 @@ def train(self, args):
528529
"ss_min_snr_gamma": args.min_snr_gamma,
529530
"ss_scale_weight_norms": args.scale_weight_norms,
530531
"ss_ip_noise_gamma": args.ip_noise_gamma,
532+
"ss_debiased_estimation": bool(args.debiased_estimation_loss),
531533
}
532534

533535
if use_user_config:
@@ -811,6 +813,8 @@ def remove_model(old_ckpt_name):
811813
loss = scale_v_prediction_loss_like_noise_prediction(loss, timesteps, noise_scheduler)
812814
if args.v_pred_like_loss:
813815
loss = add_v_prediction_like_loss(loss, timesteps, noise_scheduler, args.v_pred_like_loss)
816+
if args.debiased_estimation_loss:
817+
loss = apply_debiased_estimation(loss, timesteps, noise_scheduler)
814818

815819
loss = loss.mean() # 平均なのでbatch_sizeで割る必要なし
816820

train_textual_inversion.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@
3232
prepare_scheduler_for_custom_training,
3333
scale_v_prediction_loss_like_noise_prediction,
3434
add_v_prediction_like_loss,
35+
apply_debiased_estimation,
3536
)
3637

3738
imagenet_templates_small = [
@@ -582,6 +583,8 @@ def remove_model(old_ckpt_name):
582583
loss = scale_v_prediction_loss_like_noise_prediction(loss, timesteps, noise_scheduler)
583584
if args.v_pred_like_loss:
584585
loss = add_v_prediction_like_loss(loss, timesteps, noise_scheduler, args.v_pred_like_loss)
586+
if args.debiased_estimation_loss:
587+
loss = apply_debiased_estimation(loss, timesteps, noise_scheduler)
585588

586589
loss = loss.mean() # 平均なのでbatch_sizeで割る必要なし
587590

train_textual_inversion_XTI.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@
3434
pyramid_noise_like,
3535
apply_noise_offset,
3636
scale_v_prediction_loss_like_noise_prediction,
37+
apply_debiased_estimation,
3738
)
3839
import library.original_unet as original_unet
3940
from XTI_hijack import unet_forward_XTI, downblock_forward_XTI, upblock_forward_XTI
@@ -471,6 +472,8 @@ def remove_model(old_ckpt_name):
471472
loss = apply_snr_weight(loss, timesteps, noise_scheduler, args.min_snr_gamma)
472473
if args.scale_v_pred_loss_like_noise_pred:
473474
loss = scale_v_prediction_loss_like_noise_prediction(loss, timesteps, noise_scheduler)
475+
if args.debiased_estimation_loss:
476+
loss = apply_debiased_estimation(loss, timesteps, noise_scheduler)
474477

475478
loss = loss.mean() # 平均なのでbatch_sizeで割る必要なし
476479

0 commit comments

Comments
 (0)