diff --git a/backends/cadence/aot/compiler.py b/backends/cadence/aot/compiler.py index 302252c42a9..39511ae9177 100644 --- a/backends/cadence/aot/compiler.py +++ b/backends/cadence/aot/compiler.py @@ -11,23 +11,23 @@ import torch from executorch.backends.cadence.aot.passes import ( + InitializePipeline, + RemoveNopExpandOpPass, RemoveZeroSizedCatArgsPass, + ReplaceLogicalNotBooleanWhereWithWherePass, ReplacePT2DequantWithCadenceDequantPass, ReplacePT2QuantWithCadenceQuantPass, ReplaceScalarTensorWithFullPass, ReplaceSqueezeAndUnsqueezeWithViewPass, ) from executorch.backends.cadence.aot.quantizer.fusion_pass import QuantFusion -from executorch.backends.cadence.aot.quantizer.quantizer import ( - CadenceAtenQuantizer, - CadenceQuantizer, -) +from executorch.backends.cadence.aot.quantizer.quantizer import CadenceQuantizer from executorch.backends.cadence.aot.utils import model_is_quantized from executorch.backends.transforms.decompose_sdpa import ( DecomposeScaledDotProductAttention, ) +from executorch.backends.transforms.remove_clone_ops import RemoveCloneOpsTransform from executorch.exir import EdgeCompileConfig, EdgeProgramManager, to_edge -from pyre_extensions import assert_is_instance from torch._export import capture_pre_autograd_graph from torch.ao.quantization.pt2e.export_utils import model_is_exported from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e @@ -63,10 +63,8 @@ def quantize_pt2( converted_model = convert_pt2e(prepared_model) # Get patterns and apply fusion of dq -> op -> q to qop - patterns = [ - assert_is_instance(q, CadenceAtenQuantizer).pattern - for q in quantizer.quantizers - ] + # pyre-ignore[16]: no attribute + patterns = [q.pattern for q in quantizer.quantizers] QuantFusion(patterns)(converted_model) return converted_model @@ -148,8 +146,12 @@ def export_to_cadence( # Run a couple required passes for quant/dequant ops cadence_program_manager = edge_program_manager.transform( [ + InitializePipeline(), RemoveZeroSizedCatArgsPass(), + ReplaceLogicalNotBooleanWhereWithWherePass(), ReplaceScalarTensorWithFullPass(), + RemoveCloneOpsTransform(), + RemoveNopExpandOpPass(), ReplaceSqueezeAndUnsqueezeWithViewPass(), ReplacePT2QuantWithCadenceQuantPass(), ReplacePT2DequantWithCadenceDequantPass(), diff --git a/backends/cadence/aot/functions.yaml b/backends/cadence/aot/functions.yaml index f79d5f870da..dbfe1e3639c 100644 --- a/backends/cadence/aot/functions.yaml +++ b/backends/cadence/aot/functions.yaml @@ -62,16 +62,31 @@ - arg_meta: null kernel_name: torch::executor::full_out +- op: mean.out + kernels: + - arg_meta: null + kernel_name: torch::executor::mean_dim_out + - op: mul.out kernels: - arg_meta: null kernel_name: torch::executor::mul_out +- op: mul.Scalar_out + kernels: + - arg_meta: null + kernel_name: torch::executor::mul_scalar_out + - op: permute_copy.out kernels: - arg_meta: null kernel_name: torch::executor::permute_copy_out +- op: rsqrt.out + kernels: + - arg_meta: null + kernel_name: torch::executor::rsqrt_out + - op: sigmoid.out kernels: - arg_meta: null @@ -134,3 +149,8 @@ kernels: - arg_meta: null kernel_name: impl::reference::quantized_relu_out + +func: cadence::quantized_matmul.out(Tensor X, int X_zero_point, Tensor Y, int Y_zero_point, Tensor? bias, int out_multiplier, int out_shift, int out_zero_point, bool transposed, *, Tensor(a!) out) -> Tensor(a!) + kernels: + - arg_meta: null + kernel_name: impl::reference::quantized_matmul_out diff --git a/backends/cadence/aot/passes.py b/backends/cadence/aot/passes.py index ca8a44f00c8..db419bfb5e1 100644 --- a/backends/cadence/aot/passes.py +++ b/backends/cadence/aot/passes.py @@ -4,18 +4,19 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. -from typing import Any, Dict, Tuple +# pyre-strict + +from typing import Any, cast, Dict, Sequence, Tuple import torch from executorch.backends.cadence.aot.utils import get_edge_overload_packet from executorch.exir.dialects._ops import ops as exir_ops -from executorch.exir.pass_base import ExportPass, NodeMetadata, ProxyValue +from executorch.exir.pass_base import ExportPass, NodeMetadata, PassResult, ProxyValue +from executorch.exir.passes import dead_code_elimination_pass +from executorch.exir.passes.spec_prop_pass import SpecPropPass from torch._subclasses import FakeTensor from torch.utils._pytree import tree_map_only - -# pyre-strict - # Similar to what's done in executorch/exir/pass_base.py Argument = Any # pyre-ignore @@ -173,3 +174,95 @@ def call_operator( init_args[0] = new_args args = tuple(args) return super().call_operator(op, args, kwargs, meta) + + +class RemoveNopExpandOpPass(ExportPass): + """ + For an expand op, if the operator shape matches the expand shape, then the + expand is a nop. + """ + + def call_operator( + self, + op, # pyre-ignore + args: tuple[Argument, ...], + kwargs: dict[str, Argument], + meta: NodeMetadata, + ) -> ProxyValue: + if get_edge_overload_packet(op) not in { + exir_ops.edge.aten.expand_copy, + exir_ops.edge.aten.expand, + }: + return super().call_operator(op, args, kwargs, meta) + + # Parse the args, and check for nop condition + arg0 = cast(ProxyValue, args[0]) + arg1 = cast(Sequence[int], args[1]) + in_tensor = arg0.to_tensor() + if list(in_tensor.shape) == list(arg1): + return arg0 + + return super().call_operator(op, args, kwargs, meta) + + +class ReplaceLogicalNotBooleanWhereWithWherePass(ExportPass): + """ + A where op with a logical_not and a boolean tensor can be replaced + by a where op with flipped inputs and the initial boolean tensor. + """ + + def replace_logical_nop_where_with_where( + self, graph_module: torch.fx.GraphModule + ) -> None: + graph = graph_module.graph + for node in graph.nodes: + # We are only interested in where nodes + if node.target != exir_ops.edge.aten.where.self: + continue + + # If the third arg is not a logical_not, bail. + if node.args[0].target != exir_ops.edge.aten.logical_not.default: + continue + + # Get the third arg node and its input + logical_not_node = node.args[0] + logical_not_input_tensor = ( + logical_not_node.args[0].to_tensor() + if isinstance(logical_not_node.args[0], ProxyValue) + else logical_not_node.args[0] + ) + + # If the logical_not input is not a boolean tensor, bail. + if logical_not_input_tensor.meta["spec"].dtype != torch.bool: + continue + + # Replace the where op with another one, flipping the inputs and using the boolean + # tensor from logical_not. + with graph.inserting_before(node): + linear_node = graph.call_function( + exir_ops.edge.aten.where.self, + args=(logical_not_node.args[0], node.args[2], node.args[1]), + ) + # Replace all the uses + node.replace_all_uses_with(linear_node) + + graph_module.recompile() + graph_module.graph.eliminate_dead_code() + + def call(self, graph_module: torch.fx.GraphModule) -> PassResult: + self.replace_logical_nop_where_with_where(graph_module) + result = super().call(graph_module) + return result + + +class InitializePipeline(ExportPass): + """ + Initialize the Jarvis pipeline. This should invariably be the first pass to + run. + """ + + def call(self, graph_module: torch.fx.GraphModule) -> PassResult: + dead_code_elimination_pass(graph_module) + result = SpecPropPass()(graph_module) + assert result is not None + return result diff --git a/backends/cadence/aot/quantizer/quantizer.py b/backends/cadence/aot/quantizer/quantizer.py index 4cd3c6bfb4d..61e414ca10d 100644 --- a/backends/cadence/aot/quantizer/quantizer.py +++ b/backends/cadence/aot/quantizer/quantizer.py @@ -26,7 +26,6 @@ is_annotated, no_outside_users, ) -from pyre_extensions import assert_is_instance from torch import fx @@ -100,14 +99,11 @@ def annotate(self, model: torch.fx.GraphModule) -> torch.fx.GraphModule: continue for output, *custom_spec in anchors.output: - assert_is_instance(output, fx.Node).meta["quantization_annotation"] = ( - QuantizationAnnotation( - # pyre-ignore[6]: incompatible parameter type - output_qspec=( - custom_spec[0] if custom_spec else output_act_qspec - ), - _annotated=True, - ) + # pyre-ignore[16]: no attribute + output.meta["quantization_annotation"] = QuantizationAnnotation( + # pyre-ignore[6]: incompatible parameter type + output_qspec=(custom_spec[0] if custom_spec else output_act_qspec), + _annotated=True, ) def annotate_inputs( @@ -118,16 +114,18 @@ def annotate_inputs( spec: Optional[QuantizationSpec], ) -> None: for node, idx, *custom_spec in inputs: - _node = assert_is_instance(node, fx.Node) - annotation = _node.meta.get( + # pyre-ignore[16]: no attribute + annotation = node.meta.get( "quantization_annotation", QuantizationAnnotation(_annotated=True), ) # pyre-ignore[6]: incompatible parameter type - annotation.input_qspec_map[_node.args[idx]] = ( + # pyre-ignore[16]: no attribute + annotation.input_qspec_map[node.args[idx]] = ( custom_spec[0] if custom_spec else spec ) - _node.meta["quantization_annotation"] = annotation + # pyre-ignore[16]: no attribute + node.meta["quantization_annotation"] = annotation annotate_inputs(anchors.inputs, input_act_qspec) annotate_inputs(anchors.weights, weight_qspec) diff --git a/backends/cadence/reference/operators/CMakeLists.txt b/backends/cadence/reference/operators/CMakeLists.txt index c22dc0c9976..c81e9348501 100644 --- a/backends/cadence/reference/operators/CMakeLists.txt +++ b/backends/cadence/reference/operators/CMakeLists.txt @@ -32,12 +32,15 @@ set(_aten_ops__srcs "${EXECUTORCH_ROOT}/kernels/portable/cpu/util/matmul_ops_util.cpp" "${EXECUTORCH_ROOT}/kernels/portable/cpu/util/reduce_util.cpp" "${EXECUTORCH_ROOT}/kernels/portable/cpu/util/repeat_util.cpp" + "${EXECUTORCH_ROOT}/kernels/portable/cpu/pattern/unary_ufunc_realhb_to_floath.cpp" "${EXECUTORCH_ROOT}/kernels/portable/cpu/op_bmm.cpp" "${EXECUTORCH_ROOT}/kernels/portable/cpu/op_cat.cpp" "${EXECUTORCH_ROOT}/kernels/portable/cpu/op_clone.cpp" "${EXECUTORCH_ROOT}/kernels/portable/cpu/op_div.cpp" + "${EXECUTORCH_ROOT}/kernels/portable/cpu/op_mean.cpp" "${EXECUTORCH_ROOT}/kernels/portable/cpu/op_mul.cpp" "${EXECUTORCH_ROOT}/kernels/portable/cpu/op_permute_copy.cpp" + "${EXECUTORCH_ROOT}/kernels/portable/cpu/op_rsqrt.cpp" "${EXECUTORCH_ROOT}/kernels/portable/cpu/op_sigmoid.cpp" "${EXECUTORCH_ROOT}/kernels/portable/cpu/op_slice_copy.cpp" "${EXECUTORCH_ROOT}/kernels/portable/cpu/op_softmax.cpp" @@ -60,7 +63,8 @@ target_include_directories(aten_ops_cadence PUBLIC ${ROOT_DIR}/.. add_library( custom_ops "quantized_linear_out.cpp" "quantized_conv_out.cpp" "quantized_relu_out.cpp" "quantized_layer_norm.cpp" - "quantize_per_tensor.cpp" "dequantize_per_tensor.cpp") + "quantize_per_tensor.cpp" "dequantize_per_tensor.cpp" + "quantized_matmul_out.cpp") target_include_directories(custom_ops PUBLIC ${ROOT_DIR}/.. ${CMAKE_BINARY_DIR} ${_common_include_directories}) diff --git a/backends/cadence/reference/operators/quantized_matmul_out.cpp b/backends/cadence/reference/operators/quantized_matmul_out.cpp index 95df35caba7..49dd222a96c 100644 --- a/backends/cadence/reference/operators/quantized_matmul_out.cpp +++ b/backends/cadence/reference/operators/quantized_matmul_out.cpp @@ -13,6 +13,9 @@ namespace impl { namespace reference { namespace native { +using Tensor = exec_aten::Tensor; +using RuntimeContext = torch::executor::RuntimeContext; + // The quantized matmul. The quantized matmul accumulates in a wider register, // whose type is TA. template < @@ -50,27 +53,32 @@ __attribute__((noinline)) void qmatmul( } } -template +template void inline _typed_quantized_matmul( const Tensor& X, int64_t X_zero_point, const Tensor& Y, int64_t Y_zero_point, - const c10::optional& bias, + const exec_aten::optional& bias, int64_t out_multiplier, int64_t out_shift, int64_t out_zero_point, bool transposed, Tensor& out) { - ctype* __restrict__ out_data = out.mutable_data_ptr(); - const ctype* __restrict__ X_data = X.const_data_ptr(); - const ctype* __restrict__ Y_data = Y.const_data_ptr(); + size_t batch_size = getLeadingDims(X, X.dim() - 2); + size_t leading_dim = X.size(X.dim() - 2); + size_t out_dim = Y.size(Y.dim() - 1 - transposed); + size_t in_dim = X.size(X.dim() - 1); + + T* __restrict__ out_data = out.mutable_data_ptr(); + const T* __restrict__ X_data = X.const_data_ptr(); + const T* __restrict__ Y_data = Y.const_data_ptr(); for (size_t i = 0; i < batch_size; ++i) { - const ctype* x = X_data + i * leading_dim * in_dim; - const ctype* y = Y_data + i * in_dim * out_dim; - ctype* z = out_data + i * leading_dim * out_dim; + const T* x = X_data + i * leading_dim * in_dim; + const T* y = Y_data + i * in_dim * out_dim; + T* z = out_data + i * leading_dim * out_dim; if (transposed) { - qmatmul( + qmatmul( z, static_cast(out_multiplier), static_cast(out_shift), @@ -83,7 +91,7 @@ void inline _typed_quantized_matmul( in_dim, out_dim); } else { - qmatmul( + qmatmul( z, static_cast(out_multiplier), static_cast(out_shift), @@ -101,24 +109,18 @@ void inline _typed_quantized_matmul( } void quantized_matmul_out( + RuntimeContext& ctx, const Tensor& X, int64_t X_zero_point, const Tensor& Y, int64_t Y_zero_point, - const c10::optional& bias, + const exec_aten::optional& bias, int64_t out_multiplier, int64_t out_shift, int64_t out_zero_point, bool transposed, Tensor& out) { - (void)bias; - - size_t batch_size = getLeadingDims(X, X.dim() - 2); - size_t leading_dim = X.size(X.dim() - 2); - size_t out_dim = Y.size(Y.dim() - 1 - transposed); - size_t in_dim = X.size(X.dim() - 1); - - if (out.ScalarType() == at::ScalarType::Byte) { + if (out.scalar_type() == at::ScalarType::Byte) { _typed_quantized_matmul( X, X_zero_point, @@ -130,7 +132,7 @@ void quantized_matmul_out( out_zero_point, transposed, out); - } else if (out.ScalarType() == at::ScalarType::Char) { + } else if (out.scalar_type() == at::ScalarType::Char) { _typed_quantized_matmul( X, X_zero_point, diff --git a/examples/cadence/models/babyllama.py b/examples/cadence/models/babyllama.py new file mode 100644 index 00000000000..347f9b4a7a7 --- /dev/null +++ b/examples/cadence/models/babyllama.py @@ -0,0 +1,39 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +# Example script for exporting simple models to flatbuffer + +import logging + +from executorch.backends.cadence.aot.ops_registrations import * # noqa + +import torch + +from executorch.backends.cadence.aot.export_example import export_model + +from executorch.examples.models.llama2.llama_transformer import ModelArgs, Transformer + + +FORMAT = "[%(levelname)s %(asctime)s %(filename)s:%(lineno)s] %(message)s" +logging.basicConfig(level=logging.INFO, format=FORMAT) + + +if __name__ == "__main__": + + args = ModelArgs( + dim=512, + vocab_size=512, + hidden_dim=1024, + n_heads=8, + # use_kv_cache=True, + n_layers=1, + ) + seq = 64 + b = 1 + model = Transformer(args) + example_inputs = (torch.randint(0, 10, [b, seq], dtype=torch.int64),) + + export_model(model, example_inputs)