@@ -390,16 +390,16 @@ def transform_normals(self, normals) -> torch.Tensor:
390
390
return normals_out
391
391
392
392
def translate (self , * args , ** kwargs ) -> "Transform3d" :
393
- return self .compose (Translate (device = self .device , * args , ** kwargs ))
393
+ return self .compose (Translate (device = self .device , dtype = self . dtype , * args , ** kwargs ))
394
394
395
395
def scale (self , * args , ** kwargs ) -> "Transform3d" :
396
- return self .compose (Scale (device = self .device , * args , ** kwargs ))
396
+ return self .compose (Scale (device = self .device , dtype = self . dtype , * args , ** kwargs ))
397
397
398
398
def rotate (self , * args , ** kwargs ) -> "Transform3d" :
399
- return self .compose (Rotate (device = self .device , * args , ** kwargs ))
399
+ return self .compose (Rotate (device = self .device , dtype = self . dtype , * args , ** kwargs ))
400
400
401
401
def rotate_axis_angle (self , * args , ** kwargs ) -> "Transform3d" :
402
- return self .compose (RotateAxisAngle (device = self .device , * args , ** kwargs ))
402
+ return self .compose (RotateAxisAngle (device = self .device , dtype = self . dtype , * args , ** kwargs ))
403
403
404
404
def clone (self ) -> "Transform3d" :
405
405
"""
@@ -488,7 +488,7 @@ def __init__(
488
488
- A 1D torch tensor
489
489
"""
490
490
xyz = _handle_input (x , y , z , dtype , device , "Translate" )
491
- super ().__init__ (device = xyz .device )
491
+ super ().__init__ (device = xyz .device , dtype = dtype )
492
492
N = xyz .shape [0 ]
493
493
494
494
mat = torch .eye (4 , dtype = dtype , device = self .device )
@@ -532,7 +532,7 @@ def __init__(
532
532
- 1D torch tensor
533
533
"""
534
534
xyz = _handle_input (x , y , z , dtype , device , "scale" , allow_singleton = True )
535
- super ().__init__ (device = xyz .device )
535
+ super ().__init__ (device = xyz .device , dtype = dtype )
536
536
N = xyz .shape [0 ]
537
537
538
538
# TODO: Can we do this all in one go somehow?
@@ -571,7 +571,7 @@ def __init__(
571
571
572
572
"""
573
573
device_ = get_device (R , device )
574
- super ().__init__ (device = device_ )
574
+ super ().__init__ (device = device_ , dtype = dtype )
575
575
if R .dim () == 2 :
576
576
R = R [None ]
577
577
if R .shape [- 2 :] != (3 , 3 ):
@@ -629,7 +629,7 @@ def __init__(
629
629
# is for transforming column vectors. Therefore we transpose this matrix.
630
630
# R will always be of shape (N, 3, 3)
631
631
R = _axis_angle_rotation (axis , angle ).transpose (1 , 2 )
632
- super ().__init__ (device = angle .device , R = R )
632
+ super ().__init__ (device = angle .device , R = R , dtype = dtype )
633
633
634
634
635
635
def _handle_coord (c , dtype : torch .dtype , device : torch .device ) -> torch .Tensor :
@@ -646,8 +646,8 @@ def _handle_coord(c, dtype: torch.dtype, device: torch.device) -> torch.Tensor:
646
646
c = torch .tensor (c , dtype = dtype , device = device )
647
647
if c .dim () == 0 :
648
648
c = c .view (1 )
649
- if c .device != device :
650
- c = c .to (device = device )
649
+ if c .device != device or c . dtype != dtype :
650
+ c = c .to (device = device , dtype = dtype )
651
651
return c
652
652
653
653
@@ -696,7 +696,7 @@ def _handle_input(
696
696
if y is not None or z is not None :
697
697
msg = "Expected y and z to be None (in %s)" % name
698
698
raise ValueError (msg )
699
- return x .to (device = device_ )
699
+ return x .to (device = device_ , dtype = dtype )
700
700
701
701
if allow_singleton and y is None and z is None :
702
702
y = x
0 commit comments