Skip to content

Refactor int8 dynamic quantization with call to quantize #294

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
May 31, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions test/integration/test_integration.py
Original file line number Diff line number Diff line change
Expand Up @@ -1033,6 +1033,7 @@ def _test_lin_weight_subclass_api_impl(


@parameterized.expand(COMMON_DEVICE_DTYPE)
@unittest.skipIf(TORCH_VERSION_AFTER_2_4, "skip because there is some bug in inductor codegen")
def test_int8_dynamic_quant_subclass_api(self, device, dtype):
self._test_lin_weight_subclass_api_impl(
change_linear_weights_to_int8_dqtensors, device, 35, test_dtype=dtype
Expand Down
62 changes: 60 additions & 2 deletions test/quantization/test_quant_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,6 +118,26 @@ def forward(self, x):
x = self.linear2(x)
return x


def _ref_change_linear_weights_to_int8_dqtensors(model, filter_fn=None, **kwargs):
"""
The deprecated implementation for int8 dynamic quant API, used as a reference for
numerics and performance
"""
from torchao.quantization.quant_api import _in_features_greater_than_16
from torchao.quantization.quant_api import _is_linear
from torchao.quantization.quant_api import _get_subclass_inserter
from torchao.quantization.subclass import Int8DynamicallyQuantizedLinearWeight

if filter_fn is None:
filter_fn = lambda *args: _is_linear(*args) and _in_features_greater_than_16(
*args
)

_replace_with_custom_fn_if_matches_filter(
model, _get_subclass_inserter(Int8DynamicallyQuantizedLinearWeight, enable_parametrization=False, **kwargs), filter_fn
)

class TestQuantFlow(unittest.TestCase):
def test_dynamic_quant_gpu_singleline(self):
m = ToyLinearModel().eval()
Expand Down Expand Up @@ -492,8 +512,8 @@ def test_quantized_tensor_subclass_int8(self):
@unittest.skipIf(not TORCH_VERSION_AFTER_2_4, "Test only enabled for 2.4+")
@unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
def test_quantized_tensor_subclass_int8_dyn_quant(self):
# use 1024 so that we don't need padding
m = ToyLinearModel(1024, 1024, 1024).eval().to(torch.bfloat16).to("cuda")
# use multiples of 1024 so that we don't need padding
m = ToyLinearModel(1024, 1024, 2048).eval().to(torch.bfloat16).to("cuda")
m_copy = copy.deepcopy(m)
# setting batch_size to 20 to be compatible with the kernel
example_inputs = m.example_inputs(batch_size=20, dtype=torch.bfloat16, device="cuda")
Expand Down Expand Up @@ -525,6 +545,44 @@ def test_quantized_tensor_subclass_int8_dyn_quant(self):
# make sure it compiles
torch._export.aot_compile(m_unwrapped, example_inputs)

@unittest.skipIf(not TORCH_VERSION_AFTER_2_4, "Test only enabled for 2.4+")
@unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
@unittest.skip("This perf test is supposed to be run locally for sanity check performance when there is a change of int8 dynamic quant implementation")
def test_quantized_tensor_subclass_int8_dyn_quant_perf(self):
m = ToyLinearModel(1024, 1024, 1024).eval().to(torch.bfloat16).to("cuda")
m_ref = copy.deepcopy(m)
# setting batch_size to 20 to be compatible with the kernel
example_inputs = m.example_inputs(batch_size=20, dtype=torch.bfloat16, device="cuda")

from torchao.quantization.quant_api import change_linear_weights_to_int8_dqtensors
change_linear_weights_to_int8_dqtensors(m)

# reference
_ref_change_linear_weights_to_int8_dqtensors(m_ref)

res = m(*example_inputs)
ref = m_ref(*example_inputs)

self.assertTrue(torch.equal(res, ref))

# perf comparison
from torchao.utils import benchmark_model
# warmup
WARMUP = 5
RUNS = 100
input_tensor = example_inputs[0]
m = torch.compile(m, mode='max-autotune', fullgraph=True)

benchmark_model(m, WARMUP, input_tensor)
elapsed_time = benchmark_model(m, RUNS, input_tensor)

m_ref = torch.compile(m_ref, mode='max-autotune', fullgraph=True)
benchmark_model(m_ref, WARMUP, input_tensor)
ref_elapsed_time = benchmark_model(m_ref, RUNS, input_tensor)

print(f"elapsed time: {elapsed_time}, ref elapsed time: {ref_elapsed_time}")
self.assertTrue(elapsed_time < 1.05 * ref_elapsed_time)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this test might end up being flaky, also how long does this test take? strange to do a benchmark for unit tests

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this is pretty quick when I run it in my A100 machine, finishes in a few seconds. I could also skip this by default and just have people run this locally when making changes to these APIs

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

skipped this one by default

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah doing benchmarks in unit tests is a known anti pattern. Test environments don't need to be inconsistent and it's likely a waste of resources to make them so.




if __name__ == "__main__":
Expand Down
118 changes: 81 additions & 37 deletions torchao/dtypes/aqt.py
Original file line number Diff line number Diff line change
Expand Up @@ -177,6 +177,11 @@ def _apply_fn_to_data(self, fn):
fn(self.zero_point),
)

def _change_shape(self, shape):
return self.__class__(
self.int_data.view(shape), self.scale, self.zero_point
)

@classmethod
def __torch_dispatch__(cls, func, types, args, kwargs):
kwargs = {} if kwargs is None else kwargs
Expand All @@ -186,6 +191,11 @@ def __torch_dispatch__(cls, func, types, args, kwargs):
func, args, kwargs, args[0]._apply_fn_to_data(torch.detach)
)

