Skip to content

Commit

Permalink
Cadence - Move primary code to backends folder (#3353)
Browse files Browse the repository at this point in the history
Summary:
Pull Request resolved: #3353

See design and discussion in https://docs.google.com/document/d/1HPDTbN07WXB9PCdezmvSs_0Yy89D57F1pHHAF9UAgGg/edit#heading=h.828btb3wp67h.

Previous folder structure:
```
executorch
└── examples
    ├── aot
    ├── kernels
    ├── ops
    ├── tests
    ├── third-party/hifi4-nnlib
    └── utils
```

New folder structure:
```
executorch
├── backends
│   └── cadence
│       ├── aot
│       ├── ops_registration
│       ├── tests
│       ├── utils
│       ├── hifi
│       │   ├── kernels
│       │   ├── operators
│       │   └── third-party
│       │       └── nnlib
│       └── [other cadence DSP families]
│           ├── kernels
│           ├── operators
│           └── third-party
│               └── [any required lib]
└── examples
    └── cadence
        ├── models
        └── operators
```

Reviewed By: tarun292, cccclai

Differential Revision: D56577399

fbshipit-source-id: a19d7d689b286c0da2ef533a17e5e66ee1eb8a26
  • Loading branch information
mcremon-meta authored and facebook-github-bot committed May 7, 2024
1 parent 6c0c675 commit c001f59
Show file tree
Hide file tree
Showing 35 changed files with 122 additions and 60 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -94,8 +94,8 @@ set_property(
"${CMAKE_CURRENT_LIST_DIR}/../../cmake-out/extension/runner_util/libextension_runner_util.a"
)

add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/ops)
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/kernels)
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/hifi/operators)
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/hifi/kernels)

