Skip to content

Commit 68ce5b8

Browse files
authored
Refactor int8 dynamic quantization with call to quantize (#294)
Summary: Previously we added `quantize` as a general API (#256) for Affine Quantized tensor subclass, and also tensor subclass based dtype conversion in general. The plan is to use this to replace existing quant APIs including int4 weight only, int8 weight only, int8 dynamic quant and 8da4w (for executorch). This PR we started replacing the implementation of int8 dynamic quant API with `quantize` API with affine quantized tensor subclass. We'll make sure the performance does not regress for vit model. Test Plan: TORCH_LOGS='output_code' python tutorials/quantize_vit/run_vit_b_quant.py reference: elapsed_time: 1.4821058654785155 milliseconds after refactor: elapsed_time: 1.4804757690429688 milliseconds generated code diff: https://gist.github.com/jerryzh168/90c71107a5aaaa5d8dd2170c573e076d Reviewers: Subscribers: Tasks: Tags:
1 parent e7837d7 commit 68ce5b8

File tree

6 files changed

+205
-52
lines changed

6 files changed

+205
-52
lines changed

test/integration/test_integration.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1033,6 +1033,7 @@ def _test_lin_weight_subclass_api_impl(
10331033

10341034

10351035
@parameterized.expand(COMMON_DEVICE_DTYPE)
1036+
@unittest.skipIf(TORCH_VERSION_AFTER_2_4, "skip because there is some bug in inductor codegen")
10361037
def test_int8_dynamic_quant_subclass_api(self, device, dtype):
10371038
self._test_lin_weight_subclass_api_impl(
10381039
change_linear_weights_to_int8_dqtensors, device, 35, test_dtype=dtype

test/quantization/test_quant_api.py

Lines changed: 60 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -118,6 +118,26 @@ def forward(self, x):
118118
x = self.linear2(x)
119119
return x
120120

121+
122+
def _ref_change_linear_weights_to_int8_dqtensors(model, filter_fn=None, **kwargs):
123+
"""
124+
The deprecated implementation for int8 dynamic quant API, used as a reference for
125+
numerics and performance
126+
"""
127+
from torchao.quantization.quant_api import _in_features_greater_than_16
128+
from torchao.quantization.quant_api import _is_linear
129+
from torchao.quantization.quant_api import _get_subclass_inserter
130+
from torchao.quantization.subclass import Int8DynamicallyQuantizedLinearWeight
131+
132+
if filter_fn is None:
133+
filter_fn = lambda *args: _is_linear(*args) and _in_features_greater_than_16(
134+
*args
135+
)
136+
137+
_replace_with_custom_fn_if_matches_filter(
138+
model, _get_subclass_inserter(Int8DynamicallyQuantizedLinearWeight, enable_parametrization=False, **kwargs), filter_fn
139+
)
140+
121141
class TestQuantFlow(unittest.TestCase):
122142
def test_dynamic_quant_gpu_singleline(self):
123143
m = ToyLinearModel().eval()
@@ -492,8 +512,8 @@ def test_quantized_tensor_subclass_int8(self):
492512
@unittest.skipIf(not TORCH_VERSION_AFTER_2_4, "Test only enabled for 2.4+")
493513
@unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
494514
def test_quantized_tensor_subclass_int8_dyn_quant(self):
495-
# use 1024 so that we don't need padding
496-
m = ToyLinearModel(1024, 1024, 1024).eval().to(torch.bfloat16).to("cuda")
515+
# use multiples of 1024 so that we don't need padding
516+
m = ToyLinearModel(1024, 1024, 2048).eval().to(torch.bfloat16).to("cuda")
497517
m_copy = copy.deepcopy(m)
498518
# setting batch_size to 20 to be compatible with the kernel
499519
example_inputs = m.example_inputs(batch_size=20, dtype=torch.bfloat16, device="cuda")
@@ -525,6 +545,44 @@ def test_quantized_tensor_subclass_int8_dyn_quant(self):
525545
# make sure it compiles
526546
torch._export.aot_compile(m_unwrapped, example_inputs)
527547

548+
@unittest.skipIf(not TORCH_VERSION_AFTER_2_4, "Test only enabled for 2.4+")
549+
@unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
550+
@unittest.skip("This perf test is supposed to be run locally for sanity check performance when there is a change of int8 dynamic quant implementation")
551+
def test_quantized_tensor_subclass_int8_dyn_quant_perf(self):
552+
m = ToyLinearModel(1024, 1024, 1024).eval().to(torch.bfloat16).to("cuda")
553+
m_ref = copy.deepcopy(m)
554+
# setting batch_size to 20 to be compatible with the kernel
555+
example_inputs = m.example_inputs(batch_size=20, dtype=torch.bfloat16, device="cuda")
556+
557+
from torchao.quantization.quant_api import change_linear_weights_to_int8_dqtensors
558+
change_linear_weights_to_int8_dqtensors(m)
559+
560+
# reference
561+
_ref_change_linear_weights_to_int8_dqtensors(m_ref)
562+
563+
res = m(*example_inputs)
564+
ref = m_ref(*example_inputs)
565+
566+
self.assertTrue(torch.equal(res, ref))
567+
568+
# perf comparison
569+
from torchao.utils import benchmark_model
570+
# warmup
571+
WARMUP = 5
572+
RUNS = 100
573+
input_tensor = example_inputs[0]
574+
m = torch.compile(m, mode='max-autotune', fullgraph=True)
575+
576+
benchmark_model(m, WARMUP, input_tensor)
577+
elapsed_time = benchmark_model(m, RUNS, input_tensor)
578+
579+
m_ref = torch.compile(m_ref, mode='max-autotune', fullgraph=True)
580+
benchmark_model(m_ref, WARMUP, input_tensor)
581+
ref_elapsed_time = benchmark_model(m_ref, RUNS, input_tensor)
582+
583+
print(f"elapsed time: {elapsed_time}, ref elapsed time: {ref_elapsed_time}")
584+
self.assertTrue(elapsed_time < 1.05 * ref_elapsed_time)
585+
528586

529587

530588
if __name__ == "__main__":

torchao/dtypes/aqt.py

Lines changed: 81 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -177,6 +177,11 @@ def _apply_fn_to_data(self, fn):
177177
fn(self.zero_point),
178178
)
179179

180+
def _change_shape(self, shape):
181+
return self.__class__(
182+
self.int_data.view(shape), self.scale, self.zero_point
183+
)
184+
180185
@classmethod
181186
def __torch_dispatch__(cls, func, types, args, kwargs):
182187
kwargs = {} if kwargs is None else kwargs
@@ -186,6 +191,11 @@ def __torch_dispatch__(cls, func, types, args, kwargs):
186191
func, args, kwargs, args[0]._apply_fn_to_data(torch.detach)
187192
)
188193

194+
if func is aten.view.default:
195+
assert len(args) == 2
196+
new = args[0]._change_shape(args[1])
197+
return return_and_correct_aliasing(func, args, kwargs, new)
198+
189199
raise NotImplementedError(
190200
f"PlainAQTLayout dispatch: attempting to run {func}, this is not supported"
191201
)
@@ -245,6 +255,7 @@ def __tensor_unflatten__(
245255
cls, tensor_data_dict, tensor_attributes, outer_size, outer_stride
246256
):
247257
packed_weight, scale_and_zero = tensor_data_dict["packed_weight"], tensor_data_dict["scale_and_zero"]
258+
# TODO: fix the unflatten logic
248259
return cls(packed_weight, scale_and_zero)
249260

250261
def to(self, *args, **kwargs):
@@ -470,6 +481,11 @@ def _apply_fn_to_data(self, fn):
470481
strides=self.stride(),
471482
)
472483

484+
def _change_shape(self, shape, block_size):
485+
return self.__class__(
486+
self.layout_tensor.view(shape), block_size, shape, self.quant_min, self.quant_max, self.zero_point_domain, dtype=self.dtype, strides=self.stride()
487+
)
488+
473489
@classmethod
474490
def __torch_dispatch__(cls, func, types, args, kwargs):
475491
# Note: we only added cpu path here for 8da4w, this is for executorch, in the future
@@ -491,13 +507,7 @@ def __torch_dispatch__(cls, func, types, args, kwargs):
491507
f"AffineQuantizedTensor dispatch: attempting to run {func}, this is not supported"
492508
)
493509

