
Commit a6b701e

metascroy authored and DannyYuyang-quic committed
Bump torchao + add unit tests for torchao kernels (pytorch#9396)
### Summary

This PR bumps the torchao pin and adds unit tests and documentation for the low-bit torchao kernels.

### Test plan

New CI test.
1 parent c1bf6b8 commit a6b701e

File tree: 7 files changed, +216 −24 lines

- .ci/scripts/test_llama_torchao_lowbit.sh
- .github/workflows/trunk.yml
- examples/models/llama/CMakeLists.txt
- examples/models/llama/README.md
- examples/models/llama/source_transformation/quantize.py
- install_requirements.py
- third-party/ao
.ci/scripts/test_llama_torchao_lowbit.sh

Lines changed: 85 additions & 0 deletions (new file)

```bash
#!/bin/bash
# Copyright (c) Qualcomm Innovation Center, Inc.
# All rights reserved
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

set -exu

source "$(dirname "${BASH_SOURCE[0]}")/utils.sh"

export EXECUTORCH_ROOT="$(dirname "${BASH_SOURCE[0]}")/../.."

if [[ -z "${PYTHON_EXECUTABLE:-}" ]]; then
  PYTHON_EXECUTABLE=python3
fi

which "${PYTHON_EXECUTABLE}"

# Update tokenizers submodule
pushd $EXECUTORCH_ROOT/extension/llm/tokenizers
echo "Update tokenizers submodule"
git submodule update --init
popd

# Install ET with CMake
cmake -DPYTHON_EXECUTABLE=python \
    -DCMAKE_INSTALL_PREFIX=cmake-out \
    -DEXECUTORCH_ENABLE_LOGGING=1 \
    -DCMAKE_BUILD_TYPE=Release \
    -DEXECUTORCH_BUILD_EXTENSION_DATA_LOADER=ON \
    -DEXECUTORCH_BUILD_EXTENSION_MODULE=ON \
    -DEXECUTORCH_BUILD_EXTENSION_TENSOR=ON \
    -DEXECUTORCH_BUILD_XNNPACK=OFF \
    -DEXECUTORCH_BUILD_KERNELS_QUANTIZED=ON \
    -DEXECUTORCH_BUILD_KERNELS_OPTIMIZED=ON \
    -DEXECUTORCH_BUILD_KERNELS_CUSTOM=ON \
    -Bcmake-out .
cmake --build cmake-out -j16 --target install --config Release

# Install llama runner with torchao
cmake -DPYTHON_EXECUTABLE=python \
    -DCMAKE_PREFIX_PATH=$(python -c 'from distutils.sysconfig import get_python_lib; print(get_python_lib())') \
    -DCMAKE_BUILD_TYPE=Release \
    -DEXECUTORCH_BUILD_KERNELS_CUSTOM=ON \
    -DEXECUTORCH_BUILD_KERNELS_OPTIMIZED=ON \
    -DEXECUTORCH_BUILD_XNNPACK=OFF \
    -DEXECUTORCH_BUILD_KERNELS_QUANTIZED=ON \
    -DEXECUTORCH_BUILD_TORCHAO=ON \
    -Bcmake-out/examples/models/llama \
    examples/models/llama
cmake --build cmake-out/examples/models/llama -j16 --config Release

# Download stories110M model artifacts
download_stories_model_artifacts

echo "Creating tokenizer.bin"
$PYTHON_EXECUTABLE -m extension.llm.tokenizer.tokenizer -t tokenizer.model -o tokenizer.bin

# Export model
LLAMA_CHECKPOINT=stories110M.pt
LLAMA_PARAMS=params.json
MODEL_OUT=model.pte
TOKENIZER=tokenizer.bin

# Set low-bit quantization parameters
QLINEAR_BITWIDTH=3 # Can be 1-8
QLINEAR_GROUP_SIZE=128 # Must be multiple of 16
QEMBEDDING_BITWIDTH=4 # Can be 1-8
QEMBEDDING_GROUP_SIZE=32 # Must be multiple of 16

${PYTHON_EXECUTABLE} -m examples.models.llama.export_llama \
    --checkpoint "${LLAMA_CHECKPOINT:?}" \
    --params "${LLAMA_PARAMS:?}" \
    -kv \
    --use_sdpa_with_kv_cache \
    --output_name=${MODEL_OUT} \
    -qmode "torchao:8da${QLINEAR_BITWIDTH}w" \
    --group_size ${QLINEAR_GROUP_SIZE} \
    -E "torchao:${QEMBEDDING_BITWIDTH},${QEMBEDDING_GROUP_SIZE}" \
    --disable_dynamic_shape \
    -d fp32

# Test run
./cmake-out/examples/models/llama/llama_main --model_path=$MODEL_OUT --tokenizer_path=$TOKENIZER --prompt="Once upon a time,"
```
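For readers unfamiliar with the `-qmode` flag: the string `torchao:8da${QLINEAR_BITWIDTH}w` encodes 8-bit dynamic activations with `${QLINEAR_BITWIDTH}`-bit weights. `quantize.py` (see its diff below) recovers the bitwidth from this string with a regex; the exact pattern is defined outside this commit, so the one below is a hypothetical sketch, not the repo's code:

```python
import re


def parse_torchao_qmode(qmode: str) -> int:
    """Extract the weight bitwidth from a qmode string like "torchao:8da3w".

    The real pattern lives in quantize.py outside this diff; this regex is
    an illustrative assumption.
    """
    matches = re.findall(r"torchao:8da(\d+)w", qmode)
    assert len(matches) == 1, f"Expected 1 match for pattern but got {len(matches)}"
    return int(matches[0])


assert parse_torchao_qmode("torchao:8da3w") == 3
```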

.github/workflows/trunk.yml

Lines changed: 22 additions & 2 deletions (the re-added comment lines in the first hunk appear to differ only in trailing whitespace)

```diff
@@ -23,8 +23,8 @@ jobs:
     uses: pytorch/test-infra/.github/workflows/macos_job.yml@main
     strategy:
       matrix:
-        # Mac runners are expensive and limited, and non reliable.
-        # Do some basic testing for macos jobs, and rely mostly on
+        # Mac runners are expensive and limited, and non reliable.
+        # Do some basic testing for macos jobs, and rely mostly on
         # test-models-linux-aarch64 job instead.
         model: [emformer_join, ic4, llama2, mobilebert, mv3, resnet50, vit, w2l]
         backend: [xnnpack-quantization-delegation]
@@ -288,6 +288,26 @@ jobs:
       # Test ANE llama
       ${CONDA_RUN} sh .ci/scripts/test_ane_static_llama.sh

+  test-llama-torchao-lowbit:
+    name: test-llama-torchao-lowbit
+    uses: pytorch/test-infra/.github/workflows/macos_job.yml@main
+    with:
+      runner: macos-m1-stable
+      python-version: '3.11'
+      submodules: 'true'
+      ref: ${{ github.event_name == 'pull_request' && github.event.pull_request.head.sha || github.sha }}
+      script: |
+        set -eux
+        bash .ci/scripts/setup-conda.sh
+        eval "$(conda shell.bash hook)"
+
+        # Install requirements
+        ${CONDA_RUN} python install_executorch.py
+        ${CONDA_RUN} sh examples/models/llama/install_requirements.sh
+
+        # Run test
+        ${CONDA_RUN} sh .ci/scripts/test_llama_torchao_lowbit.sh
+
   test-llama-runner-linux:
     # Test Both linux x86 and linux aarch64
     name: test-llama-runner-linux
```

examples/models/llama/CMakeLists.txt

Lines changed: 15 additions & 8 deletions

```diff
@@ -115,16 +115,23 @@
 endif()

 if(EXECUTORCH_BUILD_TORCHAO)
-  set(TORCHAO_BUILD_EXECUTORCH_OPS ON)
-  add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/../../../third-party/ao/torchao/experimental ${CMAKE_CURRENT_BINARY_DIR}/../../../third-party/ao/torchao/experimental)
-  target_link_options_shared_lib(torchao_ops_executorch)
-  list(APPEND link_libraries torchao_ops_executorch)
+  # Currently only enable this on Arm-based Macs
   if(CMAKE_SYSTEM_NAME STREQUAL "Darwin" AND CMAKE_SYSTEM_PROCESSOR STREQUAL "arm64")
+    set(TORCHAO_BUILD_EXECUTORCH_OPS ON)
+    set(TORCHAO_BUILD_CPU_AARCH64 ON)
     add_subdirectory(
-      ${CMAKE_CURRENT_SOURCE_DIR}/../../../third-party/ao/torchao/experimental/ops/mps
-      ${CMAKE_CURRENT_BINARY_DIR}/../../../third-party/ao/torchao/experimental/ops/mps)
-    target_link_options_shared_lib(torchao_ops_mps_executorch)
-    list(APPEND link_libraries torchao_ops_mps_executorch)
+      ${CMAKE_CURRENT_SOURCE_DIR}/../../../third-party/ao/torchao/experimental
+      ${CMAKE_CURRENT_BINARY_DIR}/../../../third-party/ao/torchao/experimental
+    )
+    target_link_options_shared_lib(torchao_ops_executorch)
+    list(APPEND link_libraries torchao_ops_executorch)
+    if(EXECUTORCH_BUILD_MPS)
+      add_subdirectory(
+        ${CMAKE_CURRENT_SOURCE_DIR}/../../../third-party/ao/torchao/experimental/ops/mps
+        ${CMAKE_CURRENT_BINARY_DIR}/../../../third-party/ao/torchao/experimental/ops/mps)
+      target_link_options_shared_lib(torchao_ops_mps_executorch)
+      list(APPEND link_libraries torchao_ops_mps_executorch)
+    endif()
   endif()
 endif()
```
examples/models/llama/README.md

Lines changed: 73 additions & 0 deletions

````diff
@@ -380,6 +380,79 @@
 ### Android
 Please refer to [this tutorial](https://pytorch.org/executorch/main/llm/llama-demo-android.html) for full instructions on building the Android LLAMA Demo App.

+## Running with low-bit kernels
+
+We now give instructions for quantizing and running your model with low-bit kernels. These kernels are still experimental and require that you develop on an Arm-based Mac. Also note that low-bit quantization often requires QAT (quantization-aware training) to give good quality results, and dynamic shapes must currently be disabled when exporting a model with these kernels.
+
+First export your model for low-bit quantization (step 2 above):
+
+```
+# Set these paths to point to the downloaded files
+LLAMA_CHECKPOINT=path/to/checkpoint.pth
+LLAMA_PARAMS=path/to/params.json
+
+# Set low-bit quantization parameters
+QLINEAR_BITWIDTH=3 # Can be 1-8
+QLINEAR_GROUP_SIZE=128 # Must be multiple of 16
+QEMBEDDING_BITWIDTH=4 # Can be 1-8
+QEMBEDDING_GROUP_SIZE=32 # Must be multiple of 16
+
+python -m examples.models.llama.export_llama \
+    --model "llama3_2" \
+    --checkpoint "${LLAMA_CHECKPOINT:?}" \
+    --params "${LLAMA_PARAMS:?}" \
+    -kv \
+    --use_sdpa_with_kv_cache \
+    --metadata '{"get_bos_id":128000, "get_eos_ids":[128009, 128001]}' \
+    --output_name="llama3_2.pte" \
+    -qmode "torchao:8da${QLINEAR_BITWIDTH}w" \
+    --group_size ${QLINEAR_GROUP_SIZE} \
+    -E "torchao:${QEMBEDDING_BITWIDTH},${QEMBEDDING_GROUP_SIZE}" \
+    --disable_dynamic_shape \
+    -d fp32
+```
+
+Once the model is exported, we need to build ExecuTorch and the runner with the low-bit kernels.
+
+The first step is to install ExecuTorch (the same as step 3.1 above):
+
+```
+cmake -DPYTHON_EXECUTABLE=python \
+    -DCMAKE_INSTALL_PREFIX=cmake-out \
+    -DEXECUTORCH_ENABLE_LOGGING=1 \
+    -DCMAKE_BUILD_TYPE=Release \
+    -DEXECUTORCH_BUILD_EXTENSION_DATA_LOADER=ON \
+    -DEXECUTORCH_BUILD_EXTENSION_MODULE=ON \
+    -DEXECUTORCH_BUILD_EXTENSION_TENSOR=ON \
+    -DEXECUTORCH_BUILD_XNNPACK=ON \
+    -DEXECUTORCH_BUILD_KERNELS_QUANTIZED=ON \
+    -DEXECUTORCH_BUILD_KERNELS_OPTIMIZED=ON \
+    -DEXECUTORCH_BUILD_KERNELS_CUSTOM=ON \
+    -Bcmake-out .
+cmake --build cmake-out -j16 --target install --config Release
+```
+
+Next install the llama runner with torchao kernels enabled (similar to step 3.2 above):
+
+```
+cmake -DPYTHON_EXECUTABLE=python \
+    -DCMAKE_PREFIX_PATH=$(python -c 'from distutils.sysconfig import get_python_lib; print(get_python_lib())') \
+    -DCMAKE_BUILD_TYPE=Release \
+    -DEXECUTORCH_BUILD_KERNELS_CUSTOM=ON \
+    -DEXECUTORCH_BUILD_KERNELS_OPTIMIZED=ON \
+    -DEXECUTORCH_BUILD_XNNPACK=OFF \
+    -DEXECUTORCH_BUILD_KERNELS_QUANTIZED=ON \
+    -DEXECUTORCH_BUILD_TORCHAO=ON \
+    -Bcmake-out/examples/models/llama \
+    examples/models/llama
+cmake --build cmake-out/examples/models/llama -j16 --config Release
+```
+
+Finally run your model (similar to step 3.3 above):
+
+```
+cmake-out/examples/models/llama/llama_main --model_path=<model pte file> --tokenizer_path=<tokenizer.model> --prompt=<prompt>
+```

 ## Utility tools for Llama enablement
````
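A rough way to sanity-check the settings above: with groupwise quantization and no zero points (`has_weight_zeros=False` in this commit), each weight costs about `bitwidth + scale_bits / group_size` bits, assuming one fp32 scale per group; real layouts may add padding, so treat this as an estimate. A small sketch of the arithmetic:

```python
def effective_bits_per_weight(bitwidth: int, group_size: int, scale_bits: int = 32) -> float:
    """Approximate storage cost per weight: the quantized payload plus the
    per-group scale amortized over the group (no zero points)."""
    return bitwidth + scale_bits / group_size


# 8da3w linears with group size 128: ~3.25 bits/weight vs. 32 for fp32 (~9.8x smaller).
print(effective_bits_per_weight(3, 128))  # 3.25
# 4-bit embeddings with group size 32: ~5.0 bits/weight.
print(effective_bits_per_weight(4, 32))   # 5.0
```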

examples/models/llama/source_transformation/quantize.py

Lines changed: 16 additions & 11 deletions

```diff
@@ -98,18 +98,24 @@ def quantize(  # noqa C901
     matches = re.findall(pattern, qmode)
     assert len(matches) == 1, f"Expected 1 match for pattern but got {len(matches)}"
     bitwidth = int(matches[0][0])
-    _load_torchao_aten_lib(libname="libtorchao_ops_aten")
-    from torchao.experimental.quant_api import Int8DynActIntxWeightLinearQuantizer

-    with torch.no_grad():
-        model = Int8DynActIntxWeightLinearQuantizer(
-            device="cpu",
-            precision=torch.float32,
-            groupsize=group_size,
-            bitwidth=bitwidth,
-            has_weight_zeros=False,
-        ).quantize(model)
+    from torchao.experimental.quant_api import Int8DynamicActivationIntxWeightConfig
+    from torchao.quantization.granularity import PerGroup, PerRow
+    from torchao.quantization.quant_api import quantize_
+    from torchao.utils import unwrap_tensor_subclass

+    with torch.no_grad():
+        quantize_(
+            model,
+            Int8DynamicActivationIntxWeightConfig(
+                weight_dtype=getattr(torch, f"int{bitwidth}"),
+                granularity=(
+                    PerRow() if group_size in [0, -1] else PerGroup(group_size)
+                ),
+                has_weight_zeros=False,
+            ),
+        )
+        model = unwrap_tensor_subclass(model)
     if verbose:
         print("quantized model:", model)
     return model
@@ -752,7 +758,6 @@ def get_quant_embedding_transform(args):
     bitwidth, group_size = args.embedding_quantize.split(":")[1].split(",")
     group_size = int(group_size)
     bitwidth = int(bitwidth)
-    _load_torchao_aten_lib(libname="libtorchao_ops_aten")
     from torchao.experimental.quant_api import IntxWeightEmbeddingQuantizer

     def _torchao_embedding_quantizer(model):
```
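The hunk above replaces the old `Int8DynActIntxWeightLinearQuantizer` class with torchao's `quantize_` API. Below is a self-contained sketch of the new flow on a toy model, using only names that appear in this diff; torchao's experimental modules move between releases, so the imports may need adjusting for your pin:

```python
import torch
import torch.nn as nn

from torchao.experimental.quant_api import Int8DynamicActivationIntxWeightConfig
from torchao.quantization.granularity import PerGroup, PerRow
from torchao.quantization.quant_api import quantize_
from torchao.utils import unwrap_tensor_subclass

model = nn.Sequential(nn.Linear(256, 256), nn.ReLU(), nn.Linear(256, 64))
bitwidth, group_size = 3, 128  # mirrors the 8da3w / group-size-128 CI config

with torch.no_grad():
    # quantize_ rewrites the linear weights in place according to the config.
    quantize_(
        model,
        Int8DynamicActivationIntxWeightConfig(
            weight_dtype=getattr(torch, f"int{bitwidth}"),  # torch.int3
            granularity=PerRow() if group_size in [0, -1] else PerGroup(group_size),
            has_weight_zeros=False,
        ),
    )
    # Strip the tensor subclasses so the model can be traced/exported afterwards.
    model = unwrap_tensor_subclass(model)

print(model)
```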

install_requirements.py

Lines changed: 4 additions & 2 deletions

```diff
@@ -6,6 +6,7 @@
 # LICENSE file in the root directory of this source tree.

 import argparse
+import os
 import platform
 import re
 import subprocess
@@ -117,6 +118,8 @@ def install_requirements(use_pytorch_nightly):

     # Install packages directly from local copy instead of pypi.
     # This is usually not recommended.
+    new_env = os.environ.copy()
+    new_env["USE_CPP"] = "1"  # install torchao kernels
     subprocess.run(
         [
             sys.executable,
@@ -127,6 +130,7 @@ def install_requirements(use_pytorch_nightly):
             "--no-build-isolation",
             *LOCAL_REQUIREMENTS,
         ],
+        env=new_env,
         check=True,
     )
@@ -143,8 +147,6 @@ def main(args):


 if __name__ == "__main__":
-    import os
-
     # Before doing anything, cd to the directory containing this script.
     os.chdir(os.path.dirname(os.path.abspath(__file__)))
     if not python_is_compatible():
```
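The `USE_CPP=1` variable asks the torchao build to compile its C++ kernels during the pip install. Copying `os.environ` rather than mutating it keeps the setting scoped to the child process. A minimal sketch of the pattern; the child command here is illustrative, not the actual pip invocation:

```python
import os
import subprocess
import sys

new_env = os.environ.copy()
new_env["USE_CPP"] = "1"  # visible only to the child process

# Illustrative child command; install_requirements.py runs pip install here.
subprocess.run(
    [sys.executable, "-c", "import os; print(os.environ['USE_CPP'])"],
    env=new_env,
    check=True,
)
```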

third-party/ao

Submodule ao updated 493 files
