Qualcomm AI Engine Direct - add cli tool for QNN artifacts #4731

Merged: 1 commit, Aug 19, 2024
2 changes: 1 addition & 1 deletion backends/qualcomm/aot/python/PyQnnWrapperAdaptor.h
@@ -171,7 +171,7 @@ class PyQnnTensorWrapper {
return {enc_data, data.axis};
}
default:
QNN_EXECUTORCH_LOG_ERROR(
QNN_EXECUTORCH_LOG_WARN(
"%s QNN_QUANTIZATION_ENCODING_UNDEFINED detected",
GetName().c_str());
break;
91 changes: 91 additions & 0 deletions backends/qualcomm/tests/test_qnn_delegate.py
@@ -3,6 +3,7 @@
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
import io
import json
import subprocess
import sys
@@ -1825,6 +1826,96 @@ def required_envs(self, conditions=None) -> bool:
]
)

def test_utils_export(self):
with tempfile.TemporaryDirectory() as tmp_dir:
module = ContextBinaryExample() # noqa: F405
generate_context_binary(
module=module,
inputs=module.example_inputs(),
quantized=True,
artifact_dir=tmp_dir,
)
ctx_path = f"{tmp_dir}/model_ctx.bin"
fpath = f"{self.executorch_root}/examples/qualcomm/qaihub_scripts/utils/export.py"

# do compilation
compile_cmds = [
"python",
fpath,
"compile",
"-a",
ctx_path,
"-m",
self.model,
"-l",
"False",
"-b",
self.build_folder,
"-o",
f"{tmp_dir}/output_pte",
]
compile_process = subprocess.Popen(
compile_cmds, stdout=subprocess.DEVNULL, cwd=self.executorch_root
)
output_pte_dir = f"{tmp_dir}/output_pte/model_ctx"
compile_process.communicate()

# check artifacts are correctly generated
self.assertTrue(
all(
[
Path(output_pte_dir).exists(),
Path(f"{output_pte_dir}/model_ctx.json").exists(),
Path(f"{output_pte_dir}/model_ctx.svg").exists(),
]
)
)

# prepare input files
input_list, inputs = [], module.example_inputs()
for name, tensor in inputs.items():
tensor_path = f"{output_pte_dir}/{name}.pt"
torch.save(tensor, tensor_path)
input_list.append(tensor_path)

# do execution
output_data_dir = f"{tmp_dir}/output_data"
execute_cmds = [
"python",
fpath,
"execute",
"-p",
output_pte_dir,
"-i",
*input_list,
"-s",
self.device,
"-z",
"-b",
self.build_folder,
"-o",
output_data_dir,
]
if self.host is not None:
execute_cmds.append(f"-H {self.host}")
execute_process = subprocess.Popen(execute_cmds, cwd=self.executorch_root)
execute_process.communicate()

# read outputs
with open(f"{output_pte_dir}/model_ctx.json", "r") as f:
graph_info = json.load(f)

device_output = []
for output in graph_info["outputs"]:
with open(f"{output_data_dir}/{output['name']}.pt", "rb") as f:
buffer = io.BytesIO(f.read())
device_output.append(torch.load(buffer, weights_only=False))

# validate outputs
golden_output = module.forward(inputs["x"], inputs["y"])
self.atol, self.rtol = 1e-1, 1
self._assert_outputs_equal(golden_output, device_output)

def test_llama2_7b(self):
if not self.required_envs():
self.skipTest("missing required envs")
6 changes: 4 additions & 2 deletions backends/qualcomm/utils/utils.py
@@ -232,7 +232,9 @@ def capture_program(
return edge_ep


def from_context_binary(ctx_path: str, op_name: str):
def from_context_binary(
ctx_path: str, op_name: str, soc_model: QcomChipset = QcomChipset.SM8650
):
def implement_op(custom_op, op_name, outputs):
@torch.library.impl(
custom_op, str(op_name), dispatch_key="CompositeExplicitAutograd"
@@ -283,7 +285,7 @@ def build_tensor(tensors, dtype_map):
# dummy compiler spec would be fine, since we're not compiling
backend_options = generate_htp_compiler_spec(use_fp16=False)
compiler_specs = generate_qnn_executorch_compiler_spec(
soc_model=QcomChipset.SM8650,
soc_model=soc_model,
backend_options=backend_options,
is_from_context_binary=True,
)
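For reference, a minimal sketch of calling `from_context_binary` with the new `soc_model` argument; the import paths, chipset choice, and names below are assumptions for illustration rather than code from this change:

```python
# Hedged sketch: overriding the default SM8650 target when lowering a pre-built
# QNN context binary. Import paths reflect the repository layout and may differ.
from executorch.backends.qualcomm.serialization.qnn_compile_spec_schema import QcomChipset
from executorch.backends.qualcomm.utils.utils import from_context_binary

bundle = from_context_binary(
    ctx_path="model_ctx.bin",      # path to a pre-built QNN context binary
    op_name="ctx_loader",          # illustrative name for the generated custom op
    soc_model=QcomChipset.SM8550,  # target chipset other than the SM8650 default
)
```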
102 changes: 102 additions & 0 deletions examples/qualcomm/qaihub_scripts/utils/README.md
@@ -0,0 +1,102 @@
# CLI Tool for Compile / Deploy Pre-Built QNN Artifacts

An easy-to-use tool for generating / executing .pte programs from pre-built model libraries / context binaries produced by Qualcomm AI Engine Direct. The tool is verified with the [host environment](../../../../docs/source/build-run-qualcomm-ai-engine-direct-backend.md#host-os).
Contributor:
Is it generic for all models from AI Hub?

Collaborator (Author):
Yes, artifacts from AI Hub related to QNN are delivered in .so format; only large generative AI models ship as context binaries. Both can be transformed into a .pte program with this tool.


## Description

This tool is aimed at users who want to leverage the ExecuTorch runtime framework with their existing artifacts generated by QNN. It makes it possible to produce a .pte program in a few steps.<br/>
If users are interested in well-known applications, [Qualcomm AI HUB](https://aihub.qualcomm.com/) provides plenty of optimized state-of-the-art models ready for deployment, all of which can be downloaded in model library or context binary format.

* Model libraries (.so) from `qnn-model-lib-generator` | AI HUB, or context binaries (.bin) from `qnn-context-binary-generator` | AI HUB, can be fed to the tool directly:
- To produce a .pte program:
```bash
$ python export.py compile
```
- To perform inference with the generated .pte program:
```bash
$ python export.py execute
```

### Dependencies

* Register for Qualcomm AI HUB.
* Download the QNN SDK your target model was compiled with via this [link](https://www.qualcomm.com/developer/software/qualcomm-ai-engine-direct-sdk). The link downloads the latest version at this moment (users should be able to specify a version soon; please refer to [this](../../../../docs/source/build-run-qualcomm-ai-engine-direct-backend.md#software) for earlier releases).

### Target Model

* Consider using a [virtual environment](https://app.aihub.qualcomm.com/docs/hub/getting_started.html) for AI HUB scripts to prevent package conflicts with ExecuTorch. Please finish the [installation section](https://app.aihub.qualcomm.com/docs/hub/getting_started.html#installation) before proceeding with the following steps.
* Taking [QuickSRNetLarge-Quantized](https://aihub.qualcomm.com/models/quicksrnetlarge_quantized?searchTerm=quantized) as an example, please [install](https://huggingface.co/qualcomm/QuickSRNetLarge-Quantized#installation) the package as instructed.
* Create a workspace and export the pre-built model library:
```bash
mkdir $MY_WS && cd $MY_WS
# target chipset is `SM8650`
python -m qai_hub_models.models.quicksrnetlarge_quantized.export --target-runtime qnn --chipset qualcomm-snapdragon-8gen3
```
* The compiled model library will be located under `$MY_WS/build/quicksrnetlarge_quantized/quicksrnetlarge_quantized.so`. This model library maps to the artifacts generated by the SDK tools mentioned in the `Integration workflow` section of the [Qualcomm AI Engine Direct documentation](https://docs.qualcomm.com/bundle/publicresource/topics/80-63442-50/overview.html).

### Compiling Program

* Compile the .pte program:
```bash
# `pip install pydot` if package is missing
# Note that the device serial & hostname might not be required if the given artifact is in context binary format
PYTHONPATH=$EXECUTORCH_ROOT/.. python $EXECUTORCH_ROOT/examples/qualcomm/qaihub_scripts/utils/export.py compile -a $MY_WS/build/quicksrnetlarge_quantized/quicksrnetlarge_quantized.so -m SM8650 -s $DEVICE_SERIAL -b $EXECUTORCH_ROOT/build-android
```
* Artifacts for checking IO information (see the inspection sketch below):
- `output_pte/quicksrnetlarge_quantized/quicksrnetlarge_quantized.json`
- `output_pte/quicksrnetlarge_quantized/quicksrnetlarge_quantized.svg`
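A minimal sketch for inspecting the generated JSON; the schema assumed here (an `outputs` list whose entries carry a `name` field, with `inputs` assumed symmetric) follows what the accompanying test reads back, so treat the field names as assumptions:

```python
# Hedged sketch: print the I/O tensor names recorded in the generated graph description.
import json

with open("output_pte/quicksrnetlarge_quantized/quicksrnetlarge_quantized.json") as f:
    graph_info = json.load(f)

for io_key in ("inputs", "outputs"):
    for tensor in graph_info.get(io_key, []):  # "inputs" may be absent; "outputs" matches the test
        print(io_key, tensor["name"])
```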

### Executing Program

* Prepare a test image:
```bash
cd $MY_WS
wget https://user-images.githubusercontent.com/12981474/40157448-eff91f06-5953-11e8-9a37-f6b5693fa03f.png -O baboon.png
```
Execute the following Python script to generate input data:
```python
import torch
import torchvision.transforms as transforms
from PIL import Image
img = Image.open('baboon.png').resize((128, 128))
transform = transforms.Compose([transforms.PILToTensor()])
# convert (C, H, W) to (N, H, W, C)
# IO tensor info can be checked with quicksrnetlarge_quantized.json | .svg
img = transform(img).permute(1, 2, 0).unsqueeze(0)
torch.save(img, 'baboon.pt')
```
* Execute the .pte program:
```bash
PYTHONPATH=$EXECUTORCH_ROOT/.. python $EXECUTORCH_ROOT/examples/qualcomm/qaihub_scripts/utils/export.py execute -p output_pte/quicksrnetlarge_quantized -i baboon.pt -s $DEVICE_SERIAL -b $EXECUTORCH_ROOT/build-android
```
* Post-process the generated data:
```bash
cd output_data
```
Execute the following Python script to generate the output image:
```python
import io
import torch
import torchvision.transforms as transforms
# IO tensor info can be checked with quicksrnetlarge_quantized.json | .svg
# generally the input / output tensors share the same layout, e.g. either NHWC or NCHW
# this might not be true under different converter configurations
# learn more about the converter tool in the Qualcomm AI Engine Direct documentation
# https://docs.qualcomm.com/bundle/publicresource/topics/80-63442-50/tools.html#model-conversion
with open('output__142.pt', 'rb') as f:
buffer = io.BytesIO(f.read())
img = torch.load(buffer, weights_only=False)
transform = transforms.Compose([transforms.ToPILImage()])
img_pil = transform(img.squeeze(0))
img_pil.save('baboon_upscaled.png')
```
You can check the upscaled result now!

## Help

Please check the help messages for more information:
```bash
PYTHONPATH=$EXECUTORCH_ROOT/.. python $EXECUTORCH_ROOT/examples/qualcomm/qaihub_scripts/utils/export.py -h
PYTHONPATH=$EXECUTORCH_ROOT/.. python $EXECUTORCH_ROOT/examples/qualcomm/qaihub_scripts/utils/export.py compile -h
PYTHONPATH=$EXECUTORCH_ROOT/.. python $EXECUTORCH_ROOT/examples/qualcomm/qaihub_scripts/utils/export.py execute -h
```