Use torch.uint1 to torch.uint7 for Uintx tensor subclass #672

Merged (1 commit) on Aug 23, 2024

9 changes: 5 additions & 4 deletions test/dtypes/test_affine_quantized.py
@@ -9,13 +9,14 @@
int8_dynamic_activation_int8_weight,
int8_dynamic_activation_int8_semi_sparse_weight,
)
from torchao.dtypes import (
to_affine_quantized,
)
from torchao.utils import TORCH_VERSION_AT_LEAST_2_5

import torch
import unittest
import tempfile
from torchao.utils import (
TORCH_VERSION_AT_LEAST_2_5,
)


class TestAffineQuantized(TestCase):
@unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
100 changes: 76 additions & 24 deletions test/dtypes/test_uintx.py
@@ -6,7 +6,10 @@

from torchao.dtypes.uintx.Uintx import to_uintx
from torchao.quantization.quant_api import quantize_, uintx_weight_only
from torchao.utils import TORCH_VERSION_AT_LEAST_2_5
from torchao.utils import (
TORCH_VERSION_AT_LEAST_2_3,
TORCH_VERSION_AT_LEAST_2_5,
)

from torchao.quantization.quant_primitives import (
MappingType,
@@ -16,7 +19,12 @@
dequantize_affine,
)

bit_widths = (1, 2, 3, 4, 5, 6, 7)
# torch.uintx dtypes are introduced in 2.3
if TORCH_VERSION_AT_LEAST_2_3:
dtypes = (torch.uint1, torch.uint2, torch.uint3, torch.uint4, torch.uint5, torch.uint6, torch.uint7)
else:
dtypes = ()

group_sizes = [32, 64, 128]
devices = ["cpu", "cuda"]
@pytest.fixture(autouse=True)
@@ -36,72 +44,116 @@ def __init__(self, scale, device):
def forward(self, x):
return self.net(x)

@pytest.mark.parametrize("bit_width", bit_widths)
@pytest.mark.parametrize("dtype", dtypes)
@pytest.mark.parametrize("group_size", group_sizes)
@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
@pytest.mark.skipif(not TORCH_VERSION_AT_LEAST_2_5, reason="only works with fix in the nightly build")
def test_uintx_quant_on_cpu_then_move_to_cuda(bit_width, group_size):
def test_uintx_quant_on_cpu_then_move_to_cuda(dtype, group_size):
scale = 512
fp16_mod_on_cpu = Linear16(scale, "cpu")
quantize_(fp16_mod_on_cpu, uintx_weight_only(bit_width, group_size=group_size))
quantize_(fp16_mod_on_cpu, uintx_weight_only(dtype, group_size=group_size))
test_input_on_cpu = torch.randn(scale*2, dtype=torch.float16, device="cpu")
output_on_cpu = fp16_mod_on_cpu(test_input_on_cpu)
fp16_mod_on_cuda = fp16_mod_on_cpu.to("cuda")
test_input_on_cuda = test_input_on_cpu.to("cuda")
output_on_cuda = fp16_mod_on_cuda(test_input_on_cuda)
assert torch.allclose(output_on_cpu, output_on_cuda.cpu(), atol=1.0e-3), "The output of the model on CPU and CUDA should be close"

@pytest.mark.parametrize("bit_width", bit_widths)
@pytest.mark.parametrize("dtype", dtypes)
@pytest.mark.parametrize("group_size", group_sizes)
@pytest.mark.parametrize("device", devices)
@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
@pytest.mark.skipif(not TORCH_VERSION_AT_LEAST_2_5, reason="only works with fix in the nightly build")
def test_uintx_weight_only_model_quant(bit_width, group_size, device):
def test_uintx_weight_only_model_quant(dtype, group_size, device):
scale = 512
fp16 = Linear16(scale, device)
quantize_(fp16, uintx_weight_only(bit_width, group_size=group_size))
quantize_(fp16, uintx_weight_only(dtype, group_size=group_size))
uintx = torch.compile(fp16, fullgraph=True)
test_input = torch.randn(scale*2, dtype=torch.float16, device=device)
output = uintx.forward(test_input)
assert output != None, "model quantization failed"

@pytest.mark.parametrize("bit_width", bit_widths)
@pytest.mark.parametrize("dtype", dtypes)
@pytest.mark.parametrize("group_size", group_sizes)
@pytest.mark.parametrize("device", devices)
@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
@pytest.mark.skipif(not TORCH_VERSION_AT_LEAST_2_5, reason="only works with fix in the nightly build")
def test_uintx_weight_only_quant(bit_width, group_size, device):
def test_uintx_weight_only_quant(dtype, group_size, device):
input_float = torch.randn((1, 256), dtype=torch.float16, device = device)
mapping_type = MappingType.SYMMETRIC
quant_min = 0
quant_max = 2 ** bit_width - 1
eps = torch.finfo(torch.float32).eps
zero_point_dtype = torch.int32
zero_point_domain = ZeroPointDomain.INT
target_dtype = torch.uint8
block_size = (1, group_size)

scale, zero_point = choose_qparams_affine(
input_float, mapping_type, block_size,
target_dtype, quant_min, quant_max, eps, torch.float32,
zero_point_dtype, True, zero_point_domain
dtype, eps=eps, scale_dtype=torch.float32,
zero_point_dtype=zero_point_dtype, preserve_zero=True, zero_point_domain=zero_point_domain
)

aqt = quantize_affine(
input_float, block_size, scale,
zero_point, target_dtype,
quant_min = quant_min,
quant_max = quant_max,
zero_point_domain = zero_point_domain
zero_point, dtype,
zero_point_domain=zero_point_domain
)
# Note: output will be uint8 tensor for sub byte tensors for now

q = to_uintx(aqt, bit_width, -1)
q = to_uintx(aqt, dtype, -1)
assert q != None, "quantization failed"
dequant = dequantize_affine(
q, block_size, scale,
zero_point, target_dtype,
quant_min = quant_min,
quant_max = quant_max,
zero_point_domain = zero_point_domain
zero_point, dtype,
zero_point_domain=zero_point_domain
)
assert dequant != None, "dequantization failed"


@pytest.mark.parametrize("dtype", dtypes)
@pytest.mark.skipif(not torch.cuda.is_available(), reason="Need CUDA available")
@pytest.mark.skipif(not TORCH_VERSION_AT_LEAST_2_3, reason="sub byte dtype requires torch 2.3+")
def test_uintx_target_dtype(dtype):
from torchao.quantization.quant_api import uintx_weight_only
l = torch.nn.Linear(128, 256, dtype=torch.bfloat16, device="cuda")
# make sure it runs
uintx_weight_only(dtype)(l)
l(torch.randn(1, 128, dtype=torch.bfloat16, device="cuda"))

@pytest.mark.parametrize("dtype", dtypes)
@pytest.mark.skipif(not torch.cuda.is_available(), reason="Need CUDA available")
@pytest.mark.skipif(not TORCH_VERSION_AT_LEAST_2_5, reason="torch.compile without unwrap_tensor_subclass requires torch 2.5+")
def test_uintx_target_dtype_compile(dtype):
from torchao.quantization.quant_api import uintx_weight_only
l = torch.nn.Linear(128, 256, dtype=torch.bfloat16, device="cuda")
# make sure it runs
uintx_weight_only(dtype)(l)
l = torch.compile(l)
l(torch.randn(1, 128, dtype=torch.bfloat16, device="cuda"))


@pytest.mark.parametrize("dtype", dtypes)
@pytest.mark.skipif(not torch.cuda.is_available(), reason="Need CUDA available")
@pytest.mark.skipif(not TORCH_VERSION_AT_LEAST_2_3, reason="sub byte dtype requires torch 2.3+")
def test_uintx_model_size(dtype):
from torchao.quantization.quant_api import uintx_weight_only
from torchao.utils import get_model_size_in_bytes
# scale size = 1/64 * 2 bytes = 1/32 bytes
# zero_point size = 1/64 * 4 bytes = 1/16 bytes
# dtype data size = 1 * bit_width/8 = bit_width/8 bytes
_dtype_to_ratio = {
torch.uint1: (1/8 + 1/16 + 1/32) / 2,
torch.uint2: (2/8 + 1/16 + 1/32) / 2,
torch.uint3: (3/8 + 1/16 + 1/32) / 2,
torch.uint4: (4/8 + 1/16 + 1/32) / 2,
torch.uint5: (5/8 + 1/16 + 1/32) / 2,
torch.uint6: (6/8 + 1/16 + 1/32) / 2,
torch.uint7: (7/8 + 1/16 + 1/32) / 2,
}
l = torch.nn.Sequential(
torch.nn.Linear(128, 256, bias=False, dtype=torch.bfloat16, device="cuda")
)
bf16_size = get_model_size_in_bytes(l)
# make sure it runs
uintx_weight_only(dtype)(l[0])
quantized_size = get_model_size_in_bytes(l)
assert bf16_size * _dtype_to_ratio[dtype] == quantized_size
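
The ratios in _dtype_to_ratio above follow directly from the storage layout described in the comments. A minimal sketch of the arithmetic, assuming the default group_size of 64 with fp16 scales and int32 zero points:

# per bf16 weight element (2 bytes), the uintx representation stores:
#   packed data:  bit_width / 8 bytes
#   scale:        2 bytes / 64 elements = 1/32 bytes
#   zero_point:   4 bytes / 64 elements = 1/16 bytes
bit_width = 4  # torch.uint4
ratio = (bit_width / 8 + 1 / 16 + 1 / 32) / 2
print(ratio)  # 0.296875, so a uint4-quantized layer is roughly 30% of its bf16 size
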
7 changes: 4 additions & 3 deletions torchao/dtypes/affine_quantized_tensor.py
@@ -36,6 +36,7 @@

aten = torch.ops.aten


###############################
# Base Layout Tensor Subclass #
###############################
@@ -198,8 +199,9 @@ def from_float(
use_hqq: bool = False,
):
original_shape = input_float.shape
input_float = layout_type.pre_process(input_float)

if(use_hqq):
if use_hqq:
assert zero_point_domain == ZeroPointDomain.FLOAT and mapping_type == MappingType.ASYMMETRIC and quant_min==0, "Invalid input parameters for HQQ quantization."
nbits = int(math.log2(quant_max + 1))
axis = 1 if (block_size[0]==1) else 0
@@ -208,11 +210,10 @@ def from_float(
device = input_float.device
int_data, scale, zero_point, _ = quantize_affine_hqq(input_float, nbits=nbits, group_size=group_size, axis=axis, compute_dtype=compute_dtype, device=device, verbose=False, raw_output=False)
int_data = int_data.to(target_dtype)

else:
input_float = layout_type.pre_process(input_float)
scale, zero_point = choose_qparams_affine(input_float, mapping_type, block_size, target_dtype, quant_min, quant_max, eps, scale_dtype, zero_point_dtype, preserve_zero, zero_point_domain)
int_data = quantize_affine(input_float, block_size, scale, zero_point, target_dtype, quant_min, quant_max, zero_point_domain)
# Note: output will be uint8 tensor for sub byte tensors for now

int_data = layout_type.post_process(int_data)
layout_tensor_ctr = get_layout_tensor_constructor(type(layout_type))
35 changes: 29 additions & 6 deletions torchao/dtypes/uintx/Uintx.py
@@ -11,10 +11,30 @@
_dispatch__torch_dispatch__,
)
from torchao.dtypes.affine_quantized_tensor import PlainAQTLayout, register_layout_cls

from torchao.utils import TORCH_VERSION_AT_LEAST_2_3

aten = torch.ops.aten

# Note: Uintx does not work for torch 2.3 and below
_DTYPE_TO_BIT_WIDTH = {}
_BIT_WIDTH_TO_DTYPE = {}

if TORCH_VERSION_AT_LEAST_2_3:
_DTYPE_TO_BIT_WIDTH = {
torch.uint1: 1,
torch.uint2: 2,
torch.uint3: 3,
torch.uint4: 4,
torch.uint5: 5,
torch.uint6: 6,
torch.uint7: 7,
}

_BIT_WIDTH_TO_DTYPE = {v: k for k, v in _DTYPE_TO_BIT_WIDTH.items()}
else:
print("uintx feature need torch 2.3+, please upgrade pytorch")


class UintxTensor(torch.Tensor):
"""
Splits int data into packed shards based on bit size
@@ -90,15 +110,18 @@ def get_plain(self):
def apply_transformation(self, fn):
og = self.get_plain()
new = fn(og)
return self.from_uint8(new, self.bit_width, self.pack_dim)
dtype = _BIT_WIDTH_TO_DTYPE[self.bit_width]
return self.from_uint8(new, dtype, self.pack_dim)

# temporary until kernels on packed tensors are created
def apply_fn_to_shards(self, fn):
new_shards = [fn(shard) for shard in self.get_shards()]
return self.__class__(new_shards, self.packed_shape, self.bit_width, self.pack_dim)

@classmethod
def from_uint8(cls, int_data: torch.Tensor, bit_width, pack_dim: int = -1):
def from_uint8(cls, int_data: torch.Tensor, dtype: torch.dtype, pack_dim: int = -1):
assert dtype in _DTYPE_TO_BIT_WIDTH.keys(), f"Expected dtype to be one of {_DTYPE_TO_BIT_WIDTH.keys()}"
bit_width = _DTYPE_TO_BIT_WIDTH[dtype]
shards = pack(int_data, bit_width, dim=pack_dim)
shape = list(int_data.shape)
shape[pack_dim] = shape[pack_dim] * bit_width // 8
@@ -136,7 +159,6 @@ def to(self, *args, **kwargs):

implements = UintxTensor.implements


@implements(aten.detach.default)
def _(func, types, args, kwargs):
return return_and_correct_aliasing(
@@ -166,16 +188,17 @@ def _(func, types, args, kwargs):
return return_and_correct_aliasing(
func, args, kwargs, args[0].apply_transformation(lambda x: (x * args[1]).to(torch.uint8))
)

# quantization api integrations
to_uintx = UintxTensor.from_uint8

@dataclass(frozen=True)
class UintxLayoutType(LayoutType):
bit_width: int
dtype: torch.dtype
pack_dim: int = -1

def post_process(self, input: torch.Tensor) -> torch.Tensor:
return to_uintx(input, self.bit_width, self.pack_dim)
return to_uintx(input, self.dtype, self.pack_dim)

@register_layout_cls(UintxLayoutType)
class UintxAQTLayout(PlainAQTLayout):
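
With the mapping above, packing is keyed on the torch dtype rather than a raw bit width. A minimal round-trip sketch, assuming torch 2.3+ and an input whose packing dimension is a multiple of 8:

import torch
from torchao.dtypes.uintx.Uintx import to_uintx

# to_uintx is UintxTensor.from_uint8: the bit width is looked up via _DTYPE_TO_BIT_WIDTH
# and the uint8 values are packed into shards along pack_dim
data = torch.randint(0, 2**3, (4, 64), dtype=torch.uint8)
packed = to_uintx(data, torch.uint3, -1)  # pack along the last dimension
unpacked = packed.get_plain()             # unpack back to plain uint8 values
assert torch.equal(unpacked, data)
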
24 changes: 12 additions & 12 deletions torchao/quantization/quant_api.py
@@ -489,34 +489,34 @@ def int8_dynamic_activation_int8_semi_sparse_weight():
return int8_dynamic_activation_int8_weight(layout_type=SemiSparseLayoutType())


def uintx_weight_only(bit_width, group_size=64, pack_dim=-1):
def uintx_weight_only(dtype, group_size=64, pack_dim=-1):
"""
Applies uintx weight-only asymmetric per-group quantization to linear layers, using uintx quantization where
x is the number of bits specified by the `bit_width` argument
x is the number of bits specified by `dtype`

Args:
`dtype`: torch.uint1 to torch.uint7 sub byte dtypes
`group_size`: parameter for quantization, controls the granularity of quantization, smaller
size is more fine grained, defaults to 64
`pack_dim`: the dimension we use for packing, defaults to -1
"""
from torchao.quantization.quant_primitives import (
MappingType,
ZeroPointDomain,
choose_qparams_affine,
quantize_affine,
dequantize_affine,
)
from torchao.quantization.quant_api import _get_linear_subclass_inserter
def apply_uintx_weight_only_quant(weight):

layout_type = UintxLayoutType(bit_width=bit_width, pack_dim=pack_dim)
def apply_uintx_weight_only_quant(weight):
layout_type = UintxLayoutType(dtype=dtype, pack_dim=pack_dim)
mapping_type = MappingType.ASYMMETRIC
block_size = (1, group_size)
quant_min = 0
quant_max = 2**bit_width - 1
eps = torch.finfo(torch.float32).eps
zero_point_dtype = torch.int32
zero_point_domain = ZeroPointDomain.INT

return to_affine_quantized(
weight, mapping_type, block_size, torch.uint8,
quant_min = quant_min, quant_max = quant_max,
eps = eps, zero_point_dtype=zero_point_dtype,
weight, mapping_type, block_size, dtype,
eps=eps, zero_point_dtype=zero_point_dtype,
zero_point_domain=zero_point_domain,
layout_type=layout_type,
)
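
With this change the weight-only API is driven by the target dtype. A usage sketch mirroring the tests above (assumes a CUDA device and a bfloat16 model):

import torch
from torchao.quantization.quant_api import quantize_, uintx_weight_only

model = torch.nn.Sequential(
    torch.nn.Linear(128, 256, bias=False, dtype=torch.bfloat16, device="cuda")
)
# pass a torch.uintx dtype instead of a bare bit width
quantize_(model, uintx_weight_only(torch.uint4, group_size=64))
out = model(torch.randn(1, 128, dtype=torch.bfloat16, device="cuda"))
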
17 changes: 12 additions & 5 deletions torchao/quantization/quant_primitives.py
@@ -68,17 +68,21 @@ class ZeroPointDomain(Enum):
torch.int16: (-(2**15), 2**15 - 1),
torch.int32: (-(2**31), 2**31 - 1),
}
_SUB_BYTE_DTYPE_BOUNDS: Dict[torch.dtype, Tuple[int, int]] = {}

if TORCH_VERSION_AT_LEAST_2_3:
_DTYPE_TO_QVALUE_BOUNDS.update({
_SUB_BYTE_DTYPE_BOUNDS = {
torch.uint1: (0, 2**1-1),
torch.uint2: (0, 2**2-1),
torch.uint3: (0, 2**3-1),
torch.uint4: (0, 2**4-1),
torch.uint5: (0, 2**5-1),
torch.uint6: (0, 2**6-1),
torch.uint7: (0, 2**7-1),
})
}
_DTYPE_TO_QVALUE_BOUNDS.update(
_SUB_BYTE_DTYPE_BOUNDS
)


quant_lib = torch.library.Library("quant", "FRAGMENT")
@@ -216,6 +220,10 @@ def _quantize_affine(
"""op definition that has compatible signatures with custom op library
"""
quant_min, quant_max = _get_and_check_qmin_qmax(output_dtype, quant_min, quant_max)
# workaround for uintx dtypes, since we don't have native Uintx dtype connected with
# torch.uintx dtypes yet
if output_dtype in _SUB_BYTE_DTYPE_BOUNDS:
output_dtype = torch.uint8
return _quantize_affine_no_dtype_cast(
input,
block_size,
@@ -328,10 +336,9 @@ def _dequantize_affine(
) -> torch.Tensor:
"""op definition that has compatible signatures with custom op library
"""

# TODO: validations
# TODO: validate scale/zero_point dimensions are compatible with block_size
assert input.dtype == input_dtype, f"Expected: {input_dtype}, got: {input.dtype}"
if input_dtype not in _SUB_BYTE_DTYPE_BOUNDS:
assert input.dtype == input_dtype, f"Expected: {input_dtype}, got: {input.dtype}"
assert output_dtype in [torch.float32, torch.float16, torch.bfloat16], f"Unsupported output dtype: {output_dtype}"
quant_min, quant_max = _get_and_check_qmin_qmax(input_dtype, quant_min, quant_max)
return _dequantize_affine_no_dtype_check(
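
The sub-byte bounds table also surfaces in the quant primitives: _quantize_affine falls back to a uint8 carrier tensor for torch.uint1 through torch.uint7, and _dequantize_affine skips the exact-dtype check for them. A minimal sketch of what that means for callers (assumes torch 2.3+; the packed Uintx layout only takes over later in the layout's post_process):

import torch
from torchao.quantization.quant_primitives import (
    MappingType,
    choose_qparams_affine,
    quantize_affine,
    dequantize_affine,
)

x = torch.randn(1, 256, dtype=torch.float16)
block_size = (1, 64)
# quant_min/quant_max default to the bounds registered for torch.uint3, i.e. (0, 7)
scale, zero_point = choose_qparams_affine(
    x, MappingType.ASYMMETRIC, block_size, torch.uint3,
    eps=torch.finfo(torch.float32).eps, zero_point_dtype=torch.int32,
)
q = quantize_affine(x, block_size, scale, zero_point, torch.uint3)
assert q.dtype == torch.uint8  # sub-byte outputs are carried as uint8 for now
xq = dequantize_affine(q, block_size, scale, zero_point, torch.uint3, output_dtype=torch.float16)
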