Refactor layout implementation #491
@@ -1,12 +1,21 @@
from .nf4tensor import NF4Tensor, to_nf4
# from ..prototype.dtypes.uint2 import UInt2Tensor, BitnetTensor
from .uint4 import UInt4Tensor
from .affine_quantized_tensor import AffineQuantizedTensor, to_affine_quantized
from .affine_quantized_tensor import (
    AffineQuantizedTensor,
    to_affine_quantized,
    LayoutType,
    PlainLayoutType,
    TensorCoreTiledLayoutType,
)

__all__ = [
    "NF4Tensor",
    "to_nf4",
"UInt4Tensor" | ||
"AffineQuantizedTensor", | ||
"to_affine_quantized", | ||
"LayoutType", | ||
"PlainLayoutType", | ||
"TensorCoreTiledLayoutType", | ||
] |
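As a quick check of the expanded public API above, the new layout descriptors should be importable next to the existing entry points. A minimal sketch, assuming this file is torchao/dtypes/__init__.py (the path is inferred from the relative imports, it is not stated in the diff):

from torchao.dtypes import (
    AffineQuantizedTensor,
    to_affine_quantized,
    LayoutType,
    PlainLayoutType,
    TensorCoreTiledLayoutType,
)

# Layout descriptors are plain (frozen) dataclass instances; no tensor data is involved yet.
plain = PlainLayoutType()
tiled = TensorCoreTiledLayoutType(inner_k_tiles=4)
assert isinstance(plain, LayoutType) and isinstance(tiled, LayoutType)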
@@ -20,10 +20,35 @@
    _ATEN_OP_OR_TORCH_FN_TABLE,
    _register_layout_cls,
    _get_layout_tensor_constructor,
    LayoutType,
)
from typing import ClassVar
from dataclasses import dataclass

aten = torch.ops.aten


@dataclass(frozen=True)
class PlainLayoutType(LayoutType):
[Review thread]
- Comment or error that this shouldn't be instantiated directly.
- This can be instantiated, I think. Are you talking about LayoutType?
- I see. I guess I'm a bit thrown off because a dataclass's primary goal is to store data, whereas this class stores nothing and is really just a name. I would have instead done an enum like this.
- Enums are also a class, so you can override.
- Sorry, I don't follow the ...
    pass


@dataclass(frozen=True)
class TensorCoreTiledLayoutType(LayoutType):
    inner_k_tiles: int = 8
[Review thread]
- @msaroufim see here, we have extra configurable arguments, it's not just a name, so I'm not sure how an enum would work here.
    def pre_process(self, input: torch.Tensor) -> torch.Tensor:
        orig_out_features, orig_in_features = input.shape
        in_features = find_multiple(orig_in_features, 1024)
[Review thread]
- Where are the in and out numbers coming from? I thought constants like this were a function of the dtype as well.
- This comes from the tinygemm kernel, I think; this layout only applies to the uint4 dtype.
        out_features = find_multiple(orig_out_features, 8)
        input = torch.nn.functional.pad(
            input,
            (0, in_features - orig_in_features, 0, out_features - orig_out_features),
        )
        return input

    def extra_repr(self):
        return f"inner_k_tiles={self.inner_k_tiles}"
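Tying the two review threads above together: the frozen dataclass is doing double duty as a registry key (its type) and as a carrier of layout-specific configuration (inner_k_tiles), which is the part an enum member cannot easily provide. A minimal sketch of the contrast; the class names below are illustrative, not from the PR:

from dataclasses import dataclass
from enum import Enum

# Enum variant (as suggested in the review): works fine as a pure tag...
class LayoutKind(Enum):
    PLAIN = "plain"
    TENSOR_CORE_TILED = "tensor_core_tiled"

# ...but a member cannot carry per-instance settings, so inner_k_tiles would
# have to be passed around separately again, which is what this PR removes.

# Dataclass variant (what the diff above does): the type is the tag, the
# fields are the configuration, and frozen=True keeps instances hashable.
@dataclass(frozen=True)
class TiledLayoutSketch:
    inner_k_tiles: int = 8

a = TiledLayoutSketch(inner_k_tiles=4)
b = TiledLayoutSketch()
assert type(a) is type(b)   # same layout family, hence same registry key
assert a != b               # different packing configuration

On the 1024 and 8 constants asked about above: pre_process pads in_features up to a multiple of 1024 and out_features up to a multiple of 8, which, per the reply in the thread, is what the tinygemm int4 kernel expects; for example, a (4096, 11008) weight would be padded to (4096, 11264), since find_multiple(11008, 1024) == 11264 and 4096 is already a multiple of 8.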
def _aqt_is_int8(aqt):
    """Check if an AffineQuantizedTensor is int8 quantized Tensor"""
    return (
@@ -52,10 +77,10 @@ class AQTLayout(torch.Tensor):
    """
    Base class for the layout tensor for `AffineQuantizedTensor`
    """
    # this should be set for each layout class during registration
    extended_layout: Optional[str] = None
    def get_plain(self) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        pass

    def get_plain() -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    def get_layout_type(self) -> LayoutType:
        pass

    @classmethod
@@ -64,9 +89,15 @@ def from_plain(
        int_data: torch.Tensor,
        scale: torch.Tensor,
        zero_point: torch.Tensor,
        layout_type: LayoutType,
    ):
        pass

    def __repr__(self):
        int_data, scale, zero_point = self.get_plain()
        layout_type = self.get_layout_type()
        return f"{self.__class__.__name__}(int_data={int_data}, scale={scale}, zero_point={zero_point}, layout_type={layout_type})"

    def _get_to_kwargs(self, *args, **kwargs):
        device, dtype, _, memory_format = torch._C._nn._parse_to(*args, **kwargs)
        device = self.device if device is None else device
@@ -194,30 +225,17 @@ def from_float(
        zero_point_dtype: Optional[torch.dtype] = None,
        preserve_zero: bool = True,
        zero_point_domain: ZeroPointDomain = ZeroPointDomain.INT,
        extended_layout: str = "plain",
        # TODO: this is only for "tensor_core_tiled", need to figure out
        # the proper API for this arg
        inner_k_tiles: Optional[int] = None,
        layout_type: LayoutType = PlainLayoutType(),
    ):
        original_shape = input_float.shape
        if extended_layout == "tensor_core_tiled":
            orig_out_features, orig_in_features = input_float.shape
            in_features = find_multiple(orig_in_features, 1024)
            out_features = find_multiple(orig_out_features, 8)
            input_float = torch.nn.functional.pad(
                input_float,
                (0, in_features - orig_in_features, 0, out_features - orig_out_features),
            )
        input_float = layout_type.pre_process(input_float)

        scale, zero_point = choose_qparams_affine(input_float, mapping_type, block_size, target_dtype, quant_min, quant_max, eps, scale_dtype, zero_point_dtype, preserve_zero, zero_point_domain)
        int_data = quantize_affine(input_float, block_size, scale, zero_point, target_dtype, quant_min, quant_max, zero_point_domain)
        int_data = layout_type.post_process(int_data)

        layout_cls_ctr = get_layout_tensor_constructor(extended_layout)
        # TODO: this is temporary, need to come up with the proper UX
        if extended_layout == "tensor_core_tiled":
            layout_tensor = layout_cls_ctr(int_data, scale, zero_point, inner_k_tiles)
        else:
            layout_tensor = layout_cls_ctr(int_data, scale, zero_point)
        layout_tensor_ctr = get_layout_tensor_constructor(type(layout_type))
        layout_tensor = layout_tensor_ctr(int_data, scale, zero_point, layout_type)
        return cls(
            layout_tensor,
            block_size,
@@ -229,8 +247,8 @@ def from_float(
        )

    @property
    def extended_layout(self) -> str:
        return self.layout_tensor.extended_layout
    def layout_type(self) -> str:
        return self.layout_tensor.layout_type
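The from_float changes above are the center of the refactor: the extended_layout string and the loose inner_k_tiles kwarg collapse into one layout_type object, and the tensor-core-tiled padding moves behind LayoutType.pre_process/post_process. A hedged before/after sketch of the call site; the quantization parameters shown (mapping type, block size, target dtype, quant range) are illustrative int4-weight-only style values, not taken from this diff:

import torch
from torchao.dtypes import to_affine_quantized, TensorCoreTiledLayoutType
from torchao.quantization.quant_primitives import MappingType, ZeroPointDomain

weight = torch.randn(4096, 4096, dtype=torch.bfloat16, device="cuda")
group_size = 128

# Before (sketch): layout chosen by string, plus a kwarg only one layout used.
#   to_affine_quantized(weight, MappingType.ASYMMETRIC, (1, group_size), torch.int32,
#                       ..., extended_layout="tensor_core_tiled", inner_k_tiles=8)

# After (sketch): the layout and its configuration travel together.
quantized_weight = to_affine_quantized(
    weight,
    MappingType.ASYMMETRIC,
    (1, group_size),          # block_size: one group of 128 along in_features
    torch.int32,              # target_dtype for the packed int4 path
    quant_min=0,
    quant_max=15,
    zero_point_domain=ZeroPointDomain.FLOAT,
    layout_type=TensorCoreTiledLayoutType(inner_k_tiles=8),
)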
    @classmethod
    def __torch_function__(cls, func, types, args=(), kwargs=None):
@@ -308,13 +326,13 @@ def __torch_dispatch__(cls, func, types, args, kwargs):
def implements(aten_ops_or_torch_fn):
    return _implements(AffineQuantizedTensor, aten_ops_or_torch_fn)

def register_layout_cls(extended_layout: str):
    return _register_layout_cls(AffineQuantizedTensor, extended_layout)
def register_layout_cls(layout_type_class: type(LayoutType)):
    return _register_layout_cls(AffineQuantizedTensor, layout_type_class)

def get_layout_tensor_constructor(extended_layout: str):
    return _get_layout_tensor_constructor(AffineQuantizedTensor, extended_layout)
def get_layout_tensor_constructor(layout_type_class: type(LayoutType)):
    return _get_layout_tensor_constructor(AffineQuantizedTensor, layout_type_class)
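Since register_layout_cls and get_layout_tensor_constructor are now keyed on the LayoutType subclass instead of a string, extending the system means defining a descriptor plus a storage class and registering the pair. A skeleton under that reading; MyBlockSparseLayoutType and its field are hypothetical, the module path is inferred from the relative imports, and a real storage class would also need the __new__/__init__/__tensor_flatten__ machinery shown for PlainAQTLayout below:

from dataclasses import dataclass
from torchao.dtypes.affine_quantized_tensor import (
    AQTLayout,
    LayoutType,
    register_layout_cls,
)

@dataclass(frozen=True)
class MyBlockSparseLayoutType(LayoutType):  # hypothetical descriptor
    block_size: int = 64

@register_layout_cls(MyBlockSparseLayoutType)
class MyBlockSparseAQTLayout(AQTLayout):
    @classmethod
    def from_plain(cls, int_data, scale, zero_point, layout_type):
        assert isinstance(layout_type, MyBlockSparseLayoutType)
        ...  # pack int_data according to layout_type.block_size

    def get_plain(self):
        ...  # unpack back to (int_data, scale, zero_point)

    def get_layout_type(self):
        return self.layout_type

# from_float then finds this constructor via get_layout_tensor_constructor(type(layout_type)).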
@register_layout_cls("plain") | ||
@register_layout_cls(PlainLayoutType) | ||
class PlainAQTLayout(AQTLayout): | ||
""" | ||
Layout storage class for plain layout for affine quantized tensor, it stores int_data, scale, zero_point | ||
|
@@ -330,6 +348,7 @@ def __new__(
        int_data: torch.Tensor,
        scale: torch.Tensor,
        zero_point: torch.Tensor,
        layout_type: LayoutType,
    ):
        kwargs = {}
        kwargs["device"] = int_data.device
@@ -346,34 +365,39 @@ def __init__(
        int_data: torch.Tensor,
        scale: torch.Tensor,
        zero_point: torch.Tensor,
        layout_type: LayoutType,
    ):
        self.int_data = int_data
        self.scale = scale
        self.zero_point = zero_point
        self.layout_type = layout_type

    def __tensor_flatten__(self):
        return ["int_data", "scale", "zero_point"], []
        return ["int_data", "scale", "zero_point"], [self.layout_type]

    @classmethod
    def __tensor_unflatten__(
        cls, tensor_data_dict, tensor_attributes, outer_size, outer_stride
    ):
        int_data, scale, zero_point = tensor_data_dict["int_data"], tensor_data_dict["scale"], tensor_data_dict["zero_point"]
        return cls(int_data, scale, zero_point)
        layout_type, = tensor_attributes
        return cls(int_data, scale, zero_point, layout_type)

    def to(self, *args, **kwargs):
        kwargs = self._get_to_kwargs(*args, **kwargs)
        return self.__class__(
            self.int_data.to(kwargs["device"]),
            self.scale.to(kwargs["device"]),
            self.zero_point.to(kwargs["device"]),
            self.layout_type,
        )

    def _apply_fn_to_data(self, fn):
        return self.__class__(
            fn(self.int_data),
            fn(self.scale),
            fn(self.zero_point),
            self.layout_type,
        )
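Threading layout_type through __tensor_flatten__/__tensor_unflatten__ above is what lets the descriptor survive the points where PyTorch dismantles and rebuilds the subclass (torch.compile tracing, serialization of the inner tensors). A small round-trip sketch of that contract with illustrative CPU tensors; it calls the dunders directly only to show what is carried where:

import torch
from torchao.dtypes import PlainLayoutType
from torchao.dtypes.affine_quantized_tensor import PlainAQTLayout

int_data = torch.randint(-128, 127, (32, 64), dtype=torch.int8)
scale = torch.rand(32, 1)
zero_point = torch.zeros(32, 1, dtype=torch.int32)

plain = PlainAQTLayout.from_plain(int_data, scale, zero_point, PlainLayoutType())

# Flatten: tensor attribute names, plus non-tensor metadata that now
# includes the layout descriptor.
tensor_names, metadata = plain.__tensor_flatten__()
inner = {name: getattr(plain, name) for name in tensor_names}

# Unflatten: the descriptor comes back out of the metadata, so the rebuilt
# layout tensor still knows which layout it represents.
rebuilt = PlainAQTLayout.__tensor_unflatten__(inner, metadata, plain.size(), plain.stride())
assert isinstance(rebuilt.get_layout_type(), PlainLayoutType)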
    @classmethod
@@ -398,19 +422,24 @@ def __torch_dispatch__(cls, func, types, args, kwargs):

    __torch_function__ = torch._C._disabled_torch_function_impl

    def get_plain(self):
    def get_plain(self) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        return self.int_data, self.scale, self.zero_point

    def get_layout_type(self) -> LayoutType:
        return self.layout_type

    @classmethod
    def from_plain(
        cls,
        int_data: torch.Tensor,
        scale: torch.Tensor,
        zero_point: torch.Tensor,
        layout_type: LayoutType,
    ):
        return cls(int_data, scale, zero_point)
        assert isinstance(layout_type, PlainLayoutType)
        return cls(int_data, scale, zero_point, layout_type)
@register_layout_cls("tensor_core_tiled") | ||
@register_layout_cls(TensorCoreTiledLayoutType) | ||
class TensorCoreTiledAQTLayout(AQTLayout): | ||
""" | ||
Layout storage class for tensor_core_tiled layout for affine quantized tensor, this is for int4 only, | ||
|
@@ -427,6 +456,7 @@ def __new__( | |
packed_weight: torch.Tensor, | ||
scale_and_zero: torch.Tensor, | ||
transposed: bool, | ||
layout_type: LayoutType, | ||
): | ||
kwargs = {} | ||
kwargs["device"] = packed_weight.device | ||
|
@@ -443,31 +473,40 @@ def __init__(
        packed_weight: torch.Tensor,
        scale_and_zero: torch.Tensor,
        transposed: bool,
        layout_type: LayoutType,
    ):
        self.packed_weight = packed_weight
        self.scale_and_zero = scale_and_zero
        self.transposed = False
        self.layout_type = layout_type

    def __tensor_flatten__(self):
        return ["packed_weight", "scale_and_zero"], [self.transposed]
        return ["packed_weight", "scale_and_zero"], [self.transposed, self.layout_type]
    @classmethod
    def __tensor_unflatten__(
        cls, tensor_data_dict, tensor_attributes, outer_size, outer_stride
    ):
        packed_weight, scale_and_zero = tensor_data_dict["packed_weight"], tensor_data_dict["scale_and_zero"]
        transposed, = tensor_attributes
        return cls(packed_weight, scale_and_zero, transposed)
        transposed, layout_type, = tensor_attributes
        return cls(packed_weight, scale_and_zero, transposed, layout_type)

    @classmethod
    def from_plain(cls, int_data, scale, zero_point, inner_k_tiles=8):
    def from_plain(
        cls,
        int_data: torch.Tensor,
        scale: torch.Tensor,
        zero_point: torch.Tensor,
        layout_type: LayoutType
    ):
        assert isinstance(layout_type, TensorCoreTiledLayoutType)
        # assert int_data.dtype == torch.uint8, "torch.ops.aten._convert_weight_to_int4pack expects `uint8` dtype"
        # packed_weight = torch.ops.aten._convert_weight_to_int4pack(int_data, inner_k_tiles)
        packed_weight = torch.ops.aten._convert_weight_to_int4pack(int_data.to(torch.int32), inner_k_tiles)
        packed_weight = torch.ops.aten._convert_weight_to_int4pack(int_data.to(torch.int32), layout_type.inner_k_tiles)
        scale = scale.reshape(int_data.shape[0], -1)
        zero_point = zero_point.reshape(int_data.shape[0], -1)
        scale_and_zero = pack_tinygemm_scales_and_zeros(scale, zero_point)
        return cls(packed_weight, scale_and_zero, False)
        return cls(packed_weight, scale_and_zero, False, layout_type)

    def to(self, *args, **kwargs):
        kwargs = self._get_to_kwargs(*args, **kwargs)
@@ -477,18 +516,15 @@ def to(self, *args, **kwargs):
        return self.__class__(
            self.packed_weight.to(device),
            self.scale_and_zero.to(device),
            self.transposed
            self.transposed,
            self.layout_type,
        )

    def _apply_fn_to_data(self, fn):
        self.packed_weight = fn(self.packed_weight)
        self.scale_and_zero = fn(self.scale_and_zero)
        return self

    def __repr__(self):
        int_data, scale, zero_point = self.get_plain()
        return f"TensorCoreTiledAQTLayout(int_data={int_data}, scale={scale}, zero_point={zero_point})"

    @classmethod
    def __torch_dispatch__(cls, func, types, args, kwargs):
        kwargs = {} if kwargs is None else kwargs
@@ -511,7 +547,7 @@ def __torch_dispatch__(cls, func, types, args, kwargs):

    __torch_function__ = torch._C._disabled_torch_function_impl

    def get_plain(self):
    def get_plain(self) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        from torchao.quantization.quant_primitives import (
            ZeroPointDomain,
            quantize_affine,
@@ -542,6 +578,9 @@ def get_plain(self):
        int_data = quantize_affine(dequantized, block_size, scale, zero, target_dtype, quant_min, quant_max, zero_point_domain)
        return int_data, scale, zero

    def get_layout_type(self) -> LayoutType:
        return self.layout_type


def _quantized_linear_op(input_tensor, weight_qtensor, bias):
    """
    Quantized version of F.linear operator
@@ -565,8 +604,8 @@ def _quantized_linear_op(input_tensor, weight_qtensor, bias):
        is_cuda and
        input_is_int8 and
        input_tensor.dtype == weight_qtensor.dtype and
        input_tensor.extended_layout == "plain" and
        weight_qtensor.extended_layout == "plain"
        isinstance(input_tensor.layout_type, PlainLayoutType) and
        isinstance(weight_qtensor.layout_type, PlainLayoutType)
    ):
        #
        # 1. do the matrix form of dot(X_i, W_j)
@@ -608,7 +647,7 @@ def _quantized_linear_op(input_tensor, weight_qtensor, bias):
        weight_qtensor.dtype == torch.bfloat16 and
        len(weight_qtensor.shape) == 2 and
        weight_qtensor.zero_point_domain == ZeroPointDomain.FLOAT and
        weight_qtensor.extended_layout == "tensor_core_tiled"
        isinstance(weight_qtensor.layout_type, TensorCoreTiledLayoutType)
    ):
        assert weight_qtensor.block_size[0] == 1, f"Requires groupwise quantization, got block_size: {block_size}"
        assert input_tensor.shape[-1] == weight_qtensor.shape[1], (
@@ -651,7 +690,7 @@ def _quantized_linear_op(input_tensor, weight_qtensor, bias):
        weight_qtensor.block_size[0] == 1 and
        weight_qtensor.block_size[1] == weight_qtensor.shape[1] and
        weight_qtensor.zero_point_domain == ZeroPointDomain.INT and
        weight_qtensor.extended_layout == "plain"
        isinstance(weight_qtensor.layout_type, PlainLayoutType)
    ):
        # TODO: enable cpu and mps efficient path
        # per channel int8 weight only quantizated mm
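The last three hunks swap string comparisons on extended_layout for isinstance checks on the layout descriptor, which is what makes the dispatch extensible: a new LayoutType subclass simply falls through the existing checks and can add its own branch without sharing a string namespace. A condensed, hypothetical rendering of that pattern (the function and the branch labels are illustrative, not from the PR):

from torchao.dtypes import PlainLayoutType, TensorCoreTiledLayoutType

def pick_linear_path(weight_qtensor) -> str:
    # Mirrors the post-PR style of _quantized_linear_op: branch on the type
    # of the layout descriptor rather than on a string.
    layout_type = weight_qtensor.layout_type
    if isinstance(layout_type, TensorCoreTiledLayoutType):
        return "tinygemm int4 path"    # packed weights, bfloat16 activations
    if isinstance(layout_type, PlainLayoutType):
        return "plain int8/int4 path"  # unpacked int_data / scale / zero_point
    raise NotImplementedError(f"no linear kernel for layout {layout_type}")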