if func is aten.view.default:
assert len(args) == 2
new = args[0]._change_shape(args[1])
return return_and_correct_aliasing(func, args, kwargs, new)

raise NotImplementedError(
f"PlainAQTLayout dispatch: attempting to run {func}, this is not supported"
)
Expand Down Expand Up @@ -245,6 +255,7 @@ def __tensor_unflatten__(
cls, tensor_data_dict, tensor_attributes, outer_size, outer_stride
):
packed_weight, scale_and_zero = tensor_data_dict["packed_weight"], tensor_data_dict["scale_and_zero"]
# TODO: fix the unflatten logic
return cls(packed_weight, scale_and_zero)

def to(self, *args, **kwargs):
Expand Down Expand Up @@ -470,6 +481,11 @@ def _apply_fn_to_data(self, fn):
strides=self.stride(),
)

def _change_shape(self, shape, block_size):
return self.__class__(
self.layout_tensor.view(shape), block_size, shape, self.quant_min, self.quant_max, self.zero_point_domain, dtype=self.dtype, strides=self.stride()
)

@classmethod
def __torch_dispatch__(cls, func, types, args, kwargs):
# Note: we only added cpu path here for 8da4w, this is for executorch, in the future
Expand All @@ -491,13 +507,7 @@ def __torch_dispatch__(cls, func, types, args, kwargs):
f"AffineQuantizedTensor dispatch: attempting to run {func}, this is not supported"
)

