1
1
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
2
2
3
+ from typing import Tuple
3
4
4
5
import torch
5
6
7
+ from ..transforms import acos_linear_extrapolation
6
8
7
9
HAT_INV_SKEW_SYMMETRIC_TOL = 1e-5
8
10
9
11
10
- def so3_relative_angle (R1 , R2 , cos_angle : bool = False ):
12
+ def so3_relative_angle (
13
+ R1 : torch .Tensor ,
14
+ R2 : torch .Tensor ,
15
+ cos_angle : bool = False ,
16
+ cos_bound : float = 1e-4 ,
17
+ ) -> torch .Tensor :
11
18
"""
12
19
Calculates the relative angle (in radians) between pairs of
13
20
rotation matrices `R1` and `R2` with `angle = acos(0.5 * (Trace(R1 R2^T)-1))`
@@ -20,8 +27,12 @@ def so3_relative_angle(R1, R2, cos_angle: bool = False):
20
27
R1: Batch of rotation matrices of shape `(minibatch, 3, 3)`.
21
28
R2: Batch of rotation matrices of shape `(minibatch, 3, 3)`.
22
29
cos_angle: If==True return cosine of the relative angle rather than
23
- the angle itself. This can avoid the unstable
24
- calculation of `acos`.
30
+ the angle itself. This can avoid the unstable calculation of `acos`.
31
+ cos_bound: Clamps the cosine of the relative rotation angle to
32
+ [-1 + cos_bound, 1 - cos_bound] to avoid non-finite outputs/gradients
33
+ of the `acos` call. Note that the non-finite outputs/gradients
34
+ are returned when the angle is requested (i.e. `cos_angle==False`)
35
+ and the rotation angle is close to 0 or π.
25
36
26
37
Returns:
27
38
Corresponding rotation angles of shape `(minibatch,)`.
@@ -32,10 +43,15 @@ def so3_relative_angle(R1, R2, cos_angle: bool = False):
32
43
ValueError if `R1` or `R2` has an unexpected trace.
33
44
"""
34
45
R12 = torch .bmm (R1 , R2 .permute (0 , 2 , 1 ))
35
- return so3_rotation_angle (R12 , cos_angle = cos_angle )
46
+ return so3_rotation_angle (R12 , cos_angle = cos_angle , cos_bound = cos_bound )
36
47
37
48
38
- def so3_rotation_angle (R , eps : float = 1e-4 , cos_angle : bool = False ):
49
+ def so3_rotation_angle (
50
+ R : torch .Tensor ,
51
+ eps : float = 1e-4 ,
52
+ cos_angle : bool = False ,
53
+ cos_bound : float = 1e-4 ,
54
+ ) -> torch .Tensor :
39
55
"""
40
56
Calculates angles (in radians) of a batch of rotation matrices `R` with
41
57
`angle = acos(0.5 * (Trace(R)-1))`. The trace of the
@@ -47,8 +63,13 @@ def so3_rotation_angle(R, eps: float = 1e-4, cos_angle: bool = False):
47
63
R: Batch of rotation matrices of shape `(minibatch, 3, 3)`.
48
64
eps: Tolerance for the valid trace check.
49
65
cos_angle: If==True return cosine of the rotation angles rather than
50
- the angle itself. This can avoid the unstable
51
- calculation of `acos`.
66
+ the angle itself. This can avoid the unstable
67
+ calculation of `acos`.
68
+ cos_bound: Clamps the cosine of the rotation angle to
69
+ [-1 + cos_bound, 1 - cos_bound] to avoid non-finite outputs/gradients
70
+ of the `acos` call. Note that the non-finite outputs/gradients
71
+ are returned when the angle is requested (i.e. `cos_angle==False`)
72
+ and the rotation angle is close to 0 or π.
52
73
53
74
Returns:
54
75
Corresponding rotation angles of shape `(minibatch,)`.
@@ -68,20 +89,19 @@ def so3_rotation_angle(R, eps: float = 1e-4, cos_angle: bool = False):
68
89
if ((rot_trace < - 1.0 - eps ) + (rot_trace > 3.0 + eps )).any ():
69
90
raise ValueError ("A matrix has trace outside valid range [-1-eps,3+eps]." )
70
91
71
- # clamp to valid range
72
- rot_trace = torch .clamp (rot_trace , - 1.0 , 3.0 )
73
-
74
92
# phi ... rotation angle
75
- phi = 0.5 * (rot_trace - 1.0 )
93
+ phi_cos = (rot_trace - 1.0 ) * 0.5
76
94
77
95
if cos_angle :
78
- return phi
96
+ return phi_cos
79
97
else :
80
- # pyre-fixme[16]: `float` has no attribute `acos`.
81
- return phi .acos ()
98
+ if cos_bound > 0.0 :
99
+ return acos_linear_extrapolation (phi_cos , 1.0 - cos_bound )
100
+ else :
101
+ return torch .acos (phi_cos )
82
102
83
103
84
- def so3_exponential_map (log_rot , eps : float = 0.0001 ):
104
+ def so3_exp_map (log_rot : torch . Tensor , eps : float = 0.0001 ) -> torch . Tensor :
85
105
"""
86
106
Convert a batch of logarithmic representations of rotation matrices `log_rot`
87
107
to a batch of 3x3 rotation matrices using Rodrigues formula [1].
@@ -94,18 +114,31 @@ def so3_exponential_map(log_rot, eps: float = 0.0001):
94
114
which is handled by clamping controlled with the `eps` argument.
95
115
96
116
Args:
97
- log_rot: Batch of vectors of shape `(minibatch , 3)`.
117
+ log_rot: Batch of vectors of shape `(minibatch, 3)`.
98
118
eps: A float constant handling the conversion singularity.
99
119
100
120
Returns:
101
- Batch of rotation matrices of shape `(minibatch , 3 , 3)`.
121
+ Batch of rotation matrices of shape `(minibatch, 3, 3)`.
102
122
103
123
Raises:
104
124
ValueError if `log_rot` is of incorrect shape.
105
125
106
126
[1] https://en.wikipedia.org/wiki/Rodrigues%27_rotation_formula
107
127
"""
128
+ return _so3_exp_map (log_rot , eps = eps )[0 ]
129
+
130
+
131
+ so3_exponential_map = so3_exp_map
132
+
108
133
134
+ def _so3_exp_map (
135
+ log_rot : torch .Tensor , eps : float = 0.0001
136
+ ) -> Tuple [torch .Tensor , torch .Tensor , torch .Tensor , torch .Tensor ]:
137
+ """
138
+ A helper function that computes the so3 exponential map and,
139
+ apart from the rotation matrix, also returns intermediate variables
140
+ that can be re-used in other functions.
141
+ """
109
142
_ , dim = log_rot .shape
110
143
if dim != 3 :
111
144
raise ValueError ("Input tensor shape has to be Nx3." )
@@ -117,27 +150,35 @@ def so3_exponential_map(log_rot, eps: float = 0.0001):
117
150
fac1 = rot_angles_inv * rot_angles .sin ()
118
151
fac2 = rot_angles_inv * rot_angles_inv * (1.0 - rot_angles .cos ())
119
152
skews = hat (log_rot )
153
+ skews_square = torch .bmm (skews , skews )
120
154
121
155
R = (
122
156
# pyre-fixme[16]: `float` has no attribute `__getitem__`.
123
157
fac1 [:, None , None ] * skews
124
- + fac2 [:, None , None ] * torch . bmm ( skews , skews )
158
+ + fac2 [:, None , None ] * skews_square
125
159
+ torch .eye (3 , dtype = log_rot .dtype , device = log_rot .device )[None ]
126
160
)
127
161
128
- return R
162
+ return R , rot_angles , skews , skews_square
129
163
130
164
131
- def so3_log_map (R , eps : float = 0.0001 ):
165
+ def so3_log_map (
166
+ R : torch .Tensor , eps : float = 0.0001 , cos_bound : float = 1e-4
167
+ ) -> torch .Tensor :
132
168
"""
133
169
Convert a batch of 3x3 rotation matrices `R`
134
170
to a batch of 3-dimensional matrix logarithms of rotation matrices
135
171
The conversion has a singularity around `(R=I)` which is handled
136
- by clamping controlled with the `eps` argument .
172
+ by clamping controlled with the `eps` and `cos_bound` arguments .
137
173
138
174
Args:
139
175
R: batch of rotation matrices of shape `(minibatch, 3, 3)`.
140
176
eps: A float constant handling the conversion singularity.
177
+ cos_bound: Clamps the cosine of the rotation angle to
178
+ [-1 + cos_bound, 1 - cos_bound] to avoid non-finite outputs/gradients
179
+ of the `acos` call when computing `so3_rotation_angle`.
180
+ Note that the non-finite outputs/gradients are returned when
181
+ the rotation angle is close to 0 or π.
141
182
142
183
Returns:
143
184
Batch of logarithms of input rotation matrices
@@ -152,22 +193,26 @@ def so3_log_map(R, eps: float = 0.0001):
152
193
if dim1 != 3 or dim2 != 3 :
153
194
raise ValueError ("Input has to be a batch of 3x3 Tensors." )
154
195
155
- phi = so3_rotation_angle (R )
196
+ phi = so3_rotation_angle (R , cos_bound = cos_bound , eps = eps )
156
197
157
- phi_sin = phi .sin ()
198
+ phi_sin = torch .sin (phi )
158
199
159
- phi_denom = (
160
- torch .clamp (phi_sin .abs (), eps ) * phi_sin .sign ()
161
- + (phi_sin == 0 ).type_as (phi ) * eps
162
- )
200
+ # We want to avoid a tiny denominator of phi_factor = phi / (2.0 * phi_sin).
201
+ # Hence, for phi_sin.abs() <= 0.5 * eps, we approximate phi_factor with
202
+ # 2nd order Taylor expansion: phi_factor = 0.5 + (1.0 / 12) * phi**2
203
+ phi_factor = torch .empty_like (phi )
204
+ ok_denom = phi_sin .abs () > (0.5 * eps )
205
+ phi_factor [~ ok_denom ] = 0.5 + (phi [~ ok_denom ] ** 2 ) * (1.0 / 12 )
206
+ phi_factor [ok_denom ] = phi [ok_denom ] / (2.0 * phi_sin [ok_denom ])
207
+
208
+ log_rot_hat = phi_factor [:, None , None ] * (R - R .permute (0 , 2 , 1 ))
163
209
164
- log_rot_hat = (phi / (2.0 * phi_denom ))[:, None , None ] * (R - R .permute (0 , 2 , 1 ))
165
210
log_rot = hat_inv (log_rot_hat )
166
211
167
212
return log_rot
168
213
169
214
170
- def hat_inv (h ) :
215
+ def hat_inv (h : torch . Tensor ) -> torch . Tensor :
171
216
"""
172
217
Compute the inverse Hat operator [1] of a batch of 3x3 matrices.
173
218
@@ -188,9 +233,9 @@ def hat_inv(h):
188
233
if dim1 != 3 or dim2 != 3 :
189
234
raise ValueError ("Input has to be a batch of 3x3 Tensors." )
190
235
191
- ss_diff = (h + h .permute (0 , 2 , 1 )). abs ( ).max ()
236
+ ss_diff = torch . abs (h + h .permute (0 , 2 , 1 )).max ()
192
237
if float (ss_diff ) > HAT_INV_SKEW_SYMMETRIC_TOL :
193
- raise ValueError ("One of input matrices not skew-symmetric." )
238
+ raise ValueError ("One of input matrices is not skew-symmetric." )
194
239
195
240
x = h [:, 2 , 1 ]
196
241
y = h [:, 0 , 2 ]
@@ -201,7 +246,7 @@ def hat_inv(h):
201
246
return v
202
247
203
248
204
- def hat (v ) :
249
+ def hat (v : torch . Tensor ) -> torch . Tensor :
205
250
"""
206
251
Compute the Hat operator [1] of a batch of 3D vectors.
207
252
@@ -225,7 +270,7 @@ def hat(v):
225
270
if dim != 3 :
226
271
raise ValueError ("Input vectors have to be 3-dimensional." )
227
272
228
- h = v . new_zeros ( N , 3 , 3 )
273
+ h = torch . zeros (( N , 3 , 3 ), dtype = v . dtype , device = v . device )
229
274
230
275
x , y , z = v .unbind (1 )
231
276
0 commit comments