4 changes: 3 additions & 1 deletion test/core/test_config.py
@@ -26,6 +26,7 @@
from torchao.quantization.quant_api import (
FbgemmConfig,
Float8DynamicActivationFloat8WeightConfig,
Float8DynamicActivationInt4WeightConfig,
Float8WeightOnlyConfig,
FPXWeightOnlyConfig,
GemliteUIntXWeightOnlyConfig,
@@ -49,13 +50,14 @@
weight_dtype=torch.float8_e4m3fn,
),
UIntXWeightOnlyConfig(dtype=torch.uint1),
Float8DynamicActivationInt4WeightConfig(),
Int4DynamicActivationInt4WeightConfig(),
Int4WeightOnlyConfig(
group_size=32,
),
Int4WeightOnlyConfig(
group_size=128,
packing_format="tile_packed_to_4d",
int4_packing_format="tile_packed_to_4d",
int4_choose_qparams_algorithm="hqq",
version=2,
),
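For reference, a minimal sketch of the round trip this test suite exercises for the new entries, assuming the `config_to_dict`/`config_from_dict` serialization helpers from `torchao.core.config`:

```python
from torchao.core.config import config_from_dict, config_to_dict
from torchao.quantization.quant_api import Int4WeightOnlyConfig

# Serialize and reconstruct the v2 config to check that the renamed
# int4_packing_format and int4_choose_qparams_algorithm fields survive.
config = Int4WeightOnlyConfig(
    group_size=128,
    int4_packing_format="tile_packed_to_4d",
    int4_choose_qparams_algorithm="hqq",
    version=2,
)
reconstructed = config_from_dict(config_to_dict(config))
assert reconstructed.int4_packing_format == config.int4_packing_format
```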
@@ -26,7 +26,7 @@

BF16_ACT_CONFIG = Int4WeightOnlyConfig(
group_size=128,
packing_format="marlin_sparse",
int4_packing_format="marlin_sparse",
version=2,
)

@@ -28,7 +28,7 @@
def get_config(group_size):
return Int4WeightOnlyConfig(
group_size=group_size,
packing_format="opaque",
int4_packing_format="opaque",
version=2,
)

@@ -28,7 +28,7 @@
def get_config(group_size):
return Int4WeightOnlyConfig(
group_size=group_size,
packing_format="plain_int32",
int4_packing_format="plain_int32",
version=2,
)

@@ -29,13 +29,13 @@

BF16_ACT_CONFIG = Int4WeightOnlyConfig(
group_size=128,
packing_format="preshuffled",
int4_packing_format="preshuffled",
version=2,
)

# only 128 group_size is supported
FP8_ACT_CONFIG = Float8DynamicActivationInt4WeightConfig(
packing_format="preshuffled",
int4_packing_format="preshuffled",
)


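As a usage illustration (not part of the diff), a sketch of applying the two configs above, assuming a CUDA device with the fbgemm-gpu-genai kernels this file depends on:

```python
import torch
from torchao.quantization import quantize_

# Toy layer; preshuffled int4 supports group_size=128 with bf16 or fp8 activations.
m = torch.nn.Sequential(
    torch.nn.Linear(256, 512, dtype=torch.bfloat16, device="cuda")
)
quantize_(m, BF16_ACT_CONFIG)  # bf16 activation, preshuffled int4 weight
# or: quantize_(m, FP8_ACT_CONFIG)  # fp8 dynamic activation, int4 weight
```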
11 changes: 9 additions & 2 deletions test/quantization/quantize_/workflows/int4/test_int4_tensor.py
@@ -17,17 +17,24 @@
from torchao.quantization.quantize_.common import SupportsActivationPreScaling
from torchao.quantization.utils import compute_error
from torchao.testing.utils import TorchAOIntegrationTestCase
from torchao.utils import is_sm_at_least_90, torch_version_at_least
from torchao.utils import (
_is_fbgemm_genai_gpu_available,
is_sm_at_least_90,
torch_version_at_least,
)


@unittest.skipIf(not torch_version_at_least("2.8.0"), "Need pytorch 2.8+")
@unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
@unittest.skipIf(not is_sm_at_least_90(), "Need sm90+")
@unittest.skipIf(
not _is_fbgemm_genai_gpu_available(), "Requires fbgemm-gpu-genai >= 1.2.0"
)
class TestInt4Tensor(TorchAOIntegrationTestCase):
def setUp(self):
self.config = Int4WeightOnlyConfig(
group_size=128,
packing_format="plain",
int4_packing_format="plain",
version=2,
)
self.GPU_DEVICES = ["cuda"] if torch.cuda.is_available() else []
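A sketch of the kind of accuracy check this test class performs with the config from `setUp`, using the `compute_error` (SQNR) helper imported above; the 20 dB threshold is illustrative:

```python
import torch
from torchao.quantization import quantize_
from torchao.quantization.quant_api import Int4WeightOnlyConfig
from torchao.quantization.utils import compute_error

linear = torch.nn.Linear(128, 256, dtype=torch.bfloat16, device="cuda")
x = torch.randn(32, 128, dtype=torch.bfloat16, device="cuda")
ref = linear(x)
quantize_(
    linear,
    Int4WeightOnlyConfig(group_size=128, int4_packing_format="plain", version=2),
)
assert compute_error(ref, linear(x)) > 20  # SQNR in dB
```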
@@ -24,13 +24,13 @@

INT4_CONFIG = Int4WeightOnlyConfig(
group_size=128,
packing_format="tile_packed_to_4d",
int4_packing_format="tile_packed_to_4d",
version=2,
)

INT4_HQQ_CONFIG = Int4WeightOnlyConfig(
group_size=128,
packing_format="tile_packed_to_4d",
int4_packing_format="tile_packed_to_4d",
int4_choose_qparams_algorithm="hqq",
version=2,
)
35 changes: 18 additions & 17 deletions torchao/quantization/quant_api.py
@@ -75,6 +75,7 @@
Int4ChooseQParamsAlgorithm,
Int4MarlinSparseTensor,
Int4OpaqueTensor,
Int4PackingFormat,
Int4PlainInt32Tensor,
Int4PreshuffledTensor,
Int4Tensor,
@@ -1075,7 +1076,7 @@ class Int4WeightOnlyConfig(AOBaseConfig):
Note:
Int4WeightOnlyConfig currently supports both v1 (legacy) and v2

For v2 (version = 2), only `group_size`, `packing_format`, `int4_choose_qparams_algorithm` and `set_inductor_config` are valid, all other args will be ignored
For v2 (version = 2), only `group_size`, `int4_packing_format`, `int4_choose_qparams_algorithm` and `set_inductor_config` are valid; all other args will be ignored
For v1 (version = 1), only `group_size`, `layout`, `use_hqq`, `zero_point_domain`, `preserve_zero` and `set_inductor_config` are valid; we plan to deprecate v1 in torchao 0.15 to make this config
less confusing
"""
@@ -1087,7 +1088,7 @@ class Int4WeightOnlyConfig(AOBaseConfig):
set_inductor_config: bool = True
preserve_zero: Optional[bool] = None
# only used in version >= 2
packing_format: PackingFormat = PackingFormat.PLAIN
int4_packing_format: Int4PackingFormat = Int4PackingFormat.PLAIN
int4_choose_qparams_algorithm: Int4ChooseQParamsAlgorithm = (
Int4ChooseQParamsAlgorithm.TINYGEMM
)
@@ -1113,7 +1114,7 @@ def _int4_weight_only_quantize_tensor(weight, config):
use_hqq = config.use_hqq
int4_choose_qparams_algorithm = config.int4_choose_qparams_algorithm
zero_point_domain = config.zero_point_domain
packing_format = config.packing_format
int4_packing_format = config.int4_packing_format

if weight.shape[-1] % group_size != 0:
logger.info(
@@ -1127,50 +1128,50 @@ def _int4_weight_only_quantize_tensor(weight, config):
block_size = list(block_size)

if int4_choose_qparams_algorithm == Int4ChooseQParamsAlgorithm.HQQ:
assert packing_format == PackingFormat.TILE_PACKED_TO_4D, (
f"Int4ChooseQParamsAlgorithm.HQQ is not supported by packing format {packing_format}, it's only supported by PackingFormat.TILE_PACKED_TO_4D curretnly"
assert int4_packing_format == Int4PackingFormat.TILE_PACKED_TO_4D, (
f"Int4ChooseQParamsAlgorithm.HQQ is not supported by packing format {int4_packing_format}, it's only supported by Int4PackingFormat.TILE_PACKED_TO_4D curretnly"
)

if packing_format == PackingFormat.PRESHUFFLED:
if int4_packing_format == Int4PackingFormat.PRESHUFFLED:
new_weight = Int4PreshuffledTensor.from_hp(
weight,
block_size,
activation_dtype=torch.bfloat16,
)
return new_weight
elif packing_format == PackingFormat.PLAIN:
elif int4_packing_format == Int4PackingFormat.PLAIN:
new_weight = Int4Tensor.from_hp(
weight,
block_size,
)
return new_weight
elif packing_format == PackingFormat.PLAIN_INT32:
elif int4_packing_format == Int4PackingFormat.PLAIN_INT32:
new_weight = Int4PlainInt32Tensor.from_hp(
weight,
block_size,
)
return new_weight
elif packing_format == PackingFormat.MARLIN_SPARSE:
elif int4_packing_format == Int4PackingFormat.MARLIN_SPARSE:
new_weight = Int4MarlinSparseTensor.from_hp(
weight,
block_size,
)
return new_weight
elif packing_format == PackingFormat.OPAQUE:
elif int4_packing_format == Int4PackingFormat.OPAQUE:
new_weight = Int4OpaqueTensor.from_hp(
weight,
block_size,
)
return new_weight
elif packing_format == PackingFormat.TILE_PACKED_TO_4D:
elif int4_packing_format == Int4PackingFormat.TILE_PACKED_TO_4D:
new_weight = Int4TilePackedTo4dTensor.from_hp(
weight,
block_size,
int4_choose_qparams_algorithm=int4_choose_qparams_algorithm,
)
return new_weight
else:
raise ValueError(f"Unsupported packing format: {packing_format}")
raise ValueError(f"Unsupported int4 packing format: {int4_packing_format}")

assert config.version == 1

@@ -1254,10 +1255,10 @@ class Float8DynamicActivationInt4WeightConfig(AOBaseConfig):
and above and no benefits of making it bigger)

Args:
`packing_format`: how the weight is packed, only preshuffled is supported
`int4_packing_format`: how the weight is packed, only preshuffled is supported
"""

packing_format: PackingFormat = "preshuffled"
int4_packing_format: Int4PackingFormat = "preshuffled"


@register_quantize_module_handler(Float8DynamicActivationInt4WeightConfig)
@@ -1268,10 +1269,10 @@ def _float8_dynamic_activation_int4_weight_transform(
"applying int8 weight only quant requires module to have weight attribute"
+ " but {module} does not have one"
)
packing_format = config.packing_format
int4_packing_format = config.int4_packing_format

assert packing_format == "preshuffled", (
f"only preshuffled packing_format supported right now, got: {packing_format}"
assert int4_packing_format == "preshuffled", (
f"only preshuffled int4_packing_format supported right now, got: {int4_packing_format}"
)
weight = module.weight
group_size = 128
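One detail that makes the rename low-friction: `Int4PackingFormat` subclasses `str`, so configs can keep plain-string defaults (as `Float8DynamicActivationInt4WeightConfig` does with `"preshuffled"`) while the enum comparisons in the dispatch above still hold. A quick sketch:

```python
from torchao.quantization.quantize_.workflows import Int4PackingFormat

# str-backed Enum: members equal, and construct from, their string values
assert Int4PackingFormat.PRESHUFFLED == "preshuffled"
assert Int4PackingFormat("preshuffled") is Int4PackingFormat.PRESHUFFLED
assert Int4PackingFormat.TILE_PACKED_TO_4D.value == "tile_packed_to_4d"
```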
23 changes: 1 addition & 22 deletions torchao/quantization/quantize_/common/packing_format.py
@@ -16,7 +16,7 @@ class PackingFormat(str, Enum):

"""
plain means the format that quantized Tensor data lays out elements in Tensor sequentially,
for example: for a Tensor of shape (4, 6):
a_0_0, a_0_1, ..., a_0_5,
...
a_3_0, a_3_1, ..., a_3_5
@@ -26,32 +26,11 @@
"""
PLAIN = "plain"

"""
preshuffled is referring to the preshuffled format used by fbgemm kernels
"""
PRESHUFFLED = "preshuffled"

"""
marlin_sparse is referring to the format used by marlin kernels, only supports symmetric quantization
"""
MARLIN_SPARSE = "marlin_sparse"

"""
Unpacked to int8 means the subbyte quantized data is stored as int8
"""
UNPACKED_TO_INT8 = "unpacked_to_int8"

"""
plain_int32 is referring to the format used by int4 weight-only quantization.
which is a groupwise quantization format 2*int4 is store in a byte and 4*(int4*2) is stored in a int32.
"""
PLAIN_INT32 = "plain_int32"

"""
tile_packed_to_4d is referring to the format used by tinygemm kernels for int4 quantization
"""
TILE_PACKED_TO_4D = "tile_packed_to_4d"

"""
Opaque packing format that's used for tensors that do not have a predefined packing format
(that may be decided on hardware, tensor shape, library availability etc.) and it's not
2 changes: 2 additions & 0 deletions torchao/quantization/quantize_/workflows/__init__.py
@@ -9,6 +9,7 @@
from .int4.int4_opaque_tensor import (
Int4OpaqueTensor,
)
from .int4.int4_packing_format import Int4PackingFormat
from .int4.int4_plain_int32_tensor import (
Int4PlainInt32Tensor,
)
@@ -39,4 +40,5 @@
"IntxUnpackedTensor",
"IntxUnpackedToInt8Tensor",
"Int4ChooseQParamsAlgorithm",
"Int4PackingFormat",
]
57 changes: 57 additions & 0 deletions torchao/quantization/quantize_/workflows/int4/int4_packing_format.py
@@ -0,0 +1,57 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD 3-Clause license found in the
# LICENSE file in the root directory of this source tree.

from enum import Enum


# can switch to StrEnum (https://docs.python.org/3/library/enum.html#enum.StrEnum)
# after python 3.10 is end of life (https://devguide.python.org/versions/)
class Int4PackingFormat(str, Enum):
"""Packing format for quantized data in Int4 Tensor subclasses in torchao, represents how
the values in quantized data are packed and laid out in memory.
"""

"""
plain means the quantized Tensor data lays out elements sequentially,
for example, for a Tensor of shape (4, 6):
a_0_0, a_0_1, ..., a_0_5,
...
a_3_0, a_3_1, ..., a_3_5

For int4, two adjacent int4 elements are packed into one
uint8/int8 value in the plain packing format
"""
PLAIN = "plain"

"""
preshuffled refers to the preshuffled format used by fbgemm kernels
"""
PRESHUFFLED = "preshuffled"

"""
marlin_sparse refers to the format used by marlin kernels; it requires symmetric quantization
"""
MARLIN_SPARSE = "marlin_sparse"

"""
plain_int32 is a format where 2 adjacent int4 values are packed into a byte and 4 such packed bytes are stored in an int32 value.
"""
PLAIN_INT32 = "plain_int32"

"""
tile_packed_to_4d refers to the format used by tinygemm kernels for int4 quantization.
For a Tensor of shape (n, k), the packed weight will have dimensions
[n / 8][k / (inner_k_tiles * 16)][32][inner_k_tiles / 2], where inner_k_tiles is
currently fixed to 8 to simplify the Int4TilePackedTo4dTensor API
"""
TILE_PACKED_TO_4D = "tile_packed_to_4d"

"""
Opaque packing format that's used for tensors that do not have a predefined packing format
(which may be decided based on hardware, tensor shape, library availability etc.), where the
rest of the system does not need to understand the specific format that's adopted.
"""
OPAQUE = "opaque"