
Commit 7d8cdce

Use torch.uint1 to torch.uint7 for Uintx tensor subclass
Summary:
Previously we used an integer bit_width for uintx quantization, but we can use the torch `dtype` (torch.uint1 through torch.uint7) directly. There are still some workarounds to convert from the torch dtype back to a bit width; removing all of these hacks would require supporting the Uintx tensor subclass properly and having the `torch.uintx` dtypes dispatch to it. That is probably not the highest priority for now, since good performance is more important.

Test Plan:
python test/dtypes/test_affine_quantized.py
pytest test/dtypes/test_uintx.py

Reviewers:

Subscribers:

Tasks:

Tags:
1 parent 18e38f1 commit 7d8cdce
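To make the API change concrete, here is a minimal before/after sketch based on the diffs below (a sketch only; it assumes a CUDA device and torch 2.3+ so the sub-byte dtypes exist, and the 4-bit case is just one example):

import torch
from torchao.quantization.quant_api import quantize_, uintx_weight_only

model = torch.nn.Sequential(torch.nn.Linear(128, 256, dtype=torch.bfloat16, device="cuda"))

# before this commit: an integer bit width was passed
# quantize_(model, uintx_weight_only(4, group_size=64))

# after this commit: a torch sub-byte dtype (torch.uint1 .. torch.uint7) is passed directly
quantize_(model, uintx_weight_only(torch.uint4, group_size=64))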

File tree

6 files changed: +129 / −55 lines


test/dtypes/test_affine_quantized.py

Lines changed: 5 additions & 4 deletions
@@ -9,13 +9,14 @@
     int8_dynamic_activation_int8_weight,
     int8_dynamic_activation_int8_semi_sparse_weight,
 )
+from torchao.dtypes import (
+    to_affine_quantized,
+)
+from torchao.utils import TORCH_VERSION_AT_LEAST_2_5
+
 import torch
 import unittest
 import tempfile
-from torchao.utils import (
-    TORCH_VERSION_AT_LEAST_2_5,
-)
-

 class TestAffineQuantized(TestCase):
     @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")

test/dtypes/test_uintx.py

Lines changed: 63 additions & 21 deletions
@@ -6,7 +6,10 @@

 from torchao.dtypes.uintx.Uintx import to_uintx
 from torchao.quantization.quant_api import quantize_, uintx_weight_only
-from torchao.utils import TORCH_VERSION_AT_LEAST_2_5
+from torchao.utils import (
+    TORCH_VERSION_AT_LEAST_2_3,
+    TORCH_VERSION_AT_LEAST_2_5,
+)

 from torchao.quantization.quant_primitives import (
     MappingType,
@@ -16,7 +19,12 @@
     dequantize_affine,
 )

-bit_widths = (1, 2, 3, 4, 5, 6, 7)
+# torch.uintx dtypes are introduced in 2.3
+if TORCH_VERSION_AT_LEAST_2_3:
+    dtypes = (torch.uint1, torch.uint2, torch.uint3, torch.uint4, torch.uint5, torch.uint6, torch.uint7)
+else:
+    dtypes = ()
+
 group_sizes = [32, 64, 128]
 devices = ["cpu", "cuda"]
 @pytest.fixture(autouse=True)
@@ -36,57 +44,91 @@ def __init__(self, scale, device):
     def forward(self, x):
         return self.net(x)

-@pytest.mark.parametrize("bit_width", bit_widths)
+@pytest.mark.parametrize("dtype", dtypes)
 @pytest.mark.parametrize("group_size", group_sizes)
 @pytest.mark.parametrize("device", devices)
 @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
 @pytest.mark.skipif(not TORCH_VERSION_AT_LEAST_2_5, reason="only works with fix in the nightly build")
-def test_uintx_weight_only_model_quant(bit_width, group_size, device):
+def test_uintx_weight_only_model_quant(dtype, group_size, device):
     scale = 512
     fp16 = Linear16(scale, device)
-    quantize_(fp16, uintx_weight_only(bit_width, group_size=group_size))
+    quantize_(fp16, uintx_weight_only(dtype, group_size=group_size))
     uintx = torch.compile(fp16, fullgraph=True)
     test_input = torch.randn(scale*2, dtype=torch.float16, device=device)
     output = uintx.forward(test_input)
     assert output != None, "model quantization failed"

-@pytest.mark.parametrize("bit_width", bit_widths)
+@pytest.mark.parametrize("dtype", dtypes)
 @pytest.mark.parametrize("group_size", group_sizes)
 @pytest.mark.parametrize("device", devices)
 @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
 @pytest.mark.skipif(not TORCH_VERSION_AT_LEAST_2_5, reason="only works with fix in the nightly build")
-def test_uintx_weight_only_quant(bit_width, group_size, device):
+def test_uintx_weight_only_quant(dtype, group_size, device):
     input_float = torch.randn((1, 256), dtype=torch.float16, device = device)
     mapping_type = MappingType.SYMMETRIC
-    quant_min = 0
-    quant_max = 2 ** bit_width - 1
     eps = torch.finfo(torch.float32).eps
     zero_point_dtype = torch.int32
     zero_point_domain = ZeroPointDomain.INT
-    target_dtype = torch.uint8
     block_size = (1, group_size)

     scale, zero_point = choose_qparams_affine(
         input_float, mapping_type, block_size,
-        target_dtype, quant_min, quant_max, eps, torch.float32,
-        zero_point_dtype, True, zero_point_domain
+        dtype, eps=eps, scale_dtype=torch.float32,
+        zero_point_dtype=zero_point_dtype, preserve_zero=True, zero_point_domain=zero_point_domain
     )

     aqt = quantize_affine(
         input_float, block_size, scale,
-        zero_point, target_dtype,
-        quant_min = quant_min,
-        quant_max = quant_max,
-        zero_point_domain = zero_point_domain
+        zero_point, dtype,
+        zero_point_domain=zero_point_domain
     )
+    # Note: output will be uint8 tensor for sub byte tensors for now

-    q = to_uintx(aqt, bit_width, -1)
+    q = to_uintx(aqt, dtype, -1)
     assert q != None, "quantization failed"
     deqaunt = dequantize_affine(
         q, block_size, scale,
-        zero_point, target_dtype,
-        quant_min = quant_min,
-        quant_max = quant_max,
-        zero_point_domain = zero_point_domain
+        zero_point, dtype,
+        zero_point_domain=zero_point_domain
     )
     assert deqaunt != None, "deqauntization failed"
+
+
+@pytest.mark.parametrize("dtype", dtypes)
+@pytest.mark.skipif(not torch.cuda.is_available(), reason="Need CUDA available")
+@pytest.mark.skipif(not TORCH_VERSION_AT_LEAST_2_3, reason="sub byte dtype requires torch 2.3+")
+def test_uintx_target_dtype(dtype):
+    from torchao.quantization.quant_api import uintx_weight_only
+    l = torch.nn.Linear(128, 256, dtype=torch.bfloat16, device="cuda")
+    # make sure it runs
+    uintx_weight_only(dtype)(l)
+    l = torch.compile(l)
+    l(torch.randn(1, 128, dtype=torch.bfloat16, device="cuda"))
+
+
+@pytest.mark.parametrize("dtype", dtypes)
+@pytest.mark.skipif(not torch.cuda.is_available(), reason="Need CUDA available")
+@pytest.mark.skipif(not TORCH_VERSION_AT_LEAST_2_3, reason="sub byte dtype requires torch 2.3+")
+def test_uintx_model_size(dtype):
+    from torchao.quantization.quant_api import uintx_weight_only
+    from torchao.utils import get_model_size_in_bytes
+    # scale size = 1/64 * 2 bytes = 1/32 bytes
+    # zero_point size = 1/64 * 4 bytes = 1/16 bytes
+    # dtype data size = 1 * bit_width/8 = bit_width/8 bytes
+    _dtype_to_ratio = {
+        torch.uint1: (1/8 + 1/16 + 1/32) / 2,
+        torch.uint2: (2/8 + 1/16 + 1/32) / 2,
+        torch.uint3: (3/8 + 1/16 + 1/32) / 2,
+        torch.uint4: (4/8 + 1/16 + 1/32) / 2,
+        torch.uint5: (5/8 + 1/16 + 1/32) / 2,
+        torch.uint6: (6/8 + 1/16 + 1/32) / 2,
+        torch.uint7: (7/8 + 1/16 + 1/32) / 2,
+    }
+    l = torch.nn.Sequential(
+        torch.nn.Linear(128, 256, bias=False, dtype=torch.bfloat16, device="cuda")
+    )
+    bf16_size = get_model_size_in_bytes(l)
+    # make sure it runs
+    uintx_weight_only(dtype)(l[0])
+    quantized_size = get_model_size_in_bytes(l)
+    assert bf16_size * _dtype_to_ratio[dtype] == quantized_size
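The expected ratios in test_uintx_model_size follow from the storage layout spelled out in the comments: with the default group_size of 64, each 2-byte bf16 weight element is replaced by bit_width/8 bytes of packed data plus 1/64 of a 2-byte scale and 1/64 of a 4-byte zero point. A quick standalone check of that arithmetic for torch.uint4 (illustrative only):

bit_width = 4                     # torch.uint4
packed_bytes = bit_width / 8      # 0.5 bytes of packed data per element
scale_bytes = 2 / 64              # 2-byte scale shared by a group of 64 -> 1/32 byte per element
zero_point_bytes = 4 / 64         # int32 zero point shared by a group of 64 -> 1/16 byte per element

ratio = (packed_bytes + scale_bytes + zero_point_bytes) / 2   # relative to 2-byte bf16 weights
print(ratio)                      # 0.296875 == (4/8 + 1/16 + 1/32) / 2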

torchao/dtypes/affine_quantized_tensor.py

Lines changed: 8 additions & 7 deletions
@@ -33,6 +33,7 @@

 aten = torch.ops.aten

+
 ###############################
 # Base Layout Tensor Subclass #
 ###############################
@@ -208,8 +209,9 @@ def from_float(
         use_hqq: bool = False,
     ):
         original_shape = input_float.shape
+        input_float = layout_type.pre_process(input_float)

-        if(use_hqq):
+        if use_hqq:
             assert zero_point_domain == ZeroPointDomain.FLOAT and mapping_type == MappingType.ASYMMETRIC and quant_min==0, "Invalid input parameters for HQQ quantization."
             nbits = int(math.log2(quant_max + 1))
             axis = 1 if (block_size[0]==1) else 0
@@ -218,12 +220,11 @@ def from_float(
             device = input_float.device
             int_data, scale, zero_point, _ = quantize_affine_hqq(input_float, nbits=nbits, group_size=group_size, axis=axis, compute_dtype=compute_dtype, device=device, verbose=False, raw_output=False)
             int_data = int_data.to(target_dtype)
-
         else:
-            input_float = layout_type.pre_process(input_float)
             scale, zero_point = choose_qparams_affine(input_float, mapping_type, block_size, target_dtype, quant_min, quant_max, eps, scale_dtype, zero_point_dtype, preserve_zero, zero_point_domain)
             int_data = quantize_affine(input_float, block_size, scale, zero_point, target_dtype, quant_min, quant_max, zero_point_domain)
-
+            # Note: output will be uint8 tensor for sub byte tensors for now
+
         int_data = layout_type.post_process(int_data)
         layout_tensor_ctr = get_layout_tensor_constructor(type(layout_type))
         layout_tensor = layout_tensor_ctr(int_data, scale, zero_point, layout_type)
@@ -576,10 +577,10 @@ def from_plain(
         scale: torch.Tensor,
         zero_point: torch.Tensor,
         layout_type: LayoutType
-    ):
-
+    ):
+
         assert isinstance(layout_type, TensorCoreTiledLayoutType)
-
+
         if TORCH_VERSION_AT_LEAST_2_5:
             int_data = (int_data[::, ::2] << 4 | int_data[::, 1::2]).to(torch.uint8)
             assert int_data.dtype == torch.uint8, "torch.ops.aten._convert_weight_to_int4pack in torch 2.5 expects `uint8` dtype"

torchao/dtypes/uintx/Uintx.py

Lines changed: 29 additions & 6 deletions
@@ -11,10 +11,30 @@
     _dispatch__torch_dispatch__,
 )
 from torchao.dtypes.affine_quantized_tensor import PlainAQTLayout, register_layout_cls
-
+from torchao.utils import TORCH_VERSION_AFTER_2_3

 aten = torch.ops.aten

+# Note: Uintx does not work for torch 2.3 and below
+_DTYPE_TO_BIT_WIDTH = {}
+_BIT_WIDTH_TO_DTYPE = {}
+
+if TORCH_VERSION_AFTER_2_3:
+    _DTYPE_TO_BIT_WIDTH = {
+        torch.uint1: 1,
+        torch.uint2: 2,
+        torch.uint3: 3,
+        torch.uint4: 4,
+        torch.uint5: 5,
+        torch.uint6: 6,
+        torch.uint7: 7,
+    }
+
+    _BIT_WIDTH_TO_DTYPE = {v: k for k, v in _DTYPE_TO_BIT_WIDTH.items()}
+else:
+    print("uintx feature need torch 2.3+, please upgrade pytorch")
+
+
 class UintxTensor(torch.Tensor):
     """
     Splits int data into packed shards based on bit size
@@ -90,15 +110,18 @@ def get_plain(self):
     def apply_transformation(self, fn):
         og = self.get_plain()
         new = fn(og)
-        return self.from_uint8(new, self.bit_width, self.pack_dim)
+        dtype = _BIT_WIDTH_TO_DTYPE[self.bit_width]
+        return self.from_uint8(new, dtype, self.pack_dim)

     # temporary until kernels on packed tensors are created
     def apply_fn_to_shards(self, fn):
         new_shards = [fn(shard) for shard in self.get_shards()]
         return self.__class__(new_shards, self.packed_shape, self.bit_width, self.pack_dim)

     @classmethod
-    def from_uint8(cls, int_data: torch.Tensor, bit_width, pack_dim: int = -1):
+    def from_uint8(cls, int_data: torch.Tensor, dtype: torch.dtype, pack_dim: int = -1):
+        assert dtype in _DTYPE_TO_BIT_WIDTH.keys(), "Expected dtype to be one of {_DTYPE_TO_BITWIDTH.keys()}"
+        bit_width = _DTYPE_TO_BIT_WIDTH[dtype]
         shards = pack(int_data, bit_width, dim=pack_dim)
         shape = list(int_data.shape)
         shape[pack_dim] = shape[pack_dim] * bit_width // 8
@@ -107,7 +130,6 @@ def from_uint8(cls, int_data: torch.Tensor, bit_width, pack_dim: int = -1):

 implements = UintxTensor.implements

-
 @implements(aten.detach.default)
 def _(func, types, args, kwargs):
     return return_and_correct_aliasing(
@@ -137,16 +159,17 @@ def _(func, types, args, kwargs):
     return return_and_correct_aliasing(
         func, args, kwargs, args[0].apply_transformation(lambda x: (x * args[1]).to(torch.uint8))
     )
+
 # quantization api integrations
 to_uintx = UintxTensor.from_uint8

 @dataclass(frozen=True)
 class UintxLayoutType(LayoutType):
-    bit_width: int
+    dtype: torch.dtype
     pack_dim: int = -1

     def post_process(self, input: torch.Tensor) -> torch.Tensor:
-        return to_uintx(input, self.bit_width, self.pack_dim)
+        return to_uintx(input, self.dtype, self.pack_dim)

 @register_layout_cls(UintxLayoutType)
 class UintxAQTLayout(PlainAQTLayout):
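A small usage sketch of the updated entry point (illustrative only; it assumes torch 2.3+ so the sub-byte dtypes exist, and the shapes and values are arbitrary):

import torch
from torchao.dtypes.uintx.Uintx import to_uintx

# uint8 data whose values fit in 3 bits (0..7)
x = torch.randint(0, 8, (4, 8), dtype=torch.uint8)

# to_uintx (UintxTensor.from_uint8) now takes a torch dtype instead of a bit width;
# internally it maps the dtype back to a bit width via _DTYPE_TO_BIT_WIDTH
packed = to_uintx(x, torch.uint3, -1)

# get_plain() unpacks the shards back into a uint8 tensor of the original shape
unpacked = packed.get_plain()
print(unpacked.shape, unpacked.dtype)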

torchao/quantization/quant_api.py

Lines changed: 12 additions & 12 deletions
@@ -468,34 +468,34 @@ def int8_dynamic_activation_int8_semi_sparse_weight():
     return int8_dynamic_activation_int8_weight(layout_type=SemiSparseLayoutType())


-def uintx_weight_only(bit_width, group_size=64, pack_dim=-1):
+def uintx_weight_only(dtype, group_size=64, pack_dim=-1):
     """
     Applies uintx weight-only asymmetric per-group quantization to linear layers, using uintx quantization where
-    x is the number of bits specified by the `bit_width` argument
+    x is the number of bits specified by `dtype`
+
+    Args:
+        `dtype`: torch.uint1 to torch.uint7 sub byte dtypes
+        `group_size`: parameter for quantization, controls the granularity of quantization, smaller
+         size is more fine grained, defaults to 64
+        `pack_dim`: the dimension we use for packing, defaults to -1
     """
     from torchao.quantization.quant_primitives import (
         MappingType,
         ZeroPointDomain,
-        choose_qparams_affine,
-        quantize_affine,
-        dequantize_affine,
     )
     from torchao.quantization.quant_api import _get_linear_subclass_inserter
-    def apply_uintx_weight_only_quant(weight):

-        layout_type = UintxLayoutType(bit_width=bit_width, pack_dim=pack_dim)
+    def apply_uintx_weight_only_quant(weight):
+        layout_type = UintxLayoutType(dtype=dtype, pack_dim=pack_dim)
         mapping_type = MappingType.ASYMMETRIC
         block_size = (1, group_size)
-        quant_min = 0
-        quant_max = 2**bit_width - 1
         eps = torch.finfo(torch.float32).eps
         zero_point_dtype = torch.int32
         zero_point_domain = ZeroPointDomain.INT

         return to_affine_quantized(
-            weight, mapping_type, block_size, torch.uint8,
-            quant_min = quant_min, quant_max = quant_max,
-            eps = eps, zero_point_dtype=zero_point_dtype,
+            weight, mapping_type, block_size, dtype,
+            eps=eps, zero_point_dtype=zero_point_dtype,
             zero_point_domain=zero_point_domain,
             layout_type=layout_type,
         )
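The explicit quant_min/quant_max arguments disappear from this call because the bounds are now implied by the dtype itself, via the sub-byte entries added to _DTYPE_TO_QVALUE_BOUNDS (see the quant_primitives.py diff below). For reference, the implied ranges are simply 0 to 2**bits - 1 (a hypothetical check, not code from the commit):

# implied quantization bounds for the unsigned sub-byte dtypes (cf. _DTYPE_TO_QVALUE_BOUNDS)
for bits in range(1, 8):
    print(f"torch.uint{bits}: (0, {2 ** bits - 1})")
# torch.uint1: (0, 1), torch.uint2: (0, 3), ..., torch.uint7: (0, 127)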

torchao/quantization/quant_primitives.py

Lines changed: 12 additions & 5 deletions
@@ -68,17 +68,21 @@ class ZeroPointDomain(Enum):
     torch.int16: (-(2**15), 2**15 - 1),
     torch.int32: (-(2**31), 2**31 - 1),
 }
+_SUB_BYTE_DTYPE_BOUNDS: Dict[torch.dtype, Tuple[int, int]] = {}

 if TORCH_VERSION_AT_LEAST_2_3:
-    _DTYPE_TO_QVALUE_BOUNDS.update({
+    _SUB_BYTE_DTYPE_BOUNDS = {
         torch.uint1: (0, 2**1-1),
         torch.uint2: (0, 2**2-1),
         torch.uint3: (0, 2**3-1),
         torch.uint4: (0, 2**4-1),
         torch.uint5: (0, 2**5-1),
         torch.uint6: (0, 2**6-1),
         torch.uint7: (0, 2**7-1),
-    })
+    }
+    _DTYPE_TO_QVALUE_BOUNDS.update(
+        _SUB_BYTE_DTYPE_BOUNDS
+    )


 quant_lib = torch.library.Library("quant", "FRAGMENT")
@@ -216,6 +220,10 @@ def _quantize_affine(
     """op definition that has compatible signatures with custom op library
     """
     quant_min, quant_max = _get_and_check_qmin_qmax(output_dtype, quant_min, quant_max)
+    # workaround for uintx dtypes, since we don't have native Uintx dtype connected with
+    # torch.uintx dtypes yet
+    if output_dtype in _SUB_BYTE_DTYPE_BOUNDS:
+        output_dtype = torch.uint8
    return _quantize_affine_no_dtype_cast(
         input,
         block_size,
@@ -328,10 +336,9 @@ def _dequantize_affine(
 ) -> torch.Tensor:
     """op definition that has compatible signatures with custom op library
     """
-
-    # TODO: validations
     # TODO: validate scale/zero_point dimensions are compatible with block_size
-    assert input.dtype == input_dtype, f"Expected: {input_dtype}, got: {input.dtype}"
+    if input_dtype not in _SUB_BYTE_DTYPE_BOUNDS:
+        assert input.dtype == input_dtype, f"Expected: {input_dtype}, got: {input.dtype}"
     assert output_dtype in [torch.float32, torch.float16, torch.bfloat16], f"Unsupported output dtype: {output_dtype}"
     quant_min, quant_max = _get_and_check_qmin_qmax(input_dtype, quant_min, quant_max)
     return _dequantize_affine_no_dtype_check(
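The practical effect of the workaround in _quantize_affine: a sub-byte output_dtype selects the matching (0, 2**bits - 1) bounds, but the returned tensor is still stored as uint8 for now, matching the "output will be uint8 tensor for sub byte tensors" notes above. A rough standalone sketch (an assumption-laden example, not code from the commit; it assumes torch 2.3+ and uses the public choose_qparams_affine/quantize_affine signatures exercised in test_uintx.py):

import torch
from torchao.quantization.quant_primitives import (
    MappingType,
    choose_qparams_affine,
    quantize_affine,
)

x = torch.randn(1, 64, dtype=torch.float32)
block_size = (1, 64)

# passing torch.uint3 as the target dtype derives quant_min/quant_max = (0, 7)
scale, zero_point = choose_qparams_affine(x, MappingType.ASYMMETRIC, block_size, torch.uint3)
q = quantize_affine(x, block_size, scale, zero_point, torch.uint3)

print(q.dtype)                          # expected: torch.uint8 (sub-byte dtypes fall back to uint8 storage)
print(q.min().item(), q.max().item())   # values stay within the uint3 range [0, 7]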
