
Commit 9a56e80

Use torch.uint1 to torch.uint7 for Uintx tensor subclass (#672)
Summary: Previously we used `bit_width` for uintx quantization, but we can pass a `dtype` directly instead. There are still some workarounds to convert from the torch dtype back to a bit width; removing all of those hacks would require supporting the Uintx tensor subclass properly and having the `torch.uintx` dtypes dispatch to the tensor subclass. That is probably not the highest priority for now, since good performance is more important.

Test Plan:
python test/dtypes/test_affine_quantized.py
pytest test/dtypes/test_uintx.py

Reviewers:

Subscribers:

Tasks:

Tags:
1 parent eaf2908 commit 9a56e80
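
The user-facing change is the first argument of `uintx_weight_only`: it now takes a `torch.uintN` dtype instead of an integer bit width. A minimal before/after sketch, assuming torchao at this commit and a PyTorch build with the sub-byte dtypes (2.3+); the toy model, sizes, and device are illustrative only:

import torch
import torch.nn as nn
from torchao.quantization.quant_api import quantize_, uintx_weight_only

# toy model with a single linear layer to quantize (illustrative)
model = nn.Sequential(nn.Linear(128, 256, bias=False, dtype=torch.bfloat16, device="cuda"))

# before this commit: quantize_(model, uintx_weight_only(4, group_size=64))
# after this commit: pass the sub-byte dtype directly
quantize_(model, uintx_weight_only(torch.uint4, group_size=64))
out = model(torch.randn(1, 128, dtype=torch.bfloat16, device="cuda"))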

File tree

6 files changed: +138 −54 lines


test/dtypes/test_affine_quantized.py

Lines changed: 5 additions & 4 deletions
@@ -9,13 +9,14 @@
     int8_dynamic_activation_int8_weight,
     int8_dynamic_activation_int8_semi_sparse_weight,
 )
+from torchao.dtypes import (
+    to_affine_quantized,
+)
+from torchao.utils import TORCH_VERSION_AT_LEAST_2_5
+
 import torch
 import unittest
 import tempfile
-from torchao.utils import (
-    TORCH_VERSION_AT_LEAST_2_5,
-)
-
 
 class TestAffineQuantized(TestCase):
     @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")

test/dtypes/test_uintx.py

Lines changed: 76 additions & 24 deletions
@@ -6,7 +6,10 @@
 
 from torchao.dtypes.uintx.Uintx import to_uintx
 from torchao.quantization.quant_api import quantize_, uintx_weight_only
-from torchao.utils import TORCH_VERSION_AT_LEAST_2_5
+from torchao.utils import (
+    TORCH_VERSION_AT_LEAST_2_3,
+    TORCH_VERSION_AT_LEAST_2_5,
+)
 
 from torchao.quantization.quant_primitives import (
     MappingType,
@@ -16,7 +19,12 @@
     dequantize_affine,
 )
 
-bit_widths = (1, 2, 3, 4, 5, 6, 7)
+# torch.uintx dtypes are introduced in 2.3
+if TORCH_VERSION_AT_LEAST_2_3:
+    dtypes = (torch.uint1, torch.uint2, torch.uint3, torch.uint4, torch.uint5, torch.uint6, torch.uint7)
+else:
+    dtypes = ()
+
 group_sizes = [32, 64, 128]
 devices = ["cpu", "cuda"]
 @pytest.fixture(autouse=True)
@@ -36,72 +44,116 @@ def __init__(self, scale, device):
     def forward(self, x):
         return self.net(x)
 
-@pytest.mark.parametrize("bit_width", bit_widths)
+@pytest.mark.parametrize("dtype", dtypes)
 @pytest.mark.parametrize("group_size", group_sizes)
 @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
 @pytest.mark.skipif(not TORCH_VERSION_AT_LEAST_2_5, reason="only works with fix in the nightly build")
-def test_uintx_quant_on_cpu_then_move_to_cuda(bit_width, group_size):
+def test_uintx_quant_on_cpu_then_move_to_cuda(dtype, group_size):
     scale = 512
     fp16_mod_on_cpu = Linear16(scale, "cpu")
-    quantize_(fp16_mod_on_cpu, uintx_weight_only(bit_width, group_size=group_size))
+    quantize_(fp16_mod_on_cpu, uintx_weight_only(dtype, group_size=group_size))
     test_input_on_cpu = torch.randn(scale*2, dtype=torch.float16, device="cpu")
     output_on_cpu = fp16_mod_on_cpu(test_input_on_cpu)
     fp16_mod_on_cuda = fp16_mod_on_cpu.to("cuda")
     test_input_on_cuda = test_input_on_cpu.to("cuda")
     output_on_cuda = fp16_mod_on_cuda(test_input_on_cuda)
     assert torch.allclose(output_on_cpu, output_on_cuda.cpu(), atol=1.0e-3), "The output of the model on CPU and CUDA should be close"
 
-@pytest.mark.parametrize("bit_width", bit_widths)
+@pytest.mark.parametrize("dtype", dtypes)
 @pytest.mark.parametrize("group_size", group_sizes)
 @pytest.mark.parametrize("device", devices)
 @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
 @pytest.mark.skipif(not TORCH_VERSION_AT_LEAST_2_5, reason="only works with fix in the nightly build")
-def test_uintx_weight_only_model_quant(bit_width, group_size, device):
+def test_uintx_weight_only_model_quant(dtype, group_size, device):
     scale = 512
     fp16 = Linear16(scale, device)
-    quantize_(fp16, uintx_weight_only(bit_width, group_size=group_size))
+    quantize_(fp16, uintx_weight_only(dtype, group_size=group_size))
     uintx = torch.compile(fp16, fullgraph=True)
     test_input = torch.randn(scale*2, dtype=torch.float16, device=device)
     output = uintx.forward(test_input)
     assert output != None, "model quantization failed"
 
-@pytest.mark.parametrize("bit_width", bit_widths)
+@pytest.mark.parametrize("dtype", dtypes)
 @pytest.mark.parametrize("group_size", group_sizes)
 @pytest.mark.parametrize("device", devices)
 @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
 @pytest.mark.skipif(not TORCH_VERSION_AT_LEAST_2_5, reason="only works with fix in the nightly build")
-def test_uintx_weight_only_quant(bit_width, group_size, device):
+def test_uintx_weight_only_quant(dtype, group_size, device):
     input_float = torch.randn((1, 256), dtype=torch.float16, device = device)
     mapping_type = MappingType.SYMMETRIC
-    quant_min = 0
-    quant_max = 2 ** bit_width - 1
     eps = torch.finfo(torch.float32).eps
     zero_point_dtype = torch.int32
     zero_point_domain = ZeroPointDomain.INT
-    target_dtype = torch.uint8
     block_size = (1, group_size)
 
     scale, zero_point = choose_qparams_affine(
         input_float, mapping_type, block_size,
-        target_dtype, quant_min, quant_max, eps, torch.float32,
-        zero_point_dtype, True, zero_point_domain
+        dtype, eps=eps, scale_dtype=torch.float32,
+        zero_point_dtype=zero_point_dtype, preserve_zero=True, zero_point_domain=zero_point_domain
     )
 
     aqt = quantize_affine(
         input_float, block_size, scale,
-        zero_point, target_dtype,
-        quant_min = quant_min,
-        quant_max = quant_max,
-        zero_point_domain = zero_point_domain
+        zero_point, dtype,
+        zero_point_domain=zero_point_domain
     )
+    # Note: output will be uint8 tensor for sub byte tensors for now
 
-    q = to_uintx(aqt, bit_width, -1)
+    q = to_uintx(aqt, dtype, -1)
     assert q != None, "quantization failed"
     deqaunt = dequantize_affine(
         q, block_size, scale,
-        zero_point, target_dtype,
-        quant_min = quant_min,
-        quant_max = quant_max,
-        zero_point_domain = zero_point_domain
+        zero_point, dtype,
+        zero_point_domain=zero_point_domain
     )
     assert deqaunt != None, "deqauntization failed"
+
+
+@pytest.mark.parametrize("dtype", dtypes)
+@pytest.mark.skipif(not torch.cuda.is_available(), reason="Need CUDA available")
+@pytest.mark.skipif(not TORCH_VERSION_AT_LEAST_2_3, reason="sub byte dtype requires torch 2.3+")
+def test_uintx_target_dtype(dtype):
+    from torchao.quantization.quant_api import uintx_weight_only
+    l = torch.nn.Linear(128, 256, dtype=torch.bfloat16, device="cuda")
+    # make sure it runs
+    uintx_weight_only(dtype)(l)
+    l(torch.randn(1, 128, dtype=torch.bfloat16, device="cuda"))
+
+@pytest.mark.parametrize("dtype", dtypes)
+@pytest.mark.skipif(not torch.cuda.is_available(), reason="Need CUDA available")
+@pytest.mark.skipif(not TORCH_VERSION_AT_LEAST_2_5, reason="torch.compile without unwrap_tensor_subclass requires torch 2.5+")
+def test_uintx_target_dtype_compile(dtype):
+    from torchao.quantization.quant_api import uintx_weight_only
+    l = torch.nn.Linear(128, 256, dtype=torch.bfloat16, device="cuda")
+    # make sure it runs
+    uintx_weight_only(dtype)(l)
+    l = torch.compile(l)
+    l(torch.randn(1, 128, dtype=torch.bfloat16, device="cuda"))
+
+
+@pytest.mark.parametrize("dtype", dtypes)
+@pytest.mark.skipif(not torch.cuda.is_available(), reason="Need CUDA available")
+@pytest.mark.skipif(not TORCH_VERSION_AT_LEAST_2_3, reason="sub byte dtype requires torch 2.3+")
+def test_uintx_model_size(dtype):
+    from torchao.quantization.quant_api import uintx_weight_only
+    from torchao.utils import get_model_size_in_bytes
+    # scale size = 1/64 * 2 bytes = 1/32 bytes
+    # zero_point size = 1/64 * 4 bytes = 1/16 bytes
+    # dtype data size = 1 * bit_width/8 = bit_width/8 bytes
+    _dtype_to_ratio = {
+        torch.uint1: (1/8 + 1/16 + 1/32) / 2,
+        torch.uint2: (2/8 + 1/16 + 1/32) / 2,
+        torch.uint3: (3/8 + 1/16 + 1/32) / 2,
+        torch.uint4: (4/8 + 1/16 + 1/32) / 2,
+        torch.uint5: (5/8 + 1/16 + 1/32) / 2,
+        torch.uint6: (6/8 + 1/16 + 1/32) / 2,
+        torch.uint7: (7/8 + 1/16 + 1/32) / 2,
+    }
+    l = torch.nn.Sequential(
+        torch.nn.Linear(128, 256, bias=False, dtype=torch.bfloat16, device="cuda")
+    )
+    bf16_size = get_model_size_in_bytes(l)
+    # make sure it runs
+    uintx_weight_only(dtype)(l[0])
+    quantized_size = get_model_size_in_bytes(l)
+    assert bf16_size * _dtype_to_ratio[dtype] == quantized_size
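
The expected ratios in test_uintx_model_size follow from per-element storage with the default `group_size=64`: each bf16 weight element (2 bytes) becomes `bit_width/8` bytes of packed data plus a shared fp16 scale (2/64 bytes per element) and an int32 zero point (4/64 bytes per element). A quick sanity check of the `torch.uint4` entry, as a sketch of the arithmetic only:

# bytes per element after uint4 weight-only quantization, group_size=64
packed_data = 4 / 8      # 0.5 bytes of packed 4-bit data
scale = 2 / 64           # fp16 scale amortized over 64 elements
zero_point = 4 / 64      # int32 zero point amortized over 64 elements
ratio = (packed_data + scale + zero_point) / 2   # relative to 2-byte bf16 weights
assert abs(ratio - 0.296875) < 1e-12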

torchao/dtypes/affine_quantized_tensor.py

Lines changed: 4 additions & 3 deletions
@@ -36,6 +36,7 @@
 
 aten = torch.ops.aten
 
+
 ###############################
 # Base Layout Tensor Subclass #
 ###############################
@@ -198,8 +199,9 @@ def from_float(
         use_hqq: bool = False,
     ):
         original_shape = input_float.shape
+        input_float = layout_type.pre_process(input_float)
 
-        if(use_hqq):
+        if use_hqq:
             assert zero_point_domain == ZeroPointDomain.FLOAT and mapping_type == MappingType.ASYMMETRIC and quant_min==0, "Invalid input parameters for HQQ quantization."
             nbits = int(math.log2(quant_max + 1))
             axis = 1 if (block_size[0]==1) else 0
@@ -208,11 +210,10 @@ def from_float(
             device = input_float.device
             int_data, scale, zero_point, _ = quantize_affine_hqq(input_float, nbits=nbits, group_size=group_size, axis=axis, compute_dtype=compute_dtype, device=device, verbose=False, raw_output=False)
             int_data = int_data.to(target_dtype)
-
         else:
-            input_float = layout_type.pre_process(input_float)
             scale, zero_point = choose_qparams_affine(input_float, mapping_type, block_size, target_dtype, quant_min, quant_max, eps, scale_dtype, zero_point_dtype, preserve_zero, zero_point_domain)
             int_data = quantize_affine(input_float, block_size, scale, zero_point, target_dtype, quant_min, quant_max, zero_point_domain)
+            # Note: output will be uint8 tensor for sub byte tensors for now
 
         int_data = layout_type.post_process(int_data)
         layout_tensor_ctr = get_layout_tensor_constructor(type(layout_type))

torchao/dtypes/uintx/Uintx.py

Lines changed: 29 additions & 6 deletions
@@ -11,10 +11,30 @@
     _dispatch__torch_dispatch__,
 )
 from torchao.dtypes.affine_quantized_tensor import PlainAQTLayout, register_layout_cls
-
+from torchao.utils import TORCH_VERSION_AT_LEAST_2_3
 
 aten = torch.ops.aten
 
+# Note: Uintx does not work for torch 2.3 and below
+_DTYPE_TO_BIT_WIDTH = {}
+_BIT_WIDTH_TO_DTYPE = {}
+
+if TORCH_VERSION_AT_LEAST_2_3:
+    _DTYPE_TO_BIT_WIDTH = {
+        torch.uint1: 1,
+        torch.uint2: 2,
+        torch.uint3: 3,
+        torch.uint4: 4,
+        torch.uint5: 5,
+        torch.uint6: 6,
+        torch.uint7: 7,
+    }
+
+    _BIT_WIDTH_TO_DTYPE = {v: k for k, v in _DTYPE_TO_BIT_WIDTH.items()}
+else:
+    print("uintx feature need torch 2.3+, please upgrade pytorch")
+
+
 class UintxTensor(torch.Tensor):
     """
     Splits int data into packed shards based on bit size
@@ -90,15 +110,18 @@ def get_plain(self):
     def apply_transformation(self, fn):
         og = self.get_plain()
         new = fn(og)
-        return self.from_uint8(new, self.bit_width, self.pack_dim)
+        dtype = _BIT_WIDTH_TO_DTYPE[self.bit_width]
+        return self.from_uint8(new, dtype, self.pack_dim)
 
     # temporary until kernels on packed tensors are created
     def apply_fn_to_shards(self, fn):
         new_shards = [fn(shard) for shard in self.get_shards()]
         return self.__class__(new_shards, self.packed_shape, self.bit_width, self.pack_dim)
 
     @classmethod
-    def from_uint8(cls, int_data: torch.Tensor, bit_width, pack_dim: int = -1):
+    def from_uint8(cls, int_data: torch.Tensor, dtype: torch.dtype, pack_dim: int = -1):
+        assert dtype in _DTYPE_TO_BIT_WIDTH.keys(), "Expected dtype to be one of {_DTYPE_TO_BIT_WIDTH.keys()}"
+        bit_width = _DTYPE_TO_BIT_WIDTH[dtype]
         shards = pack(int_data, bit_width, dim=pack_dim)
         shape = list(int_data.shape)
         shape[pack_dim] = shape[pack_dim] * bit_width // 8
@@ -136,7 +159,6 @@ def to(self, *args, **kwargs):
 
 implements = UintxTensor.implements
 
-
 @implements(aten.detach.default)
 def _(func, types, args, kwargs):
     return return_and_correct_aliasing(
@@ -166,16 +188,17 @@ def _(func, types, args, kwargs):
     return return_and_correct_aliasing(
         func, args, kwargs, args[0].apply_transformation(lambda x: (x * args[1]).to(torch.uint8))
     )
+
 # quantization api integrations
 to_uintx = UintxTensor.from_uint8
 
 @dataclass(frozen=True)
 class UintxLayoutType(LayoutType):
-    bit_width: int
+    dtype: torch.dtype
     pack_dim: int = -1
 
     def post_process(self, input: torch.Tensor) -> torch.Tensor:
-        return to_uintx(input, self.bit_width, self.pack_dim)
+        return to_uintx(input, self.dtype, self.pack_dim)
 
 @register_layout_cls(UintxLayoutType)
 class UintxAQTLayout(PlainAQTLayout):
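
Internally the tensor subclass still tracks an integer bit width; the new `_DTYPE_TO_BIT_WIDTH` / `_BIT_WIDTH_TO_DTYPE` tables translate between it and the `torch.uintN` dtypes, and `from_uint8` (exposed as `to_uintx`) now takes the dtype. A small sketch of what that implies for packed shapes, assuming torch 2.3+ and torchao at this commit; the input tensor is illustrative:

import torch
from torchao.dtypes.uintx.Uintx import to_uintx, _DTYPE_TO_BIT_WIDTH

bit_width = _DTYPE_TO_BIT_WIDTH[torch.uint3]        # 3
data = torch.randint(0, 2**bit_width, (4, 64), dtype=torch.uint8)
packed = to_uintx(data, torch.uint3, -1)            # the dtype, not a bit width, selects the packing
# per from_uint8, the packed dimension shrinks to 64 * 3 // 8 = 24
print(packed.packed_shape)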

torchao/quantization/quant_api.py

Lines changed: 12 additions & 12 deletions
@@ -489,34 +489,34 @@ def int8_dynamic_activation_int8_semi_sparse_weight():
     return int8_dynamic_activation_int8_weight(layout_type=SemiSparseLayoutType())
 
 
-def uintx_weight_only(bit_width, group_size=64, pack_dim=-1):
+def uintx_weight_only(dtype, group_size=64, pack_dim=-1):
     """
     Applies uintx weight-only asymmetric per-group quantization to linear layers, using uintx quantization where
-    x is the number of bits specified by the `bit_width` argument
+    x is the number of bits specified by `dtype`
+
+    Args:
+        `dtype`: torch.uint1 to torch.uint7 sub byte dtypes
+        `group_size`: parameter for quantization, controls the granularity of quantization, smaller
+         size is more fine grained, defaults to 64
+        `pack_dim`: the dimension we use for packing, defaults to -1
     """
     from torchao.quantization.quant_primitives import (
         MappingType,
         ZeroPointDomain,
-        choose_qparams_affine,
-        quantize_affine,
-        dequantize_affine,
     )
     from torchao.quantization.quant_api import _get_linear_subclass_inserter
-    def apply_uintx_weight_only_quant(weight):
 
-        layout_type = UintxLayoutType(bit_width=bit_width, pack_dim=pack_dim)
+    def apply_uintx_weight_only_quant(weight):
+        layout_type = UintxLayoutType(dtype=dtype, pack_dim=pack_dim)
         mapping_type = MappingType.ASYMMETRIC
         block_size = (1, group_size)
-        quant_min = 0
-        quant_max = 2**bit_width - 1
         eps = torch.finfo(torch.float32).eps
         zero_point_dtype = torch.int32
         zero_point_domain = ZeroPointDomain.INT
 
         return to_affine_quantized(
-            weight, mapping_type, block_size, torch.uint8,
-            quant_min = quant_min, quant_max = quant_max,
-            eps = eps, zero_point_dtype=zero_point_dtype,
+            weight, mapping_type, block_size, dtype,
+            eps=eps, zero_point_dtype=zero_point_dtype,
             zero_point_domain=zero_point_domain,
             layout_type=layout_type,
         )
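
With the new signature, `to_affine_quantized` receives the sub-byte dtype as the target dtype, and the explicit `quant_min=0` / `quant_max=2**bit_width - 1` arguments disappear because the bounds are derived from the dtype. A standalone sketch of the call that `apply_uintx_weight_only_quant` now makes, assuming torchao at this commit; the weight tensor is illustrative:

import torch
from torchao.dtypes import to_affine_quantized
from torchao.dtypes.uintx.Uintx import UintxLayoutType
from torchao.quantization.quant_primitives import MappingType, ZeroPointDomain

weight = torch.randn(256, 128, dtype=torch.bfloat16)
aqt = to_affine_quantized(
    weight, MappingType.ASYMMETRIC, (1, 64), torch.uint4,
    eps=torch.finfo(torch.float32).eps,
    zero_point_dtype=torch.int32,
    zero_point_domain=ZeroPointDomain.INT,
    layout_type=UintxLayoutType(dtype=torch.uint4, pack_dim=-1),
)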

torchao/quantization/quant_primitives.py

Lines changed: 12 additions & 5 deletions
@@ -68,17 +68,21 @@ class ZeroPointDomain(Enum):
     torch.int16: (-(2**15), 2**15 - 1),
     torch.int32: (-(2**31), 2**31 - 1),
 }
+_SUB_BYTE_DTYPE_BOUNDS: Dict[torch.dtype, Tuple[int, int]] = {}
 
 if TORCH_VERSION_AT_LEAST_2_3:
-    _DTYPE_TO_QVALUE_BOUNDS.update({
+    _SUB_BYTE_DTYPE_BOUNDS = {
         torch.uint1: (0, 2**1-1),
         torch.uint2: (0, 2**2-1),
         torch.uint3: (0, 2**3-1),
         torch.uint4: (0, 2**4-1),
         torch.uint5: (0, 2**5-1),
         torch.uint6: (0, 2**6-1),
         torch.uint7: (0, 2**7-1),
-    })
+    }
+    _DTYPE_TO_QVALUE_BOUNDS.update(
+        _SUB_BYTE_DTYPE_BOUNDS
+    )
 
 
 quant_lib = torch.library.Library("quant", "FRAGMENT")
@@ -216,6 +220,10 @@ def _quantize_affine(
     """op definition that has compatible signatures with custom op library
     """
     quant_min, quant_max = _get_and_check_qmin_qmax(output_dtype, quant_min, quant_max)
+    # workaround for uintx dtypes, since we don't have native Uintx dtype connected with
+    # torch.uintx dtypes yet
+    if output_dtype in _SUB_BYTE_DTYPE_BOUNDS:
+        output_dtype = torch.uint8
     return _quantize_affine_no_dtype_cast(
         input,
         block_size,
@@ -328,10 +336,9 @@ def _dequantize_affine(
 ) -> torch.Tensor:
     """op definition that has compatible signatures with custom op library
     """
-
-    # TODO: validations
     # TODO: validate scale/zero_point dimensions are compatible with block_size
-    assert input.dtype == input_dtype, f"Expected: {input_dtype}, got: {input.dtype}"
+    if input_dtype not in _SUB_BYTE_DTYPE_BOUNDS:
+        assert input.dtype == input_dtype, f"Expected: {input_dtype}, got: {input.dtype}"
     assert output_dtype in [torch.float32, torch.float16, torch.bfloat16], f"Unsupported output dtype: {output_dtype}"
     quant_min, quant_max = _get_and_check_qmin_qmax(input_dtype, quant_min, quant_max)
     return _dequantize_affine_no_dtype_check(
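
The `_SUB_BYTE_DTYPE_BOUNDS` table is what lets the quant primitives accept `torch.uint1` through `torch.uint7` directly: the qmin/qmax come from the requested sub-byte dtype, while the produced tensor is still stored as `torch.uint8` (the workaround above). A small sketch of that behavior, assuming torchao at this commit and torch 2.3+; shapes and values are illustrative:

import torch
from torchao.quantization.quant_primitives import (
    MappingType,
    ZeroPointDomain,
    choose_qparams_affine,
    quantize_affine,
)

x = torch.randn(1, 256, dtype=torch.float16)
block_size = (1, 64)
scale, zero_point = choose_qparams_affine(
    x, MappingType.SYMMETRIC, block_size,
    torch.uint3, eps=torch.finfo(torch.float32).eps, scale_dtype=torch.float32,
    zero_point_dtype=torch.int32, preserve_zero=True, zero_point_domain=ZeroPointDomain.INT,
)
q = quantize_affine(x, block_size, scale, zero_point, torch.uint3,
                    zero_point_domain=ZeroPointDomain.INT)
assert q.dtype == torch.uint8                          # stored as uint8 for now
assert int(q.min()) >= 0 and int(q.max()) <= 2**3 - 1  # but bounded by torch.uint3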
