Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add adjoint, hat and vee for SE3 #68

Merged
merged 10 commits into from
Feb 9, 2022
1 change: 1 addition & 0 deletions theseus/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
)
from .geometry import (
SE2,
SE3,
SO2,
SO3,
LieGroup,
Expand Down
1 change: 1 addition & 0 deletions theseus/geometry/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from .manifold import Manifold, OptionalJacobians, local, retract
from .point_types import Point2, Point3
from .se2 import SE2
from .se3 import SE3
from .so2 import SO2
from .so3 import SO3
from .vector import Vector
175 changes: 175 additions & 0 deletions theseus/geometry/se3.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,175 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

from typing import List, Optional, Union, cast

import torch

import theseus
import theseus.constants

from .lie_group import LieGroup
from .point_types import Point3
from .so3 import SO3


class SE3(LieGroup):
def __init__(
self,
x_y_z_quaternion: Optional[torch.Tensor] = None,
mhmukadam marked this conversation as resolved.
Show resolved Hide resolved
data: Optional[torch.Tensor] = None,
name: Optional[str] = None,
dtype: Optional[torch.dtype] = None,
):
if x_y_z_quaternion is not None and data is not None:
raise ValueError("Please provide only one of x_y_z_quaternion or data.")
mhmukadam marked this conversation as resolved.
Show resolved Hide resolved
if x_y_z_quaternion is not None:
dtype = x_y_z_quaternion.dtype
if data is not None:
self._SE3_matrix_check(data)
super().__init__(data=data, name=name, dtype=dtype)
if x_y_z_quaternion is not None:
self.update_from_x_y_z_quaternion(x_y_z_quaternion=x_y_z_quaternion)

@staticmethod
def _init_data() -> torch.Tensor: # type: ignore
return torch.eye(3, 4).view(1, 3, 4)

def dof(self) -> int:
return 6

def __repr__(self) -> str:
return f"SE3(data={self.data}, name={self.name})"

def __str__(self) -> str:
with torch.no_grad():
return f"SE3(matrix={self.data}), name={self.name})"

def _adjoint_impl(self) -> torch.Tensor:
ret = torch.zeros(self.shape[0], 6, 6).to(dtype=self.dtype, device=self.device)
ret[:, :3, :3] = self[:, :3, :3]
ret[:, 3:, 3:] = self[:, :3, :3]
ret[:, :3, 3:] = SO3.hat(self[:, :3, 3]) @ self[:, :3, :3]

return ret

@staticmethod
def _SE3_matrix_check(matrix: torch.Tensor):
if matrix.ndim != 3 or matrix.shape[1:] != (3, 4):
raise ValueError("SE(3) can only be 3x4 matrices.")
SO3._SO3_matrix_check(matrix.data[:, :3, :3])

@staticmethod
def x_y_z_unit_quaternion_to_SE3(x_y_z_quaternion: torch.Tensor) -> "SE3":
if x_y_z_quaternion.ndim == 1:
x_y_z_quaternion = x_y_z_quaternion.unsqueeze(0)

if x_y_z_quaternion.ndim != 2 and x_y_z_quaternion.shape[1] != 7:
raise ValueError("x_y_z_quaternion can only be 7-D vectors.")

ret = SE3()

batch_size = x_y_z_quaternion.shape[0]
ret.data = torch.empty(batch_size, 3, 4).to(
device=x_y_z_quaternion.device, dtype=x_y_z_quaternion.dtype
)
ret[:, :3, :3] = SO3.unit_quaternion_to_SO3(x_y_z_quaternion[:, 3:]).data
ret[:, :3, 3] = x_y_z_quaternion[:, :3]

return ret

@staticmethod
def _hat_matrix_check(matrix: torch.Tensor):
if matrix.ndim != 3 or matrix.shape[1:] != (4, 4):
raise ValueError("Hat matrices of SE(3) can only be 4x4 matrices")

if matrix[:, 3].abs().max().item() > theseus.constants.EPS:
raise ValueError("The last row of hat matrices of SE(3) can only be zero.")

if (
matrix[:, :3, :3].transpose(1, 2) + matrix[:, :3, :3]
).abs().max().item() > theseus.constants.EPS:
raise ValueError(
"The 3x3 top-left corner of hat matrices of SE(3) can only be skew-symmetric."
)

@staticmethod
def exp_map(tangent_vector: torch.Tensor) -> LieGroup:
raise NotImplementedError

def _log_map_impl(self) -> torch.Tensor:
raise NotImplementedError

def _compose_impl(self, so3_2: LieGroup) -> "SE3":
raise NotImplementedError

def _inverse_impl(self, get_jacobian: bool = False) -> "SE3":
ret = torch.zeros(self.shape[0], 3, 4).to(dtype=self.dtype, device=self.device)
ret[:, :, :3] = self.data[:, :3, :3].transpose(1, 2)
ret[:, :, 3] = -(ret[:, :3, :3] @ self.data[:, :3, 3].unsqueeze(2)).view(-1, 3)
return SE3(data=ret)

def to_matrix(self) -> torch.Tensor:
ret = torch.zeros(self.shape[0], 4, 4).to(dtype=self.dtype, device=self.device)
ret[:, :3] = self.data
ret[:, 3, 3] = 1
return ret

def update_from_x_y_z_quaternion(self, x_y_z_quaternion: torch.Tensor):
self.update(SE3.x_y_z_unit_quaternion_to_SE3(x_y_z_quaternion))

def update_from_rot_and_trans(self, rotation: SO3, translation: Point3):
if rotation.shape[0] != translation.shape[0]:
raise ValueError("rotation and translation must have the same batch size.")
mhmukadam marked this conversation as resolved.
Show resolved Hide resolved

if rotation.dtype != translation.dtype:
raise ValueError("rotation and translation must be of the same type.")

if rotation.device != translation.device:
raise ValueError("rotation and translation must be on the same device.")

self.data = torch.cat((rotation.data, translation.data.unsqueeze(2)), dim=2)

@staticmethod
def hat(tangent_vector: torch.Tensor) -> torch.Tensor:
_check = tangent_vector.ndim == 2 and tangent_vector.shape[1] == 6
if not _check:
raise ValueError("Invalid vee matrix for SE(3).")
matrix = torch.zeros(tangent_vector.shape[0], 4, 4).to(
dtype=tangent_vector.dtype, device=tangent_vector.device
)
matrix[:, :3, :3] = SO3.hat(tangent_vector[:, 3:])
matrix[:, :3, 3] = tangent_vector[:, :3]

return matrix

@staticmethod
def vee(matrix: torch.Tensor) -> torch.Tensor:
SE3._hat_matrix_check(matrix)
return torch.cat((matrix[:, :3, 3], SO3.vee(matrix[:, :3, :3])), dim=1)

def _transform_shape_check(self, point: Union[Point3, torch.Tensor]):
raise NotImplementedError

def _copy_impl(self, new_name: Optional[str] = None) -> "SE3":
return SE3(data=self.data.clone(), name=new_name)

# only added to avoid casting downstream
def copy(self, new_name: Optional[str] = None) -> "SE3":
return cast(SE3, super().copy(new_name=new_name))

def transform_to(
self,
point: Union[Point3, torch.Tensor],
jacobians: Optional[List[torch.Tensor]] = None,
) -> Point3:
raise NotImplementedError

def transform_from(
self,
point: Union[Point3, torch.Tensor],
jacobians: Optional[List[torch.Tensor]] = None,
) -> Point3:
raise NotImplementedError
13 changes: 8 additions & 5 deletions theseus/geometry/so3.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,16 +31,14 @@ def __init__(
self._SO3_matrix_check(data)
super().__init__(data=data, name=name, dtype=dtype)
if quaternion is not None:
if quaternion.ndim == 1:
quaternion = quaternion.unsqueeze(0)
self.update_from_unit_quaternion(quaternion)

@staticmethod
def _init_data() -> torch.Tensor: # type: ignore
return torch.eye(3, 3).view(1, 3, 3)

def update_from_unit_quaternion(self, quaternion: torch.Tensor):
self.update(self.unit_quaternion_to_matrix(quaternion))
self.update(self.unit_quaternion_to_SO3(quaternion))

def dof(self) -> int:
return 3
Expand Down Expand Up @@ -244,8 +242,11 @@ def _rotate_shape_check(self, point: Union[Point3, torch.Tensor]):
)

@staticmethod
def unit_quaternion_to_matrix(quaternion: torch.torch.Tensor):
def unit_quaternion_to_SO3(quaternion: torch.torch.Tensor) -> "SO3":
if quaternion.ndim == 1:
quaternion = quaternion.unsqueeze(0)
SO3._unit_quaternion_check(quaternion)

q0 = quaternion[:, 0]
q1 = quaternion[:, 1]
q2 = quaternion[:, 2]
Expand All @@ -260,7 +261,9 @@ def unit_quaternion_to_matrix(quaternion: torch.torch.Tensor):
q22 = q2 * q2
q23 = q2 * q3
q33 = q3 * q3
ret = torch.zeros(quaternion.shape[0], 3, 3).to(

ret = SO3()
ret.data = torch.zeros(quaternion.shape[0], 3, 3).to(
dtype=quaternion.dtype, device=quaternion.device
)
ret[:, 0, 0] = 2 * (q00 + q11) - 1
Expand Down
28 changes: 28 additions & 0 deletions theseus/geometry/tests/test_se3.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import pytest # noqa: F401
import torch

import theseus as th

from .common import check_adjoint


def _create_random_se3(batch_size, rng):
x_y_z_quaternion = torch.rand(batch_size, 7, generator=rng).double() - 0.5
quaternion_norm = torch.linalg.norm(x_y_z_quaternion[:, 3:], dim=1, keepdim=True)
x_y_z_quaternion[:, 3:] /= quaternion_norm

return th.SE3(x_y_z_quaternion=x_y_z_quaternion)


def test_adjoint():
rng = torch.Generator()
rng.manual_seed(0)
for batch_size in [1, 20, 100]:
se3 = _create_random_se3(batch_size, rng)
tangent = torch.randn(batch_size, 6).double()
check_adjoint(se3, tangent)
6 changes: 3 additions & 3 deletions theseus/geometry/tests/test_so3.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,10 +27,10 @@ def check_SO3_log_map(tangent_vector):
assert torch.allclose(error, torch.zeros_like(error), atol=EPS)


def check_SO3_to_quaternion(so3: th.SO3):
def check_SO3_to_quaternion(so3: th.SO3, atol=1e-10):
quaternions = so3.to_quaternion()
assert torch.allclose(
th.SO3(quaternion=quaternions).to_matrix(), so3.to_matrix(), atol=1e-8
th.SO3(quaternion=quaternions).to_matrix(), so3.to_matrix(), atol=atol
)


Expand Down Expand Up @@ -125,7 +125,7 @@ def test_quaternion():
tangent_vector /= torch.linalg.norm(tangent_vector, dim=1, keepdim=True)
tangent_vector *= np.pi - 1e-11
so3 = th.SO3.exp_map(tangent_vector)
check_SO3_to_quaternion(so3)
check_SO3_to_quaternion(so3, 1e-7)

for batch_size in [1, 2, 100]:
tangent_vector = torch.rand(batch_size, 3).double() - 0.5
Expand Down