# Generate the model header file
add_custom_command(
Expand Down
30 changes: 30 additions & 0 deletions backends/cadence/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
# Cadence DSP Backends

## Supported DSPs (in progress)
- HiFi Audio
- ...

## Tutorial

Please follow the [tutorial](https://pytorch.org/executorch/main/build-run-xtensa.html) for more information on how to run models on Cadence/Xtensa DSPs.

## Directory Structure

```
executorch
├── backends
│ └── cadence
│ ├── aot
│ ├── ops_registration
│ ├── tests
│ ├── utils
│ └── hifi
│ ├── kernels
│ ├── operators
│ └── third-party
│ └── nnlib
└── examples
└── cadence
├── models
└── operators
```
Empty file.
File renamed without changes.
Original file line number Diff line number Diff line change
Expand Up @@ -8,28 +8,49 @@

import logging

from .meta_registrations import * # noqa
from executorch.backends.cadence.aot.ops_registrations import * # noqa

from torch._export import capture_pre_autograd_graph
from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e

from ...portable.utils import save_pte_program
import os
from typing import Any, Tuple

from .compiler import export_to_edge
from .quantizer import (
from executorch.backends.cadence.aot.compiler import export_to_edge
from executorch.backends.cadence.aot.quantizer import (
CadenceBaseQuantizer,
QuantFusion,
ReplacePT2DequantWithCadenceDequant,
ReplacePT2QuantWithCadenceQuant,
)
from executorch.exir import ExecutorchProgramManager
from torch import nn
from torch._export import capture_pre_autograd_graph
from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e

from .utils import print_ops_info


FORMAT = "[%(levelname)s %(asctime)s %(filename)s:%(lineno)s] %(message)s"
logging.basicConfig(level=logging.INFO, format=FORMAT)


def export_model(model, example_inputs):
def _save_pte_program(
    prog: ExecutorchProgramManager, model_name: str, output_dir: str = ""
) -> None:
    """Serialize ``prog`` to a ``.pte`` file on disk.

    Args:
        prog: Program manager whose ``write_to_file`` method emits the
            serialized program into a binary file object.
        model_name: Either a name ending in ``.pte``, which is used verbatim
            as the destination path, or a bare model name, to which ``.pte``
            is appended.
        output_dir: Directory joined in front of ``<model_name>.pte``. It is
            not consulted when ``model_name`` already ends in ``.pte``.

    Any exception raised while opening or writing the file is logged as an
    error and not propagated to the caller.
    """
    # A name that already carries the extension is taken as the full path.
    target = (
        model_name
        if model_name.endswith(".pte")
        else os.path.join(output_dir, f"{model_name}.pte")
    )

    try:
        with open(target, "wb") as out:
            prog.write_to_file(out)
            logging.info(f"Saved exported program to {target}")
    except Exception as err:
        logging.error(f"Error while saving to {target}: {err}")


def export_model(
model: nn.Module, example_inputs: Tuple[Any], file_name: str = "CadenceDemoModel"
):
# Quantizer
quantizer = CadenceBaseQuantizer()

Expand Down Expand Up @@ -70,5 +91,5 @@ def export_model(model, example_inputs):
cadence_prog_manager.exported_program().graph_module,
)

# Save the program as CadenceDemoModel.pte
save_pte_program(exec_prog, "CadenceDemoModel")
# Save the program (the default file name is CadenceDemoModel.pte)
_save_pte_program(exec_prog, file_name)
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
#include "pin_mux.h"

#include <memory>
// patternlint-disable executorch-cpp-nostdinc
#include <vector>

#include <executorch/extension/data_loader/buffer_data_loader.h>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
add_library(
cadence_kernels
kernels.cpp
${EXECUTORCH_ROOT}/examples/cadence/third-party/nnlib-hifi4/matmul_asym8uxasym8u_asym8u.cpp
${EXECUTORCH_ROOT}/backends/cadence/hifi/third-party/nnlib/matmul_asym8uxasym8u_asym8u.cpp
)

target_include_directories(
Expand Down
File renamed without changes.
File renamed without changes.
Original file line number Diff line number Diff line change
Expand Up @@ -32,31 +32,26 @@ set(_aten_ops__srcs
"${EXECUTORCH_ROOT}/kernels/portable/cpu/op_permute_copy.cpp"
"${EXECUTORCH_ROOT}/kernels/portable/cpu/util/copy_ops_util.cpp"
"${EXECUTORCH_ROOT}/kernels/portable/cpu/util/broadcast_util.cpp"
"${EXECUTORCH_ROOT}/kernels/portable/cpu/util/repeat_util.cpp"
)
"${EXECUTORCH_ROOT}/kernels/portable/cpu/util/repeat_util.cpp")
add_library(aten_ops_cadence ${_aten_ops__srcs})
target_link_libraries(aten_ops_cadence PUBLIC executorch)
target_link_libraries(aten_ops_cadence PRIVATE cadence_kernels)

# Let files say "include <executorch/path/to/header.h>".
set(_common_include_directories ${EXECUTORCH_ROOT}/..)

target_include_directories(
aten_ops_cadence PUBLIC ${ROOT_DIR}/.. ${CMAKE_BINARY_DIR}
${_common_include_directories}
)
target_include_directories(aten_ops_cadence PUBLIC ${ROOT_DIR}/..
${CMAKE_BINARY_DIR}
${_common_include_directories})

# Custom ops that are needed to run the test model.
add_library(
custom_ops
"quantized_linear_out.cpp" "quantized_conv_out.cpp" "quantized_relu_out.cpp"
"quantized_layer_norm.cpp" "quantize_per_tensor.cpp"
"dequantize_per_tensor.cpp"
)
target_include_directories(
custom_ops PUBLIC ${ROOT_DIR}/.. ${CMAKE_BINARY_DIR}
${_common_include_directories}
)
custom_ops "quantized_linear_out.cpp" "quantized_conv_out.cpp"
"quantized_relu_out.cpp" "quantized_layer_norm.cpp"
"quantize_per_tensor.cpp" "dequantize_per_tensor.cpp")
target_include_directories(custom_ops PUBLIC ${ROOT_DIR}/..
${CMAKE_BINARY_DIR}
${_common_include_directories})

target_link_libraries(custom_ops PUBLIC executorch)
target_link_libraries(custom_ops PRIVATE cadence_kernels)
Expand All @@ -65,14 +60,15 @@ target_link_libraries(custom_ops PRIVATE cadence_kernels)
# Executorch (for runtime). Here select all ops in functions.yaml
gen_selected_ops(
LIB_NAME "cadence_ops_lib" OPS_SCHEMA_YAML
"${CMAKE_CURRENT_LIST_DIR}/functions.yaml"
"${CMAKE_CURRENT_LIST_DIR}/../../aot/functions.yaml" "" ""
)
generate_bindings_for_kernels(
LIB_NAME "cadence_ops_lib" FUNCTIONS_YAML
${CMAKE_CURRENT_SOURCE_DIR}/functions.yaml
LIB_NAME "cadence_ops_lib" OPS_SCHEMA_YAML
FUNCTIONS_YAML ${CMAKE_CURRENT_SOURCE_DIR}/../../aot/functions.yaml
)
message("Generated files ${gen_command_sources}")

gen_operators_lib(
LIB_NAME "cadence_ops_lib" KERNEL_LIBS custom_ops DEPS aten_ops_cadence
)
LIB_NAME "cadence_ops_lib"
KERNEL_LIBS custom_ops
DEPS aten_ops_cadence)
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
52 changes: 33 additions & 19 deletions docs/source/build-run-xtensa.md
Original file line number Diff line number Diff line change
Expand Up @@ -64,18 +64,32 @@ Step 2. Make sure you have completed the ExecuTorch setup tutorials linked to at
The working tree is:

