Skip to content

Commit

Permalink
Make Transform3d.to() not ignore dtype
Browse files Browse the repository at this point in the history
Summary: Make Transform3d.to() no longer ignore a requested dtype when the target device matches the current device and no copy is requested. Also fix the other methods that silently dropped dtype when constructing new transforms (compose, inverse, stack, clone), and track dtype as an attribute set from the input matrix in __init__.

Reviewed By: nikhilaravi

Differential Revision: D28981171

fbshipit-source-id: 4528e6092f4a693aecbe8131ede985fca84e84cf
  • Loading branch information
patricklabatut authored and facebook-github-bot committed Jun 9, 2021
1 parent 626bf3f commit 44508ed
Show file tree
Hide file tree
Showing 2 changed files with 39 additions and 10 deletions.
27 changes: 17 additions & 10 deletions pytorch3d/transforms/transform3d.py
Original file line number Diff line number Diff line change
Expand Up @@ -162,13 +162,15 @@ def __init__(
raise ValueError(
'"matrix" has to be a tensor of shape (minibatch, 4, 4)'
)
# set the device from matrix
# set dtype and device from matrix
dtype = matrix.dtype
device = matrix.device
self._matrix = matrix.view(-1, 4, 4)

self._transforms = [] # store transforms to compose
self._lu = None
self.device = make_device(device)
self.dtype = dtype

def __len__(self):
return self.get_matrix().shape[0]
Expand Down Expand Up @@ -200,7 +202,7 @@ def compose(self, *others):
Returns:
A new Transform3d with the stored transforms
"""
out = Transform3d(device=self.device)
out = Transform3d(dtype=self.dtype, device=self.device)
out._matrix = self._matrix.clone()
for other in others:
if not isinstance(other, Transform3d):
Expand Down Expand Up @@ -259,7 +261,7 @@ def inverse(self, invert_composed: bool = False):
transformation.
"""

tinv = Transform3d(device=self.device)
tinv = Transform3d(dtype=self.dtype, device=self.device)

if invert_composed:
# first compose then invert
Expand All @@ -278,7 +280,7 @@ def inverse(self, invert_composed: bool = False):
# right-multiplies by the inverse of self._matrix
# at the end of the composition.
tinv._transforms = [t.inverse() for t in reversed(self._transforms)]
last = Transform3d(device=self.device)
last = Transform3d(dtype=self.dtype, device=self.device)
last._matrix = i_matrix
tinv._transforms.append(last)
else:
Expand All @@ -291,7 +293,7 @@ def inverse(self, invert_composed: bool = False):
def stack(self, *others):
transforms = [self] + list(others)
matrix = torch.cat([t._matrix for t in transforms], dim=0)
out = Transform3d()
out = Transform3d(dtype=self.dtype, device=self.device)
out._matrix = matrix
return out

Expand Down Expand Up @@ -392,7 +394,7 @@ def clone(self):
Returns:
new Transforms object.
"""
other = Transform3d(device=self.device)
other = Transform3d(dtype=self.dtype, device=self.device)
if self._lu is not None:
other._lu = [elem.clone() for elem in self._lu]
other._matrix = self._matrix.clone()
Expand Down Expand Up @@ -422,17 +424,22 @@ def to(
Transform3d object.
"""
device_ = make_device(device)
if not copy and self.device == device_:
dtype_ = self.dtype if dtype is None else dtype
skip_to = self.device == device_ and self.dtype == dtype_

if not copy and skip_to:
return self

other = self.clone()
if self.device == device_:

if skip_to:
return other

other.device = device_
other._matrix = self._matrix.to(device=device_, dtype=dtype)
other.dtype = dtype_
other._matrix = other._matrix.to(device=device_, dtype=dtype_)
other._transforms = [
t.to(device_, copy=copy, dtype=dtype) for t in other._transforms
t.to(device_, copy=copy, dtype=dtype_) for t in other._transforms
]
return other

Expand Down
22 changes: 22 additions & 0 deletions tests/test_transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,23 +28,45 @@ def test_to(self):
cpu_t = t.to("cpu")
self.assertEqual(cpu_device, cpu_t.device)
self.assertEqual(cpu_device, t.device)
self.assertEqual(torch.float32, cpu_t.dtype)
self.assertEqual(torch.float32, t.dtype)
self.assertIs(t, cpu_t)

cpu_t = t.to(cpu_device)
self.assertEqual(cpu_device, cpu_t.device)
self.assertEqual(cpu_device, t.device)
self.assertEqual(torch.float32, cpu_t.dtype)
self.assertEqual(torch.float32, t.dtype)
self.assertIs(t, cpu_t)

cpu_t = t.to(dtype=torch.float64, device=cpu_device)
self.assertEqual(cpu_device, cpu_t.device)
self.assertEqual(cpu_device, t.device)
self.assertEqual(torch.float64, cpu_t.dtype)
self.assertEqual(torch.float32, t.dtype)
self.assertIsNot(t, cpu_t)

cuda_device = torch.device("cuda")

cuda_t = t.to("cuda")
self.assertEqual(cuda_device, cuda_t.device)
self.assertEqual(cpu_device, t.device)
self.assertEqual(torch.float32, cuda_t.dtype)
self.assertEqual(torch.float32, t.dtype)
self.assertIsNot(t, cuda_t)

cuda_t = t.to(cuda_device)
self.assertEqual(cuda_device, cuda_t.device)
self.assertEqual(cpu_device, t.device)
self.assertEqual(torch.float32, cuda_t.dtype)
self.assertEqual(torch.float32, t.dtype)
self.assertIsNot(t, cuda_t)

cuda_t = t.to(dtype=torch.float64, device=cuda_device)
self.assertEqual(cuda_device, cuda_t.device)
self.assertEqual(cpu_device, t.device)
self.assertEqual(torch.float64, cuda_t.dtype)
self.assertEqual(torch.float32, t.dtype)
self.assertIsNot(t, cuda_t)

cpu_points = torch.rand(9, 3)
Expand Down

0 comments on commit 44508ed

Please sign in to comment.