Skip to content

Commit

Permalink
Add exportable baby llama example (#4345)
Browse files Browse the repository at this point in the history
Summary:
Pull Request resolved: #4345

Add a small LLaMa model, based on the babyllama paper. Note that this test case is only one layer by default, and the number of layers can be adjusted in the test.

Removed some pyre changes that broke the OSS AoT export, and added some required passes and operators.

Reviewed By: dulinriley

Differential Revision: D60073137

fbshipit-source-id: 8379296ad0aa4099b09d033b33479165d7c7c5c9
  • Loading branch information
mcremon-meta authored and facebook-github-bot committed Jul 30, 2024
1 parent db1c4d8 commit 1e14333
Show file tree
Hide file tree
Showing 9 changed files with 212 additions and 51 deletions.
4 changes: 3 additions & 1 deletion backends/cadence/aot/TARGETS
Original file line number Diff line number Diff line change
Expand Up @@ -28,13 +28,13 @@ python_library(
"compiler.py",
],
deps = [
"fbsource//third-party/pypi/pyre-extensions:pyre-extensions",
":passes",
":utils",
"//caffe2:torch",
"//executorch/backends/cadence/aot/quantizer:fusion_pass",
"//executorch/backends/cadence/aot/quantizer:quantizer",
"//executorch/backends/transforms:decompose_sdpa",
"//executorch/backends/transforms:remove_clone_ops",
"//executorch/exir:lib",
],
)
Expand All @@ -49,5 +49,7 @@ python_library(
"//caffe2:torch",
"//executorch/exir:pass_base",
"//executorch/exir/dialects:lib",
"//executorch/exir/passes:lib",
"//executorch/exir/passes:spec_prop_pass",
],
)
20 changes: 11 additions & 9 deletions backends/cadence/aot/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,23 +11,23 @@
import torch

