Commit ce76e4b

Refactor layout implementation
Summary: TODO
Test Plan: TODO
Reviewers:
Subscribers:
Tasks:
Tags:
1 parent 05038a1 commit ce76e4b

4 files changed: +109 -56 lines changed

torchao/dtypes/__init__.py

Lines changed: 10 additions & 1 deletion
@@ -1,12 +1,21 @@
 from .nf4tensor import NF4Tensor, to_nf4
 # from ..prototype.dtypes.uint2 import UInt2Tensor, BitnetTensor
 from .uint4 import UInt4Tensor
-from .affine_quantized_tensor import AffineQuantizedTensor, to_affine_quantized
+from .affine_quantized_tensor import (
+    AffineQuantizedTensor,
+    to_affine_quantized,
+    LayoutType,
+    PlainLayoutType,
+    TensorCoreTiledLayoutType,
+)

 __all__ = [
     "NF4Tensor",
     "to_nf4",
     "UInt4Tensor"
     "AffineQuantizedTensor",
     "to_affine_quantized",
+    "LayoutType",
+    "PlainLayoutType",
+    "TensorCoreTiledLayoutType",
 ]
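The net effect of this hunk is that the layout types become part of the public torchao.dtypes surface. A minimal sketch of how the new exports could be exercised, assuming torchao is built with this commit applied (not an excerpt from the commit itself):

# Sketch only (not part of the commit): exercising the new public exports.
from torchao.dtypes import (
    LayoutType,
    PlainLayoutType,
    TensorCoreTiledLayoutType,
)

# Layout types are frozen dataclasses, so instances are cheap value objects
# that can be compared, hashed, and used as argument defaults.
plain = PlainLayoutType()
tiled = TensorCoreTiledLayoutType(inner_k_tiles=8)
assert isinstance(plain, LayoutType) and isinstance(tiled, LayoutType)
assert tiled == TensorCoreTiledLayoutType(inner_k_tiles=8)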

torchao/dtypes/affine_quantized_tensor.py

Lines changed: 69 additions & 44 deletions
@@ -20,10 +20,32 @@
     _ATEN_OP_OR_TORCH_FN_TABLE,
     _register_layout_cls,
     _get_layout_tensor_constructor,
+    LayoutType,
 )
+from typing import ClassVar
+from dataclasses import dataclass

 aten = torch.ops.aten

+@dataclass(frozen=True)
+class PlainLayoutType(LayoutType):
+    pass
+
+@dataclass(frozen=True)
+class TensorCoreTiledLayoutType(LayoutType):
+    inner_k_tiles: int = 8
+
+    def pad_input(self, input: torch.Tensor) -> torch.Tensor:
+        orig_out_features, orig_in_features = input.shape
+        in_features = find_multiple(orig_in_features, 1024)
+        out_features = find_multiple(orig_out_features, 8)
+        input = torch.nn.functional.pad(
+            input,
+            (0, in_features - orig_in_features, 0, out_features - orig_out_features),
+        )
+        return input
+
+
 def _aqt_is_int8(aqt):
     """Check if an AffineQuantizedTensor is int8 quantized Tensor"""
     return (
@@ -52,9 +74,6 @@ class AQTLayout(torch.Tensor):
     """
     Base class for the layout tensor for `AffineQuantizedTensor`
     """
-    # this should be set for each layout class during registration
-    extended_layout: Optional[str] = None
-
     def get_plain() -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
         pass

@@ -64,6 +83,7 @@ def from_plain(
         int_data: torch.Tensor,
         scale: torch.Tensor,
         zero_point: torch.Tensor,
+        layout_type: LayoutType,
     ):
         pass

@@ -194,30 +214,16 @@ def from_float(
         zero_point_dtype: Optional[torch.dtype] = None,
         preserve_zero: bool = True,
         zero_point_domain: ZeroPointDomain = ZeroPointDomain.INT,
-        extended_layout: str = "plain",
-        # TODO: this is only for "tensor_core_tiled", need to figure out
-        # the proper API for this arg
-        inner_k_tiles: Optional[int] = None,
+        layout_type: LayoutType = PlainLayoutType(),
     ):
         original_shape = input_float.shape
-        if extended_layout == "tensor_core_tiled":
-            orig_out_features, orig_in_features = input_float.shape
-            in_features = find_multiple(orig_in_features, 1024)
-            out_features = find_multiple(orig_out_features, 8)
-            input_float = torch.nn.functional.pad(
-                input_float,
-                (0, in_features - orig_in_features, 0, out_features - orig_out_features),
-            )
+        input_float = layout_type.pad_input(input_float)

         scale, zero_point = choose_qparams_affine(input_float, mapping_type, block_size, target_dtype, quant_min, quant_max, eps, scale_dtype, zero_point_dtype, preserve_zero, zero_point_domain)
         int_data = quantize_affine(input_float, block_size, scale, zero_point, target_dtype, quant_min, quant_max, zero_point_domain)

-        layout_cls_ctr = get_layout_tensor_constructor(extended_layout)
-        # TODO: this is temporary, need to come up with the proper UX
-        if extended_layout == "tensor_core_tiled":
-            layout_tensor = layout_cls_ctr(int_data, scale, zero_point, inner_k_tiles)
-        else:
-            layout_tensor = layout_cls_ctr(int_data, scale, zero_point)
+        layout_tensor_ctr = get_layout_tensor_constructor(type(layout_type))
+        layout_tensor = layout_tensor_ctr(int_data, scale, zero_point, layout_type)
         return cls(
             layout_tensor,
             block_size,
@@ -229,8 +235,8 @@ def from_float(
         )

     @property
-    def extended_layout(self) -> str:
-        return self.layout_tensor.extended_layout
+    def layout_type(self) -> str:
+        return self.layout_tensor.layout_type

     @classmethod
     def __torch_function__(cls, func, types, args=(), kwargs=None):
@@ -308,13 +314,13 @@ def __torch_dispatch__(cls, func, types, args, kwargs):
 def implements(aten_ops_or_torch_fn):
     return _implements(AffineQuantizedTensor, aten_ops_or_torch_fn)

-def register_layout_cls(extended_layout: str):
-    return _register_layout_cls(AffineQuantizedTensor, extended_layout)
+def register_layout_cls(layout_type_class: type(LayoutType)):
+    return _register_layout_cls(AffineQuantizedTensor, layout_type_class)

-def get_layout_tensor_constructor(extended_layout: str):
-    return _get_layout_tensor_constructor(AffineQuantizedTensor, extended_layout)
+def get_layout_tensor_constructor(layout_type_class: type(LayoutType)):
+    return _get_layout_tensor_constructor(AffineQuantizedTensor, layout_type_class)

-@register_layout_cls("plain")
+@register_layout_cls(PlainLayoutType)
 class PlainAQTLayout(AQTLayout):
     """
     Layout storage class for plain layout for affine quantized tensor, it stores int_data, scale, zero_point
@@ -330,6 +336,7 @@ def __new__(
         int_data: torch.Tensor,
         scale: torch.Tensor,
         zero_point: torch.Tensor,
+        layout_type: LayoutType,
     ):
         kwargs = {}
         kwargs["device"] = int_data.device
@@ -346,34 +353,39 @@ def __init__(
         int_data: torch.Tensor,
         scale: torch.Tensor,
         zero_point: torch.Tensor,
+        layout_type: LayoutType,
     ):
         self.int_data = int_data
         self.scale = scale
         self.zero_point = zero_point
+        self.layout_type = layout_type

     def __tensor_flatten__(self):
-        return ["int_data", "scale", "zero_point"], []
+        return ["int_data", "scale", "zero_point"], [self.layout_type]

     @classmethod
     def __tensor_unflatten__(
         cls, tensor_data_dict, tensor_attributes, outer_size, outer_stride
     ):
         int_data, scale, zero_point = tensor_data_dict["int_data"], tensor_data_dict["scale"], tensor_data_dict["zero_point"]
-        return cls(int_data, scale, zero_point)
+        layout_type, = tensor_attributes
+        return cls(int_data, scale, zero_point, layout_type)

     def to(self, *args, **kwargs):
         kwargs = self._get_to_kwargs(*args, **kwargs)
         return self.__class__(
             self.int_data.to(kwargs["device"]),
             self.scale.to(kwargs["device"]),
             self.zero_point.to(kwargs["device"]),
+            self.layout_type,
         )

     def _apply_fn_to_data(self, fn):
         return self.__class__(
             fn(self.int_data),
             fn(self.scale),
             fn(self.zero_point),
+            self.layout_type,
         )

     @classmethod
@@ -407,10 +419,12 @@ def from_plain(
         int_data: torch.Tensor,
         scale: torch.Tensor,
         zero_point: torch.Tensor,
+        layout_type: LayoutType,
     ):
-        return cls(int_data, scale, zero_point)
+        assert isinstance(layout_type, PlainLayoutType)
+        return cls(int_data, scale, zero_point, layout_type)

-@register_layout_cls("tensor_core_tiled")
+@register_layout_cls(TensorCoreTiledLayoutType)
 class TensorCoreTiledAQTLayout(AQTLayout):
     """
     Layout storage class for tensor_core_tiled layout for affine quantized tensor, this is for int4 only,
@@ -427,6 +441,7 @@ def __new__(
         packed_weight: torch.Tensor,
         scale_and_zero: torch.Tensor,
         transposed: bool,
+        layout_type: LayoutType,
     ):
         kwargs = {}
         kwargs["device"] = packed_weight.device
@@ -443,29 +458,38 @@ def __init__(
         packed_weight: torch.Tensor,
         scale_and_zero: torch.Tensor,
         transposed: bool,
+        layout_type: LayoutType,
     ):
         self.packed_weight = packed_weight
         self.scale_and_zero = scale_and_zero
         self.transposed = False
+        self.layout_type = layout_type

     def __tensor_flatten__(self):
-        return ["packed_weight", "scale_and_zero"], [self.transposed]
+        return ["packed_weight", "scale_and_zero"], [self.transposed, self.layout_type]

     @classmethod
     def __tensor_unflatten__(
         cls, tensor_data_dict, tensor_attributes, outer_size, outer_stride
     ):
         packed_weight, scale_and_zero = tensor_data_dict["packed_weight"], tensor_data_dict["scale_and_zero"]
-        transposed, = tensor_attributes
-        return cls(packed_weight, scale_and_zero, transposed)
+        transposed, layout_type, = tensor_attributes
+        return cls(packed_weight, scale_and_zero, transposed, layout_type)

     @classmethod
-    def from_plain(cls, int_data, scale, zero_point, inner_k_tiles=8):
-        packed_weight = torch.ops.aten._convert_weight_to_int4pack(int_data.to(torch.int32), inner_k_tiles)
+    def from_plain(
+        cls,
+        int_data: torch.Tensor,
+        scale: torch.Tensor,
+        zero_point: torch.Tensor,
+        layout_type: LayoutType
+    ):
+        assert isinstance(layout_type, TensorCoreTiledLayoutType)
+        packed_weight = torch.ops.aten._convert_weight_to_int4pack(int_data.to(torch.int32), layout_type.inner_k_tiles)
         scale = scale.reshape(int_data.shape[0], -1)
         zero_point = zero_point.reshape(int_data.shape[0], -1)
         scale_and_zero = pack_tinygemm_scales_and_zeros(scale, zero_point)
-        return cls(packed_weight, scale_and_zero, False)
+        return cls(packed_weight, scale_and_zero, False, layout_type)

     def to(self, *args, **kwargs):
         kwargs = self._get_to_kwargs(*args, **kwargs)
@@ -475,7 +499,8 @@ def to(self, *args, **kwargs):
         return self.__class__(
             self.packed_weight.to(device),
             self.scale_and_zero.to(device),
-            self.transposed
+            self.transposed,
+            self.layout_type,
         )

     def _apply_fn_to_data(self, fn):
@@ -485,7 +510,7 @@ def _apply_fn_to_data(self, fn):

     def __repr__(self):
         int_data, scale, zero_point = self.get_plain()
-        return f"TensorCoreTiledAQTLayout(int_data={int_data}, scale={scale}, zero_point={zero_point})"
+        return f"TensorCoreTiledAQTLayout(int_data={int_data}, scale={scale}, zero_point={zero_point}, {self.layout_type})"

     @classmethod
     def __torch_dispatch__(cls, func, types, args, kwargs):
@@ -563,8 +588,8 @@ def _quantized_linear_op(input_tensor, weight_qtensor, bias):
         is_cuda and
         input_is_int8 and
         input_tensor.dtype == weight_qtensor.dtype and
-        input_tensor.extended_layout == "plain" and
-        weight_qtensor.extended_layout == "plain"
+        isinstance(input_tensor.layout_type, PlainLayoutType) and
+        isinstance(weight_qtensor.layout_type, PlainLayoutType)
     ):
         #
         # 1. do the matrix form of dot(X_i, W_j)
@@ -606,7 +631,7 @@ def _quantized_linear_op(input_tensor, weight_qtensor, bias):
         weight_qtensor.dtype == torch.bfloat16 and
         len(weight_qtensor.shape) == 2 and
         weight_qtensor.zero_point_domain == ZeroPointDomain.FLOAT and
-        weight_qtensor.extended_layout == "tensor_core_tiled"
+        isinstance(weight_qtensor.layout_type, TensorCoreTiledLayoutType)
     ):
         assert weight_qtensor.block_size[0] == 1, f"Requires groupwise quantization, got block_size: {block_size}"
         assert input_tensor.shape[-1] == weight_qtensor.shape[1], (
@@ -649,7 +674,7 @@ def _quantized_linear_op(input_tensor, weight_qtensor, bias):
         weight_qtensor.block_size[0] == 1 and
         weight_qtensor.block_size[1] == weight_qtensor.shape[1] and
         weight_qtensor.zero_point_domain == ZeroPointDomain.INT and
-        weight_qtensor.extended_layout == "plain"
+        isinstance(weight_qtensor.layout_type, PlainLayoutType)
     ):
         # TODO: enable cpu and mps efficient path
         # per channel int8 weight only quantizated mm
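To make the behavioral change above concrete: the tensor-core-tiled padding that used to live inline in `from_float` now hangs off the layout type as `pad_input`, with the base class acting as a no-op. A small hedged sketch (the 10x1000 shape is an arbitrary example, not from the commit):

# Sketch only: the padding hook moved from from_float onto the layout type.
import torch
from torchao.dtypes import PlainLayoutType, TensorCoreTiledLayoutType

w = torch.randn(10, 1000)
# The plain layout inherits the no-op default and leaves the input untouched.
assert PlainLayoutType().pad_input(w).shape == (10, 1000)
# The tensor-core-tiled layout pads out_features to a multiple of 8 and
# in_features to a multiple of 1024, as in the removed from_float branch.
padded = TensorCoreTiledLayoutType(inner_k_tiles=8).pad_input(w)
assert padded.shape == (16, 1024)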

torchao/dtypes/utils.py

Lines changed: 27 additions & 10 deletions
@@ -1,6 +1,8 @@
+import torch
 from typing import Dict, Callable
 from collections import defaultdict
 import functools
+from dataclasses import dataclass

 """
 torch_function and torch_dispatch operator dispatch registrations
@@ -28,38 +30,53 @@ def wrapper(*args, **kwargs):
         return func
     return decorator

+"""
+Base class for different LayoutType, should not be instantiated directly
+"""
+@dataclass(frozen=True)
+class LayoutType:
+    def pad_input(self, input: torch.Tensor) -> torch.Tensor:
+        return input
+
 """
 layout tensor constructor registration for different tensor subclassesa

 first key is a tensor subclass type like AffineQuantizedTensor
 second key is an extended layout string, like tensor_core_tiled
 value is a constructor for the LayoutTensor class, e.g. TensorCoreTiledAQTLayout.from_plain
 """
-_LAYOUT_CONSTRUCTOR_TABLE: Dict[Callable, Dict[str, Callable]] = defaultdict(dict)
+_LAYOUT_CONSTRUCTOR_TABLE: Dict[Callable, Dict[type(LayoutType), Callable]] = defaultdict(dict)

-def _register_layout_cls(cls: Callable, extended_layout: str):
+def _register_layout_cls(cls: Callable, layout_type_class: type(LayoutType)):
     """Helper function for layout registrations, this is used to implement
     register_layout_cls decorator for each tensor subclass, see aqt.py for example usage

     Args:
         cls: Tensor subclass type
-        extended_layout: string name for the layout type
+        layout_type_class: the class type of subclass of `LayoutType`, e.g. `PlainLayoutType`

     Returns:
         a decorator that registers the layout tensor constructor in the table
     """
     def decorator(layout_cls):
-        layout_cls.extended_layout = extended_layout
-        _LAYOUT_CONSTRUCTOR_TABLE[cls][extended_layout] = layout_cls.from_plain
+        _LAYOUT_CONSTRUCTOR_TABLE[cls][layout_type_class] = layout_cls.from_plain
         return layout_cls
     return decorator

-def _get_layout_tensor_constructor(cls: Callable, extended_layout: str) -> Callable:
-    """Get Layout class constructor (LayoutClass.from_plain) for `cls` based on `extended_layout`
+def _get_layout_tensor_constructor(cls: Callable, layout_type_class: type(LayoutType)) -> Callable:
+    """Get Layout class constructor (LayoutClass.from_plain) for `cls` based on `layout_type_class`
+    `layout_type_class` means the class type of subclass of `LayoutType`, e.g. `PlainLayoutType`
+
+    Args:
+        cls: Tensor subclass type
+        layout_type_class: the class type of subclass of `LayoutType`, e.g. `PlainLayoutType`
+
+    Returns:
+        layout tensor subclass constructor for the layout_type_class
     """
     if cls not in _LAYOUT_CONSTRUCTOR_TABLE:
         raise ValueError(f"no registered layout class constructor for: {cls}")
-    if extended_layout not in _LAYOUT_CONSTRUCTOR_TABLE[cls]:
-        raise ValueError(f"extended_layout: {extended_layout} is not supported yet for {cls}")
+    if layout_type_class not in _LAYOUT_CONSTRUCTOR_TABLE[cls]:
+        raise ValueError(f"layout_name: {layout_type_class} is not supported yet for {cls}")

-    return _LAYOUT_CONSTRUCTOR_TABLE[cls][extended_layout]
+    return _LAYOUT_CONSTRUCTOR_TABLE[cls][layout_type_class]
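For orientation, a hedged sketch of how the reworked registry is meant to be used: constructors are now keyed by the `LayoutType` subclass itself rather than a string name. `MyTensor`, `MyLayoutType`, and `MyAQTLayout` below are made-up names for illustration only, not part of the commit:

# Hypothetical sketch of the registration flow; MyTensor/MyLayoutType/MyAQTLayout
# are illustrative stand-ins, not real torchao classes.
from dataclasses import dataclass
from torchao.dtypes.utils import (
    LayoutType,
    _register_layout_cls,
    _get_layout_tensor_constructor,
)

class MyTensor:          # stands in for a tensor subclass such as AffineQuantizedTensor
    pass

@dataclass(frozen=True)
class MyLayoutType(LayoutType):
    pass

@_register_layout_cls(MyTensor, MyLayoutType)
class MyAQTLayout:
    @classmethod
    def from_plain(cls, int_data, scale, zero_point, layout_type):
        return cls()     # a real layout class would pack and store the tensors here

# Lookup is keyed on the LayoutType subclass, not on a string such as "plain".
ctor = _get_layout_tensor_constructor(MyTensor, MyLayoutType)
assert ctor == MyAQTLayout.from_plain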

torchao/quantization/quant_api.py

Lines changed: 3 additions & 1 deletion
@@ -380,6 +380,7 @@ def int4_weight_only(group_size=128, inner_k_tiles=8):
     def apply_int4_weight_only_quant(weight):
         # avoid circular dep
         from torchao.dtypes import to_affine_quantized
+        from torchao.dtypes import TensorCoreTiledLayoutType

         mapping_type = MappingType.ASYMMETRIC
         block_size = (1, group_size)
@@ -390,7 +391,8 @@ def apply_int4_weight_only_quant(weight):
         preserve_zero = False
         zero_point_dtype = torch.bfloat16
         zero_point_domain = ZeroPointDomain.FLOAT
-        return to_affine_quantized(weight, mapping_type, block_size, target_dtype, quant_min, quant_max, eps, zero_point_dtype=zero_point_dtype, preserve_zero=preserve_zero, zero_point_domain=zero_point_domain, extended_layout="tensor_core_tiled", inner_k_tiles=inner_k_tiles)
+        layout_type = TensorCoreTiledLayoutType(inner_k_tiles=inner_k_tiles)
+        return to_affine_quantized(weight, mapping_type, block_size, target_dtype, quant_min, quant_max, eps, zero_point_dtype=zero_point_dtype, preserve_zero=preserve_zero, zero_point_domain=zero_point_domain, layout_type=layout_type)

     return apply_int4_weight_only_quant

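For completeness, a hedged usage sketch of the call path this hunk changes: `int4_weight_only()` keeps its `group_size`/`inner_k_tiles` arguments but now forwards them through a `TensorCoreTiledLayoutType`. A CUDA device and the 256x1024 bfloat16 weight are assumptions for illustration; the tinygemm packing op used by this layout is CUDA-only here.

# Sketch only: applying the int4 weight-only config returned by quant_api.
import torch
from torchao.quantization.quant_api import int4_weight_only

apply_quant = int4_weight_only(group_size=128, inner_k_tiles=8)
weight = torch.randn(256, 1024, dtype=torch.bfloat16, device="cuda")
qweight = apply_quant(weight)
# The result is an AffineQuantizedTensor whose layout_type carries inner_k_tiles.
print(qweight.layout_type)   # TensorCoreTiledLayoutType(inner_k_tiles=8)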