```
examples/cadence/
├── aot
├── kernels
├── ops
├── tests
├── third-party
└── utils
executorch
├── backends
│ └── cadence
│ ├── aot
│ ├── ops_registration
│ ├── tests
│ ├── utils
│ ├── hifi
│ │ ├── kernels
│ │ ├── operators
│ │ └── third-party
│ │ └── nnlib
│ └── [other cadence DSP families]
│ ├── kernels
│ ├── operators
│ └── third-party
│ └── [any required lib]
└── examples
└── cadence
├── models
└── operators
```

***AoT (Ahead-of-Time) Components***:

The AoT folder contains all of the python scripts and functions needed to export the model to an ExecuTorch `.pte` file. In our case, [export_example.py](https://github.com/pytorch/executorch/blob/main/examples/cadence/aot/export_example.py) is an API that takes a model (nn.Module) and representative inputs and runs it through the quantizer (from [quantizer.py](https://github.com/pytorch/executorch/blob/main/examples/cadence/aot/quantizer.py)). Then a few compiler passes, also defined in [quantizer.py](https://github.com/pytorch/executorch/blob/main/examples/cadence/aot/quantizer.py), will replace operators with custom ones that are supported and optimized on the chip. Any operator needed to compute things should be defined in [meta_registrations.py](https://github.com/pytorch/executorch/blob/main/examples/cadence/aot/meta_registrations.py) and have corresponding implemetations in the other folders.
The AoT folder contains all of the python scripts and functions needed to export the model to an ExecuTorch `.pte` file. In our case, [export_example.py](https://github.com/pytorch/executorch/blob/main/backends/cadence/aot/export_example.py) is an API that takes a model (nn.Module) and representative inputs and runs it through the quantizer (from [quantizer.py](https://github.com/pytorch/executorch/blob/main/backends/cadence/aot/quantizer.py)). Then a few compiler passes, also defined in [quantizer.py](https://github.com/pytorch/executorch/blob/main/backends/cadence/aot/quantizer.py), will replace operators with custom ones that are supported and optimized on the chip. Any operator needed to compute things should be defined in [ops_registrations.py](https://github.com/pytorch/executorch/blob/main/backends/cadence/aot/ops_registrations.py) and have corresponding implementations in the other folders.

***Operators***:

Expand All @@ -101,27 +115,27 @@ python3 -m examples.portable.scripts.export --model_name="add"
***Quantized Operators***:

The other, more complex models are custom operators, including:
- a quantized [linear](https://pytorch.org/docs/stable/generated/torch.nn.Linear.html) operation. The model is defined [here](https://github.com/pytorch/executorch/blob/main/examples/cadence/tests/quantized_linear_example.py#L28). Linear is the backbone of most Automatic Speech Recognition (ASR) models.
- a quantized [conv1d](https://pytorch.org/docs/stable/generated/torch.nn.Conv1d.html) operation. The model is defined [here](https://github.com/pytorch/executorch/blob/main/examples/cadence/tests/quantized_conv1d_example.py#L36). Convolutions are important in wake word and many denoising models.
- a quantized [linear](https://pytorch.org/docs/stable/generated/torch.nn.Linear.html) operation. The model is defined [here](https://github.com/pytorch/executorch/blob/main/examples/cadence/operators/quantized_linear_op.py#L28). Linear is the backbone of most Automatic Speech Recognition (ASR) models.
- a quantized [conv1d](https://pytorch.org/docs/stable/generated/torch.nn.Conv1d.html) operation. The model is defined [here](https://github.com/pytorch/executorch/blob/main/examples/cadence/operators/quantized_conv1d_op.py#L36). Convolutions are important in wake word and many denoising models.

In both cases the generated file is called `XtensaDemoModel.pte`.
In both cases the generated file is called `CadenceDemoModel.pte`.

```bash
cd executorch
python3 -m examples.cadence.tests.quantized_<linear,conv1d>_example
python3 -m examples.cadence.operators.quantized_<linear,conv1d>_op
```

***Small Model: RNNT predictor***:

The torchaudio [RNNT-emformer](https://pytorch.org/audio/stable/tutorials/online_asr_tutorial.html) model is an Automatic Speech Recognition (ASR) model, comprised of three different submodels: an encoder, a predictor and a joiner.
The predictor is a sequence of basic ops (embedding, ReLU, linear, layer norm) and can be exported using:
The [predictor](https://github.com/pytorch/executorch/blob/main/examples/cadence/models/rnnt_predictor.py) is a sequence of basic ops (embedding, ReLU, linear, layer norm) and can be exported using:

```bash
cd executorch
python3 -m examples.cadence.tests.rnnt_predictor_quantized_example
python3 -m examples.cadence.models.rnnt_predictor
```

The generated file is called `XtensaDemoModel.pte`.
The generated file is called `CadenceDemoModel.pte`.

### Runtime

Expand Down Expand Up @@ -150,7 +164,7 @@ In order to run the CMake build, you need the path to the following:
cd executorch
rm -rf cmake-out
# prebuild and install executorch library
cmake -DCMAKE_TOOLCHAIN_FILE=<path_to_executorch>/examples/cadence/cadence.cmake \
cmake -DCMAKE_TOOLCHAIN_FILE=<path_to_executorch>/backends/cadence/cadence.cmake \
-DCMAKE_INSTALL_PREFIX=cmake-out \
-DCMAKE_BUILD_TYPE=Debug \
-DPYTHON_EXECUTABLE=python3 \
Expand All @@ -163,10 +177,10 @@ cmake -DCMAKE_TOOLCHAIN_FILE=<path_to_executorch>/examples/cadence/cadence.cmake
-DFLATC_EXECUTABLE="$(which flatc)" \
-Bcmake-out .

cmake --build cmake-out -j8 --target install --config Debug
cmake --build cmake-out -j<num_cores> --target install --config Debug
# build cadence runner
cmake -DCMAKE_BUILD_TYPE=Debug \
-DCMAKE_TOOLCHAIN_FILE=<path_to_executorch>/examples/cadence/cadence.cmake \
-DCMAKE_TOOLCHAIN_FILE=<path_to_executorch>/backends/cadence/cadence.cmake \
-DCMAKE_PREFIX_PATH=<path_to_executorch>/cmake-out \
-DMODEL_PATH=<path_to_program_file_generated_in_previous_step> \
-DNXP_SDK_ROOT_DIR=<path_to_nxp_sdk_root> -DEXECUTORCH_BUILD_FLATC=0 \
Expand Down Expand Up @@ -212,6 +226,6 @@ First 20 elements of output 0

In this tutorial, you have learned how to export a quantized operation, build the ExecuTorch runtime and run this model on the Xtensa HiFi4 DSP chip.

The (quantized linear) model in this tutorial is a typical operation appearing in ASR models, and can be extended to a complete ASR model by creating the model as a new test and adding the needed operators/kernels to [operators](https://github.com/pytorch/executorch/blob/main/examples/cadence/ops) and [kernels](https://github.com/pytorch/executorch/blob/main/examples/cadence/kernels).
The (quantized linear) model in this tutorial is a typical operation appearing in ASR models, and can be extended to a complete ASR model by creating the model as a new test and adding the needed operators/kernels to [operators](https://github.com/pytorch/executorch/blob/main/backends/cadence/hifi/operators) and [kernels](https://github.com/pytorch/executorch/blob/main/backends/cadence/hifi/kernels).

Other models can be created following the same structure, always assuming that operators and kernels are available.
Original file line number Diff line number Diff line change
Expand Up @@ -10,11 +10,11 @@

import torch

from ..aot.meta_registrations import * # noqa
from executorch.backends.cadence.aot.ops_registrations import * # noqa

from typing import Tuple

from ..aot.export_example import export_model
from executorch.backends.cadence.aot.export_example import export_model


FORMAT = "[%(levelname)s %(asctime)s %(filename)s:%(lineno)s] %(message)s"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,11 +8,11 @@

import logging

from ..aot.meta_registrations import * # noqa

import torch

from ..aot.export_example import export_model
from executorch.backends.cadence.aot.ops_registrations import * # noqa

from executorch.backends.cadence.aot.export_example import export_model


FORMAT = "[%(levelname)s %(asctime)s %(filename)s:%(lineno)s] %(message)s"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,11 +8,11 @@

import logging

from ..aot.meta_registrations import * # noqa

import torch

from ..aot.export_example import export_model
from executorch.backends.cadence.aot.ops_registrations import * # noqa

from executorch.backends.cadence.aot.export_example import export_model


FORMAT = "[%(levelname)s %(asctime)s %(filename)s:%(lineno)s] %(message)s"
Expand Down

0 comments on commit c001f59

Please sign in to comment.