from executorch.backends.cadence.aot.passes import (
InitializePipeline,
RemoveNopExpandOpPass,
RemoveZeroSizedCatArgsPass,
ReplaceLogicalNotBooleanWhereWithWherePass,
ReplacePT2DequantWithCadenceDequantPass,
ReplacePT2QuantWithCadenceQuantPass,
ReplaceScalarTensorWithFullPass,
ReplaceSqueezeAndUnsqueezeWithViewPass,
)
from executorch.backends.cadence.aot.quantizer.fusion_pass import QuantFusion
from executorch.backends.cadence.aot.quantizer.quantizer import (
CadenceAtenQuantizer,
CadenceQuantizer,
)
from executorch.backends.cadence.aot.quantizer.quantizer import CadenceQuantizer
from executorch.backends.cadence.aot.utils import model_is_quantized
from executorch.backends.transforms.decompose_sdpa import (
DecomposeScaledDotProductAttention,
)
from executorch.backends.transforms.remove_clone_ops import RemoveCloneOpsTransform
from executorch.exir import EdgeCompileConfig, EdgeProgramManager, to_edge
from pyre_extensions import assert_is_instance
from torch._export import capture_pre_autograd_graph
from torch.ao.quantization.pt2e.export_utils import model_is_exported
from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e
Expand Down Expand Up @@ -63,10 +63,8 @@ def quantize_pt2(
converted_model = convert_pt2e(prepared_model)

# Get patterns and apply fusion of dq -> op -> q to qop
patterns = [
assert_is_instance(q, CadenceAtenQuantizer).pattern
for q in quantizer.quantizers
]
# pyre-ignore[16]: no attribute
patterns = [q.pattern for q in quantizer.quantizers]
QuantFusion(patterns)(converted_model)

return converted_model
Expand Down Expand Up @@ -148,8 +146,12 @@ def export_to_cadence(
# Run a couple required passes for quant/dequant ops
cadence_program_manager = edge_program_manager.transform(
[
InitializePipeline(),
RemoveZeroSizedCatArgsPass(),
ReplaceLogicalNotBooleanWhereWithWherePass(),
ReplaceScalarTensorWithFullPass(),
RemoveCloneOpsTransform(),
RemoveNopExpandOpPass(),
ReplaceSqueezeAndUnsqueezeWithViewPass(),
ReplacePT2QuantWithCadenceQuantPass(),
ReplacePT2DequantWithCadenceDequantPass(),
Expand Down
20 changes: 20 additions & 0 deletions backends/cadence/aot/functions.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -62,16 +62,31 @@
- arg_meta: null
kernel_name: torch::executor::full_out

- op: mean.out
kernels:
- arg_meta: null
kernel_name: torch::executor::mean_dim_out

- op: mul.out
kernels:
- arg_meta: null
kernel_name: torch::executor::mul_out

- op: mul.Scalar_out
kernels:
- arg_meta: null
kernel_name: torch::executor::mul_scalar_out

- op: permute_copy.out
kernels:
- arg_meta: null
kernel_name: torch::executor::permute_copy_out

- op: rsqrt.out
kernels:
- arg_meta: null
kernel_name: torch::executor::rsqrt_out

- op: sigmoid.out
kernels:
- arg_meta: null
Expand Down Expand Up @@ -134,3 +149,8 @@
kernels:
- arg_meta: null
kernel_name: impl::reference::quantized_relu_out

func: cadence::quantized_matmul.out(Tensor X, int X_zero_point, Tensor Y, int Y_zero_point, Tensor? bias, int out_multiplier, int out_shift, int out_zero_point, bool transposed, *, Tensor(a!) out) -> Tensor(a!)
kernels:
- arg_meta: null
kernel_name: impl::reference::quantized_matmul_out
103 changes: 98 additions & 5 deletions backends/cadence/aot/passes.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,18 +4,19 @@
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from typing import Any, Dict, Tuple
# pyre-strict

from typing import Any, cast, Dict, Sequence, Tuple

import torch
from executorch.backends.cadence.aot.utils import get_edge_overload_packet
from executorch.exir.dialects._ops import ops as exir_ops
from executorch.exir.pass_base import ExportPass, NodeMetadata, ProxyValue
from executorch.exir.pass_base import ExportPass, NodeMetadata, PassResult, ProxyValue
from executorch.exir.passes import dead_code_elimination_pass
from executorch.exir.passes.spec_prop_pass import SpecPropPass
from torch._subclasses import FakeTensor
from torch.utils._pytree import tree_map_only


# pyre-strict

# Similar to what's done in executorch/exir/pass_base.py
Argument = Any # pyre-ignore

Expand Down Expand Up @@ -173,3 +174,95 @@ def call_operator(
init_args[0] = new_args
args = tuple(args)
return super().call_operator(op, args, kwargs, meta)


class RemoveNopExpandOpPass(ExportPass):
"""
For an expand op, if the operator shape matches the expand shape, then the
expand is a nop.
"""

def call_operator(
self,
op, # pyre-ignore
args: tuple[Argument, ...],
kwargs: dict[str, Argument],
meta: NodeMetadata,
) -> ProxyValue:
if get_edge_overload_packet(op) not in {
exir_ops.edge.aten.expand_copy,
exir_ops.edge.aten.expand,
}:
return super().call_operator(op, args, kwargs, meta)

# Parse the args, and check for nop condition
arg0 = cast(ProxyValue, args[0])
arg1 = cast(Sequence[int], args[1])
in_tensor = arg0.to_tensor()
if list(in_tensor.shape) == list(arg1):
return arg0

return super().call_operator(op, args, kwargs, meta)


class ReplaceLogicalNotBooleanWhereWithWherePass(ExportPass):
"""
A where op with a logical_not and a boolean tensor can be replaced
by a where op with flipped inputs and the initial boolean tensor.
"""

def replace_logical_nop_where_with_where(
self, graph_module: torch.fx.GraphModule
) -> None:
graph = graph_module.graph
for node in graph.nodes:
# We are only interested in where nodes
if node.target != exir_ops.edge.aten.where.self:
continue

# If the third arg is not a logical_not, bail.
if node.args[0].target != exir_ops.edge.aten.logical_not.default:
continue

# Get the third arg node and its input
logical_not_node = node.args[0]
logical_not_input_tensor = (
logical_not_node.args[0].to_tensor()
if isinstance(logical_not_node.args[0], ProxyValue)
else logical_not_node.args[0]
)

# If the logical_not input is not a boolean tensor, bail.
if logical_not_input_tensor.meta["spec"].dtype != torch.bool:
continue

# Replace the where op with another one, flipping the inputs and using the boolean
# tensor from logical_not.
with graph.inserting_before(node):
linear_node = graph.call_function(
exir_ops.edge.aten.where.self,
args=(logical_not_node.args[0], node.args[2], node.args[1]),
)
# Replace all the uses
node.replace_all_uses_with(linear_node)

graph_module.recompile()
graph_module.graph.eliminate_dead_code()

def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
self.replace_logical_nop_where_with_where(graph_module)
result = super().call(graph_module)
return result


class InitializePipeline(ExportPass):
"""
Initialize the Jarvis pipeline. This should invariably be the first pass to
run.
"""

def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
dead_code_elimination_pass(graph_module)
result = SpecPropPass()(graph_module)
assert result is not None
return result
1 change: 0 additions & 1 deletion backends/cadence/aot/quantizer/TARGETS
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,6 @@ python_library(
],
typing = True,
deps = [
"fbsource//third-party/pypi/pyre-extensions:pyre-extensions",
":patterns",
":utils",
"//caffe2:torch",
Expand Down
25 changes: 11 additions & 14 deletions backends/cadence/aot/quantizer/quantizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,6 @@
is_annotated,
no_outside_users,
)
from pyre_extensions import assert_is_instance

from torch import fx

Expand Down Expand Up @@ -100,14 +99,11 @@ def annotate(self, model: torch.fx.GraphModule) -> torch.fx.GraphModule:
continue

for output, *custom_spec in anchors.output:
assert_is_instance(output, fx.Node).meta["quantization_annotation"] = (
QuantizationAnnotation(
# pyre-ignore[6]: incompatible parameter type
output_qspec=(
custom_spec[0] if custom_spec else output_act_qspec
),
_annotated=True,
)
# pyre-ignore[16]: no attribute
output.meta["quantization_annotation"] = QuantizationAnnotation(
# pyre-ignore[6]: incompatible parameter type
output_qspec=(custom_spec[0] if custom_spec else output_act_qspec),
_annotated=True,
)

def annotate_inputs(
Expand All @@ -118,16 +114,17 @@ def annotate_inputs(
spec: Optional[QuantizationSpec],
) -> None:
for node, idx, *custom_spec in inputs:
_node = assert_is_instance(node, fx.Node)
annotation = _node.meta.get(
# pyre-ignore[16]: no attribute
annotation = node.meta.get(
"quantization_annotation",
QuantizationAnnotation(_annotated=True),
)
# pyre-ignore[6]: incompatible parameter type
annotation.input_qspec_map[_node.args[idx]] = (
# pyre-ignore[16]: no attribute
annotation.input_qspec_map[node.args[idx]] = (
custom_spec[0] if custom_spec else spec
)
_node.meta["quantization_annotation"] = annotation
# pyre-ignore[16]: no attribute
node.meta["quantization_annotation"] = annotation

annotate_inputs(anchors.inputs, input_act_qspec)
annotate_inputs(anchors.weights, weight_qspec)
Expand Down
6 changes: 5 additions & 1 deletion backends/cadence/reference/operators/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -32,12 +32,15 @@ set(_aten_ops__srcs
"${EXECUTORCH_ROOT}/kernels/portable/cpu/util/matmul_ops_util.cpp"
"${EXECUTORCH_ROOT}/kernels/portable/cpu/util/reduce_util.cpp"
"${EXECUTORCH_ROOT}/kernels/portable/cpu/util/repeat_util.cpp"
"${EXECUTORCH_ROOT}/kernels/portable/cpu/pattern/unary_ufunc_realhb_to_floath.cpp"
"${EXECUTORCH_ROOT}/kernels/portable/cpu/op_bmm.cpp"
"${EXECUTORCH_ROOT}/kernels/portable/cpu/op_cat.cpp"
"${EXECUTORCH_ROOT}/kernels/portable/cpu/op_clone.cpp"
"${EXECUTORCH_ROOT}/kernels/portable/cpu/op_div.cpp"
"${EXECUTORCH_ROOT}/kernels/portable/cpu/op_mean.cpp"
"${EXECUTORCH_ROOT}/kernels/portable/cpu/op_mul.cpp"
"${EXECUTORCH_ROOT}/kernels/portable/cpu/op_permute_copy.cpp"
"${EXECUTORCH_ROOT}/kernels/portable/cpu/op_rsqrt.cpp"
"${EXECUTORCH_ROOT}/kernels/portable/cpu/op_sigmoid.cpp"
"${EXECUTORCH_ROOT}/kernels/portable/cpu/op_slice_copy.cpp"
"${EXECUTORCH_ROOT}/kernels/portable/cpu/op_softmax.cpp"
Expand All @@ -60,7 +63,8 @@ target_include_directories(aten_ops_cadence PUBLIC ${ROOT_DIR}/..
add_library(
custom_ops "quantized_linear_out.cpp" "quantized_conv_out.cpp"
"quantized_relu_out.cpp" "quantized_layer_norm.cpp"
"quantize_per_tensor.cpp" "dequantize_per_tensor.cpp")
"quantize_per_tensor.cpp" "dequantize_per_tensor.cpp"
"quantized_matmul_out.cpp")
target_include_directories(custom_ops PUBLIC ${ROOT_DIR}/..
${CMAKE_BINARY_DIR}
${_common_include_directories})
Expand Down
Loading

0 comments on commit 1e14333

Please sign in to comment.