34
34
prepare_scheduler_for_custom_training ,
35
35
scale_v_prediction_loss_like_noise_prediction ,
36
36
add_v_prediction_like_loss ,
37
+ apply_debiased_estimation ,
37
38
)
38
39
from library .sdxl_original_unet import SdxlUNet2DConditionModel
39
40
@@ -548,7 +549,7 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module):
548
549
549
550
target = noise
550
551
551
- if args .min_snr_gamma or args .scale_v_pred_loss_like_noise_pred or args .v_pred_like_loss :
552
+ if args .min_snr_gamma or args .scale_v_pred_loss_like_noise_pred or args .v_pred_like_loss or args . debiased_estimation_loss :
552
553
# do not mean over batch dimension for snr weight or scale v-pred loss
553
554
loss = torch .nn .functional .mse_loss (noise_pred .float (), target .float (), reduction = "none" )
554
555
loss = loss .mean ([1 , 2 , 3 ])
@@ -559,6 +560,8 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module):
559
560
loss = scale_v_prediction_loss_like_noise_prediction (loss , timesteps , noise_scheduler )
560
561
if args .v_pred_like_loss :
561
562
loss = add_v_prediction_like_loss (loss , timesteps , noise_scheduler , args .v_pred_like_loss )
563
+ if args .debiased_estimation_loss :
564
+ loss = apply_debiased_estimation (loss , timesteps , noise_scheduler )
562
565
563
566
loss = loss .mean () # mean over batch dimension
564
567
else :
0 commit comments