9 changes: 6 additions & 3 deletions test/quantization/pt2e/test_quantize_pt2e.py
@@ -2948,10 +2948,11 @@ def has_inplace_ops(graph_module: torch.fx.GraphModule) -> bool:
@unittest.skipIf(not torch_version_at_least("2.7.0"), "Requires torch 2.7+")
class TestQuantizePT2EAffineQuantization(PT2EQuantizationTestCase):
def test_channel_group_quantization(self):
from torchao.quantization import PerGroup, PerToken
from torchao.quantization.pt2e._affine_quantization import (
AffineQuantizedMinMaxObserver,
)
from torchao.quantization.pt2e.observer import MappingType, PerGroup, PerToken
from torchao.quantization.pt2e.observer import MappingType

class BackendAQuantizer(Quantizer):
def annotate(self, model: torch.fx.GraphModule) -> torch.fx.GraphModule:
@@ -3031,13 +3032,13 @@ def forward(self, x):
def test_dynamic_affine_act_per_channel_weights(self):
import operator

from torchao.quantization import PerToken
from torchao.quantization.pt2e._affine_quantization import (
AffineQuantizedMovingAverageMinMaxObserver,
)
from torchao.quantization.pt2e.observer import (
MappingType,
PerChannelMinMaxObserver,
PerToken,
)

class BackendAQuantizer(Quantizer):
@@ -3122,12 +3123,14 @@ def forward(self, x):
def test_dynamic_per_tok_act_per_group_weights(self):
import operator

from torchao.quantization import PerGroup, PerToken

# TODO: merge into torchao observer
from torchao.quantization.pt2e._affine_quantization import (
AffineQuantizedMinMaxObserver,
AffineQuantizedPlaceholderObserver,
)
from torchao.quantization.pt2e.observer import MappingType, PerGroup, PerToken
from torchao.quantization.pt2e.observer import MappingType

class BackendAQuantizer(Quantizer):
def annotate(self, model: torch.fx.GraphModule) -> torch.fx.GraphModule:
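Taken together, the test edits above follow one pattern: granularity types such as PerGroup and PerToken now come from the package root, while MappingType stays in the pt2e observer module. A minimal sketch of the resulting import split, with illustrative values that are not taken from the diff:

from torchao.quantization import PerGroup, PerToken
from torchao.quantization.pt2e.observer import MappingType

# Illustrative configuration in the style of the tests above:
# dynamic per-token activations with per-group weights.
act_granularity = PerToken()
weight_granularity = PerGroup(group_size=128)  # group size chosen for illustration
mapping_type = MappingType.SYMMETRIC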
2 changes: 2 additions & 0 deletions torchao/quantization/__init__.py
@@ -19,6 +19,7 @@
MultiTensorInputRecorder,
)
from .granularity import (
Granularity,
PerAxis,
PerGroup,
PerRow,
@@ -197,6 +198,7 @@
"MappingType",
"ZeroPointDomain",
"TorchAODType",
"Granularity",
"PerTensor",
"PerAxis",
"PerGroup",
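With Granularity re-exported here, callers can type against the base class without importing from submodules. A short sketch of the newly public import path; the helper function is illustrative and not part of the diff:

from torchao.quantization import Granularity, PerAxis, PerTensor

def is_per_tensor(g: Granularity) -> bool:
    # Granularity is the common base class of PerTensor, PerAxis,
    # PerGroup, PerRow, PerToken, and PerBlock.
    return isinstance(g, PerTensor)

assert is_per_tensor(PerTensor())
assert not is_per_tensor(PerAxis(axis=0))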
1 change: 0 additions & 1 deletion torchao/quantization/linear_activation_quantized_tensor.py
@@ -133,7 +133,6 @@ def _same_metadata(

@implements([torch.nn.functional.linear, aten.linear.default])
def _(func, types, args, kwargs):

input_tensor = kwargs.get("input", args[0] if len(args) > 0 else None)
weight_tensor = kwargs.get("weight", args[1] if len(args) > 1 else None)
bias = kwargs.get("bias", args[2] if len(args) > 2 else None)
22 changes: 1 addition & 21 deletions torchao/quantization/observer.py
@@ -14,7 +14,6 @@

from .granularity import (
Granularity,
PerAxis,
PerRow,
PerTensor,
)
@@ -24,6 +23,7 @@
_get_reduction_params,
choose_qparams_affine_with_min_max,
)
from .utils import get_block_size

logger = logging.getLogger(__name__)

@@ -63,26 +63,6 @@ def _with_args(cls_or_self, *args, **kwargs):
return r


def get_block_size(
input_shape: Tuple[int, ...], granularity: Granularity
) -> Tuple[int, ...]:
"""Get the block size based on the input shape and granularity type.

Args:
input_shape: The input tensor shape possibly more than 2 dimensions
granularity: The granularity type of the quantization
"""
if isinstance(granularity, PerTensor):
return input_shape
elif isinstance(granularity, PerAxis):
block_size = list(input_shape)
block_size[granularity.axis] = 1
return tuple(block_size)
elif isinstance(granularity, PerRow):
return (1,) * (len(input_shape) - 1) + (input_shape[-1],)
raise ValueError(f"Unsupported Granularity: {granularity}")


ABC: Any = ABCMeta("ABC", (object,), {}) # compatible with Python 2 *and* 3:


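The helper deleted above now lives in torchao.quantization.utils (note the new import near the top of this file). A minimal usage sketch, assuming the relocated function keeps the semantics of the removed copy:

from torchao.quantization import PerAxis, PerRow, PerTensor
from torchao.quantization.utils import get_block_size

shape = (8, 16)
assert get_block_size(shape, PerTensor()) == (8, 16)      # one block: the whole tensor
assert get_block_size(shape, PerAxis(axis=0)) == (1, 16)  # one set of qparams per row
assert get_block_size(shape, PerRow()) == (1, 16)         # (1, ..., input.shape[-1])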
16 changes: 0 additions & 16 deletions torchao/quantization/pt2e/__init__.py
@@ -48,7 +48,6 @@
from .observer import (
AffineQuantizedObserverBase,
FixedQParamsObserver,
Granularity,
HistogramObserver,
MappingType,
MinMaxObserver,
@@ -57,20 +56,13 @@
NoopObserver,
ObserverBase,
PartialWrapper,
PerAxis,
PerBlock,
PerChannelMinMaxObserver,
PerGroup,
PerRow,
PerTensor,
PerToken,
PlaceholderObserver,
RecordingObserver,
ReuseInputObserver,
TorchAODType,
UniformQuantizationObserverBase,
ZeroPointDomain,
get_block_size,
)

for _f in [
@@ -139,17 +131,9 @@
"compare_results",
# should be merged with torchao/quantization/observer.py in the future
"AffineQuantizedObserverBase",
"Granularity",
"MappingType",
"PerAxis",
"PerBlock",
"PerGroup",
"PerRow",
"PerTensor",
"PerToken",
"TorchAODType",
"ZeroPointDomain",
"get_block_size",
"default_fake_quant",
"default_dynamic_fake_quant",
]
2 changes: 1 addition & 1 deletion torchao/quantization/pt2e/_affine_quantization.py
@@ -19,8 +19,8 @@
MappingType,
TorchAODType,
ZeroPointDomain,
get_block_size,
)
from torchao.quantization.utils import get_block_size

ABC: Any = ABCMeta("ABC", (object,), {}) # compatible with Python 2 *and* 3:

143 changes: 1 addition & 142 deletions torchao/quantization/pt2e/observer.py
@@ -27,6 +27,7 @@
from torch.fx import Node

import torchao
from torchao.quantization import Granularity
from torchao.quantization.pt2e.utils import (
calculate_qmin_qmax,
check_min_max_valid,
@@ -67,17 +68,9 @@
"ReuseInputObserver",
"UniformQuantizationObserverBase",
"AffineQuantizedObserverBase",
"Granularity",
"MappingType",
"PerAxis",
"PerBlock",
"PerGroup",
"PerRow",
"PerTensor",
"PerToken",
"TorchAODType",
"ZeroPointDomain",
"get_block_size",
]


@@ -1622,7 +1615,6 @@ def calculate_qparams(self):
We plan to merge the following with torchao repo after we move pt2e flow to torchao
copied from https://github.com/pytorch/ao/blob/main/torchao/quantization/observer.py
"""
from dataclasses import dataclass
from enum import Enum, auto


@@ -1679,139 +1671,6 @@ class TorchAODType(Enum):
INT7 = auto()


@dataclass(frozen=True)
class Granularity:
"""
Base class for representing the granularity of quantization.

This class serves as a parent for specific granularity types used in
quantization operations, such as per-tensor or per-axis quantization.
"""


@dataclass(frozen=True)
class PerBlock(Granularity):
"""
Represents per-block granularity in quantization. See
:func:`~torchao.quantization.quant_primitives.quantize_affine` for docs for
`block_size`

Attributes:
block_size (Tuple[int, ...]): The size of each quantization group
"""

block_size: tuple[int, ...]


@dataclass(frozen=True)
class PerTensor(Granularity):
"""
Represents per-tensor granularity in quantization.

This granularity type calculates the quantization parameters
based off the entire tensor.

"""


@dataclass(frozen=True)
class PerAxis(Granularity):
"""
Represents per-axis granularity in quantization.

This granularity type calculates different quantization parameters
along a specified axis of the tensor.

For example if the input tensor is shape [8, 16] and axis=0, then
the quantization parameters are calculated for each row of the tensor,
giving a total of 8 quantization parameters.

Attributes:
axis (int): The axis along which reduction is performed.
"""

axis: int


@dataclass(frozen=True)
class PerGroup(Granularity):
"""
Represents per-channel group granularity in quantization.

This granularity type calculates different quantization parameters
for each group of <group_size> elements.

For example if the input tensor is shape [8, 16], and the group size is 4, then
the input tensor is reshaped to [64, 4] and
quantization parameters are calculated for each group of 4 elements,
giving a total of 64 quantization parameters.

Attributes:
group_size (int): The size of each quantization group

"""

group_size: int


class PerRow(Granularity):
"""
Represents row-wise granularity in quantization.

This is a special case of per-axis quantization and is unique to Float8 matmuls,
where the input is quantized with a block_size of (1, ..., input.shape[-1]) and the weight
is quantized with a block_size of (1, weight.shape[1]).
"""


class PerToken(Granularity):
"""
Represents per-token granularity in quantization.

This granularity type calculates a different set of quantization parameters
for each token, which is represented as the last dimension of the tensor.

For example, if the input tensor has shape [2, 3, 4], then there are 6 tokens
with 4 elements each, and we will calculate 6 sets of quantization parameters,
one for each token.

If the input tensor has only two dimensions, e.g. [8, 16], then this is
equivalent to `PerAxis(axis=0)`, which yields 8 sets of quantization parameters.
"""


def get_block_size(
input_shape: tuple[int, ...], granularity: Granularity
) -> tuple[int, ...]:
"""Get the block size based on the input shape and granularity type.

Args:
input_shape: The input tensor shape possibly more than 2 dimensions
granularity: The granularity type of the quantization
"""
assert isinstance(granularity, Granularity), (
"Please provide an instance of Granularity, not subclass of it"
)
if isinstance(granularity, PerTensor):
return input_shape
elif isinstance(granularity, PerAxis):
block_size = list(input_shape)
block_size[granularity.axis] = 1
return tuple(block_size)
elif isinstance(granularity, PerRow):
return (1,) * (len(input_shape) - 1) + (input_shape[-1],)
elif isinstance(granularity, PerGroup):
assert len(input_shape) == 2, (
f"Expecting input shape dim to be 2 for per group quantization, gotinput shape: {input_shape}"
)
return (1, granularity.group_size)
elif isinstance(granularity, PerToken):
block_size = [1] * len(input_shape)
block_size[-1] = input_shape[-1]
return tuple(block_size)
raise ValueError(f"Unsupported Granularity: {granularity}")


class AffineQuantizedObserverBase(ABC, torch.nn.Module):
"""Observer module for affine quantization (https://github.com/pytorch/ao/tree/main/torchao/quantization#affine-quantization)

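The pt2e copy deleted above also covered PerGroup and PerToken. A sketch of those two cases, assuming the consolidated helper in torchao.quantization.utils handles them the same way as the removed code:

from torchao.quantization import PerGroup, PerToken
from torchao.quantization.utils import get_block_size

# PerGroup: one set of qparams per contiguous group of group_size elements;
# the removed code required a 2-D input shape for this case.
assert get_block_size((8, 16), PerGroup(group_size=4)) == (1, 4)

# PerToken: one set of qparams per token, i.e. per slice along the last dimension.
assert get_block_size((2, 3, 4), PerToken()) == (1, 1, 4)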
2 changes: 1 addition & 1 deletion torchao/quantization/qat/fake_quantizer.py
@@ -14,7 +14,6 @@
PerRow,
PerToken,
)
from torchao.quantization.observer import get_block_size
from torchao.quantization.quant_primitives import (
_DTYPE_TO_BIT_WIDTH,
_DTYPE_TO_QVALUE_BOUNDS,
@@ -28,6 +27,7 @@
)
from torchao.quantization.utils import (
_get_per_token_block_size,
get_block_size,
get_group_qparams_symmetric,
get_groupwise_affine_qparams,
)
3 changes: 2 additions & 1 deletion torchao/quantization/quant_api.py
@@ -64,7 +64,7 @@
from torchao.quantization.linear_activation_weight_observed_tensor import (
LinearActivationWeightObservedTensor,
)
from torchao.quantization.observer import AffineQuantizedObserverBase, get_block_size
from torchao.quantization.observer import AffineQuantizedObserverBase
from torchao.quantization.quantize_.common import (
KernelPreference,
)
@@ -87,6 +87,7 @@
_QUANTIZE_CONFIG_HANDLER,
register_quantize_module_handler,
)
from torchao.quantization.utils import get_block_size
from torchao.quantization.weight_tensor_linear_activation_quantization import (
to_weight_tensor_with_linear_activation_quantization_metadata,
)
@@ -23,7 +23,6 @@
preprocess_scale,
)
from torchao.quantization.granularity import PerRow, PerTensor
from torchao.quantization.observer import get_block_size
from torchao.quantization.quant_primitives import (
_choose_scale_float8,
_dequantize_affine_float8,
@@ -34,6 +33,7 @@
QuantizeTensorKwargs,
_choose_quant_func_and_quantize_tensor,
)
from torchao.quantization.utils import get_block_size
from torchao.utils import (
TorchAOBaseTensor,
_is_fbgemm_genai_gpu_available,
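The remaining changes, in qat/fake_quantizer.py, quant_api.py, and the float8 file above, are the same mechanical migration. In sketch form, with the retired paths shown as comments:

# Import paths removed by this PR:
#   from torchao.quantization.observer import get_block_size
#   from torchao.quantization.pt2e import get_block_size
# Single canonical location going forward:
from torchao.quantization.utils import get_block_size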