From 4c06907608d0091f52e72a6c461958671f7a5fea Mon Sep 17 00:00:00 2001 From: haowhsu-quic <111341466+haowhsu-quic@users.noreply.github.com> Date: Tue, 20 Aug 2024 00:03:20 +0800 Subject: [PATCH] Qualcomm AI Engine Direct - add cli tool for QNN artifacts (#4731) Summary: - cli tool for deploying precompiled model library / context bin onto executorch runtime - refactor & mionr fixes Resolved #4731 --- .../qualcomm/aot/python/PyQnnWrapperAdaptor.h | 2 +- backends/qualcomm/tests/test_qnn_delegate.py | 91 ++++ backends/qualcomm/utils/utils.py | 6 +- .../qualcomm/qaihub_scripts/utils/README.md | 102 ++++ .../qualcomm/qaihub_scripts/utils/export.py | 505 ++++++++++++++++++ examples/qualcomm/utils.py | 20 +- 6 files changed, 721 insertions(+), 5 deletions(-) create mode 100644 examples/qualcomm/qaihub_scripts/utils/README.md create mode 100644 examples/qualcomm/qaihub_scripts/utils/export.py diff --git a/backends/qualcomm/aot/python/PyQnnWrapperAdaptor.h b/backends/qualcomm/aot/python/PyQnnWrapperAdaptor.h index 98219d9763..1f7f5ccb08 100644 --- a/backends/qualcomm/aot/python/PyQnnWrapperAdaptor.h +++ b/backends/qualcomm/aot/python/PyQnnWrapperAdaptor.h @@ -171,7 +171,7 @@ class PyQnnTensorWrapper { return {enc_data, data.axis}; } default: - QNN_EXECUTORCH_LOG_ERROR( + QNN_EXECUTORCH_LOG_WARN( "%s QNN_QUANTIZATION_ENCODING_UNDEFINED detected", GetName().c_str()); break; diff --git a/backends/qualcomm/tests/test_qnn_delegate.py b/backends/qualcomm/tests/test_qnn_delegate.py index 1f779504bd..0cc91de4e1 100644 --- a/backends/qualcomm/tests/test_qnn_delegate.py +++ b/backends/qualcomm/tests/test_qnn_delegate.py @@ -3,6 +3,7 @@ # # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. +import io import json import subprocess import sys @@ -1825,6 +1826,96 @@ def required_envs(self, conditions=None) -> bool: ] ) + def test_utils_export(self): + with tempfile.TemporaryDirectory() as tmp_dir: + module = ContextBinaryExample() # noqa: F405 + generate_context_binary( + module=module, + inputs=module.example_inputs(), + quantized=True, + artifact_dir=tmp_dir, + ) + ctx_path = f"{tmp_dir}/model_ctx.bin" + fpath = f"{self.executorch_root}/examples/qualcomm/qaihub_scripts/utils/export.py" + + # do compilation + compile_cmds = [ + "python", + fpath, + "compile", + "-a", + ctx_path, + "-m", + self.model, + "-l", + "False", + "-b", + self.build_folder, + "-o", + f"{tmp_dir}/output_pte", + ] + compile_process = subprocess.Popen( + compile_cmds, stdout=subprocess.DEVNULL, cwd=self.executorch_root + ) + output_pte_dir = f"{tmp_dir}/output_pte/model_ctx" + compile_process.communicate() + + # check artifacts are correctly generated + self.assertTrue( + all( + [ + Path(output_pte_dir).exists(), + Path(f"{output_pte_dir}/model_ctx.json").exists(), + Path(f"{output_pte_dir}/model_ctx.svg").exists(), + ] + ) + ) + + # prepare input files + input_list, inputs = [], module.example_inputs() + for name, tensor in inputs.items(): + tensor_path = f"{output_pte_dir}/{name}.pt" + torch.save(tensor, tensor_path) + input_list.append(tensor_path) + + # do execution + output_data_dir = f"{tmp_dir}/output_data" + execute_cmds = [ + "python", + fpath, + "execute", + "-p", + output_pte_dir, + "-i", + *input_list, + "-s", + self.device, + "-z", + "-b", + self.build_folder, + "-o", + output_data_dir, + ] + if self.host is not None: + execute_cmds.append(f"-H {self.host}") + execute_process = subprocess.Popen(execute_cmds, cwd=self.executorch_root) + 
execute_process.communicate()
+
+            # read outputs
+            with open(f"{output_pte_dir}/model_ctx.json", "r") as f:
+                graph_info = json.load(f)
+
+            device_output = []
+            for output in graph_info["outputs"]:
+                with open(f"{output_data_dir}/{output['name']}.pt", "rb") as f:
+                    buffer = io.BytesIO(f.read())
+                device_output.append(torch.load(buffer, weights_only=False))
+
+            # validate outputs
+            golden_output = module.forward(inputs["x"], inputs["y"])
+            self.atol, self.rtol = 1e-1, 1
+            self._assert_outputs_equal(golden_output, device_output)
+
     def test_llama2_7b(self):
         if not self.required_envs():
             self.skipTest("missing required envs")
diff --git a/backends/qualcomm/utils/utils.py b/backends/qualcomm/utils/utils.py
index 27acaaa33b..2b3e6ba463 100644
--- a/backends/qualcomm/utils/utils.py
+++ b/backends/qualcomm/utils/utils.py
@@ -232,7 +232,9 @@ def capture_program(
     return edge_ep
 
 
-def from_context_binary(ctx_path: str, op_name: str):
+def from_context_binary(
+    ctx_path: str, op_name: str, soc_model: QcomChipset = QcomChipset.SM8650
+):
     def implement_op(custom_op, op_name, outputs):
         @torch.library.impl(
             custom_op, str(op_name), dispatch_key="CompositeExplicitAutograd"
@@ -283,7 +285,7 @@ def build_tensor(tensors, dtype_map):
     # dummy compiler spec would be fine, since we're not compiling
     backend_options = generate_htp_compiler_spec(use_fp16=False)
     compiler_specs = generate_qnn_executorch_compiler_spec(
-        soc_model=QcomChipset.SM8650,
+        soc_model=soc_model,
         backend_options=backend_options,
         is_from_context_binary=True,
     )
diff --git a/examples/qualcomm/qaihub_scripts/utils/README.md b/examples/qualcomm/qaihub_scripts/utils/README.md
new file mode 100644
index 0000000000..facc1da76e
--- /dev/null
+++ b/examples/qualcomm/qaihub_scripts/utils/README.md
@@ -0,0 +1,102 @@
+# CLI Tool for Compiling / Deploying Pre-Built QNN Artifacts
+
+An easy-to-use tool for generating / executing a .pte program from pre-built model libraries / context binaries produced by Qualcomm AI Engine Direct. The tool is verified with the [host environment](../../../../docs/source/build-run-qualcomm-ai-engine-direct-backend.md#host-os).
+
+## Description
+
+This tool is for users who want to leverage the ExecuTorch runtime framework with their existing artifacts generated by QNN. It lets them produce a .pte program in a few steps.
+
+If users are interested in well-known applications, [Qualcomm AI HUB](https://aihub.qualcomm.com/) is a great place to start: it provides plenty of optimized, state-of-the-art models ready for deployment. All of them can be downloaded in model library or context binary format.
+
+* Model libraries (.so) generated by `qnn-model-lib-generator` / AI HUB, or context binaries (.bin) generated by `qnn-context-binary-generator` / AI HUB, can be fed to the tool directly:
+  - To produce a .pte program:
+    ```bash
+    $ python export.py compile
+    ```
+  - To perform inference with the generated .pte program:
+    ```bash
+    $ python export.py execute
+    ```
+
+### Dependencies
+
+* Register for Qualcomm AI HUB.
+* Download the QNN SDK your favorite model was compiled with via this [link](https://www.qualcomm.com/developer/software/qualcomm-ai-engine-direct-sdk). The link currently downloads the latest version (users should be able to specify a version soon; please refer to [this](../../../../docs/source/build-run-qualcomm-ai-engine-direct-backend.md#software) for earlier releases).
+
+### Target Model
+
+* Consider using a [virtual environment](https://app.aihub.qualcomm.com/docs/hub/getting_started.html) for the AI HUB scripts to prevent package conflicts with ExecuTorch. Please finish the [installation section](https://app.aihub.qualcomm.com/docs/hub/getting_started.html#installation) before proceeding with the following steps.
+* Taking [QuickSRNetLarge-Quantized](https://aihub.qualcomm.com/models/quicksrnetlarge_quantized?searchTerm=quantized) as an example, please [install](https://huggingface.co/qualcomm/QuickSRNetLarge-Quantized#installation) the package as instructed.
+* Create a workspace and export the pre-built model library:
+  ```bash
+  mkdir $MY_WS && cd $MY_WS
+  # target chipset is `SM8650`
+  python -m qai_hub_models.models.quicksrnetlarge_quantized.export --target-runtime qnn --chipset qualcomm-snapdragon-8gen3
+  ```
+* The compiled model library will be located under `$MY_WS/build/quicksrnetlarge_quantized/quicksrnetlarge_quantized.so`. This model library maps to the artifacts generated by the SDK tools mentioned in the `Integration workflow` section of the [Qualcomm AI Engine Direct documentation](https://docs.qualcomm.com/bundle/publicresource/topics/80-63442-50/overview.html).
+
+### Compiling Program
+
+* Compile the .pte program:
+  ```bash
+  # `pip install pydot` if the package is missing
+  # Note that the device serial & hostname are not required if the given artifact is already in context binary format
+  PYTHONPATH=$EXECUTORCH_ROOT/.. python $EXECUTORCH_ROOT/examples/qualcomm/qaihub_scripts/utils/export.py compile -a $MY_WS/build/quicksrnetlarge_quantized/quicksrnetlarge_quantized.so -m SM8650 -s $DEVICE_SERIAL -b $EXECUTORCH_ROOT/build-android
+  ```
+* Artifacts for checking IO information:
+  - `output_pte/quicksrnetlarge_quantized/quicksrnetlarge_quantized.json`
+  - `output_pte/quicksrnetlarge_quantized/quicksrnetlarge_quantized.svg`
+
+### Executing Program
+
+* Prepare the test image:
+  ```bash
+  cd $MY_WS
+  wget https://user-images.githubusercontent.com/12981474/40157448-eff91f06-5953-11e8-9a37-f6b5693fa03f.png -O baboon.png
+  ```
+  Execute the following Python script to generate the input data:
+  ```python
+  import torch
+  import torchvision.transforms as transforms
+  from PIL import Image
+  img = Image.open('baboon.png').resize((128, 128))
+  transform = transforms.Compose([transforms.PILToTensor()])
+  # convert (C, H, W) to (N, H, W, C)
+  # IO tensor info can be checked with quicksrnetlarge_quantized.json | .svg
+  img = transform(img).permute(1, 2, 0).unsqueeze(0)
+  torch.save(img, 'baboon.pt')
+  ```
+* Execute the .pte program:
+  ```bash
+  PYTHONPATH=$EXECUTORCH_ROOT/.. python $EXECUTORCH_ROOT/examples/qualcomm/qaihub_scripts/utils/export.py execute -p output_pte/quicksrnetlarge_quantized -i baboon.pt -s $DEVICE_SERIAL -b $EXECUTORCH_ROOT/build-android
+  ```
+* Post-process the generated data:
+  ```bash
+  cd output_data
+  ```
+  Execute the following Python script to generate the output image:
+  ```python
+  import io
+  import torch
+  import torchvision.transforms as transforms
+  # IO tensor info can be checked with quicksrnetlarge_quantized.json | .svg
+  # generally the input / output tensors share the same layout, e.g. either NHWC or NCHW
+  # this might not be true under different converter configurations
+  # learn more with the converter tool from the Qualcomm AI Engine Direct documentation
+  # https://docs.qualcomm.com/bundle/publicresource/topics/80-63442-50/tools.html#model-conversion
+  with open('output__142.pt', 'rb') as f:
+      buffer = io.BytesIO(f.read())
+  img = torch.load(buffer, weights_only=False)
+  transform = transforms.Compose([transforms.ToPILImage()])
+  img_pil = transform(img.squeeze(0))
+  img_pil.save('baboon_upscaled.png')
+  ```
+  You can check the upscaled result now!
+
+## Help
+
+Please check the help messages for more information:
+```bash
+PYTHONPATH=$EXECUTORCH_ROOT/.. python $EXECUTORCH_ROOT/examples/qualcomm/qaihub_scripts/utils/export.py -h
+PYTHONPATH=$EXECUTORCH_ROOT/.. python $EXECUTORCH_ROOT/examples/qualcomm/qaihub_scripts/utils/export.py compile -h
+PYTHONPATH=$EXECUTORCH_ROOT/.. python $EXECUTORCH_ROOT/examples/qualcomm/qaihub_scripts/utils/export.py execute -h
+```
diff --git a/examples/qualcomm/qaihub_scripts/utils/export.py b/examples/qualcomm/qaihub_scripts/utils/export.py
new file mode 100644
index 0000000000..9dfe679649
--- /dev/null
+++ b/examples/qualcomm/qaihub_scripts/utils/export.py
@@ -0,0 +1,505 @@
+# Copyright (c) Qualcomm Innovation Center, Inc.
+# All rights reserved
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
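+
+# This utility provides two subcommands:
+#   compile - turn a QNN model library (.so) or context binary (.bin) into an
+#             ExecuTorch .pte program plus .json / .svg IO descriptions.
+#   execute - push a generated .pte to an Android device via adb, feed it the
+#             given inputs, and collect the outputs as .pt files.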
+ +import argparse +import io +import json +import logging +import os +from pathlib import Path + +import executorch.backends.qualcomm.python.PyQnnManagerAdaptor as PyQnnManagerAdaptor +import numpy as np + +import torch +from executorch.backends.qualcomm.serialization.qnn_compile_spec_schema import ( + QcomChipset, +) +from executorch.backends.qualcomm.utils.utils import ( + draw_graph, + from_context_binary, + generate_htp_compiler_spec, + generate_qnn_executorch_compiler_spec, + generate_qnn_executorch_option, +) +from executorch.examples.qualcomm.utils import make_output_dir, SimpleADB +from executorch.exir.backend.backend_api import to_backend +from executorch.exir.passes.memory_planning_pass import MemoryPlanningPass + + +def get_logger(): + logger = logging.getLogger("aihub.utils.export") + handler = logging.StreamHandler() + handler.setFormatter( + logging.Formatter( + fmt="[%(asctime)s %(prefix)s] %(levelname)-8s: %(message)s", + datefmt="%Y-%m-%d %H:%M:%S", + ) + ) + logger.addHandler(handler) + logger.setLevel(logging.INFO) + logger.propagate = False + return logging.LoggerAdapter(logger, extra={"prefix": "UTILS.EXPORT"}) + + +def get_io_info(prog_info, ctx_bin_path, compiler_spec): + def fill_tensor_info(info, qnn_tensors, category): + # fetch related IO info stored in prog_info + for i, (name, tensor) in enumerate(prog_info[category].items()): + assert qnn_tensors[i].GetName() == name, "tensor name unmatch" + encoding = qnn_tensors[i].GetEncodings() + quantization_info = { + "scale": encoding.data["scale"].tolist(), + "offset": encoding.data["offset"].tolist(), + "axis": encoding.axis, + } + info[category].append( + { + "name": name, + "shape": tuple(tensor.shape), + "dtype": str(tensor.dtype), + "encoding": quantization_info, + } + ) + + # dictionary to be serialized into json format + in_key, out_key = "inputs", "outputs" + tensor_info = {in_key: [], out_key: []} + + with open(ctx_bin_path, "rb") as f: + ctx_bin = f.read() + # leverage QNN pybind interface to retrieve tensor encodings + qnn_mgr = PyQnnManagerAdaptor.QnnManager( + generate_qnn_executorch_option(compiler_spec), ctx_bin + ) + assert qnn_mgr.Init().value == 0, "failed to load context binary" + qnn_mgr.AllocateTensor() + fill_tensor_info(tensor_info, qnn_mgr.GetGraphInputs(), in_key) + fill_tensor_info(tensor_info, qnn_mgr.GetGraphOutputs(), out_key) + qnn_mgr.Destroy() + + return tensor_info + + +def get_ones_tensor(tensor_info, logger): + logger.warning( + f"tensor '{tensor_info['name']}' use ones tensor, " + "unexpected outputs might generate" + ) + return torch.ones(tensor_info["shape"], dtype=eval(tensor_info["dtype"])) + + +def get_tensor_with_encoding(tensor, tensor_info, logger): + scale = tensor_info["encoding"]["scale"] + offset = tensor_info["encoding"]["offset"] + + # user gave wrong tensor for no encoding appears + if len(scale) == 0: + logger.error(f"tensor '{tensor_info['name']}' has no encoding") + return get_ones_tensor(tensor_info, logger) + + # quant if tensor is float with encoding + return ( + tensor.div(scale).add(offset).round().to(eval(tensor_info["dtype"])) + if tensor.dtype == torch.float + else tensor.sub(offset).mul(scale).to(torch.float32) + ) + + +def get_tensor(io_info, tensors, logger, checking_output=False): + # check if enough tensors have been given + if len(tensors) != len(io_info): + logger.error( + "given tensor numbers mismatch, " + f"expected {len(io_info)} but got {len(tensors)}" + ) + if checking_output: + logger.error( + "output tensors failed to generate, " + "please 
check executor_runner logs." + ) + exit(-1) + + return [get_ones_tensor(t, logger) for t in io_info] + + # list of tensors to be returned + ret_tensors, ret_list = [], [] + for i, info in enumerate(io_info): + ret_list.append(f"input_0_{i}.raw") + if list(tensors[i].shape) != info["shape"]: + logger.error( + f"tensor '{info['name']}' shape mismatch: " + f"users > {tensors[i].shape} - " + f"required > {info['shape']}" + ) + ret_tensors.append(get_ones_tensor(info, logger)) + continue + + ret_tensors.append( + tensors[i] + if tensors[i].dtype == eval(info["dtype"]) + else + # try quant / dequant for given tensor if possible + ret_tensors.append(get_tensor_with_encoding(tensors[i], info, logger)) + ) + return [ret_tensors], " ".join(ret_list) + + +def to_context_binary( + model_lib, soc_model, device, host, build_folder, output_folder, logger +): + ext = Path(model_lib).suffix + if ext == ".bin": + return model_lib + + assert ( + device is not None + ), "Please assign device serial for model library conversion." + logger.info(f"Generating context binary for {model_lib}") + # leverage SimpleADB for model library conversion + lib_name = Path(model_lib).stem + sdk_root = os.getenv("QNN_SDK_ROOT") + adb = SimpleADB( + qnn_sdk=sdk_root, + build_path=build_folder, + pte_path=model_lib, + workspace=f"/data/local/tmp/executorch/{lib_name}", + device_id=device, + soc_model=soc_model, + host_id=host, + ) + + logger.info("pushing QNN libraries & tool") + arch = adb.arch_table[soc_model] + files = [ + f"{sdk_root}/bin/aarch64-android/qnn-context-binary-generator", + f"{sdk_root}/lib/aarch64-android/libQnnHtp.so", + f"{sdk_root}/lib/aarch64-android/libQnnHtpV{arch}Stub.so", + f"{sdk_root}/lib/aarch64-android/libQnnHtpPrepare.so", + f"{sdk_root}/lib/hexagon-v{arch}/unsigned/libQnnHtpV{arch}Skel.so", + ] + adb.push(files=files) + + logger.info("starting conversion") + commands = " ".join( + [ + f"cd {adb.workspace} &&", + "export LD_LIBRARY_PATH=. &&", + "./qnn-context-binary-generator", + f"--model {Path(model_lib).name}", + "--backend libQnnHtp.so", + f"--binary_file {lib_name}", + ] + ) + adb.execute(custom_runner_cmd=commands) + + logger.info(f"collecting converted context binary - {lib_name}.bin") + adb._adb(["pull", f"{adb.workspace}/output/{lib_name}.bin", output_folder]) + + bin_path = f"{output_folder}/{lib_name}.bin" + assert os.path.exists(bin_path), ( + "Failed to convert context binary, " "please check logcat for more details." 
+ ) + return bin_path + + +def compile(args): + logger = get_logger() + logger.info("prepare compiler spec for qualcomm backend") + + # setup compiler spec dedicated to QNN HTP backend + backend_options = generate_htp_compiler_spec(use_fp16=False) + # setup general compiler spec for QNN + compiler_specs = generate_qnn_executorch_compiler_spec( + soc_model=getattr(QcomChipset, args.model), + backend_options=backend_options, + is_from_context_binary=True, + ) + # setup memory planning + memory_planning_pass = MemoryPlanningPass( + memory_planning_algo="greedy", + alloc_graph_input=args.allocate_graph_io, + alloc_graph_output=args.allocate_graph_io, + ) + + # dictionary for avoiding name collision when creating custom ops + name_map = {} + num_bins = len(args.artifacts) + for i, ctx_bin in enumerate(args.artifacts): + index = i + 1 + binary_name = Path(ctx_bin).stem + output_dir = f"{args.output_pte_folder}/{binary_name}" + make_output_dir(output_dir) + # conversion model library into context binary if required + ctx_bin = to_context_binary( + model_lib=ctx_bin, + soc_model=args.model, + device=args.device, + host=args.host, + build_folder=args.build_folder, + output_folder=output_dir, + logger=logger, + ) + # step 0: check if name collision happens for context binaries + logger.info(f"({index}/{num_bins}) checking custom op name of {ctx_bin}") + custom_op_name = f"ctx_loader_{binary_name}" + postfix = name_map.get(custom_op_name, 0) + if postfix > 0: + postfix += 1 + custom_op_name = f"{custom_op_name}_{postfix}" + name_map[custom_op_name] = postfix + # step 1: generate ExportedProgram with custom op as binary loader + logger.info(f"({index}/{num_bins}) exporting program for {ctx_bin}") + prog_info = from_context_binary( + ctx_bin, custom_op_name, getattr(QcomChipset, args.model) + ) + # step 2: lower to QnnBackend + logger.info(f"({index}/{num_bins}) start lowering {ctx_bin} to QnnBackend") + lowered_module = to_backend( + "QnnBackend", prog_info["edge_program"], compiler_specs + ) + # step 3: write pte files and IO information + logger.info(f"({index}/{num_bins}) exporting {binary_name}.pte") + with open(f"{output_dir}/{binary_name}.pte", "wb") as f: + f.write( + lowered_module.buffer( + extract_delegate_segments=True, memory_planning=memory_planning_pass + ) + ) + logger.info( + f"({index}/{num_bins}) exporting network graph with {binary_name}.svg" + ) + draw_graph(binary_name, output_dir, prog_info["edge_program"].graph_module) + logger.info( + f"({index}/{num_bins}) exporting graph description with {binary_name}.json" + ) + with open(f"{output_dir}/{binary_name}.json", "w") as f: + graph_info = get_io_info(prog_info, ctx_bin, compiler_specs) + graph_info["soc_model"] = args.model + json.dump(graph_info, f, indent=2) + + +def execute(args): + logger = get_logger() + + # load graph description file + pte_name = Path(args.pte_directory).stem + graph_desc = f"{args.pte_directory}/{pte_name}.json" + logger.info(f"loading graph description: {graph_desc}") + with open(graph_desc, "r") as f: + graph_info = json.load(f) + + # load input files + logger.info("loading user inputs") + user_inputs = [] + for input_file in args.input_files: + with open(input_file, "rb") as f: + buffer = io.BytesIO(f.read()) + user_inputs.append(torch.load(buffer, weights_only=False)) + + # check if inputs are valid, fallback to ones tensor if any + logger.info("generating input data") + inputs, input_list = get_tensor(graph_info["inputs"], user_inputs, logger) + + logger.info("preparing ADB connection") + # leverage 
SimpleADB for e2e inference + adb = SimpleADB( + qnn_sdk=os.getenv("QNN_SDK_ROOT"), + build_path=args.build_folder, + pte_path=f"{args.pte_directory}/{pte_name}.pte", + workspace=f"/data/local/tmp/executorch/{pte_name}", + device_id=args.device, + soc_model=graph_info["soc_model"], + host_id=args.host, + shared_buffer=args.shared_buffer, + ) + + logger.info("pushing QNN libraries & other artifacts") + adb.push(inputs=inputs, input_list=input_list) + + logger.info("starting inference") + adb.execute() + + logger.info("collecting output data") + + def post_process(): + output_info, outputs = graph_info["outputs"], [] + output_folder = f"{args.output_data_folder}/outputs" + for i, f in enumerate(sorted(os.listdir(output_folder))): + filename = os.path.join(output_folder, f) + output = np.fromfile( + filename, dtype=eval(f"np.{output_info[i]['dtype'].split('.')[-1]}") + ) + outputs.append(torch.from_numpy(output.reshape(output_info[i]["shape"]))) + os.remove(filename) + + os.rmdir(output_folder) + outputs, _ = get_tensor(output_info, outputs, logger, checking_output=True) + # dataset length equals to 1 + for i, output in enumerate(outputs[0]): + torch.save(output, f"{args.output_data_folder}/{output_info[i]['name']}.pt") + + make_output_dir(args.output_data_folder) + adb.pull(args.output_data_folder, post_process) + logger.info( + f"execution finished, please check {args.output_data_folder} for results" + ) + + +def main(): + parser = argparse.ArgumentParser( + description=( + "Utility to lower precompiled model libraries / " + "context binaries from Qualcomm AI Engine Direct to executorch" + " .pte program. Please visit https://aihub.qualcomm.com/ to " + "download your favorite models." + ), + ) + subparsers = parser.add_subparsers( + title="subcommands", + description=( + "[compile]: Compile designated model libraries / " + "context binaries into .pte files. " + "[execute]: Perform on-device inference with given .pte." + ), + ) + + sub_compile = subparsers.add_parser( + name="compile", + help=( + "e.g. python export.py compile -a model.bin -m SM8650 " + "-b /path/to/build-android" + ), + ) + sub_compile.add_argument( + "-a", + "--artifacts", + nargs="+", + type=str, + required=True, + help=( + "Path to AI HUB or QNN tool generated artifacts, " + "batch process is supported. " + "e.g. python export.py compile -a a.bin b.so c.bin " + "-m SM8650 -s $SERIAL_NO -b /path/to/build-android" + ), + ) + sub_compile.add_argument( + "-m", + "--model", + type=str, + required=True, + help="SoC model. e.g. SM8650", + ) + sub_compile.add_argument( + "-s", + "--device", + type=str, + help="Serial no of device which could be obtained by 'adb devices'.", + ) + sub_compile.add_argument( + "-o", + "--output_pte_folder", + type=str, + default="./output_pte", + help=( + "Path to output artifacts, store in 'output_pte' if not given. " + "graph descriptions & diagram will also be exported." + ), + ) + sub_compile.add_argument( + "-b", + "--build_folder", + help="Path to cmake binary directory for android, e.g., /path/to/build-android", + type=str, + required=True, + ) + sub_compile.add_argument( + "-l", + "--allocate_graph_io", + type=bool, + default=True, + help=( + "True if IO tensors are pre-allocated by framework. " + "False for users who want to manage resources in runtime." + ), + ) + sub_compile.add_argument( + "-H", + "--host", + type=str, + help="Gateway hostname.", + ) + sub_compile.set_defaults(callback=compile) + + sub_execute = subparsers.add_parser( + name="execute", + help=( + "e.g. 
python export.py execute -p model_dir -i inp.raw " "-s device_serial" + ), + ) + sub_execute.add_argument( + "-p", + "--pte_directory", + type=str, + required=True, + help="Path to .pte file folder generated from 'compile' subcommand.", + ) + sub_execute.add_argument( + "-i", + "--input_files", + nargs="*", + type=str, + help=( + "Path to input files stored via torch.save. " + "If the number / spec of input files doesn't match given .pte file, " + "tensors filled with value 1 will be taken as inputs." + ), + ) + sub_execute.add_argument( + "-s", + "--device", + type=str, + required=True, + help="Serial no of device which could be obtained by 'adb devices'.", + ) + sub_execute.add_argument( + "-o", + "--output_data_folder", + type=str, + default="./output_data", + help="Path to output data, store in 'output_data' if not given.", + ) + sub_execute.add_argument( + "-b", + "--build_folder", + help="Path to cmake binary directory for android, e.g., /path/to/build-android", + type=str, + required=True, + ) + sub_execute.add_argument( + "-z", + "--shared_buffer", + help=( + "Enables usage of shared buffer between application and backend for graph I/O." + " Please use with `--allocate_graph_io False` in compile command." + ), + action="store_true", + ) + sub_execute.add_argument( + "-H", + "--host", + type=str, + help="Gateway hostname.", + ) + sub_execute.set_defaults(callback=execute) + + args = parser.parse_args() + args.callback(args) + + +if __name__ == "__main__": + main() diff --git a/examples/qualcomm/utils.py b/examples/qualcomm/utils.py index 2293b31e59..20641c6dc8 100755 --- a/examples/qualcomm/utils.py +++ b/examples/qualcomm/utils.py @@ -39,6 +39,22 @@ class SimpleADB: + """ + A wrapper class for communicating with Android device + + Attributes: + qnn_sdk (str): QNN SDK path setup in environment variable + build_path (str): Path where artifacts were built + pte_path (str): Path where executorch binary was stored + workspace (str): Folder for storing artifacts on android device + device_id (str): Serial number of android device + soc_model (str): Chipset of device + host_id (str): Hostname of machine where device connects + error_only (bool): Redirect stdio and leave error messages only + shared_buffer (bool): Apply zero-copy mechanism in runtime + runner (str): Runtime executor binary + """ + def __init__( self, qnn_sdk, @@ -62,13 +78,13 @@ def __init__( self.input_list_filename = "input_list.txt" self.etdump_path = f"{self.workspace}/etdump.etdp" self.output_folder = f"{self.workspace}/outputs" - arch_table = { + self.arch_table = { "SM8650": "75", "SM8550": "73", "SM8475": "69", "SM8450": "69", } - self.soc_model = arch_table[soc_model] + self.soc_model = self.arch_table[soc_model] self.error_only = error_only self.shared_buffer = shared_buffer self.runner = runner
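For reference, the lowering flow driven by the `compile` subcommand can also be scripted directly. The sketch below is a minimal outline under stated assumptions: the ExecuTorch Qualcomm backend is installed, `QNN_SDK_ROOT` is set, and `model_ctx.bin` / SM8650 are hypothetical stand-ins for your own context binary and target SoC.

```python
# Minimal sketch of the lowering flow used by export.py's `compile` subcommand.
# `model_ctx.bin` and the SM8650 target are illustrative placeholders.
from executorch.backends.qualcomm.serialization.qnn_compile_spec_schema import QcomChipset
from executorch.backends.qualcomm.utils.utils import (
    from_context_binary,
    generate_htp_compiler_spec,
    generate_qnn_executorch_compiler_spec,
)
from executorch.exir.backend.backend_api import to_backend
from executorch.exir.passes.memory_planning_pass import MemoryPlanningPass

ctx_bin = "model_ctx.bin"
compiler_specs = generate_qnn_executorch_compiler_spec(
    soc_model=QcomChipset.SM8650,
    backend_options=generate_htp_compiler_spec(use_fp16=False),
    is_from_context_binary=True,
)
memory_planning_pass = MemoryPlanningPass(
    memory_planning_algo="greedy",
    alloc_graph_input=True,
    alloc_graph_output=True,
)

# Wrap the context binary in an ExportedProgram whose custom op loads it at
# runtime, lower that program to the QNN backend, then serialize the .pte file.
prog_info = from_context_binary(ctx_bin, "ctx_loader_model_ctx", QcomChipset.SM8650)
lowered = to_backend("QnnBackend", prog_info["edge_program"], compiler_specs)
with open("model_ctx.pte", "wb") as f:
    f.write(
        lowered.buffer(
            extract_delegate_segments=True, memory_planning=memory_planning_pass
        )
    )
```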
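The `execute` subcommand quantizes or dequantizes user-supplied tensors when their dtype does not match the graph IO encoding (see `get_tensor_with_encoding`). Below is a small sketch of that affine mapping; the per-tensor scale / offset values are made up for illustration, while real values come from the exported `.json` description.

```python
import torch

# Illustrative per-tensor encoding; real values are read from <graph_name>.json.
scale, offset = 0.05, 0.0

x_float = torch.rand(1, 64, 64, 3)
# float -> fixed point, mirroring tensor.div(scale).add(offset).round() in export.py
x_quant = x_float.div(scale).add(offset).round().to(torch.uint8)
# fixed point -> float, mirroring tensor.sub(offset).mul(scale)
x_dequant = x_quant.to(torch.float32).sub(offset).mul(scale)
```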
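After `execute` finishes, each graph output is stored as `<output_name>.pt` under the output data folder, and the output names / dtypes can be looked up in the graph's `.json` description. A short sketch of reading them back, with illustrative paths (the `model_ctx` name is a placeholder):

```python
import io
import json

import torch

# Illustrative paths; substitute the folder produced by the `compile` subcommand.
with open("output_pte/model_ctx/model_ctx.json", "r") as f:
    graph_info = json.load(f)

device_outputs = []
for output in graph_info["outputs"]:
    # each output tensor was saved via torch.save under the output data folder
    with open(f"output_data/{output['name']}.pt", "rb") as f:
        buffer = io.BytesIO(f.read())
    device_outputs.append(torch.load(buffer, weights_only=False))
```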