Skip to content

Commit 9f14e82

Browse files
davnov134facebook-github-bot
authored andcommitted
SO3 improvements for stable gradients.
Summary: Improves so3 functions to make gradient computation stable: - Instead of `torch.acos`, uses `acos_linear_extrapolation` which has finite gradients of reasonable magnitude for all inputs. - Adds tests for the latter. The tests of the finiteness of the gradient in `test_so3_exp_singularity`, `test_so3_exp_singularity`, `test_so3_cos_bound` would fail if the `so3` functions would call `torch.acos` instead of `acos_linear_extrapolation`. Reviewed By: bottler Differential Revision: D23326429 fbshipit-source-id: dc296abf2ae3ddfb3942c8146621491a9cb740ee
1 parent dd45123 commit 9f14e82

File tree

3 files changed

+180
-67
lines changed

3 files changed

+180
-67
lines changed

pytorch3d/transforms/__init__.py

+1
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
)
2323
from .so3 import (
2424
so3_exponential_map,
25+
so3_exp_map,
2526
so3_log_map,
2627
so3_relative_angle,
2728
so3_rotation_angle,

pytorch3d/transforms/so3.py

+78-33
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,20 @@
11
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
22

3+
from typing import Tuple
34

45
import torch
56

7+
from ..transforms import acos_linear_extrapolation
68

79
HAT_INV_SKEW_SYMMETRIC_TOL = 1e-5
810

911

10-
def so3_relative_angle(R1, R2, cos_angle: bool = False):
12+
def so3_relative_angle(
13+
R1: torch.Tensor,
14+
R2: torch.Tensor,
15+
cos_angle: bool = False,
16+
cos_bound: float = 1e-4,
17+
) -> torch.Tensor:
1118
"""
1219
Calculates the relative angle (in radians) between pairs of
1320
rotation matrices `R1` and `R2` with `angle = acos(0.5 * (Trace(R1 R2^T)-1))`
@@ -20,8 +27,12 @@ def so3_relative_angle(R1, R2, cos_angle: bool = False):
2027
R1: Batch of rotation matrices of shape `(minibatch, 3, 3)`.
2128
R2: Batch of rotation matrices of shape `(minibatch, 3, 3)`.
2229
cos_angle: If==True return cosine of the relative angle rather than
23-
the angle itself. This can avoid the unstable
24-
calculation of `acos`.
30+
the angle itself. This can avoid the unstable calculation of `acos`.
31+
cos_bound: Clamps the cosine of the relative rotation angle to
32+
[-1 + cos_bound, 1 - cos_bound] to avoid non-finite outputs/gradients
33+
of the `acos` call. Note that the non-finite outputs/gradients
34+
are returned when the angle is requested (i.e. `cos_angle==False`)
35+
and the rotation angle is close to 0 or π.
2536
2637
Returns:
2738
Corresponding rotation angles of shape `(minibatch,)`.
@@ -32,10 +43,15 @@ def so3_relative_angle(R1, R2, cos_angle: bool = False):
3243
ValueError if `R1` or `R2` has an unexpected trace.
3344
"""
3445
R12 = torch.bmm(R1, R2.permute(0, 2, 1))
35-
return so3_rotation_angle(R12, cos_angle=cos_angle)
46+
return so3_rotation_angle(R12, cos_angle=cos_angle, cos_bound=cos_bound)
3647

3748

38-
def so3_rotation_angle(R, eps: float = 1e-4, cos_angle: bool = False):
49+
def so3_rotation_angle(
50+
R: torch.Tensor,
51+
eps: float = 1e-4,
52+
cos_angle: bool = False,
53+
cos_bound: float = 1e-4,
54+
) -> torch.Tensor:
3955
"""
4056
Calculates angles (in radians) of a batch of rotation matrices `R` with
4157
`angle = acos(0.5 * (Trace(R)-1))`. The trace of the
@@ -47,8 +63,13 @@ def so3_rotation_angle(R, eps: float = 1e-4, cos_angle: bool = False):
4763
R: Batch of rotation matrices of shape `(minibatch, 3, 3)`.
4864
eps: Tolerance for the valid trace check.
4965
cos_angle: If==True return cosine of the rotation angles rather than
50-
the angle itself. This can avoid the unstable
51-
calculation of `acos`.
66+
the angle itself. This can avoid the unstable
67+
calculation of `acos`.
68+
cos_bound: Clamps the cosine of the rotation angle to
69+
[-1 + cos_bound, 1 - cos_bound] to avoid non-finite outputs/gradients
70+
of the `acos` call. Note that the non-finite outputs/gradients
71+
are returned when the angle is requested (i.e. `cos_angle==False`)
72+
and the rotation angle is close to 0 or π.
5273
5374
Returns:
5475
Corresponding rotation angles of shape `(minibatch,)`.
@@ -68,20 +89,19 @@ def so3_rotation_angle(R, eps: float = 1e-4, cos_angle: bool = False):
6889
if ((rot_trace < -1.0 - eps) + (rot_trace > 3.0 + eps)).any():
6990
raise ValueError("A matrix has trace outside valid range [-1-eps,3+eps].")
7091

71-
# clamp to valid range
72-
rot_trace = torch.clamp(rot_trace, -1.0, 3.0)
73-
7492
# phi ... rotation angle
75-
phi = 0.5 * (rot_trace - 1.0)
93+
phi_cos = (rot_trace - 1.0) * 0.5
7694

7795
if cos_angle:
78-
return phi
96+
return phi_cos
7997
else:
80-
# pyre-fixme[16]: `float` has no attribute `acos`.
81-
return phi.acos()
98+
if cos_bound > 0.0:
99+
return acos_linear_extrapolation(phi_cos, 1.0 - cos_bound)
100+
else:
101+
return torch.acos(phi_cos)
82102

83103

84-
def so3_exponential_map(log_rot, eps: float = 0.0001):
104+
def so3_exp_map(log_rot: torch.Tensor, eps: float = 0.0001) -> torch.Tensor:
85105
"""
86106
Convert a batch of logarithmic representations of rotation matrices `log_rot`
87107
to a batch of 3x3 rotation matrices using Rodrigues formula [1].
@@ -94,18 +114,31 @@ def so3_exponential_map(log_rot, eps: float = 0.0001):
94114
which is handled by clamping controlled with the `eps` argument.
95115
96116
Args:
97-
log_rot: Batch of vectors of shape `(minibatch , 3)`.
117+
log_rot: Batch of vectors of shape `(minibatch, 3)`.
98118
eps: A float constant handling the conversion singularity.
99119
100120
Returns:
101-
Batch of rotation matrices of shape `(minibatch , 3 , 3)`.
121+
Batch of rotation matrices of shape `(minibatch, 3, 3)`.
102122
103123
Raises:
104124
ValueError if `log_rot` is of incorrect shape.
105125
106126
[1] https://en.wikipedia.org/wiki/Rodrigues%27_rotation_formula
107127
"""
128+
return _so3_exp_map(log_rot, eps=eps)[0]
129+
130+
131+
so3_exponential_map = so3_exp_map
132+
108133

134+
def _so3_exp_map(
135+
log_rot: torch.Tensor, eps: float = 0.0001
136+
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
137+
"""
138+
A helper function that computes the so3 exponential map and,
139+
apart from the rotation matrix, also returns intermediate variables
140+
that can be re-used in other functions.
141+
"""
109142
_, dim = log_rot.shape
110143
if dim != 3:
111144
raise ValueError("Input tensor shape has to be Nx3.")
@@ -117,27 +150,35 @@ def so3_exponential_map(log_rot, eps: float = 0.0001):
117150
fac1 = rot_angles_inv * rot_angles.sin()
118151
fac2 = rot_angles_inv * rot_angles_inv * (1.0 - rot_angles.cos())
119152
skews = hat(log_rot)
153+
skews_square = torch.bmm(skews, skews)
120154

121155
R = (
122156
# pyre-fixme[16]: `float` has no attribute `__getitem__`.
123157
fac1[:, None, None] * skews
124-
+ fac2[:, None, None] * torch.bmm(skews, skews)
158+
+ fac2[:, None, None] * skews_square
125159
+ torch.eye(3, dtype=log_rot.dtype, device=log_rot.device)[None]
126160
)
127161

128-
return R
162+
return R, rot_angles, skews, skews_square
129163

130164

131-
def so3_log_map(R, eps: float = 0.0001):
165+
def so3_log_map(
166+
R: torch.Tensor, eps: float = 0.0001, cos_bound: float = 1e-4
167+
) -> torch.Tensor:
132168
"""
133169
Convert a batch of 3x3 rotation matrices `R`
134170
to a batch of 3-dimensional matrix logarithms of rotation matrices
135171
The conversion has a singularity around `(R=I)` which is handled
136-
by clamping controlled with the `eps` argument.
172+
by clamping controlled with the `eps` and `cos_bound` arguments.
137173
138174
Args:
139175
R: batch of rotation matrices of shape `(minibatch, 3, 3)`.
140176
eps: A float constant handling the conversion singularity.
177+
cos_bound: Clamps the cosine of the rotation angle to
178+
[-1 + cos_bound, 1 - cos_bound] to avoid non-finite outputs/gradients
179+
of the `acos` call when computing `so3_rotation_angle`.
180+
Note that the non-finite outputs/gradients are returned when
181+
the rotation angle is close to 0 or π.
141182
142183
Returns:
143184
Batch of logarithms of input rotation matrices
@@ -152,22 +193,26 @@ def so3_log_map(R, eps: float = 0.0001):
152193
if dim1 != 3 or dim2 != 3:
153194
raise ValueError("Input has to be a batch of 3x3 Tensors.")
154195

155-
phi = so3_rotation_angle(R)
196+
phi = so3_rotation_angle(R, cos_bound=cos_bound, eps=eps)
156197

157-
phi_sin = phi.sin()
198+
phi_sin = torch.sin(phi)
158199

159-
phi_denom = (
160-
torch.clamp(phi_sin.abs(), eps) * phi_sin.sign()
161-
+ (phi_sin == 0).type_as(phi) * eps
162-
)
200+
# We want to avoid a tiny denominator of phi_factor = phi / (2.0 * phi_sin).
201+
# Hence, for phi_sin.abs() <= 0.5 * eps, we approximate phi_factor with
202+
# 2nd order Taylor expansion: phi_factor = 0.5 + (1.0 / 12) * phi**2
203+
phi_factor = torch.empty_like(phi)
204+
ok_denom = phi_sin.abs() > (0.5 * eps)
205+
phi_factor[~ok_denom] = 0.5 + (phi[~ok_denom] ** 2) * (1.0 / 12)
206+
phi_factor[ok_denom] = phi[ok_denom] / (2.0 * phi_sin[ok_denom])
207+
208+
log_rot_hat = phi_factor[:, None, None] * (R - R.permute(0, 2, 1))
163209

164-
log_rot_hat = (phi / (2.0 * phi_denom))[:, None, None] * (R - R.permute(0, 2, 1))
165210
log_rot = hat_inv(log_rot_hat)
166211

167212
return log_rot
168213

169214

170-
def hat_inv(h):
215+
def hat_inv(h: torch.Tensor) -> torch.Tensor:
171216
"""
172217
Compute the inverse Hat operator [1] of a batch of 3x3 matrices.
173218
@@ -188,9 +233,9 @@ def hat_inv(h):
188233
if dim1 != 3 or dim2 != 3:
189234
raise ValueError("Input has to be a batch of 3x3 Tensors.")
190235

191-
ss_diff = (h + h.permute(0, 2, 1)).abs().max()
236+
ss_diff = torch.abs(h + h.permute(0, 2, 1)).max()
192237
if float(ss_diff) > HAT_INV_SKEW_SYMMETRIC_TOL:
193-
raise ValueError("One of input matrices not skew-symmetric.")
238+
raise ValueError("One of input matrices is not skew-symmetric.")
194239

195240
x = h[:, 2, 1]
196241
y = h[:, 0, 2]
@@ -201,7 +246,7 @@ def hat_inv(h):
201246
return v
202247

203248

204-
def hat(v):
249+
def hat(v: torch.Tensor) -> torch.Tensor:
205250
"""
206251
Compute the Hat operator [1] of a batch of 3D vectors.
207252
@@ -225,7 +270,7 @@ def hat(v):
225270
if dim != 3:
226271
raise ValueError("Input vectors have to be 3-dimensional.")
227272

228-
h = v.new_zeros(N, 3, 3)
273+
h = torch.zeros((N, 3, 3), dtype=v.dtype, device=v.device)
229274

230275
x, y, z = v.unbind(1)
231276

0 commit comments

Comments
 (0)