Skip to content

Commit 8d10856

Browse files
committed
Add Transform3d dtype propagation test
1 parent 97894fb commit 8d10856

File tree

1 file changed

+28
-0
lines changed

1 file changed

+28
-0
lines changed

tests/test_transforms.py

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -87,6 +87,34 @@ def test_to(self):
8787
t = t.cuda()
8888
t = t.cpu()
8989

90+
def test_dtype_propagation(self):
91+
"""
92+
Check that a given dtype is correctly passed along to child
93+
transformations.
94+
"""
95+
# Use at least two dtypes so we avoid only testing on the
96+
# default dtype.
97+
for dtype in [torch.float32, torch.float64]:
98+
R = torch.tensor(
99+
[[0.0, 1.0, 0.0], [0.0, 0.0, 1.0], [1.0, 0.0, 0.0]],
100+
dtype=dtype,
101+
)
102+
tf = Transform3d(dtype=dtype) \
103+
.rotate(R) \
104+
.rotate_axis_angle(
105+
R[0],
106+
'X',
107+
) \
108+
.translate(3, 2, 1) \
109+
.scale(0.5)
110+
111+
self.assertEqual(tf.dtype, dtype)
112+
for inner_tf in tf._transforms:
113+
self.assertEqual(inner_tf.dtype, dtype)
114+
115+
transformed = tf.transform_points(R)
116+
self.assertEqual(transformed.dtype, dtype)
117+
90118
def test_clone(self):
91119
"""
92120
Check that cloned transformations contain different _matrix objects.

0 commit comments

Comments
 (0)