Skip to content

Commit 209ab79

Browse files
committed
Add multi-GPU unit tests
If multiple GPUs are detected, tests are conducted on each device. This could detect bugs relating to cudaSetDevice(). Additionally, the test data size is enlarged, as CUDA memory errors tend to be latent with small data.
1 parent b9d790c commit 209ab79

File tree

1 file changed

+41
-38
lines changed

1 file changed

+41
-38
lines changed

test/test_device.py

+41-38
Original file line numberDiff line numberDiff line change
@@ -16,53 +16,56 @@ def error(self, cuda_t, cpu_t):
1616
def test_fixed_point(self):
1717
for wl, fl in [(5, 4), (3, 2)]:
1818
for rounding in ["nearest"]:
19-
t_max = 1 - (2 ** (-fl))
20-
to_quantize_cuda = torch.linspace(
21-
-t_max, t_max, steps=20, device="cuda"
22-
)
23-
to_quantize_cpu = to_quantize_cuda.clone().to("cpu")
24-
fixed_quantized_cuda = fixed_point_quantize(
25-
to_quantize_cuda, wl=wl, fl=fl, rounding=rounding
26-
)
27-
fixed_quantized_cpu = fixed_point_quantize(
28-
to_quantize_cpu, wl=wl, fl=fl, rounding=rounding
29-
)
30-
mse = self.error(fixed_quantized_cuda, fixed_quantized_cpu)
31-
self.assertTrue(mse < 1e-15)
32-
# self.assertTrue(torch.eq(fixed_quantized_cuda.cpu(), fixed_quantized_cpu).all().item())
33-
34-
def test_block_floating_point(self):
35-
for wl in [5, 3]:
36-
for rounding in ["nearest"]:
37-
for dim in [-1, 0, 1]:
38-
t_max = 1 - (2 ** (-4))
19+
for device in [("cuda:%d" % d) for d in range(torch.cuda.device_count())]:
20+
t_max = 1 - (2 ** (-fl))
3921
to_quantize_cuda = torch.linspace(
40-
-t_max, t_max, steps=20, device="cuda"
22+
-t_max, t_max, steps=1200, device=torch.device(device)
4123
)
4224
to_quantize_cpu = to_quantize_cuda.clone().to("cpu")
43-
block_quantized_cuda = block_quantize(
44-
to_quantize_cuda, wl=wl, rounding=rounding
25+
fixed_quantized_cuda = fixed_point_quantize(
26+
to_quantize_cuda, wl=wl, fl=fl, rounding=rounding
4527
)
46-
block_quantized_cpu = block_quantize(
47-
to_quantize_cpu, wl=wl, rounding=rounding
28+
fixed_quantized_cpu = fixed_point_quantize(
29+
to_quantize_cpu, wl=wl, fl=fl, rounding=rounding
4830
)
49-
mse = self.error(block_quantized_cuda, block_quantized_cpu)
50-
self.assertTrue(mse < 1e-15)
51-
# self.assertTrue(torch.eq(block_quantized_cuda.cpu(), block_quantized_cpu).all().item())
31+
mse = self.error(fixed_quantized_cuda, fixed_quantized_cpu)
32+
self.assertTrue(mse < 1e-15, msg="%.2e MSE on device '%s'" % (mse, device))
33+
# self.assertTrue(torch.eq(fixed_quantized_cuda.cpu(), fixed_quantized_cpu).all().item())
34+
35+
def test_block_floating_point(self):
36+
for wl in [5, 3]:
37+
for rounding in ["nearest"]:
38+
for dim in [-1, 0, 1]:
39+
for device in [("cuda:%d" % d) for d in range(torch.cuda.device_count())]:
40+
t_max = 1 - (2 ** (-4))
41+
to_quantize_cuda = torch.linspace(
42+
-t_max, t_max, steps=1200, device=torch.device(device)
43+
)
44+
to_quantize_cpu = to_quantize_cuda.clone().to("cpu")
45+
block_quantized_cuda = block_quantize(
46+
to_quantize_cuda, wl=wl, rounding=rounding
47+
)
48+
block_quantized_cpu = block_quantize(
49+
to_quantize_cpu, wl=wl, rounding=rounding
50+
)
51+
mse = self.error(block_quantized_cuda, block_quantized_cpu)
52+
self.assertTrue(mse < 1e-15, msg="%.2e MSE on device '%s'" % (mse, device))
53+
# self.assertTrue(torch.eq(block_quantized_cuda.cpu(), block_quantized_cpu).all().item())
5254

5355
def test_floating_point(self):
5456
for man, exp in [(2, 5), (6, 9)]:
5557
for rounding in ["nearest"]:
56-
to_quantize_cuda = torch.rand(20).cuda()
57-
to_quantize_cpu = to_quantize_cuda.clone().to("cpu")
58-
float_quantized_cuda = float_quantize(
59-
to_quantize_cuda, man=man, exp=exp, rounding=rounding
60-
)
61-
float_quantized_cpu = float_quantize(
62-
to_quantize_cpu, man=man, exp=exp, rounding=rounding
63-
)
64-
mse = self.error(float_quantized_cuda, float_quantized_cpu)
65-
self.assertTrue(mse < 1e-15)
58+
for device in [("cuda:%d" % d) for d in range(torch.cuda.device_count())]:
59+
to_quantize_cuda = torch.rand(1200).to(torch.device(device))
60+
to_quantize_cpu = to_quantize_cuda.clone().to("cpu")
61+
float_quantized_cuda = float_quantize(
62+
to_quantize_cuda, man=man, exp=exp, rounding=rounding
63+
)
64+
float_quantized_cpu = float_quantize(
65+
to_quantize_cpu, man=man, exp=exp, rounding=rounding
66+
)
67+
mse = self.error(float_quantized_cuda, float_quantized_cpu)
68+
self.assertTrue(mse < 1e-15, msg="%.2e MSE on device '%s'" % (mse, device))
6669

6770

6871
if __name__ == "__main__":

0 commit comments

Comments
 (0)