diff --git a/train.py b/train.py
index 3ae004eb..90a16c3b 100644
--- a/train.py
+++ b/train.py
@@ -171,6 +171,13 @@ def training(dataset, opt, pipe, testing_iterations, saving_iterations, checkpoi
         Ll1 = l1_loss_appearance(image, gt_image, appearances, viewpoint_idx) # use L1 loss for the transformed image if using decoupled appearance
         loss = (1.0 - opt.lambda_dssim) * Ll1 + opt.lambda_dssim * (1.0 - ssim(image, gt_image))
 
+        # alpha (mask) loss: push rendered alpha toward 0 on background pixels
+        if opt.lambda_mask > 0:
+            transparency = 1 - render_pkg["rend_alpha"].clamp(1e-6, 1 - 1e-6)  # 1 - alpha, clamped so log() stays finite
+            bg = 1 - viewpoint_cam.gt_alpha_mask  # 1 on background, 0 on foreground
+            mask_error = (-bg * torch.log(transparency)).mean()  # cross-entropy penalty on opaque background
+            loss += opt.lambda_mask * mask_error
+
         # regularization
         lambda_normal = opt.lambda_normal if iteration > 7000 else 0.0
         lambda_depth = opt.propagation_begin if iteration > opt.propagation_begin else 0.0
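
For reference, here is a minimal standalone sketch of the added term, assuming rend_alpha and gt_alpha_mask are (1, H, W) float tensors in [0, 1]; the toy inputs and the 0.1 weight below are illustrative placeholders, not values from the repo:

    import torch

    def mask_loss(rend_alpha, gt_alpha_mask):
        # Transparency = 1 - alpha, clamped away from 0 so log() stays finite.
        transparency = 1 - rend_alpha.clamp(1e-6, 1 - 1e-6)
        bg = 1 - gt_alpha_mask  # 1 on background pixels, 0 on foreground
        # Cross-entropy-style penalty: background pixels should stay fully transparent.
        return (-bg * torch.log(transparency)).mean()

    rend_alpha = torch.rand(1, 64, 64)                      # toy rendered alpha
    gt_alpha_mask = (torch.rand(1, 64, 64) > 0.5).float()   # toy binary foreground mask
    loss_term = 0.1 * mask_loss(rend_alpha, gt_alpha_mask)  # 0.1 stands in for opt.lambda_mask
    print(loss_term.item())

The term is zero wherever the ground-truth mask marks foreground, so it only discourages floaters and spurious opacity in regions the mask labels as background.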