
Commit 0613c0f

Browse files
metascroyfacebook-github-bot
authored andcommitted
Add torchchat quantizer (#897)
Summary:

Pull Request resolved: #897

This diff adds a quantizer for the new torchao kernels, similar to the Int8DynActInt4WeightQuantizer quantizer in torchchat (imported from torchao.quantization.quant_api). See the draft torchchat PR (pytorch/torchchat#1070) for how this can integrate with torchchat's quantization API.

I confirmed that models quantized with this are compatible with eager, compile, AOTI, and export to ExecuTorch in torchchat. They do not yet run on ExecuTorch because we still have not written an ExecuTorch kernel wrapper.

@jerryzh168: this does not use the new subclass API, and that is something I'd like to discuss further with you. I'll set up a sync with you this week, but I wanted to have some API on the table to ground the discussion. We do not currently have the C++ methods required to support the new subclass API (e.g., we cannot unpack the packed weights from Python; they are instead unpacked inline in the kernel). From a torchchat user's perspective, I do not think this is important, but I'd like to discuss further.

Reviewed By: digantdesai

Differential Revision: D62394341
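For orientation before the file-by-file diff, here is a minimal sketch of the new quantizer API, assembled from the example and test updated below. The module path and constructor arguments come from this diff; the layer sizes are illustrative. It assumes the custom-op library has already been built and loaded (see build_custom_op.sh and the load snippet after it).

```python
import copy

import torch

# quant_api lives under torchao/experimental in this diff; the example
# scripts add its directory to sys.path before importing.
from quant_api import Int8DynActIntxWeightQuantizer

# Illustrative model; the quantizer rewrites the nn.Linear layers it finds.
model = torch.nn.Sequential(torch.nn.Linear(4096, 4096, bias=False))

quantizer = Int8DynActIntxWeightQuantizer(
    device="cpu",
    precision=torch.float32,
    bitwidth=4,        # the updated test exercises bitwidths 1 through 7
    groupsize=256,
    has_weight_zeros=False,
)
quantized_model = quantizer.quantize(copy.deepcopy(model)).eval()
```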
1 parent 728d629 · commit 0613c0f

File tree

7 files changed: +342 −316 lines changed


torchao/experimental/kernels/cpu/aarch64/CMakeLists.txt

Lines changed: 4 additions & 4 deletions
@@ -6,8 +6,8 @@
 
 add_library(
   kernel_aarch64
-  ${TORCHAO_LIBRARIES}/torchao/experimental/kernels/cpu/aarch64/reduction/find_min_and_max.cpp
-  ${TORCHAO_LIBRARIES}/torchao/experimental/kernels/cpu/aarch64/reduction/compute_sum.cpp
-  ${TORCHAO_LIBRARIES}/torchao/experimental/kernels/cpu/aarch64/quantization/quantize.cpp
-  ${TORCHAO_LIBRARIES}/torchao/experimental/kernels/cpu/aarch64/valpacking/interleave.cpp
+  ${TORCHAO_INCLUDE_DIRS}/torchao/experimental/kernels/cpu/aarch64/reduction/find_min_and_max.cpp
+  ${TORCHAO_INCLUDE_DIRS}/torchao/experimental/kernels/cpu/aarch64/reduction/compute_sum.cpp
+  ${TORCHAO_INCLUDE_DIRS}/torchao/experimental/kernels/cpu/aarch64/quantization/quantize.cpp
+  ${TORCHAO_INCLUDE_DIRS}/torchao/experimental/kernels/cpu/aarch64/valpacking/interleave.cpp
 )

torchao/experimental/kernels/cpu/linear/examples/torch_custom_op/CMakeLists.txt

Lines changed: 4 additions & 4 deletions
@@ -13,12 +13,12 @@ set(CMAKE_BUILD_TYPE Release)
 add_compile_options("-Wall" "-Werror")
 
 include(CMakePrintHelpers)
-message("TORCHAO_LIBRARIES: ${TORCHAO_LIBRARIES}")
-include_directories(${TORCHAO_LIBRARIES})
+message("TORCHAO_INCLUDE_DIRS: ${TORCHAO_INCLUDE_DIRS}")
+include_directories(${TORCHAO_INCLUDE_DIRS})
 
-add_subdirectory(${TORCHAO_LIBRARIES}/torchao/experimental/kernels/cpu/aarch64 ${CMAKE_CURRENT_BINARY_DIR}/kernel_aarch64)
+add_subdirectory(${TORCHAO_INCLUDE_DIRS}/torchao/experimental/kernels/cpu/aarch64 ${CMAKE_CURRENT_BINARY_DIR}/kernel_aarch64)
 
-include(${TORCHAO_LIBRARIES}/torchao/experimental/kernels/cpu/Utils.cmake)
+include(${TORCHAO_INCLUDE_DIRS}/torchao/experimental/kernels/cpu/Utils.cmake)
 
 set(PLATFORM "ATEN" CACHE STRING "Choose platform surface: ATEN, EXECUTORCH")
 string(TOUPPER ${PLATFORM} PLATFORM_TO_UPPER)

torchao/experimental/kernels/cpu/linear/examples/torch_custom_op/build_custom_op.sh

Lines changed: 4 additions & 4 deletions
@@ -6,14 +6,14 @@
 # LICENSE file in the root directory of this source tree.
 
 SCRIPT_DIR=$(cd -- "$(dirname -- "${BASH_SOURCE[0]}")" &> /dev/null && pwd)
-export TORCHAO_LIBRARIES=${SCRIPT_DIR}/../../../../../../..
+export TORCHAO_INCLUDE_DIRS=${SCRIPT_DIR}/../../../../../../..
 
 export CMAKE_PREFIX_PATH="$(python -c 'import torch.utils; print(torch.utils.cmake_prefix_path)')"
 echo "CMAKE_PREFIX_PATH: ${CMAKE_PREFIX_PATH}"
-export CMAKE_OUT=/tmp/cmake-out/torch_ao/examples/torch_custom_op
-cmake -DTORCHAO_LIBRARIES=${TORCHAO_LIBRARIES} \
+export CMAKE_OUT=/tmp/cmake-out/torchao
+cmake -DTORCHAO_INCLUDE_DIRS=${TORCHAO_INCLUDE_DIRS} \
     -DCMAKE_PREFIX_PATH=${CMAKE_PREFIX_PATH} \
     -DPLATFORM="ATEN" \
-    -S ${TORCHAO_LIBRARIES}/torchao/experimental/kernels/cpu/linear/examples/torch_custom_op \
+    -S ${TORCHAO_INCLUDE_DIRS}/torchao/experimental/kernels/cpu/linear/examples/torch_custom_op \
    -B ${CMAKE_OUT}
 cmake --build ${CMAKE_OUT}
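The Python example and test below locate the library that this script builds under the new CMAKE_OUT and load it through torch.ops. A minimal sketch of that step (the path is the one set above; the extension depends on the OS):

```python
import glob

import torch

# The build above writes liblowbit_op_aten.so (Linux) or
# liblowbit_op_aten.dylib (macOS) into /tmp/cmake-out/torchao.
libs = glob.glob("/tmp/cmake-out/torchao/liblowbit_op_aten.*")
libs = [lib for lib in libs if lib.endswith(("so", "dylib"))]
torch.ops.load_library(libs[0])
```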

torchao/experimental/kernels/cpu/linear/examples/torch_custom_op/run_custom_op.py

Lines changed: 18 additions & 53 deletions
@@ -5,12 +5,18 @@
 # LICENSE file in the root directory of this source tree.
 
 import copy
+import glob
+
+import sys
 
 import torch
-from torch_custom_op import (
-    linear_a8sz_w_lowbit_reference_impl,
-    replace_linear_with_quantized_linear,
-)
+
+sys.path.insert(0, "../../../../..")
+from quant_api import Int8DynActIntxWeightQuantizer
+
+libs = glob.glob("/tmp/cmake-out/torchao/liblowbit_op_aten.*")
+libs = list(filter(lambda l: (l.endswith("so") or l.endswith("dylib")), libs))
+torch.ops.load_library(libs[0])
 
 group_size = 256
 m = 1
@@ -27,15 +33,15 @@
 
 print("Quantizing random model")
 quantized_model = copy.deepcopy(model)
-quantized_model = quantized_model.eval()
-replace_linear_with_quantized_linear(
-    quantized_model,
-    kwargs={
-        "group_size": group_size,
-        "nbit": nbit,
-        "has_weight_zeros": has_weight_zeros,
-    },
+quantizer = Int8DynActIntxWeightQuantizer(
+    device="cpu",
+    precision=torch.float32,
+    bitwidth=nbit,
+    groupsize=group_size,
+    has_weight_zeros=has_weight_zeros,
 )
+quantized_model = quantizer.quantize(quantized_model)
+quantized_model = quantized_model.eval()
 
 print("Creating random activations")
 activations = torch.randn(m, k, dtype=torch.float32)
@@ -58,44 +64,3 @@
 print("Running AOTI")
 fn = torch._export.aot_load("/tmp/torch_custom_op_example_model.so", "cpu")
 fn(activations)
-
-
-print("\nChecking correctness on layer 0")
-linear = model[0]
-quantized_linear = quantized_model[0]
-
-with torch.no_grad():
-    result = quantized_linear(activations)
-    expected_result = linear_a8sz_w_lowbit_reference_impl(
-        linear.weight, activations, group_size, nbit, has_weight_zeros
-    )
-    non_quantized_result = linear(activations)
-
-
-# Check that entries in result match entries in expected_result
-num_mismatch_at_low_tol = 0
-num_total = result.reshape(-1).shape[0]
-for i in range(num_total):
-    actual_val = result.reshape(-1)[i]
-    expected_val = expected_result.reshape(-1)[i]
-    if not torch.allclose(actual_val, expected_val):
-        num_mismatch_at_low_tol += 1
-
-    # If results are not close at a relaxed tolerance, exit with failure
-    if not torch.allclose(actual_val, expected_val, atol=1e-6):
-        assert False, "Correctness check failed"
-
-# Assert at most 5% of entries are not close at a low tolerance
-assert num_mismatch_at_low_tol / num_total <= 0.05, "Correctness check failed"
-print(
-    "Correctness check passed. All results are close, and ",
-    (num_total - num_mismatch_at_low_tol),
-    "/",
-    num_total,
-    " entries are close at a low tolerance.",
-)
-print("Quantization errors:")
-print("\tL1 error: ", torch.mean(torch.abs(result - non_quantized_result)).item())
-print("\tL2 error: ", torch.mean((result - non_quantized_result) ** 2).item())
-print("\tquantized_result[0:5]: ", result[0][0:5])
-print("\tnon_quantized_result[0:5]: ", non_quantized_result[0][0:5])

torchao/experimental/kernels/cpu/linear/examples/torch_custom_op/test_custom_op.py renamed to torchao/experimental/kernels/cpu/linear/examples/torch_custom_op/test_int8_dyn_act_intx_weight_quantizer.py

Lines changed: 35 additions & 20 deletions
@@ -4,16 +4,27 @@
 # This source code is licensed under the license found in the
 # LICENSE file in the root directory of this source tree.
 
+import copy
+
+import glob
+
+import sys
 import unittest
 
 import torch
-from torch_custom_op import (
-    linear_a8sz_w_lowbit_reference_impl,
-    replace_linear_with_quantized_linear,
+
+sys.path.insert(0, "../../../../..")
+from quant_api import (
+    _Int8DynActIntxWeightQuantizedLinearFallback,
+    Int8DynActIntxWeightQuantizer,
 )
-import copy
 
-class TestTorchCustomOp(unittest.TestCase):
+libs = glob.glob("/tmp/cmake-out/torchao/liblowbit_op_aten.*")
+libs = list(filter(lambda l: (l.endswith("so") or l.endswith("dylib")), libs))
+torch.ops.load_library(libs[0])
+
+
+class TestInt8DynActIntxWeightQuantizer(unittest.TestCase):
     def test_accuracy(self):
         group_size = 128
         m = 1
@@ -22,24 +33,27 @@ def test_accuracy(self):
         activations = torch.randn(m, k, dtype=torch.float32)
         model = torch.nn.Sequential(*[torch.nn.Linear(k, n, bias=False)])
 
-        for nbit in [2, 3, 4, 5]:
-            for has_weight_zeros in [False, True]:
+        for nbit in [1, 2, 3, 4, 5, 6, 7]:
+            for has_weight_zeros in [True, False]:
+                print(f"Testing nbit={nbit}, has_weight_zeros={has_weight_zeros}")
                 quantized_model = copy.deepcopy(model)
-                replace_linear_with_quantized_linear(
-                    quantized_model,
-                    kwargs={
-                        "group_size": group_size,
-                        "nbit": nbit,
-                        "has_weight_zeros": has_weight_zeros,
-                    },
+                quantizer = Int8DynActIntxWeightQuantizer(
+                    device="cpu",
+                    precision=torch.float32,
+                    bitwidth=nbit,
+                    groupsize=group_size,
+                    has_weight_zeros=has_weight_zeros,
                 )
+                quantized_model = quantizer.quantize(quantized_model)
 
                 with torch.no_grad():
                     result = quantized_model(activations)
-                    expected_result = linear_a8sz_w_lowbit_reference_impl(
-                        model[0].weight, activations, group_size, nbit, has_weight_zeros
+                    reference_impl = _Int8DynActIntxWeightQuantizedLinearFallback()
+                    reference_impl.quantize_and_pack_weights(
+                        model[0].weight, nbit, group_size, has_weight_zeros
                     )
-
+                    expected_result = reference_impl(activations)
+
                 num_mismatch_at_low_tol = 0
                 num_total = result.reshape(-1).shape[0]
                 for i in range(num_total):
@@ -50,7 +64,8 @@ def test_accuracy(self):
                         num_mismatch_at_low_tol += 1
 
                 # Assert at most 5% of entries are not close at a low tolerance
-                self.assertTrue(num_mismatch_at_low_tol / num_total <= 0.05)
-
-if __name__ == '__main__':
+                self.assertTrue(num_mismatch_at_low_tol / num_total <= 0.05)
+
+
+if __name__ == "__main__":
     unittest.main()
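A side note on the accuracy check: the per-element loop counts entries that fail torch.allclose at its default tolerances and asserts at most 5% mismatch. A vectorized equivalent (an assumption on my part, not part of this diff) would be:

```python
# torch.allclose/torch.isclose defaults: rtol=1e-05, atol=1e-08
close = torch.isclose(result, expected_result, rtol=1e-05, atol=1e-08)
mismatch_frac = (~close).float().mean().item()
assert mismatch_frac <= 0.05, "Correctness check failed"
```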
