
Fix dtype propagation #1141


Closed
janEbert wants to merge 4 commits

Conversation

janEbert (Contributor)

Previously, dtypes were not propagated correctly in composed transforms, resulting in errors when different dtypes were mixed. Even specifying a dtype in the constructor did not fix this, nor did specifying the dtype for each composition function invocation (e.g. as a kwarg to rotate_axis_angle).
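For example (an illustrative sketch, not taken from the PR; the exact error depends on the pytorch3d version), a composition that mixes the float32 default of Translate with the old float64 default of RotateAxisAngle fails when the matrices are combined:

import torch
from pytorch3d.transforms import RotateAxisAngle, Translate

t = Translate(torch.zeros(1, 3))                     # dtype defaults to torch.float32
r = RotateAxisAngle(torch.tensor([45.0]), axis="Z")  # dtype defaulted to torch.float64
composed = t.compose(r)

points = torch.rand(1, 5, 3)
# Combining the float32 and float64 matrices raised a dtype mismatch before the fix.
out = composed.transform_points(points)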

With this change, I also had to modify the default dtype of RotateAxisAngle, which was torch.float64; it is now torch.float32, matching all other transforms. This was required because the propagation fix broke some tests due to dtype mismatches.

This change in default dtype in turn broke two tests due to precision changes (calculations that were previously done in torch.float64 were now done in torch.float32), so I changed the precision tolerances to be less strict. I chose the lowest power of ten that passed the tests here.

Sometimes, the dtype argument was simply not propagated; sometimes a tensor was not converted even though it had a dtype different from the desired one.
This previously never threw an error because dtypes were not handled correctly. With the dtype propagation fix in 5e904ef, we'd either have to

- pass the `torch.float32` dtype explicitly everywhere, or
- adjust the default, as done here.
Due to the change in default dtype in 695e92c, the calculations are no longer done in double precision but in single precision. Thus, the result is less precise, which we need to handle. Another option would be to use double precision explicitly.
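For anyone who relied on the old double-precision default, requesting it explicitly would look roughly like this (illustrative sketch; RotateAxisAngle takes a dtype argument):

import torch
from pytorch3d.transforms import RotateAxisAngle

# Opt back in to double precision explicitly instead of relying on the
# (now float32) default.
r64 = RotateAxisAngle(
    torch.tensor([30.0], dtype=torch.float64),
    axis="Y",
    dtype=torch.float64,
)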
janEbert commented Mar 28, 2022

For anyone who stumbles upon the bug and needs an immediate solution, here's a monkey patch:

import torch as th  # needed for th.get_default_dtype() below
import pytorch3d.transforms as tfs

def fix_pytorch3d() -> None:
    """Monkey patch missing dtype propagation using default dtype."""
    # Neither a good nor elegant solution, but it handles the bug.
    tfs.Transform3d.__init__.__defaults__ = (
        (th.get_default_dtype(),)
        + tfs.Transform3d.__init__.__defaults__[1:]
    )
    tfs.Translate.__init__.__defaults__ = (
        tfs.Translate.__init__.__defaults__[:2]
        + (th.get_default_dtype(),)
        + tfs.Translate.__init__.__defaults__[3:]
    )
    tfs.Scale.__init__.__defaults__ = (
        tfs.Scale.__init__.__defaults__[:2]
        + (th.get_default_dtype(),)
        + tfs.Scale.__init__.__defaults__[3:]
    )
    tfs.Rotate.__init__.__defaults__ = (
        (th.get_default_dtype(),)
        + tfs.Rotate.__init__.__defaults__[1:]
    )
    # dtype is the third default of RotateAxisAngle (after axis and degrees).
    tfs.RotateAxisAngle.__init__.__defaults__ = (
        tfs.RotateAxisAngle.__init__.__defaults__[:2]
        + (th.get_default_dtype(),)
        + tfs.RotateAxisAngle.__init__.__defaults__[3:]
    )
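
A usage sketch, assuming the patch is applied before any transforms are constructed (exact behaviour depends on the installed pytorch3d version):

import torch as th
import pytorch3d.transforms as tfs

th.set_default_dtype(th.float64)
fix_pytorch3d()  # constructor defaults now follow torch's default dtype

# Both transforms pick up the same (patched) default dtype, so composing
# them no longer mixes float32 and float64 matrices.
r = tfs.RotateAxisAngle(th.tensor([30.0]), axis="Z")
t = tfs.Translate(th.zeros(1, 3))
composed = t.compose(r)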

facebook-github-bot (Contributor)

@bottler has imported this pull request. If you are a Meta employee, you can view this diff on Phabricator.

bottler commented Mar 29, 2022

Thank you for fixing this!
