Skip to content

Commit

Permalink
fix: Fix masked normal loss
Browse files Browse the repository at this point in the history
  • Loading branch information
hugoycj committed Oct 16, 2024
1 parent fb8cd9c commit 608afe0
Showing 1 changed file with 2 additions and 2 deletions.
4 changes: 2 additions & 2 deletions train.py
Original file line number Diff line number Diff line change
Expand Up @@ -210,13 +210,13 @@ def training(dataset, opt, pipe, testing_iterations, saving_iterations, checkpoi
if lambda_normal > 0.0:
normal_error = 0.6 * (1 - F.cosine_similarity(rend_normal, surf_normal_median, dim=0)) + \
0.4 * (1 - F.cosine_similarity(rend_normal, surf_normal_expected, dim=0))
normal_error = normal_error * viewpoint_cam.gt_alpha_mask
normal_error = normal_error * viewpoint_cam.gt_alpha_mask.mean(dim=0)
normal_error = ranking_loss(normal_error.view(-1), penalize_ratio=0.7, type='mean')
normal_loss += lambda_normal * normal_error

if lambda_normal_prior > 0 and dataset.w_normal_prior:
prior_normal = viewpoint_cam.normal_prior * (rend_alpha).detach()
prior_normal_mask = viewpoint_cam.normal_mask[0] & viewpoint_cam.gt_alpha_mask
prior_normal_mask = viewpoint_cam.normal_mask[0] & viewpoint_cam.gt_alpha_mask.mean(dim=0)

normal_prior_error = 0.6 * (1 - F.cosine_similarity(prior_normal, rend_normal, dim=0)) + \
0.4 * (1 - F.cosine_similarity(prior_normal, surf_normal_expected, dim=0))
Expand Down

0 comments on commit 608afe0

Please sign in to comment.