Skip to content

Commit

Permalink
Clean up lowering APIs (#3690)
Browse files Browse the repository at this point in the history
Summary:
Pull Request resolved: #3690

Changes some types in `export_to_edge`, adds an `export_to_cadence` function, and refactors some calls. Mostly preemptive work before adding a lot of AoT passes.

Reviewed By: dulinriley, zonglinpengmeta

Differential Revision: D57579738

fbshipit-source-id: 8c7975488636bd946894f7f7e62da7e76eb29c65
  • Loading branch information
mcremon-meta authored and facebook-github-bot committed May 23, 2024
1 parent 91bf5b9 commit ed2da4f
Show file tree
Hide file tree
Showing 4 changed files with 76 additions and 39 deletions.
32 changes: 32 additions & 0 deletions backends/cadence/aot/TARGETS
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# Buck build definitions for the Cadence AoT (ahead-of-time) Python libraries.
load("@fbcode_macros//build_defs:python_library.bzl", "python_library")

oncall("odai_jarvis")

# AoT compiler entry points (export_program / export_to_edge / export_to_cadence).
python_library(
    name = "compiler",
    srcs = [
        "compiler.py",
    ],
    deps = [
        # Cadence-specific graph passes applied by the compiler.
        ":passes",
        "//caffe2:torch",
        "//executorch/exir:lib",
    ],
)

# Graph-transformation passes (e.g. quant/dequant op replacement) used by :compiler.
python_library(
    name = "passes",
    srcs = [
        "passes.py",
    ],
    deps = [
        "//executorch/exir:pass_base",
        "//executorch/exir/dialects:lib",
    ],
)
48 changes: 38 additions & 10 deletions backends/cadence/aot/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,15 +9,20 @@

import torch

from executorch.backends.cadence.aot.passes import (
ReplacePT2DequantWithCadenceDequant,
ReplacePT2QuantWithCadenceQuant,
)
from executorch.exir import EdgeCompileConfig, EdgeProgramManager, to_edge

from torch.export import export
from torch.export.exported_program import ExportedProgram


# Export the model and lower it to an ExportedProgram (in aten IR)
def export_program(
model: torch.nn.Module,
inputs: Any,
inputs: Tuple[Any, ...],
) -> ExportedProgram:
assert isinstance(model, torch.nn.Module), "model should be an nn.Module"

Expand All @@ -37,26 +42,49 @@ def export_program(
return export(model, inputs)


# Export the model and lower it to an EdgeProgramManager (in edge IR).
def export_to_edge(
    model: torch.nn.Module,
    inputs: Tuple[Any, ...],
    dump_graphs: bool = False,
) -> EdgeProgramManager:
    """Export ``model`` and lower it to edge IR.

    Args:
        model: the eager module to export; must be an ``nn.Module``.
        inputs: example inputs used to trace the model.
        dump_graphs: if True, log the exported (aten) and edge graphs.

    Returns:
        The ``EdgeProgramManager`` holding the lowered edge program.
    """
    assert isinstance(model, torch.nn.Module), "model should be an nn.Module"

    # Export the model into an ExportedProgram (aten IR).
    expo_program = export_program(model, inputs)

    if dump_graphs:
        logging.info("Exported graph:")
        expo_program.graph_module.graph.print_tabular()

    # Call to_edge to convert the graph to edge IR.
    # Note: dim_order is skipped (https://github.com/pytorch/executorch/issues/3704)
    edge_prog_manager = to_edge(
        expo_program,
        compile_config=EdgeCompileConfig(
            _check_ir_validity=False, _skip_dim_order=True
        ),
    )

    if dump_graphs:
        logging.info("Edge graph:")
        edge_prog_manager.exported_program().graph_module.graph.print_tabular()

    return edge_prog_manager


# Export the model and lower it to an EdgeProgramManager (in edge IR), and
# apply passes specific to Cadence DSP execution.
def export_to_cadence(
    model: torch.nn.Module,
    inputs: Tuple[Any, ...],
    dump_graphs: bool = False,
) -> EdgeProgramManager:
    """Lower ``model`` to edge IR and apply Cadence-specific passes.

    Args:
        model: the eager module to export; must be an ``nn.Module``.
        inputs: example inputs used to trace the model.
        dump_graphs: if True, log the intermediate graphs during lowering.

    Returns:
        The ``EdgeProgramManager`` after the Cadence quant/dequant passes.
    """
    # Forward dump_graphs so intermediate graphs are logged when requested
    # (previously the flag was accepted but silently ignored).
    edge_program_manager = export_to_edge(model, inputs, dump_graphs)

    # Run a couple required passes for quant/dequant ops
    cadence_program_manager = edge_program_manager.transform(
        [ReplacePT2QuantWithCadenceQuant(), ReplacePT2DequantWithCadenceDequant()]
    )

    return cadence_program_manager
17 changes: 5 additions & 12 deletions backends/cadence/aot/export_example.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,11 +13,7 @@
import os
from typing import Any, Tuple

from executorch.backends.cadence.aot.compiler import export_to_edge
from executorch.backends.cadence.aot.passes import (
ReplacePT2DequantWithCadenceDequant,
ReplacePT2QuantWithCadenceQuant,
)
from executorch.backends.cadence.aot.compiler import export_to_cadence, export_to_edge
from executorch.backends.cadence.aot.quantizer.fusion_pass import QuantFusion
from executorch.backends.cadence.aot.quantizer.quantizer import CadenceQuantizer
from executorch.exir import ExecutorchProgramManager
Expand Down Expand Up @@ -68,13 +64,11 @@ def export_model(
patterns = [q.pattern for q in quantizer.quantizers]
QuantFusion(patterns)(converted_model)

# Get edge program (note: the name will change to export_to_cadence in future PRs)
edge_prog_manager, expo_prog = export_to_edge(converted_model, example_inputs)
# Get edge program
edge_prog_manager = export_to_edge(converted_model, example_inputs)

# Run a couple required passes for quant/dequant ops
cadence_prog_manager = edge_prog_manager.transform(
[ReplacePT2QuantWithCadenceQuant(), ReplacePT2DequantWithCadenceDequant()]
)
# Get edge program after Cadence specific passes
cadence_prog_manager = export_to_cadence(converted_model, example_inputs)

exec_prog = cadence_prog_manager.to_executorch()

Expand All @@ -84,7 +78,6 @@ def export_model(

# Print some information to terminal
print_ops_info(
expo_prog.graph_module,
edge_prog_manager.exported_program().graph_module,
cadence_prog_manager.exported_program().graph_module,
)
Expand Down
18 changes: 1 addition & 17 deletions backends/cadence/aot/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,29 +71,15 @@ def get_ops_count(graph_module: torch.fx.GraphModule) -> Dict[str, int]:
# from export, from to_edge, and from Jarvis. Print the available
# implementations for each op, and error out if the op is not supported.
def print_ops_info(
export_gm: torch.fx.GraphModule,
to_edge_gm: torch.fx.GraphModule,
jarvis_gm: torch.fx.GraphModule,
):
export_ops_count = get_ops_count(export_gm)
to_edge_ops_count = get_ops_count(to_edge_gm)
jarvis_ops_count = get_ops_count(jarvis_gm)

# De-duplicate the "<op>" and "<op>_copy" ops
keys_to_delete_and_add = []
for k1 in export_ops_count:
for k2 in {**to_edge_ops_count, **jarvis_ops_count}:
if k2.startswith(k1):
keys_to_delete_and_add.append((k1, k2))
break

for k in keys_to_delete_and_add:
export_ops_count[k[1]] = export_ops_count[k[0]]
del export_ops_count[k[0]]

removed_ops = []
# Get the counts of the ops that are removed from the final graph
for k in {**export_ops_count, **to_edge_ops_count}:
for k in to_edge_ops_count:
if k not in jarvis_ops_count:
removed_ops.append(k)

Expand All @@ -103,7 +89,6 @@ def print_ops_info(
op,
jarvis_ops_count[op],
to_edge_ops_count[op] if op in to_edge_ops_count else 0,
export_ops_count[op] if op in export_ops_count else 0,
]
for op in jarvis_ops_count
]
Expand All @@ -115,7 +100,6 @@ def print_ops_info(
op,
0,
to_edge_ops_count[op] if op in to_edge_ops_count else 0,
export_ops_count[op] if op in export_ops_count else 0,
]
for op in removed_ops
]
Expand Down

0 comments on commit ed2da4f

Please sign in to comment.