Skip to content

Commit 2f3cd98

Browse files
shapovalovfacebook-github-bot
authored andcommitted
6D representation of rotations.
Summary: Conversion to/from the 6D representation of rotation from the paper http://arxiv.org/abs/1812.07035 ; based on David’s implementation. Reviewed By: davnov134 Differential Revision: D22234397 fbshipit-source-id: 9e25ee93da7e3a2f2068cbe362cb5edc88649ce0
1 parent ce3da64 commit 2f3cd98

File tree

4 files changed

+81
-7
lines changed

4 files changed

+81
-7
lines changed

pytorch3d/transforms/rotation_conversions.py

Lines changed: 45 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
from typing import Optional
55

66
import torch
7+
import torch.nn.functional as F
78

89

910
"""
@@ -252,7 +253,7 @@ def random_quaternions(
252253
i.e. versors with nonnegative real part.
253254
254255
Args:
255-
n: Number to return.
256+
n: Number of quaternions in a batch to return.
256257
dtype: Type to return.
257258
device: Desired device of returned tensor. Default:
258259
uses the current device for the default tensor type.
@@ -275,7 +276,7 @@ def random_rotations(
275276
Generate random rotations as 3x3 rotation matrices.
276277
277278
Args:
278-
n: Number to return.
279+
n: Number of rotation matrices in a batch to return.
279280
dtype: Type to return.
280281
device: Device of returned tensor. Default: if None,
281282
uses the current device for the default tensor type.
@@ -400,3 +401,45 @@ def quaternion_apply(quaternion, point):
400401
quaternion_invert(quaternion),
401402
)
402403
return out[..., 1:]
404+
405+
406+
def rotation_6d_to_matrix(d6: torch.Tensor) -> torch.Tensor:
407+
"""
408+
Converts 6D rotation representation by Zhou et al. [1] to rotation matrix
409+
using Gram--Schmidt orthogonalisation per Section B of [1].
410+
Args:
411+
d6: 6D rotation representation, of size (*, 6)
412+
413+
Returns:
414+
batch of rotation matrices of size (*, 3, 3)
415+
416+
[1] Zhou, Y., Barnes, C., Lu, J., Yang, J., & Li, H.
417+
On the Continuity of Rotation Representations in Neural Networks.
418+
IEEE Conference on Computer Vision and Pattern Recognition, 2019.
419+
Retrieved from http://arxiv.org/abs/1812.07035
420+
"""
421+
422+
a1, a2 = d6[..., :3], d6[..., 3:]
423+
b1 = F.normalize(a1, dim=-1)
424+
b2 = a2 - (b1 * a2).sum(-1, keepdim=True) * b1
425+
b2 = F.normalize(b2, dim=-1)
426+
b3 = torch.cross(b1, b2, dim=-1)
427+
return torch.stack((b1, b2, b3), dim=-2)
428+
429+
430+
def matrix_to_rotation_6d(matrix: torch.Tensor) -> torch.Tensor:
431+
"""
432+
Converts rotation matrices to 6D rotation representation by Zhou et al. [1]
433+
by dropping the last row. Note that 6D representation is not unique.
434+
Args:
435+
matrix: batch of rotation matrices of size (*, 3, 3)
436+
437+
Returns:
438+
6D rotation representation, of size (*, 6)
439+
440+
[1] Zhou, Y., Barnes, C., Lu, J., Yang, J., & Li, H.
441+
On the Continuity of Rotation Representations in Neural Networks.
442+
IEEE Conference on Computer Vision and Pattern Recognition, 2019.
443+
Retrieved from http://arxiv.org/abs/1812.07035
444+
"""
445+
return matrix[..., :2, :].clone().reshape(*matrix.size()[:-2], 6)

pytorch3d/transforms/so3.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
def so3_relative_angle(R1, R2, cos_angle: bool = False):
1111
"""
1212
Calculates the relative angle (in radians) between pairs of
13-
rotation matrices `R1` and `R2` with `angle = acos(0.5 * Trace(R1 R2^T)-1)`
13+
rotation matrices `R1` and `R2` with `angle = acos(0.5 * (Trace(R1 R2^T)-1))`
1414
1515
.. note::
1616
This corresponds to a geodesic distance on the 3D manifold of rotation

pytorch3d/transforms/transform3d.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -45,9 +45,9 @@ class Transform3d:
4545
4646
.. code-block:: python
4747
48-
y1 = t3.transform_points(t2.transform_points(t2.transform_points(x)))
49-
y2 = t1.compose(t2).compose(t3).transform_points()
50-
y3 = t1.compose(t2, t3).transform_points()
48+
y1 = t3.transform_points(t2.transform_points(t1.transform_points(x)))
49+
y2 = t1.compose(t2).compose(t3).transform_points(x)
50+
y3 = t1.compose(t2, t3).transform_points(x)
5151
5252
5353
Composing transforms should broadcast.

tests/test_rotation_conversions.py

Lines changed: 32 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,16 +6,19 @@
66
import unittest
77

88
import torch
9+
from common_testing import TestCaseMixin
910
from pytorch3d.transforms.rotation_conversions import (
1011
euler_angles_to_matrix,
1112
matrix_to_euler_angles,
1213
matrix_to_quaternion,
14+
matrix_to_rotation_6d,
1315
quaternion_apply,
1416
quaternion_multiply,
1517
quaternion_to_matrix,
1618
random_quaternions,
1719
random_rotation,
1820
random_rotations,
21+
rotation_6d_to_matrix,
1922
)
2023

2124

@@ -48,7 +51,7 @@ def test_random_rotation_invariant(self):
4851
self.assertLess(chisquare_statistic, 12, (counts, chisquare_statistic, k))
4952

5053

51-
class TestRotationConversion(unittest.TestCase):
54+
class TestRotationConversion(TestCaseMixin, unittest.TestCase):
5255
def setUp(self) -> None:
5356
super().setUp()
5457
torch.manual_seed(1)
@@ -154,3 +157,31 @@ def test_quaternion_application(self):
154157
[p, q] = torch.autograd.grad(transform1.sum(), [points, quaternions])
155158
self.assertTrue(torch.isfinite(p).all())
156159
self.assertTrue(torch.isfinite(q).all())
160+
161+
def test_6d(self):
162+
"""Converting to 6d and back"""
163+
r = random_rotations(13, dtype=torch.float64)
164+
165+
# 6D representation is not unique,
166+
# but we implement it by taking the first two rows of the matrix
167+
r6d = matrix_to_rotation_6d(r)
168+
self.assertClose(r6d, r[:, :2, :].reshape(-1, 6))
169+
170+
# going to 6D and back should not change the matrix
171+
r_hat = rotation_6d_to_matrix(r6d)
172+
self.assertClose(r_hat, r)
173+
174+
# moving the second row R2 in the span of (R1, R2) should not matter
175+
r6d[:, 3:] += 2 * r6d[:, :3]
176+
r6d[:, :3] *= 3.0
177+
r_hat = rotation_6d_to_matrix(r6d)
178+
self.assertClose(r_hat, r)
179+
180+
# check that we map anything to a valid rotation
181+
r6d = torch.rand(13, 6)
182+
r6d[:4, :] *= 3.0
183+
r6d[4:8, :] -= 0.5
184+
r = rotation_6d_to_matrix(r6d)
185+
self.assertClose(
186+
torch.matmul(r, r.permute(0, 2, 1)), torch.eye(3).expand_as(r), atol=1e-6
187+
)

0 commit comments

Comments
 (0)