494-
@implements_aqt_torch_function(torch.nn.functional.linear)
495-
def functional_linear(*args, **kwargs):
496-
input_tensor, weight_qtensor, bias = (
497-
args[0],
498-
args[1],
499-
args[2] if len(args) > 2 else None,
500-
)
510+
def _quantized_linear_op(input_tensor, weight_qtensor, bias):
501511
is_cuda = weight_qtensor.is_cuda
502512
is_cpu = weight_qtensor.device == torch.device("cpu")
503513
if isinstance(weight_qtensor, AffineQuantizedTensor):
@@ -508,14 +518,10 @@ def functional_linear(*args, **kwargs):
508518
# if input tensor is quantized, either dispatch to the int8 mm kernel
509519
# or just dequantize the input tensor
510520
input_is_int8 = _aqt_is_int8_reduced_range(input_tensor)
511-
input_tensor_dtype_is_expected = input_tensor.dtype in [
512-
torch.float,
513-
torch.bfloat16
514-
]
515521
if (
516522
is_cuda and
517523
input_is_int8 and
518-
input_tensor_dtype_is_expected and
524+
input_tensor.dtype == weight_qtensor.dtype and
519525
input_tensor.layout == "plain" and
520526
weight_qtensor.layout == "plain"
521527
):
@@ -576,45 +582,83 @@ def functional_linear(*args, **kwargs):
576582
weight_qtensor.block_size[1] == weight_qtensor.shape[1] and
577583
weight_qtensor.layout == "plain"
578584
):
579-
# TODO: enable mps path as well
585+
# TODO: enable cpu and mps efficient path
580586
# per channel int8 weight only quantizated mm
581-
return torch.ops.aten._weight_int8pack_mm(input_tensor.contiguous(), weight_qtensor.layout_tensor.int_data, weight_qtensor.layout_tensor.scale)
582-
else:
583-
weight_tensor = weight_qtensor.dequantize()
584-
return torch.nn.functional.linear(input_tensor, weight_tensor, bias)
585-
else:
587+
w_vals_int8_t = weight_qtensor.layout_tensor.int_data.t().contiguous()
588+
orig_dtype = input_tensor.dtype
589+
y = (
590+
torch.mm(
591+
input_tensor.reshape(-1, input_tensor.shape[-1]),
592+
w_vals_int8_t.to(input_tensor.dtype),
593+
)
594+
* weight_qtensor.scale
595+
)
596+
y = y.reshape(*input_tensor.shape[:-1], y.shape[-1])
597+
if bias is not None:
598+
y += bias
599+
return y.to(orig_dtype)
600+
601+
# is_cpu and is_mps only, some issue with is_contiguous() currently
602+
# return torch.ops.aten._weight_int8pack_mm(input_tensor.contiguous(), w_vals_int8_t, weight_qtensor.layout_tensor.scale)
603+
604+
raise NotImplementedError("No specialized dispatch found for quantized linear op")
605+
606+
607+
@implements_aqt_torch_function(torch.nn.functional.linear)
608+
def functional_linear(*args, **kwargs):
609+
input_tensor, weight_tensor, bias = (
610+
args[0],
611+
args[1],
612+
args[2] if len(args) > 2 else None,
613+
)
614+
# using try/except here so that we can have a general fallback when input_tensor/weight_tensor
615+
# is not picked up by any of the dispatch paths in `_quantized_linear_op`, this allows us to
616+
# make the branches easier to understand in `_quantized_linear_op`
617+
try:
618+
return _quantized_linear_op(input_tensor, weight_tensor, bias)
619+
except:
586620
if isinstance(input_tensor, AffineQuantizedTensor):
587621
input_tensor = input_tensor.dequantize()
622+
if isinstance(weight_tensor, AffineQuantizedTensor):
623+
weight_tensor = weight_tensor.dequantize()
588624
return torch.nn.functional.linear(input_tensor, weight_tensor, bias)
589625

