
Commit 1b26f26

Add Int4CPULayout and update int4 woq
1 parent 01dc7da commit 1b26f26
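
This commit adds a CPU-specific packing layout (Int4CPULayout) for int4 weight-only quantization and routes int4 packing and matmul to the new *_for_cpu aten ops on CPU. A minimal usage sketch, mirroring the updated tests below (the model and shapes are illustrative, not part of the commit):

import torch
from torchao.dtypes import Int4CPULayout
from torchao.quantization.quant_api import int4_weight_only, quantize_

# int4 weight-only quantization of a bfloat16 linear on CPU,
# selecting the CPU packing layout added in this commit
model = torch.nn.Sequential(torch.nn.Linear(128, 256)).to(torch.bfloat16)
quantize_(model, int4_weight_only(group_size=32, layout=Int4CPULayout()))

out = model(torch.randn(32, 128, dtype=torch.bfloat16))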

6 files changed: +290 −34 lines changed

test/dtypes/test_affine_quantized.py

Lines changed: 22 additions & 18 deletions
@@ -11,7 +11,7 @@
     float8_weight_only,
 )
 from torchao.quantization.quant_primitives import MappingType
-from torchao.dtypes import SemiSparseLayout
+from torchao.dtypes import SemiSparseLayout, Int4CPULayout
 from torch.testing._internal import common_utils
 from torchao.utils import TORCH_VERSION_AT_LEAST_2_5
 
@@ -22,15 +22,18 @@
 is_cuda_8_9 = torch.cuda.is_available() and torch.cuda.get_device_capability() >= (8, 9)
 
 
-def get_quantization_functions(do_sparse: bool, do_int4: bool):
+def get_quantization_functions(do_sparse: bool, do_int4: bool, device: str = "cuda"):
     base_functions = [
         int8_weight_only(),
         int8_dynamic_activation_int4_weight(),
         int8_dynamic_activation_int8_weight(),
         int8_dynamic_activation_int8_weight(act_mapping_type=MappingType.ASYMMETRIC),
     ]
     if do_int4:
-        base_functions.append(int4_weight_only(group_size=32))
+        if device == "cpu":
+            base_functions.append(int4_weight_only(group_size=32, layout=Int4CPULayout()))
+        else:
+            base_functions.append(int4_weight_only(group_size=32))
 
     if do_sparse:
         base_functions.append(int8_dynamic_activation_int8_weight(layout=SemiSparseLayout()))
@@ -139,23 +142,24 @@ class TestAffineQuantizedBasic(TestCase):
     COMMON_DEVICES = ["cpu"] + (["cuda"] if torch.cuda.is_available() else [])
     COMMON_DTYPES = [torch.bfloat16]
 
-    @common_utils.parametrize("apply_quant", get_quantization_functions(False, True))
     @common_utils.parametrize("device", COMMON_DEVICES)
     @common_utils.parametrize("dtype", COMMON_DTYPES)
-    def test_flatten_unflatten(self, apply_quant, device, dtype):
-        l = torch.nn.Linear(128, 256, dtype=dtype, device=device)
-        ql = apply_quant(l)
-        lp_tensor = ql.weight
-        tensor_data_name_dict, tensor_attributes = lp_tensor.__tensor_flatten__()
-        tensor_data_dict = {name: getattr(lp_tensor, name) for name in tensor_data_name_dict}
-        outer_size = lp_tensor.size()
-        outer_stride = lp_tensor.stride()
-        reconstructed = type(lp_tensor).__tensor_unflatten__(tensor_data_dict, tensor_attributes, outer_size, outer_stride)
-        example_inputs = (torch.randn(32, 128, dtype=dtype, device=device),)
-        ref = ql(*example_inputs)
-        ql.weight = torch.nn.Parameter(reconstructed, requires_grad=False)
-        reconstruct_res = ql(*example_inputs)
-        self.assertEqual(reconstruct_res, ref)
+    def test_flatten_unflatten(self, device, dtype):
+        apply_quant_list = get_quantization_functions(False, True, device)
+        for apply_quant in apply_quant_list:
+            l = torch.nn.Linear(128, 256, dtype=dtype, device=device)
+            ql = apply_quant(l)
+            lp_tensor = ql.weight
+            tensor_data_name_dict, tensor_attributes = lp_tensor.__tensor_flatten__()
+            tensor_data_dict = {name: getattr(lp_tensor, name) for name in tensor_data_name_dict}
+            outer_size = lp_tensor.size()
+            outer_stride = lp_tensor.stride()
+            reconstructed = type(lp_tensor).__tensor_unflatten__(tensor_data_dict, tensor_attributes, outer_size, outer_stride)
+            example_inputs = (torch.randn(32, 128, dtype=dtype, device=device),)
+            ref = ql(*example_inputs)
+            ql.weight = torch.nn.Parameter(reconstructed, requires_grad=False)
+            reconstruct_res = ql(*example_inputs)
+            self.assertEqual(reconstruct_res, ref)
 
 common_utils.instantiate_parametrized_tests(TestAffineQuantized)
 common_utils.instantiate_parametrized_tests(TestAffineQuantizedBasic)

test/integration/test_integration.py

Lines changed: 14 additions & 4 deletions
@@ -19,7 +19,7 @@
 from torchao.quantization.dynamic_quant import (
     DynamicallyPerAxisQuantizedLinear,
 )
-from torchao.dtypes import TensorCoreTiledLayout
+from torchao.dtypes import TensorCoreTiledLayout, Int4CPULayout
 from torchao.quantization.quant_api import (
     int4_weight_only,
     int8_weight_only,
@@ -132,7 +132,11 @@ def _int8da_int8w_api(mod):
 
 def _int4wo_api(mod):
     if TORCH_VERSION_AT_LEAST_2_4:
-        quantize_(mod, int4_weight_only(), set_inductor_config=False)
+        device_type = next(mod.parameters()).device
+        if device_type == torch.device("cpu"):
+            quantize_(mod, int4_weight_only(layout=Int4CPULayout()), set_inductor_config=False)
+        else:
+            quantize_(mod, int4_weight_only(), set_inductor_config=False)
         if not TORCH_VERSION_AT_LEAST_2_5:
             unwrap_tensor_subclass(mod)
     else:
@@ -911,10 +915,16 @@ def test_int4_weight_only_quant_subclass_api(self, device, dtype):
     def test_int4_weight_only_quant_subclass_api_grouped(self, device, dtype):
         if dtype != torch.bfloat16:
             self.skipTest(f"Fails for {dtype}")
+        layout_list = []
+        if device == 'cuda':
+            for inner_k_tiles in [4, 2]:
+                layout_list.append(TensorCoreTiledLayout(inner_k_tiles=inner_k_tiles))
+        elif device == 'cpu':
+            layout_list.append(Int4CPULayout())
         for test_shape in ([(256, 256, 16)] + ([(256, 256, 8)] if device=='cuda' else [])):
             for groupsize in [64, 32]:
-                for inner_k_tiles in [4, 2]:
-                    kwargs = {"groupsize": groupsize, "layout": TensorCoreTiledLayout(inner_k_tiles=inner_k_tiles)}
+                for layout in layout_list:
+                    kwargs = {"groupsize": groupsize, "layout": layout}
 
                     def api(mod):
                         kwargs_copy = kwargs.copy()

torchao/dtypes/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -7,6 +7,7 @@
     PlainLayout,
     SemiSparseLayout,
     TensorCoreTiledLayout,
+    Int4CPULayout,
     to_affine_quantized_floatx,
     to_affine_quantized_floatx_static,
     # experimental, will be merged into floatx in the future

torchao/dtypes/affine_quantized_tensor.py

Lines changed: 229 additions & 0 deletions
@@ -630,6 +630,11 @@ def extra_repr(self):
         return f"inner_k_tiles={self.inner_k_tiles}"
 
 
+@dataclass(frozen=True)
+class Int4CPULayout(Layout):
+    def pre_process(self, input: torch.Tensor) -> torch.Tensor:
+        return input
+
 @dataclass(frozen=True)
 class Float8Layout(Layout):
     mm_config: Optional[Float8MMConfig] = None
@@ -1616,6 +1621,230 @@ def get_layout(self) -> Layout:
         return self._layout
 
 
+@register_layout(Int4CPULayout)
+class Int4CPUAQTTensorImpl(AQTTensorImpl):
+    """
+    TensorImpl for int4 CPU layout for affine quantized tensor, this is for int4 only,
+    used by tinygemm kernels `_weight_int4pack_mm`
+
+    It stores the original tensor of dimension [n][k] (int32 dtype) as packed weight of 2-d tensor of
+    dimension: [n][k / 2] (uint8 dtype)
+    (unpacked Tensor shape is n * k)
+
+    Note: we also pack scale and zero point together here for tinygemm kernel
+
+    Note: technically Int4 CPU layout should be the layout for the underlying packed weight
+    (int Tensor) but since the scale and zero_point are also packed into the same tensor here which is not used
+    in plain layout, we just created a layout for AQT right now, this could be improved if we split out
+    int4 aqt into a separate tensor subclass
+
+    fields:
+      packed_weight (torch.Tensor): the 2-d packed tensor in a Int4 CPU layout
+      scale_and_zero (torch.Tensor): the combined scale Tensor used to map between floating point tensor to quantized tensor and zero_point Tensor
+    """
+
+    def __new__(
+        cls,
+        packed_weight: torch.Tensor,
+        scale_and_zero: torch.Tensor,
+        transposed: bool,
+        _layout: Layout,
+    ):
+        kwargs = {}
+        kwargs["device"] = packed_weight.device
+        kwargs["layout"] = (
+            kwargs.get("layout")
+            if kwargs.get("layout", False)
+            else packed_weight.layout
+        )
+        kwargs["dtype"] = packed_weight.dtype
+        kwargs["requires_grad"] = False
+        shape = packed_weight.shape
+        return torch.Tensor._make_wrapper_subclass(cls, shape, **kwargs)  # type: ignore[attr-defined]
+
+    def __init__(
+        self,
+        packed_weight: torch.Tensor,
+        scale_and_zero: torch.Tensor,
+        transposed: bool,
+        _layout: Layout,
+    ):
+        self.packed_weight = packed_weight
+        self.scale_and_zero = scale_and_zero
+        self.transposed = False
+        self._layout = _layout
+
+    def __tensor_flatten__(self):
+        return ["packed_weight", "scale_and_zero"], [self.transposed, self._layout]
+
+    @classmethod
+    def __tensor_unflatten__(
+        cls, tensor_data_dict, tensor_attributes, outer_size, outer_stride
+    ):
+        packed_weight, scale_and_zero = (
+            tensor_data_dict["packed_weight"],
+            tensor_data_dict["scale_and_zero"],
+        )
+        (
+            transposed,
+            _layout,
+        ) = tensor_attributes
+        return cls(packed_weight, scale_and_zero, transposed, _layout)
+
+    @classmethod
+    def from_plain(
+        cls,
+        int_data: torch.Tensor,
+        scale: torch.Tensor,
+        zero_point: Optional[torch.Tensor],
+        _layout: Layout,
+    ):
+        assert isinstance(_layout, Int4CPULayout)
+
+        assert (
+            int_data.dtype == torch.int32
+        ), "torch.ops.aten._convert_weight_to_int4pack_for_cpu expects `int32` dtype"
+        packed_weight = torch.ops.aten._convert_weight_to_int4pack_for_cpu(
+            int_data, 1  # TODO:remove
+        )
+        scale = scale.reshape(int_data.shape[0], -1)
+        zero_point = zero_point.reshape(int_data.shape[0], -1)
+
+        scale_and_zero = pack_tinygemm_scales_and_zeros(scale, zero_point)
+        return cls(packed_weight, scale_and_zero, False, _layout)
+
+    def to(self, *args, **kwargs):
+        kwargs = self._get_to_kwargs(*args, **kwargs)
+        device = kwargs["device"]
+        return self.__class__(
+            self.packed_weight.to(device),
+            self.scale_and_zero.to(device),
+            self.transposed,
+            self._layout,
+        )
+
+    def _apply_fn_to_data(self, fn):
+        # self.packed_weight = fn(self.packed_weight)
+        # self.scale_and_zero = fn(self.scale_and_zero)
+        # return self
+        return self.__class__(
+            fn(self.packed_weight),
+            fn(self.scale_and_zero),
+            self.transposed,
+            self._layout,
+        )
+
+    @classmethod
+    def __torch_dispatch__(cls, func, types, args, kwargs):
+        kwargs = {} if kwargs is None else kwargs
+
+        if func is aten.detach.default:
+            return return_and_correct_aliasing(
+                func, args, kwargs, args[0]._apply_fn_to_data(torch.detach)
+            )
+
+        if func is aten.clone.default:
+            return return_and_correct_aliasing(
+                func, args, kwargs, args[0]._apply_fn_to_data(torch.clone)
+            )
+
+        if func is aten.t.default:
+            """we don't need to repack the weight and just rely on external
+            shape being changed and record the status of transpose/no-transpose
+            """
+            transposed = Int4CPUAQTTensorImpl(
+                args[0].packed_weight,
+                args[0].scale_and_zero,
+                not args[0].transposed,
+                args[0]._layout,
+            )
+            return return_and_correct_aliasing(func, args, kwargs, transposed)
+
+        if func is aten.slice.Tensor:
+            self, dim, start, end, step = fill_defaults(args, 5, [0, None, None, 1])
+            if dim == 0:
+                int_data, scale, zero_point = self.get_plain()
+                int_data = aten.slice.Tensor(int_data, dim, start, end, step)
+                # this is to handle padding
+                int_data = self._layout.post_process(int_data)
+                sliced = self.from_plain(int_data, scale, zero_point, self._layout)
+                return return_and_correct_aliasing(func, args, kwargs, sliced)
+            elif dim == 1:
+                int_data, scale, zero_point = self.get_plain()
+                assert step == 1, "Only step == 1 is supported in slicing right now"
+                data_len = int_data.shape[dim]
+                scale_len = scale.shape[dim]
+                ratio = data_len / scale_len
+                start_scale = int(start / ratio)
+                end_scale = int(end / ratio)
+
+                int_data = aten.slice.Tensor(int_data, dim, start, end, step)
+                # this is to handle padding
+                int_data = self._layout.post_process(int_data)
+                scale = aten.slice.Tensor(scale, dim, start_scale, end_scale, step)
+                zero_point = aten.slice.Tensor(
+                    zero_point, dim, start_scale, end_scale, step
+                )
+                sliced = self.from_plain(int_data, scale, zero_point, self._layout)
+                return sliced
+            else:
+                raise NotImplementedError(
+                    f"Int4CPUAQTTensorImpl dispatch: attempting to run {func}, with dim={dim}, that is not supported"
+                )
+
+        raise NotImplementedError(
+            f"Int4CPUAQTTensorImpl dispatch: attempting to run {func}, this is not supported"
+        )
+
+    __torch_function__ = torch._C._disabled_torch_function_impl
+
+    def get_plain(self) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+        from torchao.quantization.quant_primitives import (
+            ZeroPointDomain,
+            quantize_affine,
+        )
+        from torchao.quantization.utils import unpack_tinygemm_scales_and_zeros
+
+        scale, zero = unpack_tinygemm_scales_and_zeros(self.scale_and_zero)
+
+        cur_shape = self.shape
+        assert len(cur_shape) == 2
+        original_shape = (cur_shape[0], cur_shape[1] * 2)
+        eye_shape = original_shape[1]
+        groupsize = int(original_shape[1] / scale.shape[-2])
+        block_size = (1, groupsize)
+        device = self.device
+        original_dtype = torch.bfloat16
+        target_dtype = torch.int32
+        quant_min = 0
+        quant_max = 15
+        zero_point_domain = ZeroPointDomain.FLOAT
+        assert len(block_size) == 2 and block_size[0] == 1
+        dequantized = torch.ops.aten._weight_int4pack_mm_for_cpu(
+            torch.eye(eye_shape, device=device, dtype=original_dtype),
+            self.packed_weight,
+            groupsize,
+            self.scale_and_zero,
+        )
+        dequantized = dequantized.t().contiguous()
+        # TODO: move this to `unpack_tinygemm_scales_and_zeros`?
+        scale = scale.reshape(scale.shape[:-1]).contiguous()
+        zero = zero.reshape(zero.shape[:-1]).contiguous()
+        int_data = quantize_affine(
+            dequantized,
+            block_size,
+            scale,
+            zero,
+            target_dtype,
+            quant_min,
+            quant_max,
+            zero_point_domain,
+        )
+        return int_data, scale, zero
+
+    def get_layout(self) -> Layout:
+        return self._layout
+
 #####################################################
 # torch functional and aten operator implementation #
 #####################################################
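
The from_plain/get_plain pair above round-trips through the two new CPU ops: from_plain packs an [n, k] int32 tensor of 4-bit values into an [n, k / 2] uint8 tensor, and get_plain recovers the plain values by pushing an identity activation through the int4 matmul and re-quantizing the result. A rough standalone sketch of that identity-matrix trick (shapes, scales, and group size are made up for illustration; it assumes a PyTorch build that provides the *_for_cpu ops):

import torch
from torchao.quantization.utils import pack_tinygemm_scales_and_zeros

n, k, groupsize = 256, 128, 32
int_data = torch.randint(0, 16, (n, k), dtype=torch.int32)      # plain int4 values in int32 storage
scale = torch.rand(n, k // groupsize, dtype=torch.bfloat16)
zero = torch.zeros(n, k // groupsize, dtype=torch.bfloat16)
scale_and_zero = pack_tinygemm_scales_and_zeros(scale, zero)    # [k // groupsize, n, 2]

# pack: [n, k] int32 -> [n, k // 2] uint8, as described in the docstring above
packed = torch.ops.aten._convert_weight_to_int4pack_for_cpu(int_data, 1)

# get_plain's trick: an identity activation through the int4 kernel yields the
# dequantized weight, transposed back to [n, k] afterwards
eye = torch.eye(k, dtype=torch.bfloat16)
dequantized = torch.ops.aten._weight_int4pack_mm_for_cpu(eye, packed, groupsize, scale_and_zero)
dequantized = dequantized.t().contiguous()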

torchao/quantization/subclass.py

Lines changed: 18 additions & 7 deletions
@@ -458,12 +458,20 @@ def _quantized_op(act_mat, w_qtensor, bias):
             act_mat = torch.nn.functional.pad(act_mat, (0, pad_size - act_mat.shape[-1]))
 
         # matmul
-        y = aten._weight_int4pack_mm(
-            act_mat.contiguous(),
-            w_qtensor.int_data,
-            w_qtensor.groupsize,
-            w_qtensor.scales_and_zeros,
-        )
+        if act_mat.device == torch.device("cpu"):
+            y = aten._weight_int4pack_mm_for_cpu(
+                act_mat.contiguous(),
+                w_qtensor.int_data,
+                w_qtensor.groupsize,
+                w_qtensor.scales_and_zeros,
+            )
+        else:
+            y = aten._weight_int4pack_mm(
+                act_mat.contiguous(),
+                w_qtensor.int_data,
+                w_qtensor.groupsize,
+                w_qtensor.scales_and_zeros,
+            )
 
         # remove out_feature padding
         orig_out_features = (
@@ -609,5 +617,8 @@ def to_qtensor_components(cls, input_float, groupsize=128, inner_k_tiles=8):
         input_int4x8, scales_and_zeros = groupwise_affine_quantize_tensor(
             input_float, 4, groupsize, dtype=input_float.dtype
         )
-        int_data = aten._convert_weight_to_int4pack(input_int4x8, inner_k_tiles)
+        if input_float.device == torch.device("cpu"):
+            int_data = aten._convert_weight_to_int4pack_for_cpu(input_int4x8, inner_k_tiles)
+        else:
+            int_data = aten._convert_weight_to_int4pack(input_int4x8, inner_k_tiles)
         return int_data, scales_and_zeros, False, groupsize, inner_k_tiles
