diff --git a/pytorch3d/transforms/transform3d.py b/pytorch3d/transforms/transform3d.py index b2ee2593..fef29845 100644 --- a/pytorch3d/transforms/transform3d.py +++ b/pytorch3d/transforms/transform3d.py @@ -564,6 +564,22 @@ def _get_matrix_inverse(self) -> torch.Tensor: i_matrix = self._matrix * inv_mask return i_matrix + def __getitem__( + self, index: Union[int, List[int], slice, torch.BoolTensor, torch.LongTensor] + ) -> "Transform3d": + """ + Args: + index: Specifying the index of the transform to retrieve. + Can be an int, slice, list of ints, boolean, long tensor. + Supports negative indices. + + Returns: + Transform3d object with selected transforms. The tensors are not cloned. + """ + if isinstance(index, int): + index = [index] + return self.__class__(self.get_matrix()[index, 3, :3]) + class Scale(Transform3d): def __init__( @@ -613,6 +629,26 @@ def _get_matrix_inverse(self) -> torch.Tensor: imat = torch.diag_embed(ixyz, dim1=1, dim2=2) return imat + def __getitem__( + self, index: Union[int, List[int], slice, torch.BoolTensor, torch.LongTensor] + ) -> "Transform3d": + """ + Args: + index: Specifying the index of the transform to retrieve. + Can be an int, slice, list of ints, boolean, long tensor. + Supports negative indices. + + Returns: + Transform3d object with selected transforms. The tensors are not cloned. + """ + if isinstance(index, int): + index = [index] + mat = self.get_matrix()[index] + x = mat[:, 0, 0] + y = mat[:, 1, 1] + z = mat[:, 2, 2] + return self.__class__(x, y, z) + class Rotate(Transform3d): def __init__( @@ -655,6 +691,22 @@ def _get_matrix_inverse(self) -> torch.Tensor: """ return self._matrix.permute(0, 2, 1).contiguous() + def __getitem__( + self, index: Union[int, List[int], slice, torch.BoolTensor, torch.LongTensor] + ) -> "Transform3d": + """ + Args: + index: Specifying the index of the transform to retrieve. + Can be an int, slice, list of ints, boolean, long tensor. + Supports negative indices. + + Returns: + Transform3d object with selected transforms. The tensors are not cloned. + """ + if isinstance(index, int): + index = [index] + return self.__class__(self.get_matrix()[index, :3, :3]) + class RotateAxisAngle(Rotate): def __init__( diff --git a/tests/test_transforms.py b/tests/test_transforms.py index 5a2d729f..6851afbf 100644 --- a/tests/test_transforms.py +++ b/tests/test_transforms.py @@ -685,6 +685,15 @@ def test_inverse(self): self.assertTrue(torch.allclose(im, im_comp)) self.assertTrue(torch.allclose(im, im_2)) + def test_get_item(self, batch_size=5): + device = torch.device("cuda:0") + xyz = torch.randn(size=[batch_size, 3], device=device, dtype=torch.float32) + t3d = Translate(xyz) + index = 1 + t3d_selected = t3d[index] + self.assertEqual(len(t3d_selected), 1) + self.assertIsInstance(t3d_selected, Translate) + class TestScale(unittest.TestCase): def test_single_python_scalar(self): @@ -871,6 +880,15 @@ def test_inverse(self): self.assertTrue(torch.allclose(im, im_comp)) self.assertTrue(torch.allclose(im, im_2)) + def test_get_item(self, batch_size=5): + device = torch.device("cuda:0") + s = torch.randn(size=[batch_size, 3], device=device, dtype=torch.float32) + t3d = Scale(s) + index = 1 + t3d_selected = t3d[index] + self.assertEqual(len(t3d_selected), 1) + self.assertIsInstance(t3d_selected, Scale) + class TestTransformBroadcast(unittest.TestCase): def test_broadcast_transform_points(self): @@ -986,6 +1004,15 @@ def test_inverse(self, batch_size=5): self.assertTrue(torch.allclose(im, im_comp, atol=1e-4)) self.assertTrue(torch.allclose(im, im_2, atol=1e-4)) + def test_get_item(self, batch_size=5): + device = torch.device("cuda:0") + r = random_rotations(batch_size, dtype=torch.float32, device=device) + t3d = Rotate(r) + index = 1 + t3d_selected = t3d[index] + self.assertEqual(len(t3d_selected), 1) + self.assertIsInstance(t3d_selected, Rotate) + class TestRotateAxisAngle(unittest.TestCase): def test_rotate_x_python_scalar(self):