Skip to content

Commit

Permalink
Qualcomm AI Engine Direct - Enable HTP emulator test in x86 host (#4503)
Browse files Browse the repository at this point in the history
Summary:
- Enable x64 runner
- Enable HTP emulator test on unit test
- Fix unexpected error message
- Fix multi-contexts UT's mismatching datatype issue
- Port x64 dequantize flow instead of using arm_neon intrinsics
- Fix EtDump flow on runner and unittest

Pull Request resolved: #4503

Reviewed By: digantdesai

Differential Revision: D60598800

Pulled By: cccclai

fbshipit-source-id: bfb9df7948c3f64b2bd0e140836dfbd2d4655c0b
  • Loading branch information
chuntl authored and facebook-github-bot committed Aug 2, 2024
1 parent d59419c commit 1090bcd
Show file tree
Hide file tree
Showing 11 changed files with 160 additions and 61 deletions.
18 changes: 12 additions & 6 deletions backends/qualcomm/aot/ir/qcir_utils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,7 @@ Qnn_DataType_t ToDataType(qcir::DataType type) {
}

flatbuffers::Offset<qcir::QuantizeParam> ToQuantizeParam(
const Qnn_QuantizeParams_t& param,
const Qnn_Tensor_t& tensor,
flatbuffers::FlatBufferBuilder* builder) {
static const std::unordered_map<Qnn_Definition_t, qcir::QuantizeDef> def_map{
{QNN_DEFINITION_IMPL_GENERATED, qcir::QuantizeDef::IMPL_GENERATED},
Expand All @@ -124,6 +124,7 @@ flatbuffers::Offset<qcir::QuantizeParam> ToQuantizeParam(

int32_t axis = 0;
uint32_t bitwidth = 0;
auto param = QNN_VER_PTR(tensor)->quantizeParams;
auto quant_type = type_map.at(param.quantizationEncoding);
std::vector<qcir::ScaleOffset> data;
std::vector<float> scales;
Expand Down Expand Up @@ -160,7 +161,9 @@ flatbuffers::Offset<qcir::QuantizeParam> ToQuantizeParam(
}
} break;
default:
QNN_EXECUTORCH_LOG_ERROR("QNN_QUANTIZATION_ENCODING_UNDEFINED detected");
QNN_EXECUTORCH_LOG_WARN(
"QNN_QUANTIZATION_ENCODING_UNDEFINED detected: %s",
QNN_VER_PTR(tensor)->name);
break;
}
return CreateQuantizeParamDirect(
Expand All @@ -174,7 +177,7 @@ flatbuffers::Offset<qcir::QuantizeParam> ToQuantizeParam(
&data);
}

Qnn_QuantizeParams_t ToQuantizeParam(const qparam_type& param) {
Qnn_QuantizeParams_t ToQuantizeParam(const tensor_type& tensor) {
static const std::unordered_map<qcir::QuantizeDef, Qnn_Definition_t> def_map{
{qcir::QuantizeDef::IMPL_GENERATED, QNN_DEFINITION_IMPL_GENERATED},
{qcir::QuantizeDef::DEFINED, QNN_DEFINITION_DEFINED},
Expand All @@ -196,6 +199,7 @@ Qnn_QuantizeParams_t ToQuantizeParam(const qparam_type& param) {
};

Qnn_QuantizeParams_t p = QNN_QUANTIZE_PARAMS_INIT;
auto param = tensor->qparam();
p.encodingDefinition = def_map.at(param->def());
p.quantizationEncoding = type_map.at(param->type());
switch (p.quantizationEncoding) {
Expand Down Expand Up @@ -225,7 +229,9 @@ Qnn_QuantizeParams_t ToQuantizeParam(const qparam_type& param) {
const_cast<int32_t*>(param->offsets()->data());
} break;
default:
QNN_EXECUTORCH_LOG_ERROR("qcir::QuantizeType::UNDEFINED detected");
QNN_EXECUTORCH_LOG_WARN(
"qcir::QuantizeType::UNDEFINED detected: %s",
tensor->name()->c_str());
break;
}
return p;
Expand All @@ -248,7 +254,7 @@ flatbuffers::Offset<qcir::Tensor> ToTensor(
&shape,
ToTensorType(QNN_VER_PTR(tensor)->type),
ToDataType(QNN_VER_PTR(tensor)->dataType),
ToQuantizeParam(QNN_VER_PTR(tensor)->quantizeParams, builder),
ToQuantizeParam(tensor, builder),
&buffer);
}

Expand All @@ -261,7 +267,7 @@ Qnn_Tensor_t ToTensor(const tensor_type& tensor) {
QNN_VER_PTR(t)->name = tensor->name()->c_str();
QNN_VER_PTR(t)->type = ToTensorType(tensor->type());
QNN_VER_PTR(t)->dataType = ToDataType(tensor->dtype());
QNN_VER_PTR(t)->quantizeParams = ToQuantizeParam(tensor->qparam());
QNN_VER_PTR(t)->quantizeParams = ToQuantizeParam(tensor);
QNN_VER_PTR(t)->rank = tensor->shape()->size();
QNN_VER_PTR(t)->dimensions = const_cast<uint32_t*>(tensor->shape()->data());
QNN_VER_PTR(t)->clientBuf.dataSize = tensor->data()->size();
Expand Down
4 changes: 2 additions & 2 deletions backends/qualcomm/aot/ir/qcir_utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -26,9 +26,9 @@ qcir::DataType ToDataType(Qnn_DataType_t type);
Qnn_DataType_t ToDataType(qcir::DataType type);

flatbuffers::Offset<qcir::QuantizeParam> ToQuantizeParam(
const Qnn_QuantizeParams_t& param,
const Qnn_Tensor_t& tensor,
flatbuffers::FlatBufferBuilder* builder);
Qnn_QuantizeParams_t ToQuantizeParam(const qparam_type& type);
Qnn_QuantizeParams_t ToQuantizeParam(const tensor_type& tensor);

flatbuffers::Offset<qcir::Tensor> ToTensor(
const Qnn_Tensor_t& tensor,
Expand Down
4 changes: 2 additions & 2 deletions backends/qualcomm/runtime/QnnExecuTorchBackend.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -23,12 +23,12 @@ Result<DelegateHandle*> QnnExecuTorchBackend::init(
ArrayRef<CompileSpec> compile_specs) const {
// covert SizedBuffer to qnn ExecuTorch option
QnnExecuTorchContextBinary qnn_context_blob;
const qnn_delegate::QnnExecuTorchOptions* qnn_executorch_options;
const qnn_delegate::QnnExecuTorchOptions* qnn_executorch_options = nullptr;

qnn_context_blob.buffer = const_cast<void*>(processed->data());
qnn_context_blob.nbytes = processed->size();

// covert CompileSpec to qnn ExecuTorch option
// convert CompileSpec to qnn ExecuTorch option
for (auto& compile_spec : compile_specs) {
if (std::strcmp(compile_spec.key, QNN_COMPILE_SPEC) == 0)
qnn_executorch_options =
Expand Down
7 changes: 7 additions & 0 deletions backends/qualcomm/runtime/SharedBuffer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,12 @@ SharedBuffer& SharedBuffer::GetSharedBufferManager() {
std::lock_guard<std::mutex> lk(init_mutex_);
static SharedBuffer shared_buffer_manager;
if (!shared_buffer_manager.GetInitialize()) {
#if defined(__aarch64__)
Error status = shared_buffer_manager.Load();
#else
// For x86_64 platform
Error status = Error::Ok;
#endif
if (status == Error::Ok) {
shared_buffer_manager.SetInitialize(true);
}
Expand All @@ -96,9 +101,11 @@ SharedBuffer& SharedBuffer::GetSharedBufferManager() {
}

SharedBuffer::~SharedBuffer() {
#if defined(__aarch64__)
if (initialize_) {
SharedBuffer::GetSharedBufferManager().UnLoad();
}
#endif
};

void* SharedBuffer::AllocMem(size_t bytes, size_t alignment) {
Expand Down
24 changes: 19 additions & 5 deletions backends/qualcomm/scripts/build.sh
Original file line number Diff line number Diff line change
Expand Up @@ -107,19 +107,33 @@ if [ "$BUILD_X86_64" = true ]; then
rm -rf $BUILD_ROOT && mkdir $BUILD_ROOT
fi
cd $BUILD_ROOT
# TODO: Use CMAKE_BUILD_TYPE=RelWithDebInfo, and handle flatcc issues
cmake \
-DCMAKE_BUILD_TYPE=RelWithDebInfo \
-DCMAKE_BUILD_TYPE=Debug \
-DCMAKE_INSTALL_PREFIX=$BUILD_ROOT \
-DQNN_SDK_ROOT=${QNN_SDK_ROOT} \
-DEXECUTORCH_BUILD_QNN=ON \
-DEXECUTORCH_BUILD_SDK=ON \
-DEXECUTORCH_BUILD_EXTENSION_MODULE=ON \
-DEXECUTORCH_ENABLE_EVENT_TRACER=ON \
-DPYTHON_EXECUTABLE=$PYTHON_EXECUTABLE \
-DBUCK2=$BUCK2 \
-S $PRJ_ROOT \
-B $BUILD_ROOT \

cmake \
--build $BUILD_ROOT \
-t "PyQnnManagerAdaptor" "PyQnnWrapperAdaptor" -j16
cmake --build $BUILD_ROOT -j16 --target install

rm -f $PRJ_ROOT/backends/qualcomm/python/*
cp -fv $BUILD_ROOT/backends/qualcomm/Py* "$PRJ_ROOT/backends/qualcomm/python"

EXAMPLE_ROOT=examples/qualcomm
CMAKE_PREFIX_PATH="${BUILD_ROOT}/lib/cmake/ExecuTorch;${BUILD_ROOT}/third-party/gflags;"

cmake $PRJ_ROOT/$EXAMPLE_ROOT \
-DCMAKE_BUILD_TYPE=Debug \
-DCMAKE_PREFIX_PATH=$CMAKE_PREFIX_PATH \
-DCMAKE_FIND_ROOT_PATH_MODE_PACKAGE=BOTH \
-DPYTHON_EXECUTABLE=$PYTHON_EXECUTABLE \
-B$EXAMPLE_ROOT

cmake --build $EXAMPLE_ROOT -j16
fi
15 changes: 11 additions & 4 deletions backends/qualcomm/tests/test_qnn_delegate.py
Original file line number Diff line number Diff line change
Expand Up @@ -147,6 +147,7 @@ def test_qnn_backend_element_wise_ceil(self):

def test_qnn_backend_element_wise_div(self):
eps = 1e-03
torch.manual_seed(8)
test_comb = [
{
QCOM_MODULE: [Div()], # noqa: F405
Expand Down Expand Up @@ -721,6 +722,7 @@ def test_qnn_backend_element_wise_ceil(self):

def test_qnn_backend_element_wise_div(self):
eps = 1e-03
torch.manual_seed(8)
test_comb = [
{
QCOM_MODULE: [Div()], # noqa: F405
Expand Down Expand Up @@ -1323,7 +1325,6 @@ def test_qnn_backend_multi_contexts_composite(self):
exec_prog = edge_prog.to_executorch()
self.verify_output(module.get_reference_module(), sample_input, exec_prog)

@unittest.expectedFailure
def test_qnn_backend_profile_op(self):
TestQNN.enable_profile = True
backend_options = generate_htp_compiler_spec(use_fp16=True)
Expand All @@ -1338,7 +1339,7 @@ def test_qnn_backend_profile_op(self):
module,
sample_input,
expected_partitions=1,
expected_profile_events=25,
expected_profile_events=24,
)

def test_qnn_backend_shared_buffer(self):
Expand Down Expand Up @@ -1488,7 +1489,6 @@ def test_qnn_backend_multi_contexts_composite(self):
exec_prog = edge_prog.to_executorch()
self.verify_output(module.get_reference_module(), sample_input, exec_prog)

@unittest.expectedFailure
def test_qnn_backend_profile_op(self):
TestQNN.enable_profile = True
backend_options = generate_htp_compiler_spec(use_fp16=False)
Expand All @@ -1504,7 +1504,7 @@ def test_qnn_backend_profile_op(self):
module,
sample_input,
expected_partitions=1,
expected_profile_events=26,
expected_profile_events=25,
)

def test_qnn_backend_shared_buffer(self):
Expand Down Expand Up @@ -2288,6 +2288,12 @@ def setup_environment():
help="Path to open source software model repository",
type=str,
)
parser.add_argument(
"-x",
"--enable_x86_64",
help="Enable unittest to be executed on x86_64 platform",
action="store_true",
)

args, ns_args = parser.parse_known_args(namespace=unittest)
TestQNN.host = args.host
Expand All @@ -2304,6 +2310,7 @@ def setup_environment():
TestQNN.error_only = args.error_only
TestQNN.oss_repo = args.oss_repo
TestQNN.shared_buffer = args.shared_buffer
TestQNN.enable_x86_64 = args.enable_x86_64
return sys.argv[:1] + ns_args


Expand Down
86 changes: 64 additions & 22 deletions backends/qualcomm/tests/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,11 @@
QcomChipset,
)
from executorch.backends.qualcomm.utils.utils import capture_program
from executorch.examples.qualcomm.scripts.utils import SimpleADB
from executorch.examples.qualcomm.scripts.utils import (
generate_inputs,
make_output_dir,
SimpleADB,
)

from executorch.exir.backend.backend_api import to_backend
from executorch.exir.backend.compile_spec_schema import CompileSpec
Expand Down Expand Up @@ -133,6 +137,7 @@ class TestQNN(unittest.TestCase):
use_16a16w: str = "16a16w"
use_16a4w: str = "16a4w"
shared_buffer: bool = False
enable_x86_64: bool = False

def _assert_outputs_equal(self, model_output, ref_output):
self.assertTrue(len(ref_output) == len(model_output))
Expand Down Expand Up @@ -201,40 +206,75 @@ def verify_output(
tmp_dir,
)

device_output_dir = f"{tmp_dir}/outputs"
device_outputs = []
output_dir = f"{tmp_dir}/outputs"
outputs = []
etdump_path = f"{tmp_dir}/etdump.etdp"

def post_process():
for i, f in enumerate(sorted(os.listdir(device_output_dir))):
filename = os.path.join(device_output_dir, f)
for i, f in enumerate(sorted(os.listdir(output_dir))):
filename = os.path.join(output_dir, f)
output = np.fromfile(filename, dtype=ref_outputs[i].numpy().dtype)
output = torch.from_numpy(output).reshape(ref_outputs[i].shape)
device_outputs.append(output)
outputs.append(output)

def validate_profile():
inspector = Inspector(etdump_path=etdump_path, etrecord=etrecord_path)
self.assertTrue(
len(inspector.to_dataframe().index) == expected_profile_events
)

adb = SimpleADB(
qnn_sdk=os.getenv("QNN_SDK_ROOT"),
build_path=self.build_folder,
pte_path=pte_fname,
workspace="/data/local/tmp/qnn_executorch_test",
device_id=self.device,
host_id=self.host,
soc_model=self.model,
error_only=self.error_only,
)
adb.push(inputs=[sample_inputs], input_list=input_list)
adb.execute()
adb.pull(output_path=tmp_dir, callback=post_process)
self._assert_outputs_equal(device_outputs, ref_outputs)
if self.enable_x86_64:
generate_inputs(tmp_dir, "input_list.txt", [sample_inputs], input_list)
make_output_dir(output_dir)

target = "x86_64-linux-clang"
qnn_sdk = os.environ.get("QNN_SDK_ROOT", None)
assert qnn_sdk, "QNN_SDK_ROOT was not found in environment variable"

build_path = "build_x86_64"
cmds = [
# export LD_LIBRARY_PATH to QNN_SDK_ROOT
f"export LD_LIBRARY_PATH={qnn_sdk}/lib/{target}/:{self.executorch_root}/{build_path}/lib && "
# qnn_executor_runner
f"{self.executorch_root}/{build_path}/examples/qualcomm/qnn_executor_runner",
f"--model_path {pte_fname}",
f"--input_list_path {tmp_dir}/input_list.txt",
f"--output_folder_path {output_dir}",
]

subprocess.run(
" ".join(cmds),
shell=True,
executable="/bin/bash",
capture_output=True,
cwd=tmp_dir,
)

# Verify the outputs
post_process()
self._assert_outputs_equal(outputs, ref_outputs)

# Verify the etdump
if expected_profile_events != -1:
validate_profile()
else:
adb = SimpleADB(
qnn_sdk=os.getenv("QNN_SDK_ROOT"),
build_path=self.build_folder,
pte_path=pte_fname,
workspace="/data/local/tmp/qnn_executorch_test",
device_id=self.device,
host_id=self.host,
soc_model=self.model,
error_only=self.error_only,
)
adb.push(inputs=[sample_inputs], input_list=input_list)
adb.execute()
adb.pull(output_path=tmp_dir, callback=post_process)
self._assert_outputs_equal(outputs, ref_outputs)

if expected_profile_events != -1:
adb.pull_etdump(etdump_path, callback=validate_profile)
if expected_profile_events != -1:
adb.pull_etdump(etdump_path, callback=validate_profile)

def lower_module_and_test_output(
self,
Expand Down Expand Up @@ -362,6 +402,8 @@ def _insert_clone(
(node,),
)
inserted_node.meta["val"] = node.meta["val"]
if "quant_attrs" in node.meta:
inserted_node.meta["quant_attrs"] = node.meta["quant_attrs"]
for user in users:
user.replace_input_with(node, inserted_node)

Expand Down
3 changes: 0 additions & 3 deletions examples/qualcomm/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,6 @@
set(CMAKE_CXX_STANDARD 17)
# qnn_executor_runner: Like executor_runner but with QNN

if(NOT ${ANDROID})
message(FATAL_ERROR "Not building Android, quitting...")
endif()
cmake_minimum_required(VERSION 3.19)
project(qualcomm_runner_example)

Expand Down
5 changes: 2 additions & 3 deletions examples/qualcomm/executor_runner/qnn_executor_runner.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@

#include <gflags/gflags.h>

#include <chrono>
#include <fstream>
#include <memory>

Expand Down Expand Up @@ -202,10 +203,8 @@ int main(int argc, char** argv) {
// be used by a single thread at at time, but it can be reused.
//
torch::executor::ETDumpGen etdump_gen = torch::executor::ETDumpGen();
// TODO: So far we have issues with etdump_gen during load_method. Enable it
// after the issues are fixed.
Result<Method> method =
program->load_method(method_name, &memory_manager, nullptr);
program->load_method(method_name, &memory_manager, &etdump_gen);
ET_CHECK_MSG(
method.ok(),
"Loading of method %s failed with status 0x%" PRIx32,
Expand Down
Loading

0 comments on commit 1090bcd

Please sign in to comment.