
Commit 80b4a72

Support Llama3 qaihub (#4789)
1 parent 012d61f commit 80b4a72

17 files changed (+730, -186 lines)

backends/qualcomm/runtime/QnnManager.cpp

Lines changed: 2 additions & 1 deletion
@@ -332,7 +332,8 @@ Error QnnManager::AllocateTensor() {
     const std::string& tensor_name = tensor_wrapper->GetName();
     // this is required by identifying shared buffer mechanism
     // info might be missed if context binary came from qnn_converter
-    if (tensor_name.find("output_") == std::string::npos) {
+    if (options_->is_from_context_binary() &&
+        tensor_name.find("output_") == std::string::npos) {
       tensor_wrapper->SetName("output_" + tensor_name);
     }
     if (IsTensorDump()) {

backends/qualcomm/tests/test_qnn_delegate.py

Lines changed: 42 additions & 1 deletion
@@ -1923,7 +1923,7 @@ def test_llama2_7b(self):
         prompt = "Explain the rules of baseball"
         cmds = [
             "python",
-            f"{self.executorch_root}/examples/qualcomm/qaihub_scripts/llama2/qaihub_llama2_7b.py",
+            f"{self.executorch_root}/examples/qualcomm/qaihub_scripts/llama/llama2/qaihub_llama2_7b.py",
             "--artifact",
             self.artifact_dir,
             "--build_folder",
@@ -1957,6 +1957,47 @@ def test_llama2_7b(self):
                 model_out = msg["result"]
                 self.assertTrue(model_out.startswith(prompt))
 
+    def test_llama3_8b(self):
+        if not self.required_envs():
+            self.skipTest("missing required envs")
+
+        prompt = "Explain the rules of baseball"
+        cmds = [
+            "python",
+            f"{self.executorch_root}/examples/qualcomm/qaihub_scripts/llama/llama3/qaihub_llama3_8b.py",
+            "--artifact",
+            self.artifact_dir,
+            "--build_folder",
+            self.build_folder,
+            "--device",
+            self.device,
+            "--model",
+            self.model,
+            "--tokenizer_model",
+            f"{self.artifact_dir}/tokenizer.model",
+            "--context_binaries",
+            f"{self.artifact_dir}",
+            "--ip",
+            self.ip,
+            "--port",
+            str(self.port),
+            "--prompt",
+            f"{prompt}",
+        ]
+        if self.host:
+            cmds.extend(["--host", self.host])
+
+        p = subprocess.Popen(cmds, stdout=subprocess.DEVNULL)
+        with Listener((self.ip, self.port)) as listener:
+            conn = listener.accept()
+            p.communicate()
+            msg = json.loads(conn.recv())
+            if "Error" in msg:
+                self.fail(msg["Error"])
+            else:
+                model_out = msg["result"]
+                self.assertTrue(model_out.startswith(prompt))
+
 
 class TestExampleScript(TestQNN):
     def required_envs(self, conditions=None) -> bool:

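The new test mirrors test_llama2_7b: it launches the example script as a subprocess and waits for the generated text to come back over a multiprocessing connection. Below is a minimal sketch of the script-side half of that handshake, assuming a hypothetical send_result helper (the real scripts wire this up through their --ip/--port arguments):

```python
import json
from multiprocessing.connection import Client


def send_result(ip: str, port: int, text: str) -> None:
    # Connect back to the Listener opened by the test and hand over the
    # generated text as a JSON string; the test reads msg["result"] on
    # success or msg["Error"] on failure.
    with Client((ip, port)) as conn:
        conn.send(json.dumps({"result": text}))
```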
backends/qualcomm/utils/utils.py

Lines changed: 7 additions & 1 deletion
@@ -121,6 +121,7 @@ def replace_linear(module: torch.nn.Module):
 
 def canonicalize_program(
     exported_program: ExportedProgram | List[LoweredBackendModule],
+    custom_buffer_size=None,
 ):
     # check if user specifies to use multi_contexts
     # this is a generic approach in case there exists multiple backends
@@ -140,7 +141,12 @@ def process_exported_program(prog):
         return max_sf_buf_size, module_map
 
     def process_lowered_module(module):
-        return len(module.processed_bytes), {
+        spill_fill_size = (
+            len(module.processed_bytes)
+            if custom_buffer_size is None
+            else custom_buffer_size
+        )
+        return spill_fill_size, {
             module: convert_to_option(module.compile_specs[0].value)
         }
 
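The new custom_buffer_size argument lets callers override the spill-fill buffer size that canonicalize_program would otherwise derive from each lowered module's processed bytes. A minimal usage sketch follows, assuming the import path mirrors the file location and that lowered_modules is the list of LoweredBackendModule objects produced by the qaihub scripts (the override branch is only illustrative):

```python
from executorch.backends.qualcomm.utils.utils import canonicalize_program


def setup_spill_fill(lowered_modules, custom_buffer_size=None):
    if custom_buffer_size is None:
        # Default: spill-fill size comes from len(module.processed_bytes).
        canonicalize_program(lowered_modules)
    else:
        # Override: pin the spill-fill buffer to an explicit byte size.
        canonicalize_program(lowered_modules, custom_buffer_size=custom_buffer_size)
```
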
examples/qualcomm/CMakeLists.txt

Lines changed: 4 additions & 4 deletions
@@ -69,15 +69,15 @@ target_include_directories(
 
 # build qnn_executor_runner
 add_subdirectory(
-  ${CMAKE_CURRENT_SOURCE_DIR}/executor_runner
+  ${CMAKE_CURRENT_SOURCE_DIR}/executor_runner
 )
 
 # build qnn_llama_runner
 add_subdirectory(
-  ${CMAKE_CURRENT_SOURCE_DIR}/oss_scripts/llama2
+  ${CMAKE_CURRENT_SOURCE_DIR}/oss_scripts/llama2
 )
 
-# build qaihub_llama2_7b_runner
+# build qaihub_llama2_7b_runner and qaihub_llama3_8b_runner
 add_subdirectory(
-  ${CMAKE_CURRENT_SOURCE_DIR}/qaihub_scripts/llama2
+  ${CMAKE_CURRENT_SOURCE_DIR}/qaihub_scripts/llama
 )

examples/qualcomm/oss_scripts/llama2/qnn_llama_runner.cpp

Lines changed: 2 additions & 3 deletions
@@ -9,10 +9,9 @@
 /**
  * @file
  *
- * This tool can run ExecuTorch model files with Qualcomm AI Engine Direct
- * and the portable kernels.
+ * This tool can run ExecuTorch model files with Qualcomm AI Engine Direct.
  *
- * User could specify arguments like desired input data, iterations, etc.
+ * User could specify arguments like desired prompt, temperature, etc.
  */
 
 #include <executorch/backends/qualcomm/runtime/QnnExecuTorch.h>

examples/qualcomm/qaihub_scripts/llama/CMakeLists.txt

Lines changed: 94 additions & 0 deletions
@@ -0,0 +1,94 @@
+# Copyright (c) Qualcomm Innovation Center, Inc.
+# All rights reserved
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+# preprocess qaihub runner src files for llama2,3
+set(_qaihub_llama_runner__srcs ${_llama_runner__srcs})
+list(TRANSFORM _qaihub_llama_runner__srcs PREPEND "${EXECUTORCH_SOURCE_DIR}/")
+list(FILTER _qaihub_llama_runner__srcs EXCLUDE REGEX ".*(/runner/).*")
+list(PREPEND _qaihub_llama_runner__srcs
+  ${CMAKE_CURRENT_LIST_DIR}/runner/runner.cpp
+  ${CMAKE_CURRENT_LIST_DIR}/runner/runner.h
+  ${CMAKE_CURRENT_LIST_DIR}/runner/io_memory.cpp
+  ${CMAKE_CURRENT_LIST_DIR}/runner/io_memory.h
+)
+
+
+# preprocess qaihub llama2 7b runner src files
+set(_qaihub_llama2_7b_runner__srcs ${_qaihub_llama_runner__srcs})
+
+list(PREPEND _qaihub_llama2_7b_runner__srcs
+  ${CMAKE_CURRENT_LIST_DIR}/llama2/qaihub_llama2_7b_runner.cpp
+)
+
+# build qaihub llama2 7b runner
+add_executable(qaihub_llama2_7b_runner ${_qaihub_llama2_7b_runner__srcs})
+target_include_directories(qaihub_llama2_7b_runner
+  PUBLIC ${_common_include_directories}
+)
+target_link_libraries(qaihub_llama2_7b_runner
+  qnn_executorch_backend
+  executorch_no_prim_ops
+  extension_data_loader
+  extension_module
+  gflags
+)
+target_compile_options(qaihub_llama2_7b_runner
+  PUBLIC ${_common_compile_options}
+)
+
+
+# preprocess qaihub llama3 8b runner src files
+set(_qaihub_llama3_8b_runner__srcs ${_qaihub_llama_runner__srcs})
+
+list(PREPEND _qaihub_llama3_8b_runner__srcs
+  ${CMAKE_CURRENT_LIST_DIR}/llama3/qaihub_llama3_8b_runner.cpp
+)
+
+# Adding a compile option to differentiate llama2 with llama3 logic
+list(APPEND _common_compile_options -DQAIHUB_LLAMA3_RUNNER)
+
+# find RE2 for tokenizer
+set(ABSL_ENABLE_INSTALL ON)
+set(ABSL_PROPAGATE_CXX_STD ON)
+set(_pic_flag ${CMAKE_POSITION_INDEPENDENT_CODE})
+set(CMAKE_POSITION_INDEPENDENT_CODE ON)
+add_subdirectory(
+  ${CMAKE_CURRENT_SOURCE_DIR}/../../../../extension/llm/third-party/abseil-cpp
+  ${CMAKE_CURRENT_BINARY_DIR}/abseil-cpp
+)
+add_subdirectory(
+  ${CMAKE_CURRENT_SOURCE_DIR}/../../../../extension/llm/third-party/re2
+  ${CMAKE_CURRENT_BINARY_DIR}/re2
+)
+set(CMAKE_POSITION_INDEPENDENT_CODE ${_pic_flag})
+
+
+list(APPEND _qaihub_llama3_8b_runner__srcs
+  ${CMAKE_CURRENT_SOURCE_DIR}/../../../../extension/llm/tokenizer/tiktoken.cpp
+)
+list(APPEND _qaihub_llama3_8b_runner__srcs
+  ${CMAKE_CURRENT_SOURCE_DIR}/../../../models/llama2/tokenizer/llama_tiktoken.cpp
+)
+set(_preprocessor_flag -DET_USE_TIKTOKEN)
+
+
+# build qaihub llama3 8b runner
+add_executable(qaihub_llama3_8b_runner ${_qaihub_llama3_8b_runner__srcs})
+target_include_directories(qaihub_llama3_8b_runner
+  PUBLIC ${_common_include_directories}
+)
+
+target_link_libraries(qaihub_llama3_8b_runner
+  qnn_executorch_backend
+  executorch_no_prim_ops
+  extension_data_loader
+  extension_module
+  gflags
+  re2::re2
+)
+target_compile_options(qaihub_llama3_8b_runner
+  PUBLIC ${_common_compile_options}
+)

examples/qualcomm/qaihub_scripts/llama/README.md

Lines changed: 52 additions & 0 deletions
@@ -0,0 +1,52 @@
+# Summary
+
+## Overview
+This file provides instructions for running LLAMA2 and LLAMA3 with different parameters via the Qualcomm HTP backend. The following settings support Llama-2-7b-chat-hf and Llama-3-8b-chat-hf.
+
+Please check the corresponding section for more information.
+
+## Llama-2-7b-chat-hf
+This example demonstrates how to run Llama-2-7b-chat-hf on mobile via the Qualcomm HTP backend. The model was precompiled into context binaries by [Qualcomm AI HUB](https://aihub.qualcomm.com/).
+Note that the pre-compiled context binaries cannot be further fine-tuned for other downstream tasks.
+
+### Instructions
+#### Step 1: Setup
+1. Follow the [tutorial](https://pytorch.org/executorch/main/getting-started-setup) to set up ExecuTorch.
+2. Follow the [tutorial](https://pytorch.org/executorch/stable/build-run-qualcomm-ai-engine-direct-backend.html) to build the Qualcomm AI Engine Direct backend.
+
+#### Step 2: Prepare Model
+1. Create an account at https://aihub.qualcomm.com/
+2. Follow the instructions at https://huggingface.co/qualcomm/Llama-v2-7B-Chat to export context binaries (this will take some time to finish)
+
+```bash
+# tokenizer.model: https://huggingface.co/meta-llama/Llama-2-7b-chat-hf/blob/main/tokenizer.model
+# tokenizer.bin:
+python -m examples.models.llama2.tokenizer.tokenizer -t tokenizer.model -o tokenizer.bin
+```
+
+#### Step 3: Run default examples
+```bash
+# AIHUB_CONTEXT_BINARIES: ${PATH_TO_AIHUB_WORKSPACE}/build/llama_v2_7b_chat_quantized
+python examples/qualcomm/qaihub_scripts/llama/llama2/qaihub_llama2_7b.py -a ${ARTIFACTS} -b cmake-out-android -s ${SERIAL_NUM} -m ${SOC_MODEL} --context_binaries ${AIHUB_CONTEXT_BINARIES} --tokenizer_bin tokenizer.bin --prompt "What is Python?"
+```
+
+## Llama-3-8b-chat-hf
+This example demonstrates how to run Llama-3-8b-chat-hf on mobile via the Qualcomm HTP backend. The model was precompiled into context binaries by [Qualcomm AI HUB](https://aihub.qualcomm.com/).
+Note that the pre-compiled context binaries cannot be further fine-tuned for other downstream tasks. This example script has been tested on a 16GB RAM device and verified to work.
+
+### Instructions
+#### Step 1: Setup
+1. Follow the [tutorial](https://pytorch.org/executorch/main/getting-started-setup) to set up ExecuTorch.
+2. Follow the [tutorial](https://pytorch.org/executorch/stable/build-run-qualcomm-ai-engine-direct-backend.html) to build the Qualcomm AI Engine Direct backend.
+
+#### Step 2: Prepare Model
+1. Create an account at https://aihub.qualcomm.com/
+2. Follow the instructions at https://huggingface.co/qualcomm/Llama-v3-8B-Chat to export context binaries (this will take some time to finish)
+3. For the Llama 3 tokenizer, please refer to https://github.com/meta-llama/llama-models/blob/main/README.md for instructions on how to download tokenizer.model.
+
+
+#### Step 3: Run default examples
+```bash
+# AIHUB_CONTEXT_BINARIES: ${PATH_TO_AIHUB_WORKSPACE}/build/llama_v3_8b_chat_quantized
+python examples/qualcomm/qaihub_scripts/llama/llama3/qaihub_llama3_8b.py -a ${ARTIFACTS} -b cmake-out-android -s ${SERIAL_NUM} -m ${SOC_MODEL} --context_binaries ${AIHUB_CONTEXT_BINARIES} --tokenizer_model tokenizer.model --prompt "What is baseball?"
+```

examples/qualcomm/qaihub_scripts/llama2/qaihub_llama2_7b.py renamed to examples/qualcomm/qaihub_scripts/llama/llama2/qaihub_llama2_7b.py

Lines changed: 7 additions & 9 deletions
@@ -55,6 +55,11 @@ def main(args):
         is_from_context_binary=True,
     )
 
+    pte_name = (
+        "qaihub_llama2_7b_prompt"
+        if args.use_prompt_processor
+        else "qaihub_llama2_7b_token"
+    )
     if args.pre_gen_pte is None:
         # create custom operators as context loader
         bundle_programs = [
@@ -69,7 +74,7 @@ def main(args):
         # setup spill-fill buffer for relieving runtime memory usage
         canonicalize_program(lowered_modules)
         # export pte files
-        pte_name, pte_files = "qaihub_llama7b", []
+        pte_files = []
         for i in range(len(target_names)):
             print(f"pte {i} generating...")
             memory_planning_pass = MemoryPlanningPass(
@@ -90,7 +95,6 @@ def main(args):
             lowered_modules.pop(0)
             gc.collect()
     else:
-        pte_name = "qaihub_llama7b"
        pte_files = [f"{args.pre_gen_pte}/{pte_name}_{i}.pte" for i in range(4)]
 
     if args.compile_only:
@@ -109,12 +113,6 @@ def get_logit_encoding(path_to_last_shard: str):
         qnn_mgr.Destroy()
         return encoding.data["scale"].item(), encoding.data["offset"].item()
 
-    # setup required paths accordingly
-    # qnn_sdk : QNN SDK path setup in environment variable
-    # artifact_path : path where artifacts were built
-    # pte_path : path where executorch binary was stored
-    # device_id : serial number of android device
-    # workspace : folder for storing artifacts on android device
     adb = SimpleADB(
         qnn_sdk=os.getenv("QNN_SDK_ROOT"),
         build_path=args.build_folder,
@@ -123,7 +121,7 @@ def get_logit_encoding(path_to_last_shard: str):
         device_id=args.device,
         host_id=args.host,
         soc_model=args.model,
-        runner="examples/qualcomm/qaihub_scripts/llama2/qaihub_llama2_7b_runner",
+        runner="examples/qualcomm/qaihub_scripts/llama/qaihub_llama2_7b_runner",
     )
     output_file = "result.txt"
     pos_embs_file = ["freq_cos", "freq_sin"]

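Hoisting pte_name above the branch means the export path and the --pre_gen_pte path now agree on the shard file names, which also encode whether the prompt-processor or token-generator variant was compiled. A small sketch of the resulting naming scheme, mirroring the list comprehension above (the directory argument is just an illustrative placeholder):

```python
def shard_pte_paths(pre_gen_dir: str, use_prompt_processor: bool) -> list:
    # The 7B model is split into four context-binary shards, each exported
    # as <pte_name>_<i>.pte under the pre-generated directory.
    pte_name = (
        "qaihub_llama2_7b_prompt" if use_prompt_processor else "qaihub_llama2_7b_token"
    )
    return [f"{pre_gen_dir}/{pte_name}_{i}.pte" for i in range(4)]


# e.g. shard_pte_paths("./artifacts", True)
#   -> ["./artifacts/qaihub_llama2_7b_prompt_0.pte", ..., "./artifacts/qaihub_llama2_7b_prompt_3.pte"]
```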
examples/qualcomm/qaihub_scripts/llama2/qaihub_llama2_7b_runner.cpp renamed to examples/qualcomm/qaihub_scripts/llama/llama2/qaihub_llama2_7b_runner.cpp

Lines changed: 4 additions & 5 deletions
@@ -9,15 +9,13 @@
 /**
  * @file
  *
- * This tool can run ExecuTorch model files with Qualcomm AI Engine Direct
- * and the portable kernels.
+ * This tool can run Llama2 7b with Qualcomm AI Engine Direct.
  *
- * User could specify arguments like desired input data, iterations, etc.
- * Currently we assume that the outputs are all fp32 tensors.
+ * User could specify arguments like desired prompt, eval_mode, etc.
  */
 
 #include <executorch/backends/qualcomm/runtime/QnnExecuTorch.h>
-#include <executorch/examples/qualcomm/qaihub_scripts/llama2/runner/runner.h>
+#include <executorch/examples/qualcomm/qaihub_scripts/llama/runner/runner.h>
 #include <executorch/extension/runner_util/managed_tensor.h>
 #include <executorch/runtime/platform/log.h>
 
@@ -68,6 +66,7 @@ int main(int argc, char** argv) {
   Runner runner(
       models_path,
       pos_embs_path,
+      {8, 8, 8, 8},
       FLAGS_tokenizer_path.c_str(),
       FLAGS_eval_mode,
       FLAGS_temperature,