@implements_aqt_torch_function(torch.nn.functional.linear)
def functional_linear(*args, **kwargs):
input_tensor, weight_qtensor, bias = (
args[0],
args[1],
args[2] if len(args) > 2 else None,
)
def _quantized_linear_op(input_tensor, weight_qtensor, bias):
is_cuda = weight_qtensor.is_cuda
is_cpu = weight_qtensor.device == torch.device("cpu")
if isinstance(weight_qtensor, AffineQuantizedTensor):
Expand All @@ -508,14 +518,10 @@ def functional_linear(*args, **kwargs):
# if input tensor is quantized, either dispatch to the int8 mm kernel
# or just dequantize the input tensor
input_is_int8 = _aqt_is_int8_reduced_range(input_tensor)
input_tensor_dtype_is_expected = input_tensor.dtype in [
torch.float,
torch.bfloat16
]
if (
is_cuda and
input_is_int8 and
input_tensor_dtype_is_expected and
input_tensor.dtype == weight_qtensor.dtype and
input_tensor.layout == "plain" and
weight_qtensor.layout == "plain"
):
Expand Down Expand Up @@ -576,45 +582,83 @@ def functional_linear(*args, **kwargs):
weight_qtensor.block_size[1] == weight_qtensor.shape[1] and
weight_qtensor.layout == "plain"
):
# TODO: enable mps path as well
# TODO: enable cpu and mps efficient path
# per channel int8 weight only quantizated mm
return torch.ops.aten._weight_int8pack_mm(input_tensor.contiguous(), weight_qtensor.layout_tensor.int_data, weight_qtensor.layout_tensor.scale)
else:
weight_tensor = weight_qtensor.dequantize()
return torch.nn.functional.linear(input_tensor, weight_tensor, bias)
else:
w_vals_int8_t = weight_qtensor.layout_tensor.int_data.t().contiguous()
orig_dtype = input_tensor.dtype
y = (
torch.mm(
input_tensor.reshape(-1, input_tensor.shape[-1]),
w_vals_int8_t.to(input_tensor.dtype),
)
* weight_qtensor.scale
)
y = y.reshape(*input_tensor.shape[:-1], y.shape[-1])
if bias is not None:
y += bias
return y.to(orig_dtype)

# is_cpu and is_mps only, some issue with is_contiguous() currently
# return torch.ops.aten._weight_int8pack_mm(input_tensor.contiguous(), w_vals_int8_t, weight_qtensor.layout_tensor.scale)

raise NotImplementedError("No specialized dispatch found for quantized linear op")


@implements_aqt_torch_function(torch.nn.functional.linear)
def functional_linear(*args, **kwargs):
input_tensor, weight_tensor, bias = (
args[0],
args[1],
args[2] if len(args) > 2 else None,
)
# using try/except here so that we can have a general fallback when input_tensor/weight_tensor
# is not picked up by any of the dispatch paths in `_quantized_linear_op`, this allows us to
# make the branches easier to understand in `_quantized_linear_op`
try:
return _quantized_linear_op(input_tensor, weight_tensor, bias)
except:
if isinstance(input_tensor, AffineQuantizedTensor):
input_tensor = input_tensor.dequantize()
if isinstance(weight_tensor, AffineQuantizedTensor):
weight_tensor = weight_tensor.dequantize()
return torch.nn.functional.linear(input_tensor, weight_tensor, bias)


@implements_aqt_aten_ops([aten.mm.default, aten.addmm.default])
def aten_mm(func, *args, **kwargs):
if not args[0].is_floating_point():
raise NotImplementedError(f"{func} is not implemented for non floating point input")