590-
591626
@implements_aqt_aten_ops([aten.mm.default, aten.addmm.default])
592627
def aten_mm(func, *args, **kwargs):
593628
if not args[0].is_floating_point():
594629
raise NotImplementedError(f"{func} is not implemented for non floating point input")
595630

631+
# using try/except here so that we can have a general fallback when input_tensor/weight_tensor
632+
# is not picked up by any of the dispatch paths in `_quantized_linear_op`, this allows us to
633+
# make the branches easier to understand in `_quantized_linear_op`
596634
if func == aten.addmm.default:
597-
assert args[1].shape[-1] == args[2].shape[0], (
598-
f"need mat1 shape: {args[1].shape} final"
599-
f"dim to match mat2 shape: {args[2].shape} first dim "
600-
)
601-
input_tensor, weight_qtensor, bias = (
635+
input_tensor, weight_tensor, bias = (
602636
args[1],
603637
args[2],
604638
args[0],
605639
)
640+
try:
641+
return _quantized_linear_op(input_tensor, weight_tensor, bias)
642+
except:
643+
if isinstance(input_tensor, AffineQuantizedTensor):
644+
input_tensor = input_tensor.dequantize()
645+
if isinstance(weight_tensor, AffineQuantizedTensor):
646+
weight_tensor = weight_tensor.dequantize()
647+
return func(bias, input_tensor, weight_tensor)
606648
else:
607-
assert args[0].shape[-1] == args[1].shape[0], (
608-
f"need mat1 shape: {args[0].shape} final dim"
609-
f"to match mat2 shape: {args[1].shape} first dim"
610-
)
611-
input_tensor, weight_qtensor, bias = (
649+
input_tensor, weight_tensor, bias = (
612650
args[0],
613651
args[1],
614-
None if len(args) == 2 else args[2],
652+
None
615653
)
616-
weight_tensor = weight_qtensor.dequantize()
617-
return func(input_tensor, weight_tensor, bias)
654+
try:
655+
return _quantized_linear_op(input_tensor, weight_tensor, bias)
656+
except:
657+
if isinstance(input_tensor, AffineQuantizedTensor):
658+
input_tensor = input_tensor.dequantize()
659+
if isinstance(weight_tensor, AffineQuantizedTensor):
660+
weight_tensor = weight_tensor.dequantize()
661+
return func(input_tensor, weight_tensor)
618662

619663
@implements_aqt_aten_ops([aten.detach.default])
620664
def detach(func, *args, **kwargs):
@@ -641,10 +685,10 @@ def _to_copy(func, *args, **kwargs):
641685

642686
@implements_aqt_aten_ops([aten.t.default])
643687
def t(func, *args, **kwargs):
644-
# TODO: need to implement this
645-
# args[0].transposed = not args[0].transposed
646-
# new = args[0]._change_shape(args[0].shape[::-1])
647-
# return return_and_correct_aliasing(func, args, kwargs, new)
648-
raise Exception("transpose not implemented yet")
688+
block_size = args[0].block_size
689+
assert len(block_size) == 2
690+
transposed_block_size = (block_size[1], block_size[0])
691+
new = args[0]._change_shape(args[0].shape[::-1], transposed_block_size)
692+
return return_and_correct_aliasing(func, args, kwargs, new)
649693

650694
to_aq = AffineQuantizedTensor.from_float

torchao/quantization/quant_api.py

Lines changed: 13 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,11 @@
2525
from typing import Any, Callable
2626

2727
from .dynamic_quant import DynamicallyPerAxisQuantizedLinear
28-
from .utils import TORCH_VERSION_AFTER_2_3, TORCH_VERSION_AFTER_2_4
28+
from .utils import (
29+
TORCH_VERSION_AFTER_2_3,
30+
TORCH_VERSION_AFTER_2_4,
31+
unwrap_tensor_subclass,
32+
)
2933

3034
from .subclass import (
3135
Int4WeightOnlyQuantizedLinearWeight,
@@ -187,9 +191,13 @@ def change_linear_weights_to_int8_dqtensors(model, filter_fn=None, **kwargs):
187191
*args
188192
)
189193

190-
_replace_with_custom_fn_if_matches_filter(
191-
model, _get_subclass_inserter(Int8DynamicallyQuantizedLinearWeight, enable_parametrization=TORCH_VERSION_AFTER_2_4, **kwargs), filter_fn
192-
)
194+
if TORCH_VERSION_AFTER_2_4:
195+
quantize(model, get_apply_int8dyn_quant(), filter_fn)
196+
unwrap_tensor_subclass(model, filter_fn)
197+
else:
198+
_replace_with_custom_fn_if_matches_filter(
199+
model, _get_subclass_inserter(Int8DynamicallyQuantizedLinearWeight, enable_parametrization=False, **kwargs), filter_fn
200+
)
193201

194202

195203
def change_linear_weights_to_int8_woqtensors(model, filter_fn=None, **kwargs):
@@ -282,7 +290,7 @@ def quantize(model: torch.nn.Module, apply_tensor_subclass: Callable[[torch.Tens
282290
zero_point_dtype = torch.bfloat16
283291
zero_point_domain = ZeroPointDomain.FLOAT
284292
285-
apply_weight_quant = lambda x: to_aqt(x, mapping_type, block_size, target_dtype, quant_min, quant_max, eps, zero_point_dtype=zero_point_dtype, preserve_zero=preserve_zero, zero_point_domain=zero_point_domain)
293+
apply_weight_quant = lambda x: to_aq(x, mapping_type, block_size, target_dtype, quant_min, quant_max, eps, zero_point_dtype=zero_point_dtype, preserve_zero=preserve_zero, zero_point_domain=zero_point_domain)
286294
287295
# apply to modules under block0 submodule
288296
def filter_fn(module, fqn):

0 commit comments

Comments
 (0)