Skip to content

Allow indexing for classes inheriting Transform3d #1801

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed

Conversation

ListIndexOutOfRange
Copy link
Contributor

Currently, it is not possible to access a sub-transform using an indexer for all 3d transforms inheriting the Transforms3d class.
For instance:

from pytorch3d import transforms

N = 10
r = transforms.random_rotations(N)
T = transforms.Transform3d().rotate(R=r)
R = transforms.Rotate(r)

x = T[0]  # ok
x = R[0]  # TypeError: __init__() got an unexpected keyword argument 'matrix'

This is because all these classes (namely Rotate, Translate, Scale, RotateAxisAngle) inherit the __getitem__() method from Transform3d which has the following code on line 201:

return self.__class__(matrix=self.get_matrix()[index])

The four classes inheriting Transform3d are not initialized through a matrix argument, hence they error.
I propose to modify the __getitem__() method of the Transform3d class to fix this behavior. The least invasive way to do it I can think of consists of creating an empty instance of the current class, then setting the _matrix attribute manually. Thus, instead of

return self.__class__(matrix=self.get_matrix()[index])

I propose to do:

instance = self.__class__.__new__(self.__class__)
instance._matrix = self.get_matrix()[index]
return instance

As far as I can tell, this modification occurs no modification whatsoever for the user, except for the ability to index all 3d transforms.

@facebook-github-bot facebook-github-bot added the CLA Signed This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed. label May 23, 2024
@bottler
Copy link
Contributor

bottler commented May 23, 2024

Thanks for the report. Could the offending line be more simply changed from:

self.__class__(matrix=self.get_matrix()[index])

to

Transform3d(matrix=self.get_matrix()[index])

?

The use of __class__ seems to be to enable subclasses to work nicely, but it fails completely.

If indexing a Rotate, Translate, Scale or RotateAxisAngle instance should return another such instance, it would be best to give those classes their own __getitem__s.

@ListIndexOutOfRange
Copy link
Contributor Author

Ok so I updated the code.

Indeed, I think changing the class by indexing (i.e. R[0] being an instance of Transform3d when R is an instance of Rotate) should be avoided.
On the other hand, I feel like it's a bit too much to write a __getitem__ method for every class inheriting Transform3d.
As Transforms3d has 4 attributes only, I propose to add the following lines to Transform3d's __getitem method:

for attr in ('_transforms', '_lu', 'device', 'dtype'):
    setattr(instance, attr, getattr(self, attr))

@bottler
Copy link
Contributor

bottler commented May 31, 2024

I think I disagree with both these points. Rotate, Scale etc are effectively just alternate ways to initialize a Transform3d, so there's nothing wrong if R[0] returns a Transform3d. And writing a __getitem__ for each class would not be much code.

@ListIndexOutOfRange
Copy link
Contributor Author

I see your point but

  • if Rotate, Scale, etc... are ways to initialize Transform3d from a pure programming point of view, they correspond to mathematical (nested) groups (SO(3), SE(3), SIM(3)). I'd like to obtain an element of S0(3) when indexing a collection of SO(3) elements. Furthermore, I can easily imagine code relying on type checking (python isinstance(R, Rotate)). Overall, I don't see why not do it, as it just feels more like "the way it should be".

  • as for the way to implement it, sure, giving each subclass its own __getitem__ method is not too much work, but don't you think the code using settatr does the job ? (which is basically what __init__ is doing anyway)

@ListIndexOutOfRange
Copy link
Contributor Author

Well in the end, I implemented the __getitem__ method for each subclass. Sorry, I didn't understand at first why the other way wasn't working properly.
Now, the indexing will return the same type as the indexed class, and behavior should be correct.
Let met now what you think !

@bottler
Copy link
Contributor

bottler commented Jun 10, 2024

Implementation is good. Would you be able to add a test case for each? Then I'd be happy to merge.

@ListIndexOutOfRange
Copy link
Contributor Author

I'm not exactly sure how to do it to be honest, but here is a first attempt.
As far as I understand, it is not required to check every type of indexing (int, list slice, ...), but let me know if that's not the case.

@bottler
Copy link
Contributor

bottler commented Jun 11, 2024

Looks good. We basically just need something. This PR is fine, and will hopefully get merged soon. Thank you!

@facebook-github-bot
Copy link
Contributor

@bottler has imported this pull request. If you are a Meta employee, you can view this diff on Phabricator.

@facebook-github-bot
Copy link
Contributor

@bottler merged this pull request in b0462d8.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
CLA Signed This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed. Merged
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants