
Update api_ref_dtypes docs #1610


Merged · 14 commits · Jan 24, 2025
33 changes: 28 additions & 5 deletions docs/source/api_ref_dtypes.rst
@@ -6,19 +6,42 @@ torchao.dtypes

.. currentmodule:: torchao.dtypes

Layouts and Tensor Subclasses
-----------------------------
.. autosummary::
:toctree: generated/
:nosignatures:

NF4Tensor
AffineQuantizedTensor
Layout
PlainLayout
SemiSparseLayout
TensorCoreTiledLayout
Float8Layout
FloatxTensor
FloatxTensorCoreLayout
MarlinSparseLayout
BlockSparseLayout
UintxLayout
MarlinQQQTensor
MarlinQQQLayout
Int4CPULayout
CutlassInt4PackedLayout

Quantization techniques
-----------------------
.. autosummary::
:toctree: generated/
:nosignatures:

to_affine_quantized_intx
to_affine_quantized_intx_static
to_affine_quantized_fpx
to_affine_quantized_floatx
to_affine_quantized_floatx_static
to_marlinqqq_quantized_intx
to_nf4
..
_NF4Tensor - add after fixing torchao/dtypes/nf4tensor.py:docstring
of torchao.dtypes.nf4tensor.NF4Tensor.dequantize_scalers:6:Unexpected indentation.
37 changes: 19 additions & 18 deletions torchao/dtypes/affine_quantized_tensor.py
@@ -44,9 +44,8 @@
# Tensor Subclass Definition #
##############################
class AffineQuantizedTensor(TorchAOBaseTensor):
"""
Affine quantized tensor subclass. Affine quantization means we quantize the floating point tensor with an affine transformation:
quantized_tensor = float_tensor / scale + zero_point
"""Affine quantized tensor subclass. Affine quantization means we quantize the floating point tensor with an affine transformation:
quantized_tensor = float_tensor / scale + zero_point

To see what happens during choose_qparams, quantization and dequantization for affine quantization,
please checkout https://github.com/pytorch/ao/blob/main/torchao/quantization/quant_primitives.py
@@ -56,21 +55,18 @@ class AffineQuantizedTensor(TorchAOBaseTensor):
regardless of the internal representation's type or orientation.

fields:
- tensor_impl (AQTTensorImpl): tensor that serves as a general tensor impl storage for the quantized data,
  e.g. storing plain tensors (int_data, scale, zero_point) or packed formats depending on device and operator/kernel
- block_size (Tuple[int, ...]): granularity of quantization, i.e. the size of the group of tensor elements that share the same quantization parameters;
  e.g. when the block size matches the input tensor shape, we are using per-tensor quantization
- shape (torch.Size): the shape of the original high precision Tensor
- quant_min (Optional[int]): minimum quantized value for the Tensor; if not specified, it will be derived from the dtype of `int_data`
- quant_max (Optional[int]): maximum quantized value for the Tensor; if not specified, it will be derived from the dtype of `int_data`
- zero_point_domain (ZeroPointDomain): the domain that zero_point is in, either integer or float;
  if zero_point is in the integer domain, it is added to the quantized integer value during quantization;
  if zero_point is in the floating point domain, it is subtracted from the floating point (unquantized) value during quantization;
  default is ZeroPointDomain.INT
- dtype: dtype of the original high precision tensor, e.g. torch.float32
"""

@staticmethod
@@ -207,6 +203,7 @@ def from_hp_to_intx(
_layout: Layout = PlainLayout(),
use_hqq: bool = False,
):
"""Convert a high precision tensor to an integer affine quantized tensor."""
original_shape = input_float.shape
input_float = _layout.pre_process(input_float)
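A hedged usage sketch of the public wrapper around this classmethod (`to_affine_quantized_intx`, listed in the .rst table above); the mapping type and block size are illustrative choices, not defaults:

```python
import torch

from torchao.dtypes import to_affine_quantized_intx
from torchao.quantization.quant_primitives import MappingType

weight = torch.randn(128, 128)
# block_size (1, 128) means each row shares one set of quantization parameters
int8_weight = to_affine_quantized_intx(
    weight, MappingType.SYMMETRIC, (1, 128), torch.int8
)
```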

@@ -302,6 +299,7 @@ def from_hp_to_intx_static(
zero_point_domain: Optional[ZeroPointDomain] = ZeroPointDomain.INT,
_layout: Layout = PlainLayout(),
):
"""Create an integer AffineQuantizedTensor from a high precision tensor using static parameters."""
if target_dtype not in FP8_TYPES:
assert (
zero_point_domain is not None
@@ -348,6 +346,7 @@ def from_hp_to_floatx(
_layout: Layout,
scale_dtype: Optional[torch.dtype] = None,
):
"""Convert a high precision tensor to a float8 quantized tensor."""
if target_dtype in FP8_TYPES:
return cls.from_hp_to_intx(
input_float=input_float,
@@ -378,6 +377,7 @@ def from_hp_to_floatx_static(
target_dtype: torch.dtype,
_layout: Layout,
):
"""Create a float8 AffineQuantizedTensor from a high precision tensor using static parameters."""
if target_dtype in FP8_TYPES:
return cls.from_hp_to_intx_static(
input_float=input_float,
@@ -401,6 +401,7 @@ def from_hp_to_fpx(
input_float: torch.Tensor,
_layout: Layout,
):
"""Create a floatx AffineQuantizedTensor from a high precision tensor. Floatx is represented as ebits and mbits, and supports the representation of float1-float7."""
from torchao.dtypes.floatx import FloatxTensorCoreLayout

assert isinstance(
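Following the new docstring, a hedged sketch of this path via its public alias `to_affine_quantized_fpx`; the e3m2 split is one illustrative choice from the float1-float7 range:

```python
import torch

from torchao.dtypes import to_affine_quantized_fpx
from torchao.dtypes.floatx import FloatxTensorCoreLayout

weight = torch.randn(256, 256, dtype=torch.float16)
# ebits=3, mbits=2 encodes an fp6 value (1 sign + 3 exponent + 2 mantissa bits)
fp6_weight = to_affine_quantized_fpx(weight, FloatxTensorCoreLayout(ebits=3, mbits=2))
```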
6 changes: 6 additions & 0 deletions torchao/dtypes/floatx/float8_layout.py
@@ -25,6 +25,12 @@

@dataclass(frozen=True)
class Float8Layout(Layout):
"""Represents the layout configuration for Float8 affine quantized tensors.

Attributes:
mm_config (Optional[Float8MMConfig]): Configuration for matrix multiplication operations involving Float8 tensors. If None, default settings are used.
"""

mm_config: Optional[Float8MMConfig] = None
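A hedged construction sketch; `Float8MMConfig` and its `use_fast_accum` flag come from `torchao.float8.inference` and are assumptions about that config, not part of this diff:

```python
from torchao.dtypes import Float8Layout
from torchao.float8.inference import Float8MMConfig

# mm_config=None (the default) falls back to default scaled-mm settings
layout = Float8Layout(mm_config=Float8MMConfig(use_fast_accum=True))
```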


4 changes: 3 additions & 1 deletion torchao/dtypes/floatx/floatx_tensor_core_layout.py
@@ -450,7 +450,9 @@ def from_scaled_tc_floatx(tensor: Tensor, ebits: int, mbits: int, scale=None) ->
# quantization api integrations
@dataclass(frozen=True)
class FloatxTensorCoreLayout(Layout):
"""Layout type for FloatxTensorCoreAQTTensorImpl"""
"""FloatxTensorCoreLayout is a data class that defines the layout for a tensor with a specific number of exponent bits (ebits) and mantissa bits (mbits).
This layout is used in the context of quantization and packing of tensors optimized for TensorCore operations.
"""

ebits: int
mbits: int
4 changes: 2 additions & 2 deletions torchao/dtypes/nf4tensor.py
@@ -662,10 +662,9 @@ def dequantize_scalers(
) -> torch.Tensor:
"""Used to unpack the double quantized scalers

Args:
    input_tensor: Input tensor to convert to QLoRA format; this is the quantized scalers in int8 format
    quantization_factor: Tensor of per_scaler_block quantization factors stored in inpt_weight.dtype
    scaler_block_size: Scaler block size to use for double quantization.

"""
@@ -953,6 +952,7 @@ def linear_nf4(input: torch.Tensor, weight: NF4Tensor) -> torch.Tensor:


def to_nf4(tensor, block_size: int = 64, scaler_block_size: int = 256):
"""Convert a given tensor to normalized float 4-bit tensor."""
return NF4Tensor.from_tensor(tensor, block_size, scaler_block_size)
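A short, hedged usage sketch for the helper documented above, with the default block sizes spelled out; `linear_nf4` is the matching compute path shown in the hunk header:

```python
import torch

from torchao.dtypes.nf4tensor import linear_nf4, to_nf4

weight = torch.randn(256, 256, dtype=torch.bfloat16)
nf4_weight = to_nf4(weight, block_size=64, scaler_block_size=256)
# linear_nf4 dequantizes the NF4 weight on the fly inside the linear op
out = linear_nf4(torch.randn(8, 256, dtype=torch.bfloat16), nf4_weight)
```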


6 changes: 6 additions & 0 deletions torchao/dtypes/uintx/block_sparse_layout.py
@@ -27,6 +27,12 @@

@dataclass(frozen=True)
class BlockSparseLayout(Layout):
"""BlockSparseLayout is a data class that represents the layout of a block sparse matrix.

Attributes:
blocksize (int): The size of the blocks in the sparse matrix. Default is 64.
"""

blocksize: int = 64
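A minimal, hedged construction sketch; the non-default value is illustrative:

```python
from torchao.dtypes import BlockSparseLayout

# store the sparse matrix in 32x32 blocks instead of the default 64
layout = BlockSparseLayout(blocksize=32)
```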


2 changes: 2 additions & 0 deletions torchao/dtypes/uintx/cutlass_int4_packed_layout.py
@@ -29,6 +29,8 @@ def _aqt_is_int4(aqt):

@dataclass(frozen=True)
class CutlassInt4PackedLayout(Layout):
"""Layout class for int4 packed layout for affine quantized tensor, for cutlass kernel."""

pass


7 changes: 4 additions & 3 deletions torchao/dtypes/uintx/int4_cpu_layout.py
@@ -24,15 +24,16 @@

@dataclass(frozen=True)
class Int4CPULayout(Layout):
"""Only for PyTorch version at least 2.6"""
"""Layout class for int4 CPU layout for affine quantized tensor, used by tinygemm kernels `_weight_int4pack_mm_for_cpu`.
Only for PyTorch version at least 2.6
"""

pass
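A hedged usage sketch on CPU, assuming `int4_weight_only` accepts a `layout` argument as on the CUDA path (and PyTorch >= 2.6, as the docstring notes):

```python
import torch

from torchao.dtypes import Int4CPULayout
from torchao.quantization import int4_weight_only, quantize_

model = torch.nn.Sequential(torch.nn.Linear(1024, 1024)).to(torch.bfloat16)
quantize_(model, int4_weight_only(layout=Int4CPULayout()))
```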


@register_layout(Int4CPULayout)
class Int4CPUAQTTensorImpl(AQTTensorImpl):
"""
TensorImpl for int4 CPU layout for affine quantized tensor, this is for int4 only,
"""TensorImpl for int4 CPU layout for affine quantized tensor, this is for int4 only,
used by tinygemm kernels `_weight_int4pack_mm_for_cpu`
It stores the original tensor of dimension [n][k] (int32 dtype) as packed weight of 2-d tensor of
dimension: [n][k / 2] (uint8 dtype)
6 changes: 4 additions & 2 deletions torchao/dtypes/uintx/marlin_qqq_tensor.py
@@ -29,8 +29,7 @@


class MarlinQQQTensor(AffineQuantizedTensor):
"""
MarlinQQQ quantized tensor subclass which inherits AffineQuantizedTensor class.
"""MarlinQQQ quantized tensor subclass which inherits AffineQuantizedTensor class.

To see what happens during choose_qparams_and_quantize_affine_qqq, quantization and dequantization for marlin qqq quantization,
please checkout https://github.com/pytorch/ao/blob/main/torchao/quantization/quant_primitives.py
@@ -58,6 +57,7 @@ def from_hp_to_intx(
zero_point_domain: Optional[ZeroPointDomain] = ZeroPointDomain.INT,
_layout: Optional[Layout] = None,
):
"""Converts a floating point tensor to a Marlin QQQ quantized tensor."""
original_shape = input_float.shape
input_float = _layout.pre_process(input_float)
nbits = int(math.log2(quant_max - quant_min + 1))
@@ -81,6 +81,8 @@

@dataclass(frozen=True)
class MarlinQQQLayout(Layout):
"""MarlinQQQLayout is a layout class for Marlin QQQ quantization."""

pass


11 changes: 11 additions & 0 deletions torchao/dtypes/uintx/marlin_sparse_layout.py
@@ -71,6 +71,17 @@ def _linear_fp_act_int4_weight_sparse_marlin_impl(input_tensor, weight_tensor, b

@dataclass(frozen=True)
class MarlinSparseLayout(Layout):
"""MarlinSparseLayout is a layout class for handling sparse tensor formats
specifically designed for the Marlin sparse kernel. This layout is used
to optimize the storage and computation of affine quantized tensors with
2:4 sparsity patterns.

The layout ensures that the tensor data is pre-processed and stored in a
format that is compatible with the Marlin sparse kernel operations. It
provides methods for preprocessing input tensors and managing the layout
of quantized tensors.
"""

def pre_process(self, input: torch.Tensor) -> torch.Tensor:
"""Preprocess the input tensor to be in the correct format for the Marlin sparse kernel.
- 1º: the input tensor is transposed since the linear layer keeps the weights in a transposed format
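A hedged end-to-end sketch; it assumes `int4_weight_only` accepts a `layout` argument and a CUDA model whose Linear weights are compatible with (or pre-processed into) the 2:4 pattern:

```python
import torch

from torchao.dtypes import MarlinSparseLayout
from torchao.quantization import int4_weight_only, quantize_

model = torch.nn.Sequential(torch.nn.Linear(1024, 1024)).cuda().half()
quantize_(model, int4_weight_only(layout=MarlinSparseLayout()))
```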
7 changes: 7 additions & 0 deletions torchao/dtypes/uintx/semi_sparse_layout.py
@@ -66,6 +66,13 @@ def _linear_int8_act_int8_weight_semi_structured_sparse_impl(

@dataclass(frozen=True)
class SemiSparseLayout(Layout):
"""SemiSparseLayout is a layout class for handling semi-structured sparse
matrices in affine quantized tensors. This layout is specifically designed
to work with the 2:4 sparsity pattern, where two out of every four elements
are pruned to zero. This class provides methods for preprocessing input
tensors to conform to this sparsity pattern.
"""

def pre_process(self, input: torch.Tensor) -> torch.Tensor:
# prune to 2:4 if not already
temp = input.detach()
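To make the 2:4 pattern concrete, a standalone sketch of magnitude-based 2:4 pruning; this illustrates the pattern only and is not the kernel torchao calls internally:

```python
import torch

x = torch.randn(8, 16)
# in every contiguous group of four elements, keep the two largest magnitudes
groups = x.view(-1, 4)
_, keep = groups.abs().topk(2, dim=-1)
mask = torch.zeros_like(groups, dtype=torch.bool).scatter(-1, keep, True)
pruned = (groups * mask).view_as(x)  # two of every four entries are now zero
```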
10 changes: 5 additions & 5 deletions torchao/dtypes/uintx/tensor_core_tiled_layout.py
@@ -91,9 +91,10 @@ def _linear_bf16_act_uint4_weight_impl(input_tensor, weight_tensor, bias):

@dataclass(frozen=True)
class TensorCoreTiledLayout(Layout):
"""
inner_k_tiles is an internal argument for packing function of tensor core tiled layout
that can affect the performance of the matmul kernel
"""TensorCoreTiledLayout is a layout class for handling tensor core tiled layouts in affine quantized tensors. It provides methods for pre-processing and post-processing tensors to fit the required layout for efficient computation on tensor cores.

Attributes:
inner_k_tiles (int): An internal argument for the packing function of tensor core tiled layout that can affect the performance of the matmul kernel. Defaults to 8.
"""

inner_k_tiles: int = 8
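A hedged sketch of how this layout is usually selected; the `group_size` and `inner_k_tiles` values shown are customary defaults, spelled out explicitly:

```python
import torch

from torchao.dtypes import TensorCoreTiledLayout
from torchao.quantization import int4_weight_only, quantize_

model = torch.nn.Sequential(torch.nn.Linear(1024, 1024)).cuda().to(torch.bfloat16)
quantize_(
    model,
    int4_weight_only(group_size=128, layout=TensorCoreTiledLayout(inner_k_tiles=8)),
)
```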
@@ -149,8 +150,7 @@ def extra_repr(self):

@register_layout(TensorCoreTiledLayout)
class TensorCoreTiledAQTTensorImpl(AQTTensorImpl):
"""
TensorImpl for tensor_core_tiled layout for affine quantized tensor, this is for int4 only,
"""TensorImpl for tensor_core_tiled layout for affine quantized tensor, this is for int4 only,
used by tinygemm kernels `_weight_int4pack_mm`

It stores the original tensor of dimension [n][k] (int32 dtype) as packed weight of 4-d tensor of
11 changes: 11 additions & 0 deletions torchao/dtypes/uintx/uintx_layout.py
@@ -209,6 +209,17 @@ def _(func, types, args, kwargs):

@dataclass(frozen=True)
class UintxLayout(Layout):
"""A layout class for Uintx tensors, which are tensors with elements packed into
smaller bit-widths than the standard 8-bit byte. This layout is used to define
how the data is stored and processed in UintxTensor objects.

Attributes:
dtype (torch.dtype): The data type of the tensor elements, which determines
the bit-width used for packing.
pack_dim (int): The dimension along which the data is packed. Default is -1,
which indicates the last dimension.
"""

dtype: torch.dtype
pack_dim: int = -1
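A hedged construction sketch; the sub-byte `torch.uint4` dtype assumes a recent PyTorch with the prototype uintN dtypes available:

```python
import torch

from torchao.dtypes import UintxLayout

# pack 4-bit elements along the last dimension
layout = UintxLayout(dtype=torch.uint4, pack_dim=-1)
```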

19 changes: 14 additions & 5 deletions torchao/dtypes/utils.py
@@ -27,6 +27,15 @@

@dataclass(frozen=True)
class Layout:
"""The Layout class serves as a base class for defining different data layouts for tensors.
It provides methods for pre-processing and post-processing tensors, as well as static
pre-processing with additional parameters like scale, zero_point, and block_size.

The Layout class is designed to be extended by other layout classes that define specific
data representations and behaviors for tensors. It is used in conjunction with TensorImpl
classes to represent custom data layouts and how tensors interact with different operators.
"""

def pre_process(self, input: torch.Tensor) -> torch.Tensor:
return input
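To illustrate the extension point this docstring describes, a hypothetical subclass; `MyPaddedLayout` and its padding behavior are invented for illustration and are not torchao APIs:

```python
from dataclasses import dataclass

import torch

from torchao.dtypes.utils import Layout


@dataclass(frozen=True)
class MyPaddedLayout(Layout):
    """Hypothetical layout that pads the last dim to a multiple before packing."""

    multiple: int = 16

    def pre_process(self, input: torch.Tensor) -> torch.Tensor:
        pad = (-input.shape[-1]) % self.multiple
        return torch.nn.functional.pad(input, (0, pad))
```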

@@ -49,13 +58,13 @@ def extra_repr(self) -> str:
return ""


"""
Plain Layout, the most basic Layout, also has no extra metadata, will typically be the default
"""


@dataclass(frozen=True)
class PlainLayout(Layout):
"""PlainLayout is the most basic layout class, inheriting from the Layout base class.
It does not add any additional metadata or processing steps to the tensor.
Typically, this layout is used as the default when no specific layout is required.
"""

pass

