Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add exportable baby llama example #4345

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 11 additions & 9 deletions backends/cadence/aot/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,23 +11,23 @@
import torch

from executorch.backends.cadence.aot.passes import (
InitializePipeline,
RemoveNopExpandOpPass,
RemoveZeroSizedCatArgsPass,
ReplaceLogicalNotBooleanWhereWithWherePass,
ReplacePT2DequantWithCadenceDequantPass,
ReplacePT2QuantWithCadenceQuantPass,
ReplaceScalarTensorWithFullPass,
ReplaceSqueezeAndUnsqueezeWithViewPass,
)
from executorch.backends.cadence.aot.quantizer.fusion_pass import QuantFusion
from executorch.backends.cadence.aot.quantizer.quantizer import (
CadenceAtenQuantizer,
CadenceQuantizer,
)
from executorch.backends.cadence.aot.quantizer.quantizer import CadenceQuantizer
from executorch.backends.cadence.aot.utils import model_is_quantized
from executorch.backends.transforms.decompose_sdpa import (
DecomposeScaledDotProductAttention,
)
from executorch.backends.transforms.remove_clone_ops import RemoveCloneOpsTransform
from executorch.exir import EdgeCompileConfig, EdgeProgramManager, to_edge
from pyre_extensions import assert_is_instance
from torch._export import capture_pre_autograd_graph
from torch.ao.quantization.pt2e.export_utils import model_is_exported
from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e
Expand Down Expand Up @@ -63,10 +63,8 @@ def quantize_pt2(
converted_model = convert_pt2e(prepared_model)

# Get patterns and apply fusion of dq -> op -> q to qop
patterns = [
assert_is_instance(q, CadenceAtenQuantizer).pattern
for q in quantizer.quantizers
]
# pyre-ignore[16]: no attribute
patterns = [q.pattern for q in quantizer.quantizers]
QuantFusion(patterns)(converted_model)

return converted_model
Expand Down Expand Up @@ -148,8 +146,12 @@ def export_to_cadence(
# Run a couple required passes for quant/dequant ops
cadence_program_manager = edge_program_manager.transform(
[
InitializePipeline(),
RemoveZeroSizedCatArgsPass(),
ReplaceLogicalNotBooleanWhereWithWherePass(),
ReplaceScalarTensorWithFullPass(),
RemoveCloneOpsTransform(),
RemoveNopExpandOpPass(),
ReplaceSqueezeAndUnsqueezeWithViewPass(),
ReplacePT2QuantWithCadenceQuantPass(),
ReplacePT2DequantWithCadenceDequantPass(),
Expand Down
20 changes: 20 additions & 0 deletions backends/cadence/aot/functions.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -62,16 +62,31 @@
- arg_meta: null
kernel_name: torch::executor::full_out

- op: mean.out
kernels:
- arg_meta: null
kernel_name: torch::executor::mean_dim_out

- op: mul.out
kernels:
- arg_meta: null
kernel_name: torch::executor::mul_out

- op: mul.Scalar_out
kernels:
- arg_meta: null
kernel_name: torch::executor::mul_scalar_out

- op: permute_copy.out
kernels:
- arg_meta: null
kernel_name: torch::executor::permute_copy_out

- op: rsqrt.out
kernels:
- arg_meta: null
kernel_name: torch::executor::rsqrt_out

- op: sigmoid.out
kernels:
- arg_meta: null
Expand Down Expand Up @@ -134,3 +149,8 @@
kernels:
- arg_meta: null
kernel_name: impl::reference::quantized_relu_out

func: cadence::quantized_matmul.out(Tensor X, int X_zero_point, Tensor Y, int Y_zero_point, Tensor? bias, int out_multiplier, int out_shift, int out_zero_point, bool transposed, *, Tensor(a!) out) -> Tensor(a!)
kernels:
- arg_meta: null
kernel_name: impl::reference::quantized_matmul_out
103 changes: 98 additions & 5 deletions backends/cadence/aot/passes.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,18 +4,19 @@
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from typing import Any, Dict, Tuple
# pyre-strict

from typing import Any, cast, Dict, Sequence, Tuple

import torch
from executorch.backends.cadence.aot.utils import get_edge_overload_packet
from executorch.exir.dialects._ops import ops as exir_ops
from executorch.exir.pass_base import ExportPass, NodeMetadata, ProxyValue
from executorch.exir.pass_base import ExportPass, NodeMetadata, PassResult, ProxyValue
from executorch.exir.passes import dead_code_elimination_pass
from executorch.exir.passes.spec_prop_pass import SpecPropPass
from torch._subclasses import FakeTensor
from torch.utils._pytree import tree_map_only


# pyre-strict

# Similar to what's done in executorch/exir/pass_base.py
Argument = Any # pyre-ignore

Expand Down Expand Up @@ -173,3 +174,95 @@ def call_operator(
init_args[0] = new_args
args = tuple(args)
return super().call_operator(op, args, kwargs, meta)


class RemoveNopExpandOpPass(ExportPass):
"""
For an expand op, if the operator shape matches the expand shape, then the
expand is a nop.
"""

def call_operator(
self,
op, # pyre-ignore
args: tuple[Argument, ...],
kwargs: dict[str, Argument],
meta: NodeMetadata,
) -> ProxyValue:
if get_edge_overload_packet(op) not in {
exir_ops.edge.aten.expand_copy,
exir_ops.edge.aten.expand,
}:
return super().call_operator(op, args, kwargs, meta)

# Parse the args, and check for nop condition
arg0 = cast(ProxyValue, args[0])
arg1 = cast(Sequence[int], args[1])
in_tensor = arg0.to_tensor()
if list(in_tensor.shape) == list(arg1):
return arg0

return super().call_operator(op, args, kwargs, meta)


class ReplaceLogicalNotBooleanWhereWithWherePass(ExportPass):
"""
A where op with a logical_not and a boolean tensor can be replaced
by a where op with flipped inputs and the initial boolean tensor.
"""

def replace_logical_nop_where_with_where(
self, graph_module: torch.fx.GraphModule
) -> None:
graph = graph_module.graph
for node in graph.nodes:
# We are only interested in where nodes
if node.target != exir_ops.edge.aten.where.self:
continue

# If the third arg is not a logical_not, bail.
if node.args[0].target != exir_ops.edge.aten.logical_not.default:
continue

# Get the third arg node and its input
logical_not_node = node.args[0]
logical_not_input_tensor = (
logical_not_node.args[0].to_tensor()
if isinstance(logical_not_node.args[0], ProxyValue)
else logical_not_node.args[0]
)

# If the logical_not input is not a boolean tensor, bail.
if logical_not_input_tensor.meta["spec"].dtype != torch.bool:
continue

# Replace the where op with another one, flipping the inputs and using the boolean
# tensor from logical_not.
with graph.inserting_before(node):
linear_node = graph.call_function(
exir_ops.edge.aten.where.self,
args=(logical_not_node.args[0], node.args[2], node.args[1]),
)
# Replace all the uses
node.replace_all_uses_with(linear_node)

graph_module.recompile()
graph_module.graph.eliminate_dead_code()

def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
self.replace_logical_nop_where_with_where(graph_module)
result = super().call(graph_module)
return result


class InitializePipeline(ExportPass):
"""
Initialize the Jarvis pipeline. This should invariably be the first pass to
run.
"""

def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
dead_code_elimination_pass(graph_module)
result = SpecPropPass()(graph_module)
assert result is not None
return result
24 changes: 11 additions & 13 deletions backends/cadence/aot/quantizer/quantizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,6 @@
is_annotated,
no_outside_users,
)
from pyre_extensions import assert_is_instance

from torch import fx

Expand Down Expand Up @@ -100,14 +99,11 @@ def annotate(self, model: torch.fx.GraphModule) -> torch.fx.GraphModule:
continue

for output, *custom_spec in anchors.output:
assert_is_instance(output, fx.Node).meta["quantization_annotation"] = (
QuantizationAnnotation(
# pyre-ignore[6]: incompatible parameter type
output_qspec=(
custom_spec[0] if custom_spec else output_act_qspec
),
_annotated=True,
)
# pyre-ignore[16]: no attribute
output.meta["quantization_annotation"] = QuantizationAnnotation(
# pyre-ignore[6]: incompatible parameter type
output_qspec=(custom_spec[0] if custom_spec else output_act_qspec),
_annotated=True,
)

def annotate_inputs(
Expand All @@ -118,16 +114,18 @@ def annotate_inputs(
spec: Optional[QuantizationSpec],
) -> None:
for node, idx, *custom_spec in inputs:
_node = assert_is_instance(node, fx.Node)
annotation = _node.meta.get(
# pyre-ignore[16]: no attribute
annotation = node.meta.get(
"quantization_annotation",
QuantizationAnnotation(_annotated=True),
)
# pyre-ignore[6]: incompatible parameter type
annotation.input_qspec_map[_node.args[idx]] = (
# pyre-ignore[16]: no attribute
annotation.input_qspec_map[node.args[idx]] = (
custom_spec[0] if custom_spec else spec
)
_node.meta["quantization_annotation"] = annotation
# pyre-ignore[16]: no attribute
node.meta["quantization_annotation"] = annotation

annotate_inputs(anchors.inputs, input_act_qspec)
annotate_inputs(anchors.weights, weight_qspec)
Expand Down
6 changes: 5 additions & 1 deletion backends/cadence/reference/operators/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -32,12 +32,15 @@ set(_aten_ops__srcs
"${EXECUTORCH_ROOT}/kernels/portable/cpu/util/matmul_ops_util.cpp"
"${EXECUTORCH_ROOT}/kernels/portable/cpu/util/reduce_util.cpp"
"${EXECUTORCH_ROOT}/kernels/portable/cpu/util/repeat_util.cpp"
"${EXECUTORCH_ROOT}/kernels/portable/cpu/pattern/unary_ufunc_realhb_to_floath.cpp"
"${EXECUTORCH_ROOT}/kernels/portable/cpu/op_bmm.cpp"
"${EXECUTORCH_ROOT}/kernels/portable/cpu/op_cat.cpp"
"${EXECUTORCH_ROOT}/kernels/portable/cpu/op_clone.cpp"
"${EXECUTORCH_ROOT}/kernels/portable/cpu/op_div.cpp"
"${EXECUTORCH_ROOT}/kernels/portable/cpu/op_mean.cpp"
"${EXECUTORCH_ROOT}/kernels/portable/cpu/op_mul.cpp"
"${EXECUTORCH_ROOT}/kernels/portable/cpu/op_permute_copy.cpp"
"${EXECUTORCH_ROOT}/kernels/portable/cpu/op_rsqrt.cpp"
"${EXECUTORCH_ROOT}/kernels/portable/cpu/op_sigmoid.cpp"
"${EXECUTORCH_ROOT}/kernels/portable/cpu/op_slice_copy.cpp"
"${EXECUTORCH_ROOT}/kernels/portable/cpu/op_softmax.cpp"
Expand All @@ -60,7 +63,8 @@ target_include_directories(aten_ops_cadence PUBLIC ${ROOT_DIR}/..
add_library(
custom_ops "quantized_linear_out.cpp" "quantized_conv_out.cpp"
"quantized_relu_out.cpp" "quantized_layer_norm.cpp"
"quantize_per_tensor.cpp" "dequantize_per_tensor.cpp")
"quantize_per_tensor.cpp" "dequantize_per_tensor.cpp"
"quantized_matmul_out.cpp")
target_include_directories(custom_ops PUBLIC ${ROOT_DIR}/..
${CMAKE_BINARY_DIR}
${_common_include_directories})
Expand Down
42 changes: 22 additions & 20 deletions backends/cadence/reference/operators/quantized_matmul_out.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,9 @@ namespace impl {
namespace reference {
namespace native {

using Tensor = exec_aten::Tensor;
using RuntimeContext = torch::executor::RuntimeContext;

// The quantized matmul. The quantized matmul accumulates in a wider register,
// whose type is TA.
template <
Expand Down Expand Up @@ -50,27 +53,32 @@ __attribute__((noinline)) void qmatmul(
}
}

template <ctype>
template <typename T>
void inline _typed_quantized_matmul(
const Tensor& X,
int64_t X_zero_point,
const Tensor& Y,
int64_t Y_zero_point,
const c10::optional<Tensor>& bias,
const exec_aten::optional<Tensor>& bias,
int64_t out_multiplier,
int64_t out_shift,
int64_t out_zero_point,
bool transposed,
Tensor& out) {
ctype* __restrict__ out_data = out.mutable_data_ptr<ctype>();
const ctype* __restrict__ X_data = X.const_data_ptr<ctype>();
const ctype* __restrict__ Y_data = Y.const_data_ptr<ctype>();
size_t batch_size = getLeadingDims(X, X.dim() - 2);
size_t leading_dim = X.size(X.dim() - 2);
size_t out_dim = Y.size(Y.dim() - 1 - transposed);
size_t in_dim = X.size(X.dim() - 1);

T* __restrict__ out_data = out.mutable_data_ptr<T>();
const T* __restrict__ X_data = X.const_data_ptr<T>();
const T* __restrict__ Y_data = Y.const_data_ptr<T>();
for (size_t i = 0; i < batch_size; ++i) {
const ctype* x = X_data + i * leading_dim * in_dim;
const ctype* y = Y_data + i * in_dim * out_dim;
ctype* z = out_data + i * leading_dim * out_dim;
const T* x = X_data + i * leading_dim * in_dim;
const T* y = Y_data + i * in_dim * out_dim;
T* z = out_data + i * leading_dim * out_dim;
if (transposed) {
qmatmul<ctype, int32_t, true>(
qmatmul<T, int32_t, true>(
z,
static_cast<int32_t>(out_multiplier),
static_cast<int32_t>(out_shift),
Expand All @@ -83,7 +91,7 @@ void inline _typed_quantized_matmul(
in_dim,
out_dim);
} else {
qmatmul<ctype, int32_t, false>(
qmatmul<T, int32_t, false>(
z,
static_cast<int32_t>(out_multiplier),
static_cast<int32_t>(out_shift),
Expand All @@ -101,24 +109,18 @@ void inline _typed_quantized_matmul(
}

void quantized_matmul_out(
RuntimeContext& ctx,
const Tensor& X,
int64_t X_zero_point,
const Tensor& Y,
int64_t Y_zero_point,
const c10::optional<Tensor>& bias,
const exec_aten::optional<Tensor>& bias,
int64_t out_multiplier,
int64_t out_shift,
int64_t out_zero_point,
bool transposed,
Tensor& out) {
(void)bias;

size_t batch_size = getLeadingDims(X, X.dim() - 2);
size_t leading_dim = X.size(X.dim() - 2);
size_t out_dim = Y.size(Y.dim() - 1 - transposed);
size_t in_dim = X.size(X.dim() - 1);

if (out.ScalarType() == at::ScalarType::Byte) {
if (out.scalar_type() == at::ScalarType::Byte) {
_typed_quantized_matmul<uint8_t>(
X,
X_zero_point,
Expand All @@ -130,7 +132,7 @@ void quantized_matmul_out(
out_zero_point,
transposed,
out);
} else if (out.ScalarType() == at::ScalarType::Char) {
} else if (out.scalar_type() == at::ScalarType::Char) {
_typed_quantized_matmul<int8_t>(
X,
X_zero_point,
Expand Down
Loading
Loading