Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Allow indexing for classes inheriting Transform3d (#1801)
Summary: Currently, it is not possible to access a sub-transform using an indexer for all 3d transforms inheriting the `Transforms3d` class. For instance: ```python from pytorch3d import transforms N = 10 r = transforms.random_rotations(N) T = transforms.Transform3d().rotate(R=r) R = transforms.Rotate(r) x = T[0] # ok x = R[0] # TypeError: __init__() got an unexpected keyword argument 'matrix' ``` This is because all these classes (namely `Rotate`, `Translate`, `Scale`, `RotateAxisAngle`) inherit the `__getitem__()` method from `Transform3d` which has the [following code on line 201](https://github.com/facebookresearch/pytorch3d/blob/main/pytorch3d/transforms/transform3d.py#L201): ```python return self.__class__(matrix=self.get_matrix()[index]) ``` The four classes inheriting `Transform3d` are not initialized through a matrix argument, hence they error. I propose to modify the `__getitem__()` method of the `Transform3d` class to fix this behavior. The least invasive way to do it I can think of consists of creating an empty instance of the current class, then setting the `_matrix` attribute manually. Thus, instead of ```python return self.__class__(matrix=self.get_matrix()[index]) ``` I propose to do: ```python instance = self.__class__.__new__(self.__class__) instance._matrix = self.get_matrix()[index] return instance ``` As far as I can tell, this modification occurs no modification whatsoever for the user, except for the ability to index all 3d transforms. Pull Request resolved: #1801 Reviewed By: MichaelRamamonjisoa Differential Revision: D58410389 Pulled By: bottler fbshipit-source-id: f371e4c63d2ae4c927a7ad48c2de8862761078de
- Loading branch information