Skip to content

Commit

Permalink
Refactor int4 and int8 weight only quantization to use quantize (py…
Browse files Browse the repository at this point in the history
…torch#301)

* Replace implementation for int8 dynamic quantization with call to `quantize`

Summary:
Previously we added `quantize` as a general API (pytorch#256) for
Affine Quantized tensor subclass, and also tensor subclass based dtype conversion in general.

The plan is to use this to replace existing quant APIs including int4 weight only, int8 weight only, int8 dynamic quant
and 8da4w (for executorch).

This PR we started replacing the implementation of int8 dynamic quant API with `quantize` API with affine quantized tensor
subclass. We'll make sure the performance does not regress for vit model.

Test Plan:
TORCH_LOGS='output_code' python tutorials/quantize_vit/run_vit_b_quant.py

reference: elapsed_time:  1.4821058654785155  milliseconds
after refactor: elapsed_time:  1.4804757690429688  milliseconds

generated code diff: https://gist.github.com/jerryzh168/90c71107a5aaaa5d8dd2170c573e076d

Reviewers:

Subscribers:

Tasks:

Tags:

* Refactor int8 weight only quant to use `quantize`

Summary:
Similar to pytorch#294 we replaced the implementation
of int8 weight only quant to used the newly added `quantize` function, as a part of
the unification effort for affine quantization

Test Plan:
1. unit perf test:
python test/quantization/test_quant_api.py -k test_quantized_tensor_subclass_int8_wo_quant_perf

elapsed time: 0.23909856796264647, ref elapsed time: 0.25150911331176756
elapsed time: 0.24894208908081056, ref elapsed time: 0.2570047950744629
elapsed time: 0.21607391357421876, ref elapsed time: 0.22809568405151368

2. integration test:

TORCH_LOGS='output_code' python tutorials/quantize_vit/run_vit_b_quant.py

Reference: elapsed_time:  1.355208740234375  milliseconds
After refactor: elapsed_time:  1.32778857421875  milliseconds

code diff (gist): https://gist.github.com/jerryzh168/921a722cf20d476c8fc5888482e722dc
code diff (meta-only paste): https://www.internalfb.com/phabricator/paste/view/P1387333845

Reviewers:

Subscribers:

Tasks:

Tags:

* Replace implementation for int8 dynamic quantization with call to `quantize`

Summary:
Previously we added `quantize` as a general API (pytorch#256) for
Affine Quantized tensor subclass, and also tensor subclass based dtype conversion in general.

The plan is to use this to replace existing quant APIs including int4 weight only, int8 weight only, int8 dynamic quant
and 8da4w (for executorch).

This PR we started replacing the implementation of int8 dynamic quant API with `quantize` API with affine quantized tensor
subclass. We'll make sure the performance does not regress for vit model.

Test Plan:
TORCH_LOGS='output_code' python tutorials/quantize_vit/run_vit_b_quant.py

reference: elapsed_time:  1.4821058654785155  milliseconds
after refactor: elapsed_time:  1.4804757690429688  milliseconds

generated code diff: https://gist.github.com/jerryzh168/90c71107a5aaaa5d8dd2170c573e076d

Reviewers:

Subscribers:

Tasks:

Tags:

* Refactor int4 weight only quantization with call to `quantize`

Summary:
This is similar to pytorch#294 but applied for int4 weight only quantization

Test Plan:

unit perf test:
python test/quantization/test_quant_api.py -k test_quantized_tensor_subclass_int4_wo_quant_perf
elapsed time: 0.2166275215148926, ref elapsed time: 0.2191881561279297
elapsed time: 0.2376406478881836, ref elapsed time: 0.22721023559570314
elapsed time: 0.21919679641723633, ref elapsed time: 0.2154969596862793

integration perf test:

reference: elapsed_time:  2.5900126953125  milliseconds
after refactor: elapsed_time:  2.56680078125  milliseconds
diff: no diff

TORCH_LOGS='output_code' python tutorials/quantize_vit/run_vit_b_quant.py

Before:
After:
generated code diff:

Reviewers:

Subscribers:

Tasks:

Tags:

---------

Co-authored-by: Mark Saroufim <marksaroufim@meta.com>
  • Loading branch information
jerryzh168 and msaroufim authored Jun 4, 2024
1 parent 1433cb4 commit 9d1d8df
Show file tree
Hide file tree
Showing 8 changed files with 397 additions and 169 deletions.
118 changes: 118 additions & 0 deletions benchmarks/benchmark_aq.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,118 @@
"""Benchmarks for affine quantized tensor, this includes int8 dynamic quant, int8 weight only quant and int4 weight only quant APIs
"""
import torch
from torchao.quantization.subclass import (
Int8WeightOnlyQuantizedLinearWeight,
Int4WeightOnlyQuantizedLinearWeight,
)
from torchao.quantization.utils import (
TORCH_VERSION_AFTER_2_4,
)
from torchao.quantization.quant_api import (
_replace_with_custom_fn_if_matches_filter,
)
import copy

class ToyLinearModel(torch.nn.Module):
def __init__(self, m=64, n=32, k=64):
super().__init__()
self.linear1 = torch.nn.Linear(m, n, bias=False).to(torch.float)
self.linear2 = torch.nn.Linear(n, k, bias=False).to(torch.float)

def example_inputs(self, batch_size=1, dtype=torch.float, device="cpu"):
return (torch.randn(batch_size, self.linear1.in_features, dtype=dtype, device=device),)

def forward(self, x):
x = self.linear1(x)
x = self.linear2(x)
return x

def _ref_change_linear_weights_to_int8_dqtensors(model, filter_fn=None, **kwargs):
"""
The deprecated implementation for int8 dynamic quant API, used as a reference for
numerics and performance
"""
from torchao.quantization.quant_api import _in_features_greater_than_16
from torchao.quantization.quant_api import _is_linear
from torchao.quantization.quant_api import _get_subclass_inserter
from torchao.quantization.subclass import Int8DynamicallyQuantizedLinearWeight

if filter_fn is None:
filter_fn = lambda *args: _is_linear(*args) and _in_features_greater_than_16(
*args
)

_replace_with_custom_fn_if_matches_filter(
model, _get_subclass_inserter(Int8DynamicallyQuantizedLinearWeight, enable_parametrization=False, **kwargs), filter_fn
)

def _get_ref_change_linear_weights_to_woqtensors(deprecated_tenosr_subclass):
def _ref_change_linear_weights_to_woqtensors(model, filter_fn=None, **kwargs):
"""
The deprecated implementation for weight only quant API, used as a reference for
numerics and performance
"""
from torchao.quantization.quant_api import _is_linear
from torchao.quantization.quant_api import _get_subclass_inserter

filter_fn = kwargs.pop("filter_fn", _is_linear)

_replace_with_custom_fn_if_matches_filter(
model,
_get_subclass_inserter(deprecated_tenosr_subclass, enable_parametrization=True, **kwargs),
filter_fn,
)

return _ref_change_linear_weights_to_woqtensors

_ref_change_linear_weights_to_int8_woqtensors = _get_ref_change_linear_weights_to_woqtensors(Int8WeightOnlyQuantizedLinearWeight)
_ref_change_linear_weights_to_int4_woqtensors = _get_ref_change_linear_weights_to_woqtensors(Int4WeightOnlyQuantizedLinearWeight)


def _bench_quantized_tensor_subclass_perf(api, ref_api, kwargs=None):
if kwargs is None:
kwargs = {}

m = ToyLinearModel(1024, 1024, 1024).eval().to(torch.bfloat16).to("cuda")
m_ref = copy.deepcopy(m)
# setting batch_size to 20 to be compatible with the kernel
example_inputs = m.example_inputs(batch_size=20, dtype=torch.bfloat16, device="cuda")

api(m, **kwargs)

# reference
ref_api(m_ref, **kwargs)

res = m(*example_inputs)
ref = m_ref(*example_inputs)

assert torch.equal(res, ref)

# perf comparison
from torchao.utils import benchmark_model
# warmup
WARMUP = 5
RUNS = 100
input_tensor = example_inputs[0]
m = torch.compile(m, mode='max-autotune', fullgraph=True)

benchmark_model(m, WARMUP, input_tensor)
elapsed_time = benchmark_model(m, RUNS, input_tensor)

m_ref = torch.compile(m_ref, mode='max-autotune', fullgraph=True)
benchmark_model(m_ref, WARMUP, input_tensor)
ref_elapsed_time = benchmark_model(m_ref, RUNS, input_tensor)

print(f"elapsed time: {elapsed_time}, ref elapsed time: {ref_elapsed_time}")
assert elapsed_time < 1.05 * ref_elapsed_time

if __name__ == "__main__" and TORCH_VERSION_AFTER_2_4 and torch.cuda.is_available():
from torchao.quantization.quant_api import change_linear_weights_to_int8_dqtensors
_bench_quantized_tensor_subclass_perf(change_linear_weights_to_int8_dqtensors, _ref_change_linear_weights_to_int8_dqtensors)

from torchao.quantization.quant_api import change_linear_weights_to_int8_woqtensors
_bench_quantized_tensor_subclass_perf(change_linear_weights_to_int8_woqtensors, _ref_change_linear_weights_to_int8_woqtensors)

kwargs = {"groupsize": 32}
from torchao.quantization.quant_api import change_linear_weights_to_int4_woqtensors
_bench_quantized_tensor_subclass_perf(change_linear_weights_to_int4_woqtensors, _ref_change_linear_weights_to_int4_woqtensors, kwargs)
3 changes: 3 additions & 0 deletions test/integration/test_integration.py
Original file line number Diff line number Diff line change
Expand Up @@ -930,6 +930,7 @@ def _test_lin_weight_subclass_impl(
)

@parameterized.expand(COMMON_DEVICE_DTYPE)
@unittest.skipIf(TORCH_VERSION_AFTER_2_4, "skip because there is some bug in inductor codegen")
def test_int8_dynamic_quant_subclass(self, device, dtype):
self._test_lin_weight_subclass_impl(
Int8DynamicallyQuantizedLinearWeight.from_float, device, 35, test_dtype=dtype
Expand Down Expand Up @@ -1217,6 +1218,8 @@ def forward(self, x):
@parameterized.expand(COMMON_DEVICE_DTYPE)
@torch.no_grad()
def test_save_load_dqtensors(self, device, dtype):
if device == "cpu":
self.skipTest(f"indcutor failed for cpu right now")
self._test_handle_save_load_meta_impl(change_linear_weights_to_int8_dqtensors, device, test_dtype=dtype)

@parameterized.expand(COMMON_DEVICE_DTYPE)
Expand Down
78 changes: 30 additions & 48 deletions test/quantization/test_quant_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,8 @@
from torchao.quantization.subclass import (
to_laq,
LinearActQuantizedTensor,
Int8WeightOnlyQuantizedLinearWeight,
Int4WeightOnlyQuantizedLinearWeight,
)
from torchao.quantization.quant_api import (
_replace_with_custom_fn_if_matches_filter,
Expand Down Expand Up @@ -138,6 +140,28 @@ def _ref_change_linear_weights_to_int8_dqtensors(model, filter_fn=None, **kwargs
model, _get_subclass_inserter(Int8DynamicallyQuantizedLinearWeight, enable_parametrization=False, **kwargs), filter_fn
)

def _get_ref_change_linear_weights_to_woqtensors(deprecated_tenosr_subclass):
def _ref_change_linear_weights_to_woqtensors(model, filter_fn=None, **kwargs):
"""
The deprecated implementation for weight only quant API, used as a reference for
numerics and performance
"""
from torchao.quantization.quant_api import _is_linear
from torchao.quantization.quant_api import _get_subclass_inserter

filter_fn = kwargs.pop("filter_fn", _is_linear)

_replace_with_custom_fn_if_matches_filter(
model,
_get_subclass_inserter(deprecated_tenosr_subclass, enable_parametrization=True, **kwargs),
filter_fn,
)

return _ref_change_linear_weights_to_woqtensors

_ref_change_linear_weights_to_int8_woqtensors = _get_ref_change_linear_weights_to_woqtensors(Int8WeightOnlyQuantizedLinearWeight)
_ref_change_linear_weights_to_int4_woqtensors = _get_ref_change_linear_weights_to_woqtensors(Int4WeightOnlyQuantizedLinearWeight)

class TestQuantFlow(unittest.TestCase):
def test_dynamic_quant_gpu_singleline(self):
m = ToyLinearModel().eval()
Expand Down Expand Up @@ -478,8 +502,7 @@ def test_quantized_tensor_subclass_int4(self):
assert isinstance(m.linear2.weight, AffineQuantizedTensor)

# reference
from torchao.quantization.quant_api import change_linear_weights_to_int4_woqtensors
change_linear_weights_to_int4_woqtensors(m_copy, groupsize=groupsize)
_ref_change_linear_weights_to_int4_woqtensors(m_copy, groupsize=groupsize)

res = m(*example_inputs)
ref = m_copy(*example_inputs)
Expand All @@ -489,7 +512,7 @@ def test_quantized_tensor_subclass_int4(self):

@unittest.skipIf(not TORCH_VERSION_AFTER_2_4, "Test only enabled for 2.4+")
@unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
def test_quantized_tensor_subclass_int8(self):
def test_quantized_tensor_subclass_int8_wo(self):
m = ToyLinearModel().eval().to(torch.bfloat16)
m_copy = copy.deepcopy(m)
example_inputs = tuple(map(lambda x: x.to(torch.bfloat16), m.example_inputs()))
Expand All @@ -500,13 +523,13 @@ def test_quantized_tensor_subclass_int8(self):
assert isinstance(m.linear2.weight, AffineQuantizedTensor)

# reference
from torchao.quantization.quant_api import change_linear_weights_to_int8_woqtensors
change_linear_weights_to_int8_woqtensors(m_copy)
_ref_change_linear_weights_to_int8_woqtensors(m_copy)


res = m(*example_inputs)
ref = m_copy(*example_inputs)

torch.testing.assert_close(res, ref, rtol=0.00001, atol=1e-2)
self.assertTrue(torch.equal(res, ref))


@unittest.skipIf(not TORCH_VERSION_AFTER_2_4, "Test only enabled for 2.4+")
Expand All @@ -525,8 +548,7 @@ def test_quantized_tensor_subclass_int8_dyn_quant(self):
assert isinstance(m.linear2.weight.original_weight_tensor, AffineQuantizedTensor)

# reference
from torchao.quantization.quant_api import change_linear_weights_to_int8_dqtensors
change_linear_weights_to_int8_dqtensors(m_copy)
_ref_change_linear_weights_to_int8_dqtensors(m_copy)

res = m(*example_inputs)
ref = m_copy(*example_inputs)
Expand All @@ -545,45 +567,5 @@ def test_quantized_tensor_subclass_int8_dyn_quant(self):
# make sure it compiles
torch._export.aot_compile(m_unwrapped, example_inputs)

@unittest.skipIf(not TORCH_VERSION_AFTER_2_4, "Test only enabled for 2.4+")
@unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
@unittest.skip("This perf test is supposed to be run locally for sanity check performance when there is a change of int8 dynamic quant implementation")
def test_quantized_tensor_subclass_int8_dyn_quant_perf(self):
m = ToyLinearModel(1024, 1024, 1024).eval().to(torch.bfloat16).to("cuda")
m_ref = copy.deepcopy(m)
# setting batch_size to 20 to be compatible with the kernel
example_inputs = m.example_inputs(batch_size=20, dtype=torch.bfloat16, device="cuda")

from torchao.quantization.quant_api import change_linear_weights_to_int8_dqtensors
change_linear_weights_to_int8_dqtensors(m)

# reference
_ref_change_linear_weights_to_int8_dqtensors(m_ref)

res = m(*example_inputs)
ref = m_ref(*example_inputs)

self.assertTrue(torch.equal(res, ref))

# perf comparison
from torchao.utils import benchmark_model
# warmup
WARMUP = 5
RUNS = 100
input_tensor = example_inputs[0]
m = torch.compile(m, mode='max-autotune', fullgraph=True)

benchmark_model(m, WARMUP, input_tensor)
elapsed_time = benchmark_model(m, RUNS, input_tensor)

m_ref = torch.compile(m_ref, mode='max-autotune', fullgraph=True)
benchmark_model(m_ref, WARMUP, input_tensor)
ref_elapsed_time = benchmark_model(m_ref, RUNS, input_tensor)

print(f"elapsed time: {elapsed_time}, ref elapsed time: {ref_elapsed_time}")
self.assertTrue(elapsed_time < 1.05 * ref_elapsed_time)



if __name__ == "__main__":
unittest.main()
Loading

0 comments on commit 9d1d8df

Please sign in to comment.