 
 from torchao.dtypes.uintx.Uintx import to_uintx
 from torchao.quantization.quant_api import quantize_, uintx_weight_only
-from torchao.utils import TORCH_VERSION_AT_LEAST_2_5
+from torchao.utils import (
+    TORCH_VERSION_AT_LEAST_2_3,
+    TORCH_VERSION_AT_LEAST_2_5,
+)
 
 from torchao.quantization.quant_primitives import (
     MappingType,
     ZeroPointDomain,
     choose_qparams_affine,
     quantize_affine,
     dequantize_affine,
 )
 
-bit_widths = (1, 2, 3, 4, 5, 6, 7)
+# torch.uintx dtypes are introduced in 2.3
+if TORCH_VERSION_AT_LEAST_2_3:
+    dtypes = (torch.uint1, torch.uint2, torch.uint3, torch.uint4, torch.uint5, torch.uint6, torch.uint7)
+else:
+    dtypes = ()
+
 group_sizes = [32, 64, 128]
 devices = ["cpu", "cuda"]
 @pytest.fixture(autouse=True)
@@ -36,72 +44,116 @@ def __init__(self, scale, device):
     def forward(self, x):
         return self.net(x)
 
-@pytest.mark.parametrize("bit_width", bit_widths)
+@pytest.mark.parametrize("dtype", dtypes)
 @pytest.mark.parametrize("group_size", group_sizes)
 @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
 @pytest.mark.skipif(not TORCH_VERSION_AT_LEAST_2_5, reason="only works with fix in the nightly build")
-def test_uintx_quant_on_cpu_then_move_to_cuda(bit_width, group_size):
+def test_uintx_quant_on_cpu_then_move_to_cuda(dtype, group_size):
     scale = 512
     fp16_mod_on_cpu = Linear16(scale, "cpu")
-    quantize_(fp16_mod_on_cpu, uintx_weight_only(bit_width, group_size=group_size))
+    quantize_(fp16_mod_on_cpu, uintx_weight_only(dtype, group_size=group_size))
     test_input_on_cpu = torch.randn(scale * 2, dtype=torch.float16, device="cpu")
     output_on_cpu = fp16_mod_on_cpu(test_input_on_cpu)
     fp16_mod_on_cuda = fp16_mod_on_cpu.to("cuda")
     test_input_on_cuda = test_input_on_cpu.to("cuda")
     output_on_cuda = fp16_mod_on_cuda(test_input_on_cuda)
     assert torch.allclose(output_on_cpu, output_on_cuda.cpu(), atol=1.0e-3), "The output of the model on CPU and CUDA should be close"
 
-@pytest.mark.parametrize("bit_width", bit_widths)
+@pytest.mark.parametrize("dtype", dtypes)
 @pytest.mark.parametrize("group_size", group_sizes)
 @pytest.mark.parametrize("device", devices)
 @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
 @pytest.mark.skipif(not TORCH_VERSION_AT_LEAST_2_5, reason="only works with fix in the nightly build")
-def test_uintx_weight_only_model_quant(bit_width, group_size, device):
+def test_uintx_weight_only_model_quant(dtype, group_size, device):
     scale = 512
     fp16 = Linear16(scale, device)
-    quantize_(fp16, uintx_weight_only(bit_width, group_size=group_size))
+    quantize_(fp16, uintx_weight_only(dtype, group_size=group_size))
     uintx = torch.compile(fp16, fullgraph=True)
     test_input = torch.randn(scale * 2, dtype=torch.float16, device=device)
     output = uintx.forward(test_input)
     assert output is not None, "model quantization failed"
 
-@pytest.mark.parametrize("bit_width", bit_widths)
+@pytest.mark.parametrize("dtype", dtypes)
 @pytest.mark.parametrize("group_size", group_sizes)
 @pytest.mark.parametrize("device", devices)
 @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
 @pytest.mark.skipif(not TORCH_VERSION_AT_LEAST_2_5, reason="only works with fix in the nightly build")
-def test_uintx_weight_only_quant(bit_width, group_size, device):
+def test_uintx_weight_only_quant(dtype, group_size, device):
     input_float = torch.randn((1, 256), dtype=torch.float16, device=device)
     mapping_type = MappingType.SYMMETRIC
-    quant_min = 0
-    quant_max = 2 ** bit_width - 1
     eps = torch.finfo(torch.float32).eps
     zero_point_dtype = torch.int32
     zero_point_domain = ZeroPointDomain.INT
-    target_dtype = torch.uint8
     block_size = (1, group_size)
 
     scale, zero_point = choose_qparams_affine(
         input_float, mapping_type, block_size,
-        target_dtype, quant_min, quant_max, eps, torch.float32,
-        zero_point_dtype, True, zero_point_domain
+        dtype, eps=eps, scale_dtype=torch.float32,
+        zero_point_dtype=zero_point_dtype, preserve_zero=True, zero_point_domain=zero_point_domain
     )
 
     aqt = quantize_affine(
         input_float, block_size, scale,
-        zero_point, target_dtype,
-        quant_min=quant_min,
-        quant_max=quant_max,
-        zero_point_domain=zero_point_domain
+        zero_point, dtype,
+        zero_point_domain=zero_point_domain
     )
+    # Note: output will be uint8 tensor for sub byte tensors for now
 
-    q = to_uintx(aqt, bit_width, -1)
+    q = to_uintx(aqt, dtype, -1)
     assert q is not None, "quantization failed"
     dequant = dequantize_affine(
         q, block_size, scale,
-        zero_point, target_dtype,
-        quant_min=quant_min,
-        quant_max=quant_max,
-        zero_point_domain=zero_point_domain
+        zero_point, dtype,
+        zero_point_domain=zero_point_domain
     )
     assert dequant is not None, "dequantization failed"
+
+
+@pytest.mark.parametrize("dtype", dtypes)
+@pytest.mark.skipif(not torch.cuda.is_available(), reason="Need CUDA available")
+@pytest.mark.skipif(not TORCH_VERSION_AT_LEAST_2_3, reason="sub byte dtype requires torch 2.3+")
+def test_uintx_target_dtype(dtype):
+    from torchao.quantization.quant_api import uintx_weight_only
+    l = torch.nn.Linear(128, 256, dtype=torch.bfloat16, device="cuda")
+    # make sure it runs
+    uintx_weight_only(dtype)(l)
+    l(torch.randn(1, 128, dtype=torch.bfloat16, device="cuda"))
+
+@pytest.mark.parametrize("dtype", dtypes)
+@pytest.mark.skipif(not torch.cuda.is_available(), reason="Need CUDA available")
+@pytest.mark.skipif(not TORCH_VERSION_AT_LEAST_2_5, reason="torch.compile without unwrap_tensor_subclass requires torch 2.5+")
+def test_uintx_target_dtype_compile(dtype):
+    from torchao.quantization.quant_api import uintx_weight_only
+    l = torch.nn.Linear(128, 256, dtype=torch.bfloat16, device="cuda")
+    # make sure it runs
+    uintx_weight_only(dtype)(l)
+    l = torch.compile(l)
+    l(torch.randn(1, 128, dtype=torch.bfloat16, device="cuda"))
+
+
+@pytest.mark.parametrize("dtype", dtypes)
+@pytest.mark.skipif(not torch.cuda.is_available(), reason="Need CUDA available")
+@pytest.mark.skipif(not TORCH_VERSION_AT_LEAST_2_3, reason="sub byte dtype requires torch 2.3+")
+def test_uintx_model_size(dtype):
+    from torchao.quantization.quant_api import uintx_weight_only
+    from torchao.utils import get_model_size_in_bytes
+    # scale size = 1/64 * 2 bytes = 1/32 bytes
+    # zero_point size = 1/64 * 4 bytes = 1/16 bytes
+    # dtype data size = 1 * bit_width/8 = bit_width/8 bytes
+    _dtype_to_ratio = {
+        torch.uint1: (1/8 + 1/16 + 1/32) / 2,
+        torch.uint2: (2/8 + 1/16 + 1/32) / 2,
+        torch.uint3: (3/8 + 1/16 + 1/32) / 2,
+        torch.uint4: (4/8 + 1/16 + 1/32) / 2,
+        torch.uint5: (5/8 + 1/16 + 1/32) / 2,
+        torch.uint6: (6/8 + 1/16 + 1/32) / 2,
+        torch.uint7: (7/8 + 1/16 + 1/32) / 2,
+    }
+    l = torch.nn.Sequential(
+        torch.nn.Linear(128, 256, bias=False, dtype=torch.bfloat16, device="cuda")
+    )
+    bf16_size = get_model_size_in_bytes(l)
+    # make sure it runs
+    uintx_weight_only(dtype)(l[0])
+    quantized_size = get_model_size_in_bytes(l)
+    assert bf16_size * _dtype_to_ratio[dtype] == quantized_size
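As a sanity check on the size arithmetic used in test_uintx_model_size above, here is a small standalone Python sketch (not part of the diff) that recomputes the expected compression ratios; the 64-element group size, 2-byte scale, and 4-byte int32 zero point per group are taken from the comments in that test:

# Expected bytes per weight element after uintx weight-only quantization,
# relative to a 2-byte bf16 weight (assumes group_size=64, 2-byte scale, 4-byte zero point).
group_size = 64
for bit_width in range(1, 8):
    # packed data + amortized scale + amortized zero point, per element
    bytes_per_elem = bit_width / 8 + 2 / group_size + 4 / group_size
    ratio = bytes_per_elem / 2
    # matches the (bit_width/8 + 1/16 + 1/32) / 2 entries in _dtype_to_ratio
    assert ratio == (bit_width / 8 + 1 / 16 + 1 / 32) / 2
    print(f"uint{bit_width}: expected quantized/bf16 size ratio = {ratio}")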