
Commit b150224

metascroy authored and facebook-github-bot committed
Add torchchat quantizer (#897)
Summary: Pull Request resolved: #897

This diff adds a quantizer for the new torchao kernels that is similar to the Int8DynActInt4WeightQuantizer quantizer in torchchat (imported from torchao.quantization.quant_api). See the draft torchchat PR (pytorch/torchchat#1070) for how this can integrate with torchchat's quantization API.

I confirmed that models quantized with this are compatible with eager, compile, AOTI, and export to ExecuTorch in torchchat. They do not run on ExecuTorch because we have not yet written an ExecuTorch kernel wrapper.

jerryzh168: this does not use the new subclass API, and that is something I'd like to discuss further with you. I'll set up a sync with you this week, but I wanted to have some API on the table to ground the discussion. We do not currently have the required C++ methods implemented to support the new subclass API (e.g., we cannot unpack the packed weights from Python; they are instead unpacked inline in the kernel). From a torchchat user's perspective, I do not think this is important, but I'd like to discuss further.

Differential Revision: D62394341
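For orientation, here is a minimal sketch of the quantizer API this diff adds, assembled from the example and test scripts below. The toy model and argument values are placeholders; the sketch assumes quant_api is importable (the scripts below arrange this by inserting torchao/experimental onto sys.path):

    import torch
    from quant_api import Int8DynActIntxWeightQuantizer

    # A toy float32 model; any nn.Module containing Linear layers is handled the same way.
    model = torch.nn.Sequential(torch.nn.Linear(4096, 1024, bias=False))

    # Constructor arguments mirror those used in run_custom_op.py below.
    quantizer = Int8DynActIntxWeightQuantizer(
        device="cpu",
        precision=torch.float32,
        bitwidth=4,            # low-bit weight width; the tests sweep 1 through 7
        groupsize=256,         # quantization group size
        has_weight_zeros=False,
    )
    model = quantizer.quantize(model)

    out = model(torch.randn(1, 4096, dtype=torch.float32))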
1 parent 2dea315 commit b150224

8 files changed: +421 −351 lines

torchao/experimental/kernels/cpu/aarch64/CMakeLists.txt

Lines changed: 4 additions & 4 deletions
@@ -6,8 +6,8 @@
 
 add_library(
   kernel_aarch64
-  ${TORCHAO_LIBRARIES}/torchao/experimental/kernels/cpu/aarch64/reduction/find_min_and_max.cpp
-  ${TORCHAO_LIBRARIES}/torchao/experimental/kernels/cpu/aarch64/reduction/compute_sum.cpp
-  ${TORCHAO_LIBRARIES}/torchao/experimental/kernels/cpu/aarch64/quantization/quantize.cpp
-  ${TORCHAO_LIBRARIES}/torchao/experimental/kernels/cpu/aarch64/valpacking/interleave.cpp
+  ${TORCHAO_INCLUDE_DIRS}/torchao/experimental/kernels/cpu/aarch64/reduction/find_min_and_max.cpp
+  ${TORCHAO_INCLUDE_DIRS}/torchao/experimental/kernels/cpu/aarch64/reduction/compute_sum.cpp
+  ${TORCHAO_INCLUDE_DIRS}/torchao/experimental/kernels/cpu/aarch64/quantization/quantize.cpp
+  ${TORCHAO_INCLUDE_DIRS}/torchao/experimental/kernels/cpu/aarch64/valpacking/interleave.cpp
 )

torchao/experimental/kernels/cpu/linear/examples/torch_custom_op/CMakeLists.txt

Lines changed: 4 additions & 4 deletions
@@ -13,12 +13,12 @@ set(CMAKE_BUILD_TYPE Release)
 add_compile_options("-Wall" "-Werror")
 
 include(CMakePrintHelpers)
-message("TORCHAO_LIBRARIES: ${TORCHAO_LIBRARIES}")
-include_directories(${TORCHAO_LIBRARIES})
+message("TORCHAO_INCLUDE_DIRS: ${TORCHAO_INCLUDE_DIRS}")
+include_directories(${TORCHAO_INCLUDE_DIRS})
 
-add_subdirectory(${TORCHAO_LIBRARIES}/torchao/experimental/kernels/cpu/aarch64 ${CMAKE_CURRENT_BINARY_DIR}/kernel_aarch64)
+add_subdirectory(${TORCHAO_INCLUDE_DIRS}/torchao/experimental/kernels/cpu/aarch64 ${CMAKE_CURRENT_BINARY_DIR}/kernel_aarch64)
 
-include(${TORCHAO_LIBRARIES}/torchao/experimental/kernels/cpu/Utils.cmake)
+include(${TORCHAO_INCLUDE_DIRS}/torchao/experimental/kernels/cpu/Utils.cmake)
 
 set(PLATFORM "ATEN" CACHE STRING "Choose platform surface: ATEN, EXECUTORCH")
 string(TOUPPER ${PLATFORM} PLATFORM_TO_UPPER)

torchao/experimental/kernels/cpu/linear/examples/torch_custom_op/build_custom_op.sh

Lines changed: 4 additions & 4 deletions
@@ -6,14 +6,14 @@
 # LICENSE file in the root directory of this source tree.
 
 SCRIPT_DIR=$(cd -- "$(dirname -- "${BASH_SOURCE[0]}")" &> /dev/null && pwd)
-export TORCHAO_LIBRARIES=${SCRIPT_DIR}/../../../../../../..
+export TORCHAO_INCLUDE_DIRS=${SCRIPT_DIR}/../../../../../../..
 
 export CMAKE_PREFIX_PATH="$(python -c 'import torch.utils; print(torch.utils.cmake_prefix_path)')"
 echo "CMAKE_PREFIX_PATH: ${CMAKE_PREFIX_PATH}"
-export CMAKE_OUT=/tmp/cmake-out/torch_ao/examples/torch_custom_op
-cmake -DTORCHAO_LIBRARIES=${TORCHAO_LIBRARIES} \
+export CMAKE_OUT=/tmp/cmake-out/torchao
+cmake -DTORCHAO_INCLUDE_DIRS=${TORCHAO_INCLUDE_DIRS} \
     -DCMAKE_PREFIX_PATH=${CMAKE_PREFIX_PATH} \
     -DPLATFORM="ATEN" \
-    -S ${TORCHAO_LIBRARIES}/torchao/experimental/kernels/cpu/linear/examples/torch_custom_op \
+    -S ${TORCHAO_INCLUDE_DIRS}/torchao/experimental/kernels/cpu/linear/examples/torch_custom_op \
     -B ${CMAKE_OUT}
 cmake --build ${CMAKE_OUT}
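Once the script has run, the built shared library is loaded from Python as in the snippet below; this mirrors the loading logic in run_custom_op.py and the test file in this diff, and assumes the build above succeeded:

    import glob
    import torch

    # The build script now writes its output under /tmp/cmake-out/torchao (see the diff above).
    libs = glob.glob("/tmp/cmake-out/torchao/liblowbit_op_aten.*")
    # Keep only shared-library artifacts (.so on Linux, .dylib on macOS).
    libs = list(filter(lambda l: (l.endswith("so") or l.endswith("dylib")), libs))
    torch.ops.load_library(libs[0])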

torchao/experimental/kernels/cpu/linear/examples/torch_custom_op/run_custom_op.py

Lines changed: 20 additions & 52 deletions
@@ -5,12 +5,21 @@
 # LICENSE file in the root directory of this source tree.
 
 import copy
+import glob
+import os
+
+import sys
 
 import torch
-from torch_custom_op import (
-    linear_a8sz_w_lowbit_reference_impl,
-    replace_linear_with_quantized_linear,
+
+sys.path.insert(
+    0, os.path.abspath(os.path.join(os.path.dirname(__file__), "../../../../.."))
 )
+from quant_api import Int8DynActIntxWeightQuantizer
+
+libs = glob.glob("/tmp/cmake-out/torchao/liblowbit_op_aten.*")
+libs = list(filter(lambda l: (l.endswith("so") or l.endswith("dylib")), libs))
+torch.ops.load_library(libs[0])
 
 group_size = 256
 m = 1
@@ -27,15 +36,15 @@
 
 print("Quantizing random model")
 quantized_model = copy.deepcopy(model)
-quantized_model = quantized_model.eval()
-replace_linear_with_quantized_linear(
-    quantized_model,
-    kwargs={
-        "group_size": group_size,
-        "nbit": nbit,
-        "has_weight_zeros": has_weight_zeros,
-    },
+quantizer = Int8DynActIntxWeightQuantizer(
+    device="cpu",
+    precision=torch.float32,
+    bitwidth=nbit,
+    groupsize=group_size,
+    has_weight_zeros=has_weight_zeros,
 )
+quantized_model = quantizer.quantize(quantized_model)
+quantized_model = quantized_model.eval()
 
 print("Creating random activations")
 activations = torch.randn(m, k, dtype=torch.float32)
@@ -58,44 +67,3 @@
 print("Running AOTI")
 fn = torch._export.aot_load("/tmp/torch_custom_op_example_model.so", "cpu")
 fn(activations)
-
-
-print("\nChecking correctness on layer 0")
-linear = model[0]
-quantized_linear = quantized_model[0]
-
-with torch.no_grad():
-    result = quantized_linear(activations)
-    expected_result = linear_a8sz_w_lowbit_reference_impl(
-        linear.weight, activations, group_size, nbit, has_weight_zeros
-    )
-    non_quantized_result = linear(activations)
-
-
-# Check that entries in result match entries in expected_result
-num_mismatch_at_low_tol = 0
-num_total = result.reshape(-1).shape[0]
-for i in range(num_total):
-    actual_val = result.reshape(-1)[i]
-    expected_val = expected_result.reshape(-1)[i]
-    if not torch.allclose(actual_val, expected_val):
-        num_mismatch_at_low_tol += 1
-
-    # If results are not close at a relaxed tolerance, exit with failure
-    if not torch.allclose(actual_val, expected_val, atol=1e-6):
-        assert False, "Correctness check failed"
-
-# Assert at most 5% of entries are not close at a low tolerance
-assert num_mismatch_at_low_tol / num_total <= 0.05, "Correctness check failed"
-print(
-    "Correctness check passed. All results are close, and ",
-    (num_total - num_mismatch_at_low_tol),
-    "/",
-    num_total,
-    " entries are close at a low tolerance.",
-)
-print("Quantization errors:")
-print("\tL1 error: ", torch.mean(torch.abs(result - non_quantized_result)).item())
-print("\tL2 error: ", torch.mean((result - non_quantized_result) ** 2).item())
-print("\tquantized_result[0:5]: ", result[0][0:5])
-print("\tnon_quantized_result[0:5]: ", non_quantized_result[0][0:5])

torchao/experimental/kernels/cpu/linear/examples/torch_custom_op/test_custom_op.py

Lines changed: 0 additions & 56 deletions
This file was deleted.
(The hunk below adds a new 79-line test file; its path is not captured in this view.)

@@ -0,0 +1,79 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+import copy
+
+import glob
+import os
+
+import sys
+import unittest
+
+import torch
+
+sys.path.insert(
+    0, os.path.abspath(os.path.join(os.path.dirname(__file__), "../../../../.."))
+)
+from quant_api import (
+    _Int8DynActIntxWeightQuantizedLinearFallback,
+    Int8DynActIntxWeightQuantizer,
+)
+
+libs = glob.glob("/tmp/cmake-out/torchao/liblowbit_op_aten.*")
+libs = list(filter(lambda l: (l.endswith("so") or l.endswith("dylib")), libs))
+if len(libs) == 0:
+    print(
+        "Could not find library lowbit_op_aten; please run `sh build_custom_op.sh` to build the library. A slow fallback kernel will be used instead."
+    )
+else:
+    torch.ops.load_library(libs[0])
+
+
+class TestInt8DynActIntxWeightQuantizer(unittest.TestCase):
+    def test_accuracy(self):
+        group_size = 128
+        m = 1
+        n = 1071
+        k = 4096
+        activations = torch.randn(m, k, dtype=torch.float32)
+        model = torch.nn.Sequential(*[torch.nn.Linear(k, n, bias=False)])
+
+        for nbit in [1, 2, 3, 4, 5, 6, 7]:
+            for has_weight_zeros in [True, False]:
+                print(f"Testing nbit={nbit}, has_weight_zeros={has_weight_zeros}")
+                quantized_model = copy.deepcopy(model)
+                quantizer = Int8DynActIntxWeightQuantizer(
+                    device="cpu",
+                    precision=torch.float32,
+                    bitwidth=nbit,
+                    groupsize=group_size,
+                    has_weight_zeros=has_weight_zeros,
+                )
+                quantized_model = quantizer.quantize(quantized_model)
+
+                with torch.no_grad():
+                    result = quantized_model(activations)
+                    reference_impl = _Int8DynActIntxWeightQuantizedLinearFallback()
+                    reference_impl.quantize_and_pack_weights(
+                        model[0].weight, nbit, group_size, has_weight_zeros
+                    )
+                    expected_result = reference_impl(activations)
+
+                num_mismatch_at_low_tol = 0
+                num_total = result.reshape(-1).shape[0]
+                for i in range(num_total):
+                    actual_val = result.reshape(-1)[i]
+                    expected_val = expected_result.reshape(-1)[i]
+                    self.assertTrue(torch.allclose(actual_val, expected_val, atol=1e-6))
+                    if not torch.allclose(actual_val, expected_val):
+                        num_mismatch_at_low_tol += 1
+
+                # Assert at most 5% of entries are not close at a low tolerance
+                self.assertTrue(num_mismatch_at_low_tol / num_total <= 0.05)
+
+
+if __name__ == "__main__":
+    unittest.main()
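As the test above shows, the Python fallback can stand in for the compiled kernel as a reference implementation. A minimal standalone sketch, using only names that appear in this diff (sizes and settings are placeholders; the argument order follows the quantize_and_pack_weights call in the test):

    import torch
    from quant_api import _Int8DynActIntxWeightQuantizedLinearFallback

    linear = torch.nn.Linear(4096, 1071, bias=False)
    activations = torch.randn(1, 4096, dtype=torch.float32)

    # Arguments: weight, nbit, group_size, has_weight_zeros
    reference_impl = _Int8DynActIntxWeightQuantizedLinearFallback()
    reference_impl.quantize_and_pack_weights(linear.weight, 4, 128, False)

    expected_result = reference_impl(activations)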
