Add exportable baby llama example (#4345)
Summary:

Add a small LLaMA model based on the babyllama paper. Note that the test case uses only one layer by default; the number of layers can be adjusted in the test.

Remove some pyre changes that broke the OSS AoT export, and add the required passes and operators.
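
For example, a deeper variant can be exported by raising n_layers in the test's ModelArgs (a minimal sketch based on the example file in this commit; the values are illustrative):

args = ModelArgs(
    dim=512,
    vocab_size=512,
    hidden_dim=1024,
    n_heads=8,
    n_layers=4,  # the example defaults to 1
)
model = Transformer(args)
export_model(model, (torch.randint(0, 10, [1, 64], dtype=torch.int64),))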

Differential Revision: D60073137
mcremon-meta authored and facebook-github-bot committed Jul 24, 2024
1 parent fd2dccf commit f623ca8
Showing 6 changed files with 171 additions and 27 deletions.
20 changes: 11 additions & 9 deletions backends/cadence/aot/compiler.py
@@ -11,23 +11,23 @@
import torch

from executorch.backends.cadence.aot.passes import (
InitializePipeline,
RemoveNopExpandOpPass,
RemoveZeroSizedCatArgsPass,
ReplaceLogicalNotBooleanWhereWithWherePass,
ReplacePT2DequantWithCadenceDequantPass,
ReplacePT2QuantWithCadenceQuantPass,
ReplaceScalarTensorWithFullPass,
ReplaceSqueezeAndUnsqueezeWithViewPass,
)
from executorch.backends.cadence.aot.quantizer.fusion_pass import QuantFusion
from executorch.backends.cadence.aot.quantizer.quantizer import (
CadenceAtenQuantizer,
CadenceQuantizer,
)
from executorch.backends.cadence.aot.quantizer.quantizer import CadenceQuantizer
from executorch.backends.cadence.aot.utils import model_is_quantized
from executorch.backends.transforms.decompose_sdpa import (
DecomposeScaledDotProductAttention,
)
from executorch.backends.transforms.remove_clone_ops import RemoveCloneOpsTransform
from executorch.exir import EdgeCompileConfig, EdgeProgramManager, to_edge
from pyre_extensions import assert_is_instance
from torch._export import capture_pre_autograd_graph
from torch.ao.quantization.pt2e.export_utils import model_is_exported
from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e
@@ -63,10 +63,8 @@ def quantize_pt2(
converted_model = convert_pt2e(prepared_model)

# Get patterns and apply fusion of dq -> op -> q to qop
patterns = [
assert_is_instance(q, CadenceAtenQuantizer).pattern
for q in quantizer.quantizers
]
# pyre-ignore[16]: no attribute
patterns = [q.pattern for q in quantizer.quantizers]
QuantFusion(patterns)(converted_model)

return converted_model
@@ -148,8 +146,12 @@ def export_to_cadence(
# Run the required passes, including the quant/dequant replacements
cadence_program_manager = edge_program_manager.transform(
[
InitializePipeline(),
RemoveZeroSizedCatArgsPass(),
ReplaceLogicalNotBooleanWhereWithWherePass(),
ReplaceScalarTensorWithFullPass(),
RemoveCloneOpsTransform(),
RemoveNopExpandOpPass(),
ReplaceSqueezeAndUnsqueezeWithViewPass(),
ReplacePT2QuantWithCadenceQuantPass(),
ReplacePT2DequantWithCadenceDequantPass(),
10 changes: 10 additions & 0 deletions backends/cadence/aot/functions.yaml
@@ -62,6 +62,11 @@
- arg_meta: null
kernel_name: torch::executor::full_out

- op: mean.out
kernels:
- arg_meta: null
kernel_name: torch::executor::mean_out

- op: mul.out
kernels:
- arg_meta: null
@@ -72,6 +77,11 @@
- arg_meta: null
kernel_name: torch::executor::permute_copy_out

- op: rsqrt.out
kernels:
- arg_meta: null
kernel_name: torch::executor::rsqrt_out

- op: sigmoid.out
kernels:
- arg_meta: null
103 changes: 98 additions & 5 deletions backends/cadence/aot/passes.py
@@ -4,18 +4,19 @@
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from typing import Any, Dict, Tuple
# pyre-strict

from typing import Any, cast, Dict, Sequence, Tuple

import torch
from executorch.backends.cadence.aot.utils import get_edge_overload_packet
from executorch.exir.dialects._ops import ops as exir_ops
from executorch.exir.pass_base import ExportPass, NodeMetadata, ProxyValue
from executorch.exir.pass_base import ExportPass, NodeMetadata, PassResult, ProxyValue
from executorch.exir.passes import dead_code_elimination_pass
from executorch.exir.passes.spec_prop_pass import SpecPropPass
from torch._subclasses import FakeTensor
from torch.utils._pytree import tree_map_only


# pyre-strict

# Similar to what's done in executorch/exir/pass_base.py
Argument = Any # pyre-ignore

@@ -173,3 +174,95 @@ def call_operator(
init_args[0] = new_args
args = tuple(args)
return super().call_operator(op, args, kwargs, meta)


class RemoveNopExpandOpPass(ExportPass):
"""
For an expand op, if the input tensor's shape already matches the expand
shape, then the expand is a nop.
"""

def call_operator(
self,
op, # pyre-ignore
args: tuple[Argument, ...],
kwargs: dict[str, Argument],
meta: NodeMetadata,
) -> ProxyValue:
if get_edge_overload_packet(op) not in {
exir_ops.edge.aten.expand_copy,
exir_ops.edge.aten.expand,
}:
return super().call_operator(op, args, kwargs, meta)

# Parse the args, and check for nop condition
arg0 = cast(ProxyValue, args[0])
arg1 = cast(Sequence[int], args[1])
in_tensor = arg0.to_tensor()
if list(in_tensor.shape) == list(arg1):
return arg0

return super().call_operator(op, args, kwargs, meta)


class ReplaceLogicalNotBooleanWhereWithWherePass(ExportPass):
"""
A where op with a logical_not and a boolean tensor can be replaced
by a where op with flipped inputs and the initial boolean tensor.
"""

def replace_logical_nop_where_with_where(
self, graph_module: torch.fx.GraphModule
) -> None:
graph = graph_module.graph
for node in graph.nodes:
# We are only interested in where nodes
if node.target != exir_ops.edge.aten.where.self:
continue

# If the condition (the first arg) is not a logical_not, bail.
if node.args[0].target != exir_ops.edge.aten.logical_not.default:
continue

# Get the logical_not node and its input
logical_not_node = node.args[0]
logical_not_input_tensor = (
logical_not_node.args[0].to_tensor()
if isinstance(logical_not_node.args[0], ProxyValue)
else logical_not_node.args[0]
)

# If the logical_not input is not a boolean tensor, bail.
if logical_not_input_tensor.meta["spec"].dtype != torch.bool:
continue

# Replace the where op with another one, flipping the inputs and using the boolean
# tensor from logical_not.
with graph.inserting_before(node):
where_node = graph.call_function(
exir_ops.edge.aten.where.self,
args=(logical_not_node.args[0], node.args[2], node.args[1]),
)
# Replace all the uses
node.replace_all_uses_with(where_node)

graph_module.recompile()
graph_module.graph.eliminate_dead_code()

def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
self.replace_logical_nop_where_with_where(graph_module)
result = super().call(graph_module)
return result


class InitializePipeline(ExportPass):
"""
Initialize the Jarvis pipeline. This should invariably be the first pass to
run.
"""

def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
dead_code_elimination_pass(graph_module)
result = SpecPropPass()(graph_module)
assert result is not None
return result
24 changes: 11 additions & 13 deletions backends/cadence/aot/quantizer/quantizer.py
@@ -26,7 +26,6 @@
is_annotated,
no_outside_users,
)
from pyre_extensions import assert_is_instance

from torch import fx

@@ -100,14 +99,11 @@ def annotate(self, model: torch.fx.GraphModule) -> torch.fx.GraphModule:
continue

for output, *custom_spec in anchors.output:
assert_is_instance(output, fx.Node).meta["quantization_annotation"] = (
QuantizationAnnotation(
# pyre-ignore[6]: incompatible parameter type
output_qspec=(
custom_spec[0] if custom_spec else output_act_qspec
),
_annotated=True,
)
# pyre-ignore[16]: no attribute
output.meta["quantization_annotation"] = QuantizationAnnotation(
# pyre-ignore[6]: incompatible parameter type
output_qspec=(custom_spec[0] if custom_spec else output_act_qspec),
_annotated=True,
)

def annotate_inputs(
@@ -118,16 +114,18 @@ def annotate_inputs(
spec: Optional[QuantizationSpec],
) -> None:
for node, idx, *custom_spec in inputs:
_node = assert_is_instance(node, fx.Node)
annotation = _node.meta.get(
# pyre-ignore[16]: no attribute
annotation = node.meta.get(
"quantization_annotation",
QuantizationAnnotation(_annotated=True),
)
# pyre-ignore[6]: incompatible parameter type
annotation.input_qspec_map[_node.args[idx]] = (
# pyre-ignore[16]: no attribute
annotation.input_qspec_map[node.args[idx]] = (
custom_spec[0] if custom_spec else spec
)
_node.meta["quantization_annotation"] = annotation
# pyre-ignore[16]: no attribute
node.meta["quantization_annotation"] = annotation

annotate_inputs(anchors.inputs, input_act_qspec)
annotate_inputs(anchors.weights, weight_qspec)
2 changes: 2 additions & 0 deletions backends/cadence/reference/operators/CMakeLists.txt
@@ -36,8 +36,10 @@ set(_aten_ops__srcs
"${EXECUTORCH_ROOT}/kernels/portable/cpu/op_cat.cpp"
"${EXECUTORCH_ROOT}/kernels/portable/cpu/op_clone.cpp"
"${EXECUTORCH_ROOT}/kernels/portable/cpu/op_div.cpp"
"${EXECUTORCH_ROOT}/kernels/portable/cpu/op_mean.cpp"
"${EXECUTORCH_ROOT}/kernels/portable/cpu/op_mul.cpp"
"${EXECUTORCH_ROOT}/kernels/portable/cpu/op_permute_copy.cpp"
"${EXECUTORCH_ROOT}/kernels/portable/cpu/op_rsqrt.cpp"
"${EXECUTORCH_ROOT}/kernels/portable/cpu/op_sigmoid.cpp"
"${EXECUTORCH_ROOT}/kernels/portable/cpu/op_slice_copy.cpp"
"${EXECUTORCH_ROOT}/kernels/portable/cpu/op_softmax.cpp"
39 changes: 39 additions & 0 deletions examples/cadence/models/babyllama.py
@@ -0,0 +1,39 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# Example script for exporting simple models to flatbuffer

import logging

from executorch.backends.cadence.aot.ops_registrations import * # noqa

import torch

from executorch.backends.cadence.aot.export_example import export_model

from executorch.examples.models.llama2.llama_transformer import ModelArgs, Transformer


FORMAT = "[%(levelname)s %(asctime)s %(filename)s:%(lineno)s] %(message)s"
logging.basicConfig(level=logging.INFO, format=FORMAT)


if __name__ == "__main__":

args = ModelArgs(
dim=512,
vocab_size=512,
hidden_dim=1024,
n_heads=8,
# use_kv_cache=True,
n_layers=1,
)
seq = 64
b = 1
model = Transformer(args)
example_inputs = (torch.randint(0, 10, [b, seq], dtype=torch.int64),)

export_model(model, example_inputs)
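
A typical local run would look something like the following; the exact entry point is an assumption and may vary by setup:

# Hypothetical invocation from the executorch repository root:
#   python examples/cadence/models/babyllama.py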
