Skip to content

Commit 2f668ec

Browse files
Alexey Sidnevfacebook-github-bot
Alexey Sidnev
authored andcommitted
Disable gradient calculation in _check_valid_rotation_matrix()
Summary: # Make `transform3d.py` a little bit better (performance and code quality) ## 1. Add decorator `torch.no_grad()` to the function `_check_valid_rotation_matrix()` Function `_check_valid_rotation_matrix()` is needed to identify errors during forward pass only, it's not used for gradients. ## 2. Replace two calls `to` with the single one Reviewed By: bottler Differential Revision: D29656501 fbshipit-source-id: 4419e24dbf436c1b60abf77bda4376fb87a593be
1 parent 0c02ae9 commit 2f668ec

File tree

1 file changed

+4
-1
lines changed

1 file changed

+4
-1
lines changed

pytorch3d/transforms/transform3d.py

+4-1
Original file line numberDiff line numberDiff line change
@@ -566,7 +566,7 @@ def __init__(
566566
if R.shape[-2:] != (3, 3):
567567
msg = "R must have shape (3, 3) or (N, 3, 3); got %s"
568568
raise ValueError(msg % repr(R.shape))
569-
R = R.to(dtype=dtype).to(device=device_)
569+
R = R.to(device=device_, dtype=dtype)
570570
_check_valid_rotation_matrix(R, tol=orthogonal_tol)
571571
N = R.shape[0]
572572
mat = torch.eye(4, dtype=dtype, device=device_)
@@ -752,6 +752,9 @@ def _broadcast_bmm(a, b):
752752
return a.bmm(b)
753753

754754

755+
# pyre-fixme[56]: Decorator `torch.no_grad(...)` could not be called, because
756+
# its type `no_grad` is not callable.
757+
@torch.no_grad()
755758
def _check_valid_rotation_matrix(R, tol: float = 1e-7):
756759
"""
757760
Determine if R is a valid rotation matrix by checking it satisfies the

0 commit comments

Comments
 (0)