Skip to content

Commit 6539363

Browse files
committed
add train_network
1 parent fe6f189 commit 6539363

File tree

2 files changed

+3
-0
lines changed

2 files changed

+3
-0
lines changed

.gitignore

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,3 +6,4 @@ venv
66
build
77
.vscode
88
wandb
9+
.vs

train_network.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -813,6 +813,8 @@ def remove_model(old_ckpt_name):
813813
loss = scale_v_prediction_loss_like_noise_prediction(loss, timesteps, noise_scheduler)
814814
if args.v_pred_like_loss:
815815
loss = add_v_prediction_like_loss(loss, timesteps, noise_scheduler, args.v_pred_like_loss)
816+
if args.debiased_estimation_loss:
817+
loss = apply_debiased_estimation(loss, timesteps, noise_scheduler)
816818

817819
loss = loss.mean() # 平均なのでbatch_sizeで割る必要なし
818820

0 commit comments

Comments
 (0)