Cadence examples -- Add RNNT encoder from torchaudio (#3691)
Summary:
Pull Request resolved: #3691

As titled.

Reviewed By: tarun292

Differential Revision: D57617809

fbshipit-source-id: bd3bbee10ae120804a58ee9dc43ff9598686b3ec
mcremon-meta authored and facebook-github-bot committed Jun 5, 2024
1 parent e85b52a commit b2fafe9
Showing 7 changed files with 375 additions and 12 deletions.
16 changes: 16 additions & 0 deletions backends/cadence/aot/TARGETS
@@ -8,6 +8,20 @@ load("@fbcode_macros//build_defs:python_library.bzl", "python_library")

oncall("odai_jarvis")

python_library(
name = "utils",
srcs = [
"utils.py",
],
deps = [
"fbsource//third-party/pypi/tabulate:tabulate",
"//caffe2:torch",
"//executorch/exir:memory",
"//executorch/exir/dialects:lib",
"//executorch/exir/dialects/edge:lib",
],
)

python_library(
name = "compiler",
srcs = [
@@ -26,6 +40,8 @@ python_library(
"passes.py",
],
deps = [
":utils",
"//caffe2:torch",
"//executorch/exir:pass_base",
"//executorch/exir/dialects:lib",
],
15 changes: 12 additions & 3 deletions backends/cadence/aot/compiler.py
@@ -10,8 +10,11 @@
 import torch
 
 from executorch.backends.cadence.aot.passes import (
-    ReplacePT2DequantWithCadenceDequant,
-    ReplacePT2QuantWithCadenceQuant,
+    RemoveZeroSizedCatArgsPass,
+    ReplacePT2DequantWithCadenceDequantPass,
+    ReplacePT2QuantWithCadenceQuantPass,
+    ReplaceScalarTensorWithFullPass,
+    ReplaceSqueezeAndUnsqueezeWithViewPass,
 )
 from executorch.exir import EdgeCompileConfig, EdgeProgramManager, to_edge
 
@@ -84,7 +87,13 @@ def export_to_cadence(
 
     # Run a couple required passes for quant/dequant ops
     cadence_program_manager = edge_program_manager.transform(
-        [ReplacePT2QuantWithCadenceQuant(), ReplacePT2DequantWithCadenceDequant()]
+        [
+            RemoveZeroSizedCatArgsPass(),
+            ReplaceScalarTensorWithFullPass(),
+            ReplaceSqueezeAndUnsqueezeWithViewPass(),
+            ReplacePT2QuantWithCadenceQuantPass(),
+            ReplacePT2DequantWithCadenceDequantPass(),
+        ]
     )
 
     return cadence_program_manager
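The cleanup passes are listed ahead of the quant/dequant replacements, so the Cadence ops see an already-simplified graph. Below is a minimal sketch of driving the same transform machinery by hand; the toy module, input shape, and pass selection are illustrative, not taken from this commit.

import torch
from executorch.backends.cadence.aot.passes import (
    RemoveZeroSizedCatArgsPass,
    ReplaceSqueezeAndUnsqueezeWithViewPass,
)
from executorch.exir import to_edge

class TinyModel(torch.nn.Module):
    def forward(self, x):
        return x.unsqueeze(0).squeeze(0)

# Export to the edge dialect, then apply passes in list order; transform()
# returns a new EdgeProgramManager holding the rewritten graph.
edge_mgr = to_edge(torch.export.export(TinyModel(), (torch.randn(4),)))
cleaned_mgr = edge_mgr.transform(
    [RemoveZeroSizedCatArgsPass(), ReplaceSqueezeAndUnsqueezeWithViewPass()]
)
print(cleaned_mgr.exported_program().graph_module.graph)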
9 changes: 5 additions & 4 deletions backends/cadence/aot/export_example.py
@@ -45,7 +45,9 @@ def _save_pte_program(
 
 
 def export_model(
-    model: nn.Module, example_inputs: Tuple[Any], file_name: str = "CadenceDemoModel"
+    model: nn.Module,
+    example_inputs: Tuple[Any, ...],
+    file_name: str = "CadenceDemoModel",
 ):
     # Quantizer
     quantizer = CadenceQuantizer()
@@ -72,9 +74,8 @@ def export_model(
 
     exec_prog = cadence_prog_manager.to_executorch()
 
-    logging.info(
-        f"Final exported graph module:\n{exec_prog.exported_program().graph_module}"
-    )
+    logging.info("Final exported graph:")
+    exec_prog.exported_program().graph_module.graph.print_tabular()
 
     # Print some information to terminal
     print_ops_info(
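The widened signature accepts any tuple of example inputs, and the final graph is now printed with FX's graph.print_tabular(), which is what the new tabulate dependency in TARGETS supports. A usage sketch with a hypothetical stand-in module (the actual RNNT encoder example lives in the portion of the diff not shown on this page):

import torch
from torch import nn
from executorch.backends.cadence.aot.export_example import export_model

class StubEncoder(nn.Module):
    def forward(self, x):
        return torch.relu(x)

# (batch, time, feature) is a common speech-encoder input layout; the shape
# here is illustrative rather than taken from the torchaudio RNNT model.
example_inputs = (torch.randn(1, 100, 80),)
export_model(StubEncoder(), example_inputs, file_name="StubEncoder")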
65 changes: 65 additions & 0 deletions backends/cadence/aot/functions.yaml
@@ -12,11 +12,46 @@


# aten ops
- op: _to_copy.out
kernels:
- arg_meta: null
kernel_name: torch::executor::to_copy_out

- op: _softmax.out
kernels:
- arg_meta: null
kernel_name: torch::executor::softmax_out

- op: add.out
kernels:
- arg_meta: null
kernel_name: torch::executor::add_out

- op: bmm.out
kernels:
- arg_meta: null
kernel_name: torch::executor::bmm_out

- op: cat.out
kernels:
- arg_meta: null
kernel_name: torch::executor::cat_out

- op: clone.out
kernels:
- arg_meta: null
kernel_name: torch::executor::clone_out

- op: div.out
kernels:
- arg_meta: null
kernel_name: torch::executor::div_out

- op: div.out_mode
kernels:
- arg_meta: null
kernel_name: torch::executor::div_out_mode

- op: embedding.out
kernels:
- arg_meta: null
@@ -27,16 +62,46 @@
- arg_meta: null
kernel_name: torch::executor::full_out

- op: mul.out
kernels:
- arg_meta: null
kernel_name: torch::executor::mul_out

- op: permute_copy.out
kernels:
- arg_meta: null
kernel_name: torch::executor::permute_copy_out

- op: sigmoid.out
kernels:
- arg_meta: null
kernel_name: torch::executor::sigmoid_out

- op: slice_copy.Tensor_out
kernels:
- arg_meta: null
kernel_name: torch::executor::slice_copy_Tensor_out

- op: split_with_sizes_copy.out
kernels:
- arg_meta: null
kernel_name: torch::executor::split_with_sizes_copy_out

- op: sub.out
kernels:
- arg_meta: null
kernel_name: torch::executor::sub_out

- op: view_copy.out
kernels:
- arg_meta: null
kernel_name: torch::executor::view_copy_out

- op: where.self_out
kernels:
- arg_meta: null
kernel_name: torch::executor::where_out

# custom ops
- func: cadence::quantize_per_tensor.out(Tensor input, float scale, int zero_point, int quant_min, int quant_max, ScalarType dtype, *, Tensor(a!) out) -> Tensor(a!)
variants: function
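Each aten entry binds an out-variant op to the portable kernel that implements it; the custom cadence::quantize_per_tensor schema above registers the backend's own quantize kernel. For reference, standard affine per-tensor quantization (a sketch of the math the schema describes, not the Cadence kernel itself) looks like:

import torch

def quantize_per_tensor_ref(input, scale, zero_point, quant_min, quant_max, dtype):
    # Scale to integer steps, shift by the zero point, then clamp into the
    # representable range of the target dtype.
    q = torch.round(input / scale) + zero_point
    return torch.clamp(q, quant_min, quant_max).to(dtype)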
100 changes: 97 additions & 3 deletions backends/cadence/aot/passes.py
@@ -4,11 +4,15 @@
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
 
+import torch
+from executorch.backends.cadence.aot.utils import get_edge_overload_packet
 from executorch.exir.dialects._ops import ops as exir_ops
-from executorch.exir.pass_base import ExportPass
+from executorch.exir.pass_base import ExportPass, ProxyValue
+from torch._subclasses import FakeTensor
+from torch.utils._pytree import tree_map_only
 
 
-class ReplacePT2QuantWithCadenceQuant(ExportPass):
+class ReplacePT2QuantWithCadenceQuantPass(ExportPass):
     """
     Replace the pt2 quantization ops with custom cadence quantization ops.
     """
@@ -25,7 +29,7 @@ def call_operator(self, op, args, kwargs, meta):
         )
 
 
-class ReplacePT2DequantWithCadenceDequant(ExportPass):
+class ReplacePT2DequantWithCadenceDequantPass(ExportPass):
     """
     Replace the pt2 dequantization ops with custom cadence dequantization ops.
     """
@@ -40,3 +44,93 @@ def call_operator(self, op, args, kwargs, meta):
kwargs,
meta,
)


class ReplaceScalarTensorWithFullPass(ExportPass):
"""
aten.scalar_tensor can be replaced by aten.full with a shape of [1].
"""

def call_operator(self, op, args, kwargs, meta):
if op not in {
exir_ops.edge.aten.scalar_tensor.default,
torch.ops.aten.scalar_tensor.default,
}:
return super().call_operator(op, args, kwargs, meta)

return super().call_operator(
exir_ops.edge.aten.full.default,
(
[1],
args[0],
),
{},
meta,
)


class ReplaceSqueezeAndUnsqueezeWithViewPass(ExportPass):
"""
When the shape is static, replace squeeze_copy and unsqueeze_copy ops with
view_copy op
"""

def call_operator(self, op, args, kwargs, meta):
# Instead of testing EdgeOpOverload, test EdgeOpOverloadPacket,
# which allows us to cover all overloads.
if get_edge_overload_packet(op) not in {
exir_ops.edge.aten.squeeze_copy,
exir_ops.edge.aten.unsqueeze_copy,
}:
return super().call_operator(op, args, kwargs, meta)
# Get the output tensor shape
out_shape = meta["val"].shape

# Bail out if any dim is not an int (dynamic shape)
for dim in list(out_shape):
if not isinstance(dim, int):
return super().call_operator(op, args, kwargs, meta)

# Return a view op with the new shape
view_args = (args[0], list(out_shape))
return super().call_operator(
exir_ops.edge.aten.view_copy.default, view_args, kwargs, meta
)


class RemoveZeroSizedCatArgsPass(ExportPass):
def call_operator(self, op, args, kwargs, meta):
if op != exir_ops.edge.aten.cat.default:
return super().call_operator(op, args, kwargs, meta)

# Remove any zero-sized tensor arg to form a new args list.
new_args = []
for arg in args[0]:
arg_tensor = arg.to_tensor() if isinstance(arg, ProxyValue) else arg
if arg_tensor.numel() > 0:
new_args.append(arg)

# If all the tensors were empty, we just return an empty tensor with
# the right shape.
if not new_args:
args_data, kwargs_data = tree_map_only(
ProxyValue, lambda x: x.data, (args, kwargs)
)
result = op(*args_data, **kwargs_data)
# When tracing with PT2, the FakeTensor mode requires the constant
# argument to be set to itself.
# TODO(matthiascremon): confirm this is the best way to do this.
if isinstance(result, FakeTensor):
result.constant = result
return torch.empty_like(result)

# If there was only one tensor in the new_args list,
# we can safely erase this cat op.
if len(new_args) == 1:
return new_args[0]

# Otherwise, we replace args[0] with new_args.
args = list(args)
args[0] = new_args
args = tuple(args)
return super().call_operator(op, args, kwargs, meta)
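The three new passes encode simple tensor-level identities that can be sanity-checked in eager mode. One caveat: aten.scalar_tensor yields a 0-d tensor while aten.full([1], ...) yields a 1-d one, so the first rewrite relies on downstream ops treating both as a broadcastable scalar. A quick sketch with arbitrary tensors:

import torch

# ReplaceScalarTensorWithFullPass: scalar_tensor(v) ~ full([1], v)
assert torch.ops.aten.scalar_tensor(0.5).item() == torch.ops.aten.full([1], 0.5).item()

# ReplaceSqueezeAndUnsqueezeWithViewPass: with static shapes, squeeze and
# unsqueeze are pure reshapes, so a view with the output shape is equivalent.
x = torch.randn(2, 1, 3)
assert torch.equal(x.squeeze(1), x.view([2, 3]))
assert torch.equal(x.unsqueeze(0), x.view([1, 2, 1, 3]))

# RemoveZeroSizedCatArgsPass: zero-sized tensors contribute nothing to cat.
y = torch.randn(2, 3)
assert torch.equal(torch.cat([torch.empty(0, 3), y]), y)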
21 changes: 19 additions & 2 deletions backends/cadence/hifi/operators/CMakeLists.txt
@@ -24,10 +24,27 @@ set(_aten_ops__srcs
     "${CMAKE_CURRENT_SOURCE_DIR}/op_embedding.cpp"
     "${CMAKE_CURRENT_SOURCE_DIR}/op_full.cpp"
     "${CMAKE_CURRENT_SOURCE_DIR}/op_view_copy.cpp"
-    "${EXECUTORCH_ROOT}/kernels/portable/cpu/op_permute_copy.cpp"
     "${EXECUTORCH_ROOT}/kernels/portable/cpu/util/activation_ops_util.cpp"
     "${EXECUTORCH_ROOT}/kernels/portable/cpu/util/copy_ops_util.cpp"
     "${EXECUTORCH_ROOT}/kernels/portable/cpu/util/broadcast_util.cpp"
-    "${EXECUTORCH_ROOT}/kernels/portable/cpu/util/repeat_util.cpp")
+    "${EXECUTORCH_ROOT}/kernels/portable/cpu/util/index_util.cpp"
+    "${EXECUTORCH_ROOT}/kernels/portable/cpu/util/kernel_ops_util.cpp"
+    "${EXECUTORCH_ROOT}/kernels/portable/cpu/util/matmul_ops_util.cpp"
+    "${EXECUTORCH_ROOT}/kernels/portable/cpu/util/reduce_util.cpp"
+    "${EXECUTORCH_ROOT}/kernels/portable/cpu/util/repeat_util.cpp"
+    "${EXECUTORCH_ROOT}/kernels/portable/cpu/op_bmm.cpp"
+    "${EXECUTORCH_ROOT}/kernels/portable/cpu/op_cat.cpp"
+    "${EXECUTORCH_ROOT}/kernels/portable/cpu/op_clone.cpp"
+    "${EXECUTORCH_ROOT}/kernels/portable/cpu/op_div.cpp"
+    "${EXECUTORCH_ROOT}/kernels/portable/cpu/op_mul.cpp"
+    "${EXECUTORCH_ROOT}/kernels/portable/cpu/op_permute_copy.cpp"
+    "${EXECUTORCH_ROOT}/kernels/portable/cpu/op_sigmoid.cpp"
+    "${EXECUTORCH_ROOT}/kernels/portable/cpu/op_slice_copy.cpp"
+    "${EXECUTORCH_ROOT}/kernels/portable/cpu/op_softmax.cpp"
+    "${EXECUTORCH_ROOT}/kernels/portable/cpu/op_split_with_sizes_copy.cpp"
+    "${EXECUTORCH_ROOT}/kernels/portable/cpu/op_sub.cpp"
+    "${EXECUTORCH_ROOT}/kernels/portable/cpu/op_to_copy.cpp"
+    "${EXECUTORCH_ROOT}/kernels/portable/cpu/op_where.cpp")
 add_library(aten_ops_cadence ${_aten_ops__srcs})
 target_link_libraries(aten_ops_cadence PUBLIC executorch)
+target_link_libraries(aten_ops_cadence PRIVATE cadence_kernels)
(1 remaining changed file was not loaded on this page.)