Skip to content

Commit 5e904ef

Browse files
committed
Fix dtype propagation
Sometimes, the dtype argument was simply not propagated, sometimes a tensor was not converted even though it had a dtype different from the desired one.
1 parent 8d10856 commit 5e904ef

File tree

1 file changed

+11
-11
lines changed

1 file changed

+11
-11
lines changed

pytorch3d/transforms/transform3d.py

Lines changed: 11 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -390,16 +390,16 @@ def transform_normals(self, normals) -> torch.Tensor:
390390
return normals_out
391391

392392
def translate(self, *args, **kwargs) -> "Transform3d":
393-
return self.compose(Translate(device=self.device, *args, **kwargs))
393+
return self.compose(Translate(device=self.device, dtype=self.dtype, *args, **kwargs))
394394

395395
def scale(self, *args, **kwargs) -> "Transform3d":
396-
return self.compose(Scale(device=self.device, *args, **kwargs))
396+
return self.compose(Scale(device=self.device, dtype=self.dtype, *args, **kwargs))
397397

398398
def rotate(self, *args, **kwargs) -> "Transform3d":
399-
return self.compose(Rotate(device=self.device, *args, **kwargs))
399+
return self.compose(Rotate(device=self.device, dtype=self.dtype, *args, **kwargs))
400400

401401
def rotate_axis_angle(self, *args, **kwargs) -> "Transform3d":
402-
return self.compose(RotateAxisAngle(device=self.device, *args, **kwargs))
402+
return self.compose(RotateAxisAngle(device=self.device, dtype=self.dtype, *args, **kwargs))
403403

404404
def clone(self) -> "Transform3d":
405405
"""
@@ -488,7 +488,7 @@ def __init__(
488488
- A 1D torch tensor
489489
"""
490490
xyz = _handle_input(x, y, z, dtype, device, "Translate")
491-
super().__init__(device=xyz.device)
491+
super().__init__(device=xyz.device, dtype=dtype)
492492
N = xyz.shape[0]
493493

494494
mat = torch.eye(4, dtype=dtype, device=self.device)
@@ -532,7 +532,7 @@ def __init__(
532532
- 1D torch tensor
533533
"""
534534
xyz = _handle_input(x, y, z, dtype, device, "scale", allow_singleton=True)
535-
super().__init__(device=xyz.device)
535+
super().__init__(device=xyz.device, dtype=dtype)
536536
N = xyz.shape[0]
537537

538538
# TODO: Can we do this all in one go somehow?
@@ -571,7 +571,7 @@ def __init__(
571571
572572
"""
573573
device_ = get_device(R, device)
574-
super().__init__(device=device_)
574+
super().__init__(device=device_, dtype=dtype)
575575
if R.dim() == 2:
576576
R = R[None]
577577
if R.shape[-2:] != (3, 3):
@@ -629,7 +629,7 @@ def __init__(
629629
# is for transforming column vectors. Therefore we transpose this matrix.
630630
# R will always be of shape (N, 3, 3)
631631
R = _axis_angle_rotation(axis, angle).transpose(1, 2)
632-
super().__init__(device=angle.device, R=R)
632+
super().__init__(device=angle.device, R=R, dtype=dtype)
633633

634634

635635
def _handle_coord(c, dtype: torch.dtype, device: torch.device) -> torch.Tensor:
@@ -646,8 +646,8 @@ def _handle_coord(c, dtype: torch.dtype, device: torch.device) -> torch.Tensor:
646646
c = torch.tensor(c, dtype=dtype, device=device)
647647
if c.dim() == 0:
648648
c = c.view(1)
649-
if c.device != device:
650-
c = c.to(device=device)
649+
if c.device != device or c.dtype != dtype:
650+
c = c.to(device=device, dtype=dtype)
651651
return c
652652

653653

@@ -696,7 +696,7 @@ def _handle_input(
696696
if y is not None or z is not None:
697697
msg = "Expected y and z to be None (in %s)" % name
698698
raise ValueError(msg)
699-
return x.to(device=device_)
699+
return x.to(device=device_, dtype=dtype)
700700

701701
if allow_singleton and y is None and z is None:
702702
y = x

0 commit comments

Comments
 (0)