Skip to content

Commit

Permalink
Replace depth alpha mask logic with torch.where for gradients in spla…
Browse files Browse the repository at this point in the history
…tfacto (#2856)


* replace alpha depth logic with torch.where for differentiability
  • Loading branch information
kerrj authored Jan 31, 2024
1 parent de36210 commit 0e01a90
Showing 1 changed file with 1 addition and 2 deletions.
3 changes: 1 addition & 2 deletions nerfstudio/models/splatfacto.py
Original file line number Diff line number Diff line change
Expand Up @@ -779,8 +779,7 @@ def get_outputs(self, camera: Cameras) -> Dict[str, Union[torch.Tensor, List]]:
W,
background=torch.zeros(3, device=self.device),
)[..., 0:1] # type: ignore
depth_im[alpha > 0] = depth_im[alpha > 0] / alpha[alpha > 0]
depth_im[alpha == 0] = 1000
depth_im = torch.where(alpha > 0, depth_im / alpha, depth_im.detach().max())

return {"rgb": rgb, "depth": depth_im, "accumulation": alpha} # type: ignore

Expand Down

0 comments on commit 0e01a90

Please sign in to comment.