Commit 9f366a9

Merge branch 'main' into cpu_int_scaled_mm_2

2 parents: 0930f71 + cbd90e3

4 files changed: +57, -6 lines

.github/workflows/regression_test.yml (+1, -1)

@@ -40,7 +40,7 @@ jobs:
           gpu-arch-version: "12.1"
         - name: CUDA Nightly
           runs-on: linux.g5.12xlarge.nvidia.gpu
-          torch-spec: '--pre torch==2.6.0.dev20241022 --index-url https://download.pytorch.org/whl/nightly/cu121'
+          torch-spec: '--pre torch --index-url https://download.pytorch.org/whl/nightly/cu121'
           gpu-arch-type: "cuda"
           gpu-arch-version: "12.1"

test/integration/test_integration.py (+16, -3)

@@ -74,6 +74,7 @@
     AutoQuantizableLinearWeight,
     AQFloat8WeightOnlyQuantizedLinearWeight,
     AQFloat8PerRowScalingDynamicallyQuantizedLinearWeight,
+    AQFloat8PerTensorScalingDynamicallyQuantizedLinearWeight,
 )
 from torch.ao.quantization.quantize_fx import convert_to_reference_fx, prepare_fx
 import os

@@ -770,11 +771,23 @@ def test_aq_float8_weight_only_quant_subclass(self, device, dtype):
     @parameterized.expand(COMMON_DEVICE_DTYPE)
     @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_5, "autoquant+aqt needs newer pytorch")
     @unittest.skipIf(not is_H100, "Need H100 to run")
-    def test_aq_float8_dynamic_quant_subclass(self, device, dtype):
+    def test_aq_float8_dynamic_quant_rowwise_scaling_subclass(self, device, dtype):
         if dtype != torch.bfloat16:
-            self.skipTest("Fails for {dtype}")
+            with self.assertRaisesRegex(AssertionError, "PerRow quantization only works for bfloat16 precision"):
+                self._test_lin_weight_subclass_impl(
+                    AQFloat8PerRowScalingDynamicallyQuantizedLinearWeight.from_float, device, 25, test_dtype=dtype
+                )
+        else:
+            self._test_lin_weight_subclass_impl(
+                AQFloat8PerRowScalingDynamicallyQuantizedLinearWeight.from_float, device, 25, test_dtype=dtype
+            )
+
+    @parameterized.expand(COMMON_DEVICE_DTYPE)
+    @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_5, "autoquant+aqt needs newer pytorch")
+    @unittest.skipIf(not is_H100, "Need H100 to run")
+    def test_aq_float8_dynamic_quant_tensorwise_scaling_subclass(self, device, dtype):
         self._test_lin_weight_subclass_impl(
-            AQFloat8PerRowScalingDynamicallyQuantizedLinearWeight.from_float, device, 25, test_dtype=dtype
+            AQFloat8PerTensorScalingDynamicallyQuantizedLinearWeight.from_float, device, 25, test_dtype=dtype
         )

     @parameterized.expand(COMMON_DEVICE_DTYPE)
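For context, a rough sketch of the behavior difference the split tests capture: per-row scaling asserts bfloat16 precision, while the new per-tensor variant is exercised for every COMMON_DEVICE_DTYPE entry without a dtype guard. This snippet is illustrative only and not part of the commit; it assumes a CUDA device with float8 support (the tests themselves require an H100), and the exact point where the assertion fires is hedged by wrapping both construction and the linear call.

# Illustrative only: mirrors what the two tests above check.
import torch
from torchao.quantization.autoquant import (
    AQFloat8PerRowScalingDynamicallyQuantizedLinearWeight,
    AQFloat8PerTensorScalingDynamicallyQuantizedLinearWeight,
)

w = torch.randn(32, 64, device="cuda", dtype=torch.float16)  # non-bfloat16 weight
x = torch.randn(8, 64, device="cuda", dtype=torch.float16)

try:
    # Per-row scaling: the test expects an AssertionError for non-bfloat16 dtypes.
    qw = AQFloat8PerRowScalingDynamicallyQuantizedLinearWeight.from_float(w)
    torch.nn.functional.linear(x, qw)
except AssertionError as e:
    print(e)  # per the test regex: "PerRow quantization only works for bfloat16 precision"

# Per-tensor scaling: run for all dtypes in the new test, no bfloat16 restriction.
qw = AQFloat8PerTensorScalingDynamicallyQuantizedLinearWeight.from_float(w)
torch.nn.functional.linear(x, qw)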

torchao/quantization/autoquant.py (+39, -1)

@@ -1,3 +1,4 @@
+from typing import Callable
 import torch
 import torchao
 from torchao.quantization.quant_primitives import (

@@ -500,7 +501,7 @@ class AQFloat8PerRowScalingDynamicallyQuantizedLinearWeight(AQMixin, LinearActiv
     """
     AutoQuantizable version of Float8DynamicallyQuantizedLinearWeight using per row scaling
     """
-    activation_granularity: str = PerRow()
+    activation_granularity = PerRow()
     @classmethod
     def from_float(cls, weight):

@@ -537,6 +538,42 @@ def get_per_token_block_size(x):
         weight = super(AQFloat8PerRowScalingDynamicallyQuantizedLinearWeight, cls).from_float(weight, input_quant_func)
         return weight

+class AQFloat8PerTensorScalingDynamicallyQuantizedLinearWeight(AQMixin, LinearActivationQuantizedTensor):
+    """
+    AutoQuantizable version of Float8DynamicallyQuantizedLinearWeight using per tensor scaling
+    """
+    activation_granularity = PerTensor()
+    @classmethod
+    def from_float(cls, weight):
+
+        # avoid circular dep
+        from torchao.dtypes import to_affine_quantized_floatx
+        from torchao.quantization.quant_api import _input_activation_quant_func_fp8
+        # weight settings
+        def get_weight_block_size(x):
+            assert x.ndim == 2, "Only works for 2D tensors"
+            return x.shape
+        target_dtype = torch.float8_e4m3fn
+
+        input_target_dtype = torch.float8_e4m3fn
+        _layout = Float8Layout(mm_config=Float8MMConfig(use_fast_accum=True))
+        input_quant_func = lambda x: _input_activation_quant_func_fp8(
+            x=x,
+            activation_granularity=cls.activation_granularity,
+            activation_dtype=input_target_dtype,
+        )
+        block_size = get_weight_block_size(weight)
+        weight = to_affine_quantized_floatx(
+            input_float=weight,
+            block_size=block_size,
+            target_dtype=target_dtype,
+            _layout=_layout,
+            scale_dtype=torch.float32,
+        )
+        from torchao.float8.inference import _is_rowwise_scaled
+        weight = super(AQFloat8PerTensorScalingDynamicallyQuantizedLinearWeight, cls).from_float(weight, input_quant_func)
+        return weight
+

 # here we don't include int4 quantization in since int8 tends to be a better apples to apples comparison
 DEFAULT_AUTOQUANT_CLASS_LIST = [

@@ -557,6 +594,7 @@ def get_per_token_block_size(x):
 OTHER_AUTOQUANT_CLASS_LIST = [
     AQFloat8WeightOnlyQuantizedLinearWeight,
     AQFloat8PerRowScalingDynamicallyQuantizedLinearWeight,
+    AQFloat8PerTensorScalingDynamicallyQuantizedLinearWeight,
 ]
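The new subclass is exposed through the float8 candidate list, OTHER_AUTOQUANT_CLASS_LIST. A minimal usage sketch, not part of the commit, assuming autoquant's qtensor_class_list argument and an H100-class GPU:

# Illustrative only: let autoquant benchmark the float8 candidates, which now
# include the per-tensor scaling subclass added in this commit.
import torch
import torchao
from torchao.quantization.autoquant import OTHER_AUTOQUANT_CLASS_LIST

model = torch.nn.Sequential(torch.nn.Linear(1024, 1024)).to("cuda", torch.bfloat16)
example_input = torch.randn(16, 1024, device="cuda", dtype=torch.bfloat16)

# autoquant wraps eligible linear weights, profiles each candidate class
# (weight-only fp8, per-row fp8, and now per-tensor fp8), and keeps the fastest.
model = torchao.autoquant(model, qtensor_class_list=OTHER_AUTOQUANT_CLASS_LIST)
model(example_input)  # running the model triggers the benchmarking/selection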

torchao/utils.py (+1, -1)

@@ -389,7 +389,7 @@ class MyTensor(torch.Tensor):
         return cls._ATEN_OP_OR_TORCH_FN_TABLE[func](func, types, args, kwargs)

     arg_types = tuple(type(arg) for arg in args)
-    kwarg_types = {k: type(arg) for k, arg in kwargs}
+    kwarg_types = {k: type(arg) for k, arg in kwargs.items()}
     raise NotImplementedError(f"{cls.__name__} dispatch: attempting to run unimplemented operator/function: {func=}, {types=}, {arg_types=}, {kwarg_types=}")

 def _register_layout(tensor_class: Callable, layout_class: Callable):
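The one-line fix addresses a latent bug in the error path: iterating a dict directly yields only its keys, so the old comprehension tried to unpack each key string into (k, arg). A small illustration, not from the commit:

import torch

kwargs = {"dtype": torch.float32}

# Old code: {k: type(arg) for k, arg in kwargs}
# iterates over the keys, so it attempts to unpack the string "dtype"
# into (k, arg) and raises ValueError before the intended error message is built.
kwarg_types = {k: type(arg) for k, arg in kwargs.items()}
print(kwarg_types)  # {'dtype': <class 'torch.dtype'>}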
