Skip to content

Commit

Permalink
Allow indexing for classes inheriting Transform3d (#1801)
Browse files Browse the repository at this point in the history
Summary:
Currently, it is not possible to access a sub-transform using an indexer for all 3d transforms inheriting the `Transforms3d` class.
For instance:

```python
from pytorch3d import transforms

N = 10
r = transforms.random_rotations(N)
T = transforms.Transform3d().rotate(R=r)
R = transforms.Rotate(r)

x = T[0]  # ok
x = R[0]  # TypeError: __init__() got an unexpected keyword argument 'matrix'
```

This is because all these classes (namely `Rotate`, `Translate`, `Scale`, `RotateAxisAngle`) inherit the `__getitem__()` method from `Transform3d` which has the [following code on line 201](https://github.com/facebookresearch/pytorch3d/blob/main/pytorch3d/transforms/transform3d.py#L201):

```python
return self.__class__(matrix=self.get_matrix()[index])
```

The four classes inheriting `Transform3d` are not initialized through a matrix argument, hence they error.
I propose to modify the `__getitem__()` method of the `Transform3d` class to fix this behavior. The least invasive way to do it I can think of consists of creating an empty instance of the current class, then setting the `_matrix` attribute manually. Thus, instead of
```python
return self.__class__(matrix=self.get_matrix()[index])
```
I propose to do:
```python
instance = self.__class__.__new__(self.__class__)
instance._matrix = self.get_matrix()[index]
return instance
```

As far as I can tell, this modification occurs no modification whatsoever for the user, except for the ability to index all 3d transforms.

Pull Request resolved: #1801

Reviewed By: MichaelRamamonjisoa

Differential Revision: D58410389

Pulled By: bottler

fbshipit-source-id: f371e4c63d2ae4c927a7ad48c2de8862761078de
  • Loading branch information
ListIndexOutOfRange authored and facebook-github-bot committed Jun 17, 2024
1 parent b66d17a commit b0462d8
Show file tree
Hide file tree
Showing 2 changed files with 79 additions and 0 deletions.
52 changes: 52 additions & 0 deletions pytorch3d/transforms/transform3d.py
Original file line number Diff line number Diff line change
Expand Up @@ -564,6 +564,22 @@ def _get_matrix_inverse(self) -> torch.Tensor:
i_matrix = self._matrix * inv_mask
return i_matrix

def __getitem__(
self, index: Union[int, List[int], slice, torch.BoolTensor, torch.LongTensor]
) -> "Transform3d":
"""
Args:
index: Specifying the index of the transform to retrieve.
Can be an int, slice, list of ints, boolean, long tensor.
Supports negative indices.
Returns:
Transform3d object with selected transforms. The tensors are not cloned.
"""
if isinstance(index, int):
index = [index]
return self.__class__(self.get_matrix()[index, 3, :3])


class Scale(Transform3d):
def __init__(
Expand Down Expand Up @@ -613,6 +629,26 @@ def _get_matrix_inverse(self) -> torch.Tensor:
imat = torch.diag_embed(ixyz, dim1=1, dim2=2)
return imat

def __getitem__(
self, index: Union[int, List[int], slice, torch.BoolTensor, torch.LongTensor]
) -> "Transform3d":
"""
Args:
index: Specifying the index of the transform to retrieve.
Can be an int, slice, list of ints, boolean, long tensor.
Supports negative indices.
Returns:
Transform3d object with selected transforms. The tensors are not cloned.
"""
if isinstance(index, int):
index = [index]
mat = self.get_matrix()[index]
x = mat[:, 0, 0]
y = mat[:, 1, 1]
z = mat[:, 2, 2]
return self.__class__(x, y, z)


class Rotate(Transform3d):
def __init__(
Expand Down Expand Up @@ -655,6 +691,22 @@ def _get_matrix_inverse(self) -> torch.Tensor:
"""
return self._matrix.permute(0, 2, 1).contiguous()

def __getitem__(
self, index: Union[int, List[int], slice, torch.BoolTensor, torch.LongTensor]
) -> "Transform3d":
"""
Args:
index: Specifying the index of the transform to retrieve.
Can be an int, slice, list of ints, boolean, long tensor.
Supports negative indices.
Returns:
Transform3d object with selected transforms. The tensors are not cloned.
"""
if isinstance(index, int):
index = [index]
return self.__class__(self.get_matrix()[index, :3, :3])


class RotateAxisAngle(Rotate):
def __init__(
Expand Down
27 changes: 27 additions & 0 deletions tests/test_transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -685,6 +685,15 @@ def test_inverse(self):
self.assertTrue(torch.allclose(im, im_comp))
self.assertTrue(torch.allclose(im, im_2))

def test_get_item(self, batch_size=5):
device = torch.device("cuda:0")
xyz = torch.randn(size=[batch_size, 3], device=device, dtype=torch.float32)
t3d = Translate(xyz)
index = 1
t3d_selected = t3d[index]
self.assertEqual(len(t3d_selected), 1)
self.assertIsInstance(t3d_selected, Translate)


class TestScale(unittest.TestCase):
def test_single_python_scalar(self):
Expand Down Expand Up @@ -871,6 +880,15 @@ def test_inverse(self):
self.assertTrue(torch.allclose(im, im_comp))
self.assertTrue(torch.allclose(im, im_2))

def test_get_item(self, batch_size=5):
device = torch.device("cuda:0")
s = torch.randn(size=[batch_size, 3], device=device, dtype=torch.float32)
t3d = Scale(s)
index = 1
t3d_selected = t3d[index]
self.assertEqual(len(t3d_selected), 1)
self.assertIsInstance(t3d_selected, Scale)


class TestTransformBroadcast(unittest.TestCase):
def test_broadcast_transform_points(self):
Expand Down Expand Up @@ -986,6 +1004,15 @@ def test_inverse(self, batch_size=5):
self.assertTrue(torch.allclose(im, im_comp, atol=1e-4))
self.assertTrue(torch.allclose(im, im_2, atol=1e-4))

def test_get_item(self, batch_size=5):
device = torch.device("cuda:0")
r = random_rotations(batch_size, dtype=torch.float32, device=device)
t3d = Rotate(r)
index = 1
t3d_selected = t3d[index]
self.assertEqual(len(t3d_selected), 1)
self.assertIsInstance(t3d_selected, Rotate)


class TestRotateAxisAngle(unittest.TestCase):
def test_rotate_x_python_scalar(self):
Expand Down

0 comments on commit b0462d8

Please sign in to comment.