# using try/except here so that we can have a general fallback when input_tensor/weight_tensor
# is not picked up by any of the dispatch paths in `_quantized_linear_op`, this allows us to
# make the branches easier to understand in `_quantized_linear_op`
if func == aten.addmm.default:
assert args[1].shape[-1] == args[2].shape[0], (
f"need mat1 shape: {args[1].shape} final"
f"dim to match mat2 shape: {args[2].shape} first dim "
)
input_tensor, weight_qtensor, bias = (
input_tensor, weight_tensor, bias = (
args[1],
args[2],
args[0],
)
try:
return _quantized_linear_op(input_tensor, weight_tensor, bias)
except:
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

there' a bunch of code duplication here, also why do we need the try except block?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

oh we actually need to call the function in different ways here, not sure when the change is reverted, will fix

try except is used as a fallback when the specific configuration of input and weight tensor is not caught by any of the special dispatches in _quantized_linear_op

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

added some comments

if isinstance(input_tensor, AffineQuantizedTensor):
input_tensor = input_tensor.dequantize()
if isinstance(weight_tensor, AffineQuantizedTensor):
weight_tensor = weight_tensor.dequantize()
return func(bias, input_tensor, weight_tensor)
else:
assert args[0].shape[-1] == args[1].shape[0], (
f"need mat1 shape: {args[0].shape} final dim"
f"to match mat2 shape: {args[1].shape} first dim"
)
input_tensor, weight_qtensor, bias = (
input_tensor, weight_tensor, bias = (
args[0],
args[1],
None if len(args) == 2 else args[2],
None
)
weight_tensor = weight_qtensor.dequantize()
return func(input_tensor, weight_tensor, bias)
try:
return _quantized_linear_op(input_tensor, weight_tensor, bias)
except:
if isinstance(input_tensor, AffineQuantizedTensor):
input_tensor = input_tensor.dequantize()
if isinstance(weight_tensor, AffineQuantizedTensor):
weight_tensor = weight_tensor.dequantize()
return func(input_tensor, weight_tensor)
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

so here is a difference of how we call the function, since we have aten.mm here, the order of passing around the args are different from aten.addmm


@implements_aqt_aten_ops([aten.detach.default])
def detach(func, *args, **kwargs):
Expand All @@ -641,10 +685,10 @@ def _to_copy(func, *args, **kwargs):

@implements_aqt_aten_ops([aten.t.default])
def t(func, *args, **kwargs):
# TODO: need to implement this
# args[0].transposed = not args[0].transposed
# new = args[0]._change_shape(args[0].shape[::-1])
# return return_and_correct_aliasing(func, args, kwargs, new)
raise Exception("transpose not implemented yet")
block_size = args[0].block_size
assert len(block_size) == 2
transposed_block_size = (block_size[1], block_size[0])
new = args[0]._change_shape(args[0].shape[::-1], transposed_block_size)
return return_and_correct_aliasing(func, args, kwargs, new)

to_aq = AffineQuantizedTensor.from_float
18 changes: 13 additions & 5 deletions torchao/quantization/quant_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,11 @@
from typing import Any, Callable

from .dynamic_quant import DynamicallyPerAxisQuantizedLinear
from .utils import TORCH_VERSION_AFTER_2_3, TORCH_VERSION_AFTER_2_4
from .utils import (
TORCH_VERSION_AFTER_2_3,
TORCH_VERSION_AFTER_2_4,
unwrap_tensor_subclass,
)

from .subclass import (
Int4WeightOnlyQuantizedLinearWeight,
Expand Down Expand Up @@ -187,9 +191,13 @@ def change_linear_weights_to_int8_dqtensors(model, filter_fn=None, **kwargs):
*args
)

_replace_with_custom_fn_if_matches_filter(
model, _get_subclass_inserter(Int8DynamicallyQuantizedLinearWeight, enable_parametrization=TORCH_VERSION_AFTER_2_4, **kwargs), filter_fn
)
if TORCH_VERSION_AFTER_2_4:
quantize(model, get_apply_int8dyn_quant(), filter_fn)
unwrap_tensor_subclass(model, filter_fn)
else:
_replace_with_custom_fn_if_matches_filter(
model, _get_subclass_inserter(Int8DynamicallyQuantizedLinearWeight, enable_parametrization=False, **kwargs), filter_fn
)


def change_linear_weights_to_int8_woqtensors(model, filter_fn=None, **kwargs):
Expand Down Expand Up @@ -282,7 +290,7 @@ def quantize(model: torch.nn.Module, apply_tensor_subclass: Callable[[torch.Tens
zero_point_dtype = torch.bfloat16
zero_point_domain = ZeroPointDomain.FLOAT

apply_weight_quant = lambda x: to_aqt(x, mapping_type, block_size, target_dtype, quant_min, quant_max, eps, zero_point_dtype=zero_point_dtype, preserve_zero=preserve_zero, zero_point_domain=zero_point_domain)
apply_weight_quant = lambda x: to_aq(x, mapping_type, block_size, target_dtype, quant_min, quant_max, eps, zero_point_dtype=zero_point_dtype, preserve_zero=preserve_zero, zero_point_domain=zero_point_domain)

# apply to modules under block0 submodule
def filter_fn(module, fqn):
Expand Down
Loading
Loading