@@ -16,53 +16,56 @@ def error(self, cuda_t, cpu_t):
16
16
def test_fixed_point(self):
    """Compare fixed-point quantization on every CUDA device against the CPU.

    For each ``(wl, fl)`` pair and rounding mode, the same evenly spaced
    inputs are quantized once on each visible CUDA device and once on the
    CPU. The test asserts that ``self.error`` of the two results is below
    1e-15.

    The test is skipped when no CUDA device is visible; otherwise the
    device loop would be empty and the test would pass without checking
    anything.
    """
    device_count = torch.cuda.device_count()
    if device_count == 0:
        self.skipTest("no CUDA device available")
    devices = ["cuda:%d" % d for d in range(device_count)]

    for wl, fl in [(5, 4), (3, 2)]:
        for rounding in ["nearest"]:
            for device in devices:
                # Sample the symmetric interval [-t_max, t_max].
                t_max = 1 - (2 ** (-fl))
                to_quantize_cuda = torch.linspace(
                    -t_max, t_max, steps=1200, device=torch.device(device)
                )
                # Copy the exact same inputs to the CPU for the reference run.
                to_quantize_cpu = to_quantize_cuda.clone().to("cpu")
                fixed_quantized_cuda = fixed_point_quantize(
                    to_quantize_cuda, wl=wl, fl=fl, rounding=rounding
                )
                fixed_quantized_cpu = fixed_point_quantize(
                    to_quantize_cpu, wl=wl, fl=fl, rounding=rounding
                )
                mse = self.error(fixed_quantized_cuda, fixed_quantized_cpu)
                self.assertTrue(
                    mse < 1e-15,
                    msg="%.2e MSE on device '%s'" % (mse, device),
                )
34
+
35
def test_block_floating_point(self):
    """Compare block floating-point quantization on every CUDA device against the CPU.

    For each word length and rounding mode, the same evenly spaced inputs
    are quantized once on each visible CUDA device and once on the CPU.
    The test asserts that ``self.error`` of the two results is below 1e-15.

    The test is skipped when no CUDA device is visible; otherwise the
    device loop would be empty and the test would pass without checking
    anything.
    """
    device_count = torch.cuda.device_count()
    if device_count == 0:
        self.skipTest("no CUDA device available")
    devices = ["cuda:%d" % d for d in range(device_count)]

    for wl in [5, 3]:
        for rounding in ["nearest"]:
            # NOTE(review): `dim` is iterated but never passed to
            # block_quantize, so the three iterations repeat the same check.
            # Presumably `dim=dim` was intended -- confirm that it is valid
            # for this 1-D input before wiring it through.
            for dim in [-1, 0, 1]:
                for device in devices:
                    # Sample the symmetric interval [-t_max, t_max].
                    t_max = 1 - (2 ** (-4))
                    to_quantize_cuda = torch.linspace(
                        -t_max, t_max, steps=1200, device=torch.device(device)
                    )
                    # Copy the exact same inputs to the CPU for the reference run.
                    to_quantize_cpu = to_quantize_cuda.clone().to("cpu")
                    block_quantized_cuda = block_quantize(
                        to_quantize_cuda, wl=wl, rounding=rounding
                    )
                    block_quantized_cpu = block_quantize(
                        to_quantize_cpu, wl=wl, rounding=rounding
                    )
                    mse = self.error(block_quantized_cuda, block_quantized_cpu)
                    self.assertTrue(
                        mse < 1e-15,
                        msg="%.2e MSE on device '%s'" % (mse, device),
                    )
52
54
53
55
def test_floating_point(self):
    """Compare low-precision float quantization on every CUDA device against the CPU.

    For each ``(man, exp)`` pair and rounding mode, 1200 random values are
    quantized once on each visible CUDA device and once on the CPU. The
    test asserts that ``self.error`` of the two results is below 1e-15.

    The test is skipped when no CUDA device is visible; otherwise the
    device loop would be empty and the test would pass without checking
    anything.
    """
    device_count = torch.cuda.device_count()
    if device_count == 0:
        self.skipTest("no CUDA device available")
    devices = ["cuda:%d" % d for d in range(device_count)]

    for man, exp in [(2, 5), (6, 9)]:
        for rounding in ["nearest"]:
            for device in devices:
                # Draw the inputs on the CPU, then move them to the device
                # under test.
                to_quantize_cuda = torch.rand(1200).to(torch.device(device))
                # Copy the exact same inputs back for the CPU reference run.
                to_quantize_cpu = to_quantize_cuda.clone().to("cpu")
                float_quantized_cuda = float_quantize(
                    to_quantize_cuda, man=man, exp=exp, rounding=rounding
                )
                float_quantized_cpu = float_quantize(
                    to_quantize_cpu, man=man, exp=exp, rounding=rounding
                )
                mse = self.error(float_quantized_cuda, float_quantized_cpu)
                self.assertTrue(
                    mse < 1e-15,
                    msg="%.2e MSE on device '%s'" % (mse, device),
                )
66
69
67
70
68
71
if __name__ == "__main__" :
0 commit comments