
Commit 4d232b8

Add support for AQTStorage and PlainAQTStorage
Summary:
Today `AffineQuantizedTensor` hardcodes its storage format as `int_data`, `scale`, and `zero_point` tensors, which does not work if we want to support packed weights. This PR hides the storage details of `AffineQuantizedTensor` behind a family of tensor subclasses that all inherit from a base storage type, `AQTStorage` (affine quantized tensor storage). This PR only adds a plain storage tensor (`PlainAQTStorage`) that stores the `int_data`, `scale`, and `zero_point` tensors directly; in the next PR we'll also support storing packed weight (the result of `torch.ops.aten._convert_weight_to_int4pack`) in a different type of `AQTStorage`.

`AffineQuantizedTensor` will have the following:
- storage_tensor: AQTStorage (can hold data in different storage formats)
- storage_layout: str (a string identifying the kind of storage_tensor we have; can be used in dispatch)

Test Plan:
python test/quantization/test_quant_api.py

Reviewers:

Subscribers:

Tasks:

Tags:
1 parent 90b5e17 commit 4d232b8
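
For illustration, here is a minimal usage sketch of the API described above. The import paths, `MappingType` value, shapes, and keyword arguments are assumptions inferred from the diff below, not part of this commit:

    # Hypothetical usage sketch; import paths and argument values are
    # assumptions based on the diff in this commit, not part of it.
    import torch
    from torchao.dtypes.aqt import AffineQuantizedTensor
    from torchao.quantization.quant_primitives import MappingType

    weight = torch.randn(128, 256)
    # quantize the weight to int8 with one quantization group per output channel
    aqt = AffineQuantizedTensor.from_float(
        weight,
        mapping_type=MappingType.SYMMETRIC,
        block_size=(1, 256),
        target_dtype=torch.int8,
        storage_layout="plain",   # new argument added in this PR
    )
    print(aqt.storage_layout)                                    # "plain"
    int_data, scale, zero_point = aqt.storage_tensor.get_plain()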

File tree

1 file changed (+153, −40 lines)


torchao/dtypes/aqt.py

Lines changed: 153 additions & 40 deletions
@@ -18,14 +18,14 @@
 def _aqt_is_int8(aqt):
     """Check if an AffineQuantizedTensor is int8 quantized Tensor"""
     return (
-        aqt.int_data.dtype == torch.int8 and
+        aqt.storage_tensor.dtype == torch.int8 and
         aqt.quant_min is None or aqt.quant_min == -128 and
         aqt.quant_max is None or aqt.quant_max == 127
     )

 def _aqt_is_int8_reduced_range(aqt):
     return (
-        aqt.int_data.dtype == torch.int8 and
+        aqt.storage_tensor.dtype == torch.int8 and
         aqt.quant_min == -127 and
         aqt.quant_max is None or aqt.quant_max == 127
     )
@@ -34,7 +34,7 @@ def _aqt_is_uint4(aqt):
     """Check if an AffineQuantizedTensor is uint4 quantized Tensor"""
     # TODO: use torch.uint4
     return (
-        aqt.int_data.dtype == torch.int32 and
+        aqt.storage_tensor.dtype == torch.int32 and
         aqt.quant_min is None or aqt.quant_min == 0 and
         aqt.quant_max is None or aqt.quant_max == 15
     )
@@ -69,6 +69,121 @@ def implements_aqt_aten_ops(aten_ops):
 def implements_aqt_torch_function(torch_function):
     return implements_torch_function(AffineQuantizedTensor, torch_function)

+_STORAGE_LAYOUT_TO_AQT_STORAGE_CLS: Dict[str, Callable] = {}
+
+def register_aqt_storage_cls(storage_layout: str):
+    def decorator(storage_cls):
+        storage_cls.storage_layout = storage_layout
+        _STORAGE_LAYOUT_TO_AQT_STORAGE_CLS[storage_layout] = storage_cls
+        return storage_cls
+    return decorator
+
+def get_aqt_storage_cls(storage_layout: str) -> Callable:
+    if storage_layout not in _STORAGE_LAYOUT_TO_AQT_STORAGE_CLS:
+        raise ValueError(f"storage layout: {storage_layout} is not supported yet")
+    return _STORAGE_LAYOUT_TO_AQT_STORAGE_CLS.get(storage_layout)
+
+class AQTStorage(torch.Tensor):
+    # this should be set for each storage class during registration
+    storage_layout: Optional[str] = None
+
+    def __init__(
+        self,
+        int_data: torch.Tensor,
+        scale: torch.Tensor,
+        zero_point: torch.Tensor,
+    ):
+        pass
+
+    def get_plain() -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+        pass
+
+@register_aqt_storage_cls("plain")
+class PlainAQTStorage(AQTStorage):
+    def __new__(
+        cls,
+        int_data: torch.Tensor,
+        scale: torch.Tensor,
+        zero_point: torch.Tensor,
+    ):
+        kwargs = {}
+        kwargs["device"] = int_data.device
+        kwargs["layout"] = (
+            kwargs.get("layout") if kwargs.get("layout", False) else int_data.layout
+        )
+        kwargs["dtype"] = int_data.dtype
+        kwargs["requires_grad"] = False
+        shape = int_data.shape
+        return torch.Tensor._make_wrapper_subclass(cls, shape, **kwargs)  # type: ignore[attr-defined]
+
+    def __init__(
+        self,
+        int_data: torch.Tensor,
+        scale: torch.Tensor,
+        zero_point: torch.Tensor,
+    ):
+        self.int_data = int_data
+        self.scale = scale
+        self.zero_point = zero_point
+
+    def __tensor_flatten__(self):
+        return ["int_data", "scale", "zero_point"], []
+
+    @classmethod
+    def __tensor_unflatten__(
+        cls, tensor_data_dict, tensor_attributes, outer_size, outer_stride
+    ):
+        int_data, scale, zero_point = tensor_data_dict["int_data"], tensor_data_dict["scale"], tensor_data_dict["zero_point"]
+        return cls(int_data, scale, zero_point)
+
+    # TODO: dedup
+    def _get_to_kwargs(self, *args, **kwargs):
+        device, dtype, _, memory_format = torch._C._nn._parse_to(*args, **kwargs)
+        device = self.device if device is None else device
+        dtype = self.dtype if dtype is None else dtype
+        memory_format = (
+            memory_format if memory_format is not None else torch.preserve_format
+        )
+        kwargs = {
+            "device": device,
+            "dtype": dtype,
+            "memory_format": memory_format,
+        }
+        return kwargs
+
+    def to(self, *args, **kwargs):
+        kwargs = self._get_to_kwargs(*args, **kwargs)
+        return self.__class__(
+            self.int_data.to(kwargs["device"]),
+            self.scale.to(kwargs["device"]),
+            self.zero_point.to(kwargs["device"]),
+        )
+
+    def _apply_fn_to_data(self, fn):
+        return self.__class__(
+            fn(self.int_data),
+            fn(self.scale),
+            fn(self.zero_point),
+        )
+
+    @classmethod
+    def __torch_dispatch__(cls, func, types, args, kwargs):
+        kwargs = {} if kwargs is None else kwargs
+
+        if func is aten.detach.default:
+            return return_and_correct_aliasing(
+                func, args, kwargs, args[0]._apply_fn_to_data(torch.detach)
+            )
+
+        raise NotImplementedError(
+            f"PlainAQTStorage dispatch: attempting to run {func}, this is not supported"
+        )
+
+    __torch_function__ = torch._C._disabled_torch_function_impl
+
+    def get_plain(self):
+        return self.int_data, self.scale, self.zero_point
+

 class AffineQuantizedTensor(torch.Tensor):
     """
@@ -103,9 +218,7 @@ class AffineQuantizedTensor(torch.Tensor):
     @staticmethod
     def __new__(
         cls,
-        int_data: torch.Tensor,
-        scale: torch.Tensor,
-        zero_point: torch.Tensor,
+        storage_tensor: AQTStorage,
         block_size: Tuple[int, ...],
         shape: torch.Size,
         quant_min: Optional[int] = None,
@@ -115,9 +228,9 @@ def __new__(
         strides=None,
     ):
         kwargs = {}
-        kwargs["device"] = int_data.device
+        kwargs["device"] = storage_tensor.device
         kwargs["layout"] = (
-            kwargs.get("layout") if kwargs.get("layout", False) else int_data.layout
+            kwargs.get("layout") if kwargs.get("layout", False) else storage_tensor.layout
         )
         if dtype is None:
             dtype = scale.dtype
@@ -129,9 +242,7 @@ def __new__(

     def __init__(
         self,
-        int_data: torch.Tensor,
-        scale: torch.Tensor,
-        zero_point: torch.Tensor,
+        storage_tensor: AQTStorage,
         block_size: Tuple[int, ...],
         shape: torch.Size,
         quant_min: Optional[int] = None,
@@ -140,9 +251,7 @@ def __init__(
         dtype=None,
         strides=None,
     ):
-        self.int_data = int_data
-        self.scale = scale
-        self.zero_point = zero_point
+        self.storage_tensor = storage_tensor
         self.block_size = block_size
         self.quant_min = quant_min
         self.quant_max = quant_max
@@ -157,21 +266,20 @@ def __repr__(self):
     def dequantize(self, output_dtype=None):
         if output_dtype is None:
             output_dtype = self.dtype
-        return dequantize_affine(self.int_data, self.block_size, self.scale, self.zero_point, self.int_data.dtype, self.quant_min, self.quant_max, self.zero_point_domain, output_dtype=output_dtype)
+        int_data, scale, zero_point = self.storage_tensor.get_plain()
+        return dequantize_affine(int_data, self.block_size, scale, zero_point, int_data.dtype, self.quant_min, self.quant_max, self.zero_point_domain, output_dtype=output_dtype)

     def __tensor_flatten__(self):
-        return ["int_data", "scale", "zero_point"], [self.block_size, self.shape, self.quant_min, self.quant_max, self.zero_point_domain, self.dtype]
+        return ["storage_tensor"], [self.block_size, self.shape, self.quant_min, self.quant_max, self.zero_point_domain, self.dtype]

     @classmethod
     def __tensor_unflatten__(
         cls, tensor_data_dict, tensor_attributes, outer_size, outer_stride
     ):
-        int_data, scale, zero_point = tensor_data_dict["int_data"], tensor_data_dict["scale"], tensor_data_dict["zero_point"]
+        storage_tensor = tensor_data_dict["storage_tensor"]
         block_size, shape, quant_min, quant_max, zero_point_domain, dtype = tensor_attributes
         return cls(
-            int_data,
-            scale,
-            zero_point,
+            storage_tensor,
             block_size,
             shape if outer_size is None else outer_size,
             quant_min,
@@ -195,13 +303,15 @@ def from_float(
         zero_point_dtype: Optional[torch.dtype] = None,
         preserve_zero: bool = True,
         zero_point_domain: ZeroPointDomain = ZeroPointDomain.INT,
+        storage_layout: str = "plain",
     ):
         scale, zero_point = choose_qparams_affine(input_float, mapping_type, block_size, target_dtype, quant_min, quant_max, eps, scale_dtype, zero_point_dtype, preserve_zero, zero_point_domain)
         int_data = quantize_affine(input_float, block_size, scale, zero_point, target_dtype, quant_min, quant_max, zero_point_domain)
+
+        storage_cls = get_aqt_storage_cls(storage_layout)
+        storage_tensor = storage_cls(int_data, scale, zero_point)
         return cls(
-            int_data,
-            scale,
-            zero_point,
+            storage_tensor,
             block_size,
             input_float.shape,
             quant_min,
@@ -210,6 +320,10 @@ def from_float(
             dtype=input_float.dtype
         )

+    @property
+    def storage_layout(self) -> str:
+        return self.storage_tensor.storage_layout
+
     @classmethod
     def __torch_function__(cls, func, types, args=(), kwargs=None):
         kwargs = {} if kwargs is None else kwargs
@@ -238,9 +352,7 @@ def _get_to_kwargs(self, *args, **kwargs):
     def to(self, *args, **kwargs):
         kwargs = self._get_to_kwargs(*args, **kwargs)
         return self.__class__(
-            self.int_data.to(kwargs["device"]),
-            self.scale.to(kwargs["device"]),
-            self.zero_point.to(kwargs["device"]),
+            self.storage_tensor.to(kwargs["device"]),
             self.block_size,
             self.shape,
             self.quant_min,
@@ -251,9 +363,7 @@ def to(self, *args, **kwargs):

     def _apply_fn_to_data(self, fn):
         return self.__class__(
-            fn(self.int_data),
-            fn(self.scale),
-            fn(self.zero_point),
+            fn(self.storage_tensor),
             self.block_size,
             self.shape,
             self.quant_min,
@@ -308,7 +418,9 @@ def functional_linear(*args, **kwargs):
             if (
                 is_cuda and
                 input_is_int8 and
-                input_tensor_dtype_is_expected
+                input_tensor_dtype_is_expected and
+                input_tensor.storage_layout == "plain" and
+                weight_qtensor.storage_layout == "plain"
             ):
                 #
                 # 1. do the matrix form of dot(X_i, W_j)
@@ -321,10 +433,10 @@ def functional_linear(*args, **kwargs):
                 # value of a float 16, (which results in a value of inf even if multiplying
                 # by the other scale would bring it within the expected range)

-                x_vals_int8 = input_tensor.int_data
-                x_scales = input_tensor.scale
-                w_vals_int8_t = weight_qtensor.int_data.contiguous().t()
-                w_scales = weight_qtensor.scale
+                x_vals_int8 = input_tensor.storage_tensor.int_data
+                x_scales = input_tensor.storage_tensor.scale
+                w_vals_int8_t = weight_qtensor.storage_tensor.int_data.contiguous().t()
+                w_scales = weight_qtensor.storage_tensor.scale
                 tmp = x_vals_int8.reshape(-1, x_vals_int8.shape[-1])
                 y_dot_scaled = int_scaled_matmul(tmp, w_vals_int8_t, x_scales.reshape(-1, 1))

@@ -344,22 +456,22 @@ def functional_linear(*args, **kwargs):
         # weight only quantization
         # TODO: enable cpu and mps path as well
         # TODO: make sure weight dimension matches the expectation of the int4mm kernel
-        # TODO: move this to TinygemmAffineQuantizedTensor
         if (
             is_cuda and
             weight_is_uint4 and
             weight_qtensor.dtype == torch.bfloat16 and
             len(weight_qtensor.shape) == 2 and
             weight_qtensor.block_size[0] == 1 and
-            weight_qtensor.zero_point_domain == ZeroPointDomain.FLOAT
+            weight_qtensor.zero_point_domain == ZeroPointDomain.FLOAT and
+            weight_qtensor.storage_layout == "plain"
         ):
             # groupwise int4 quantization
             # TODO: currently doing packing on the fly, we'll need to figure out
             # the API to do packing before hand
             # TODO: expose the arg
             innerKTiles = 8
-            packed_weight = torch.ops.aten._convert_weight_to_int4pack(weight_qtensor.int_data.to(torch.int32), innerKTiles)
-            scales_and_zeros = pack_tinygemm_scales_and_zeros(weight_qtensor.scale, weight_qtensor.zero_point)
+            packed_weight = torch.ops.aten._convert_weight_to_int4pack(weight_qtensor.storage_tensor.int_data.to(torch.int32), innerKTiles)
+            scales_and_zeros = pack_tinygemm_scales_and_zeros(weight_qtensor.storage_tensor.scale, weight_qtensor.storage_tensor.zero_point)
             groupsize = weight_qtensor.block_size[-1]
             return torch.ops.aten._weight_int4pack_mm(input_tensor.contiguous(), packed_weight, groupsize, scales_and_zeros)
         elif (
@@ -368,11 +480,12 @@ def functional_linear(*args, **kwargs):
             len(weight_qtensor.shape) == 2 and
             len(weight_qtensor.block_size) == 2 and
             weight_qtensor.block_size[0] == 1 and
-            weight_qtensor.block_size[1] == weight_qtensor.shape[1]
+            weight_qtensor.block_size[1] == weight_qtensor.shape[1] and
+            weight_qtensor.storage_layout == "plain"
         ):
             # TODO: enable mps path as well
             # per channel int8 weight only quantizated mm
-            return torch.ops.aten._weight_int8pack_mm(input_tensor.contiguous(), weight_qtensor.int_data, weight_qtensor.scale)
+            return torch.ops.aten._weight_int8pack_mm(input_tensor.contiguous(), weight_qtensor.storage_tensor.int_data, weight_qtensor.storage_tensor.scale)
         else:
             weight_tensor = weight_qtensor.dequantize()
             return torch.nn.functional.linear(input_tensor, weight_tensor, bias)
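
Finally, a hypothetical end-to-end sketch of the dispatch path touched above: the specialized kernel branches now also require `storage_layout == "plain"`, and anything else falls back to dequantize-then-linear. Module paths, shapes, and dtypes below are assumptions, not taken from this commit:

    # Hypothetical illustration; not part of this commit.
    import torch
    import torch.nn.functional as F
    from torchao.dtypes.aqt import AffineQuantizedTensor
    from torchao.quantization.quant_primitives import MappingType

    w = torch.randn(64, 32, dtype=torch.bfloat16)
    qw = AffineQuantizedTensor.from_float(
        w, MappingType.SYMMETRIC, (1, 32), torch.int8, storage_layout="plain"
    )
    x = torch.randn(8, 32, dtype=torch.bfloat16)
    # F.linear routes through functional_linear; the int8/int4 kernel branches
    # are taken only when qw.storage_layout == "plain", otherwise the weight is
    # dequantized and a regular linear runs.
    y = F.linear(x, qw)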

0 commit comments
