[ExecuTorch][to_backend] add AllNodePartitioner #9822

Merged · 4 commits · Apr 4, 2025
1 change: 1 addition & 0 deletions exir/backend/canonical_partitioners/TARGETS
@@ -7,6 +7,7 @@ runtime.python_library(
srcs = [
"duplicate_dequant_node_pass.py",
"pattern_op_partitioner.py",
"all_node_partitioner.py",
],
visibility = [
"//executorch/...",
55 changes: 55 additions & 0 deletions exir/backend/canonical_partitioners/all_node_partitioner.py
@@ -0,0 +1,55 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from typing import Dict, List

import torch
from executorch.exir.backend.backend_details import ExportedProgram
from executorch.exir.backend.compile_spec_schema import CompileSpec
from executorch.exir.backend.partitioner import (
DelegationSpec,
Partitioner,
PartitionResult,
)
from torch._export.utils import is_buffer, is_lifted_tensor_constant, is_param


def is_non_tensor_placeholder(node: torch.fx.Node, ep: ExportedProgram) -> bool:
"""
Returns true if the node is a placeholder node and it is not a tensor
"""
return node.op == "placeholder" and not (
is_param(ep, node) or is_buffer(ep, node) or is_lifted_tensor_constant(ep, node)
)


class AllNodePartitioner(Partitioner):
def __init__(
self,
backend_id: str,
compile_specs: List[CompileSpec],
):
"""
Partitioner that lowers every single node in the graph module unconditionally
to the specified backend_id
"""
super().__init__()
self.delegation_spec = DelegationSpec(backend_id, compile_specs)

def partition(self, exported_program: ExportedProgram) -> PartitionResult:
# tag all nodes
partition_tags: Dict[str, DelegationSpec] = {}
for node in exported_program.graph_module.graph.nodes:
if is_non_tensor_placeholder(node, exported_program) or node.op == "output":
continue

delegation_tag = self.delegation_spec.backend_id
node.meta["delegation_tag"] = delegation_tag
partition_tags[delegation_tag] = self.delegation_spec

return PartitionResult(
tagged_exported_program=exported_program, partition_tags=partition_tags
)
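
For context, here is a minimal usage sketch (not part of this diff) of the new partitioner driven through the EdgeProgramManager to_backend flow. The module, backend name, and compile spec mirror the tests added in this PR.

import torch
from executorch.exir import to_edge
from executorch.exir.backend.canonical_partitioners.all_node_partitioner import (
    AllNodePartitioner,
)
from executorch.exir.backend.compile_spec_schema import CompileSpec
from executorch.exir.backend.test.backend_with_compiler_demo import (
    BackendWithCompilerDemo,  # importing this registers the demo backend used by the tests
)


class SinModule(torch.nn.Module):
    def forward(self, x):
        return torch.sin(x)


model_inputs = (torch.ones(1),)
partitioner = AllNodePartitioner(
    BackendWithCompilerDemo.__name__,
    [CompileSpec("max_value", bytes([model_inputs[0].shape[0]]))],
)

edge_manager = to_edge(torch.export.export(SinModule(), model_inputs))
edge_manager = edge_manager.to_backend(partitioner)  # every node is tagged for the demo backend
exec_prog = edge_manager.to_executorch()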
179 changes: 179 additions & 0 deletions exir/backend/test/test_backends.py
@@ -10,7 +10,11 @@

import executorch.exir as exir
import torch
from executorch.exir import to_edge
from executorch.exir.backend.backend_api import LoweredBackendModule, to_backend
from executorch.exir.backend.canonical_partitioners.all_node_partitioner import (
AllNodePartitioner,
)
from executorch.exir.backend.compile_spec_schema import CompileSpec
from executorch.exir.backend.partitioner import (
DelegationSpec,
@@ -1266,3 +1270,178 @@ def forward(self, x: List[torch.Tensor]):

gm = exir.capture(ComposedM(), inputs, exir.CaptureConfig()).to_edge()
gm(*inputs)

def test_to_backend_delegation_spec(self):
class SinModule(torch.nn.Module):
def __init__(self):
super().__init__()

def forward(self, x):
return [torch.sin(x)]

sin_module = SinModule()
model_inputs = (torch.ones(1),)
max_value = model_inputs[0].shape[0]

partitioner = AllNodePartitioner(
"BackendWithCompilerDemo", [CompileSpec("max_value", bytes([max_value]))]
)

edgeir_m = to_edge(torch.export.export(sin_module, model_inputs))
edgeir_m = edgeir_m.to_backend(partitioner)
exec_prog = edgeir_m.to_executorch()
graph_module = exec_prog.exported_program().graph_module
# Check that there is not an aten.sin node.
self.assertTrue(
exir_ops.edge.aten.sin
not in {node.target for node in graph_module.graph.nodes}
)

# Check that there exists a call_delegate, representing the call to the
# delegated function
FileCheck().check("torch.ops.higher_order.executorch_call_delegate").run(
graph_module.code
)
lowered_submodules = get_lowered_submodules(graph_module)
self.assertEqual(len(lowered_submodules), 1)

for node in graph_module.graph.nodes:
if node.op == "call_function" and node.target == executorch_call_delegate:
# Check that first arg is lowered_module_{unique_id}
self.assertEqual(node.args[0].target, "lowered_module_0")

program = exec_prog.executorch_program

# Check the program can be printed
print_program(program)

# Check the backend delegate
self.check_backend_delegate(
program=program,
delegate=program.execution_plan[0].delegates[0],
expected_id=BackendWithCompilerDemo.__name__,
expected_processed=b"1version:0#op:demo::aten.sin.default, numel:1, dtype:torch.float32<debug_handle>2#",
)

# Check the delegate instruction
self.assertTrue(
isinstance(
program.execution_plan[0].chains[0].instructions[0].instr_args,
DelegateCall,
)
)
buff = exec_prog.buffer

executorch_module = _load_for_executorch_from_buffer(buff)
model_inputs = torch.ones(1)
model_outputs = executorch_module.forward([model_inputs])
self.assertEqual(
model_inputs,
torch.ones(1),
)
expected_output = 0.8333 * torch.ones(1)

self.assertTrue(
torch.allclose(model_outputs[0], expected_output, atol=1e-03, rtol=1e-03)
)

def test_to_backend_multimethod_delegation_spec(self):
class SinModule(torch.nn.Module):
def __init__(self):
super().__init__()

def forward(self, x):
return torch.sin(x)

def inputs(self):
return (torch.ones(1),)

class AddMulModule(torch.nn.Module):
def __init__(self):
super().__init__()

def forward(self, a, x, b):
y = torch.mm(a, x)
z = torch.add(y, b)
return z

def inputs(self):
return (torch.ones(2, 2), 2 * torch.ones(2, 2), 3 * torch.ones(2, 2))

sin_module = SinModule()
max_value_sin = sin_module.inputs()[0].shape[0]
sin_partitioner = AllNodePartitioner(
"BackendWithCompilerDemo",
[CompileSpec("max_value", bytes([max_value_sin]))],
)

add_mul_module = AddMulModule()
max_value_add_mul = add_mul_module.inputs()[0].shape[0]
add_mul_partitioner = AllNodePartitioner(
"BackendWithCompilerDemo",
[CompileSpec("max_value", bytes([max_value_add_mul]))],
)

edgeir_m = to_edge(
{
"sin": torch.export.export(sin_module, sin_module.inputs()),
"add_mul": torch.export.export(add_mul_module, add_mul_module.inputs()),
}
)
edgeir_m = edgeir_m.to_backend(
{
"sin": sin_partitioner,
"add_mul": add_mul_partitioner,
}
)
exec_prog = edgeir_m.to_executorch()

for method_name in ["sin", "add_mul"]:
graph_module = exec_prog.exported_program(method_name).graph_module
# Check delegated nodes are gone
self.assertTrue(
exir_ops.edge.aten.sin
not in {node.target for node in graph_module.graph.nodes}
)
self.assertTrue(
exir_ops.edge.aten.add
not in {node.target for node in graph_module.graph.nodes}
)
self.assertTrue(
exir_ops.edge.aten.mm
not in {node.target for node in graph_module.graph.nodes}
)
# Check that there exists a call_delegate, representing the call to the
# delegated function
FileCheck().check("torch.ops.higher_order.executorch_call_delegate").run(
graph_module.code
)
lowered_submodules = get_lowered_submodules(graph_module)
self.assertEqual(len(lowered_submodules), 1)

program = exec_prog.executorch_program

# Check the program can be printed
print_program(program)

buff = exec_prog.buffer

executorch_module = _load_for_executorch_from_buffer(buff)

for method_name, module in {
"sin": sin_module,
"add_mul": add_mul_module,
}.items():
inputs_flattened, _ = tree_flatten(module.inputs())
model_outputs = executorch_module.run_method(
method_name, tuple(inputs_flattened)
)

if method_name == "sin":
# The backend-with-compiler demo uses a Taylor approximation of sin
ref_output = 0.8333 * torch.ones(1)
else:
ref_output = module(*module.inputs())
self.assertTrue(
torch.allclose(model_outputs[0], ref_output, atol=1e-03, rtol=1e-03)
)
15 changes: 15 additions & 0 deletions exir/backend/test/test_backends_lifted.py
@@ -11,6 +11,9 @@
import torch
from executorch.exir import to_edge
from executorch.exir.backend.backend_api import LoweredBackendModule, to_backend
from executorch.exir.backend.canonical_partitioners.all_node_partitioner import (
AllNodePartitioner,
)
from executorch.exir.backend.compile_spec_schema import CompileSpec
from executorch.exir.backend.partitioner import (
DelegationSpec,
@@ -138,6 +141,18 @@ def forward(self, x):

self.assertTrue(torch.allclose(new_res, expected_res))

# Test same flow but through edge_program_manager
edgeir_m = to_edge(export(sin_module, model_inputs, strict=True))
loweredir_m = edgeir_m.to_backend(
AllNodePartitioner(BackendWithCompilerDemo.__name__, [])
)
lowered_sin_module = get_lowered_submodules(
loweredir_m.exported_program().graph_module
)[0][1]

new_res = lowered_sin_module(*model_inputs)[0]

self.assertTrue(torch.allclose(new_res, expected_res))
# TODO(tkaruturi): emitting single LoweredBackendModule
# program = to_edge(export(graph_module)).to_executorch()._emitter_output.program

49 changes: 49 additions & 0 deletions exir/backend/test/test_compatibility.py
@@ -10,6 +10,9 @@
from executorch.exir import to_edge
from executorch.exir._serialize import _serialize_pte_binary
from executorch.exir.backend.backend_api import to_backend
from executorch.exir.backend.canonical_partitioners.all_node_partitioner import (
AllNodePartitioner,
)
from executorch.exir.backend.compile_spec_schema import CompileSpec
from executorch.exir.backend.test.backend_with_compiler_demo import (
BackendWithCompilerDemo,
@@ -65,3 +68,49 @@ def forward(self, x):
"loading method forward failed with error 0x30",
):
executorch_module = _load_for_executorch_from_buffer(buff)

def test_compatibility_in_runtime_edge_program_manager(self):
class SinModule(torch.nn.Module):
def __init__(self):
super().__init__()

def forward(self, x):
return torch.sin(x)

sin_module = SinModule()
model_inputs = (torch.ones(1),)
edgeir_m = to_edge(export(sin_module, model_inputs, strict=True))
max_value = model_inputs[0].shape[0]
compile_specs = [CompileSpec("max_value", bytes([max_value]))]
lowered_edge_irm = edgeir_m.to_backend(
AllNodePartitioner("BackendWithCompilerDemo", compile_specs)
)
exec_prog = lowered_edge_irm.to_executorch()

buff = exec_prog.buffer

# The demo backend loads and runs successfully
executorch_module = _load_for_executorch_from_buffer(buff)
model_inputs = torch.ones(1)
_ = executorch_module.forward([model_inputs])

prog = exec_prog.executorch_program
# Rewrite the delegate version number from 0 to 1.
prog.backend_delegate_data[0].data = bytes(
"1version:1#op:demo::aten.sin.default, numel:1, dtype:torch.float32<debug_handle>1#",
encoding="utf8",
)

# Generate the .pte file with the wrong version.
buff = bytes(
_serialize_pte_binary(
program=prog,
)
)

# Loading throws a runtime error with error code 0x30, meaning the delegate is incompatible.
with self.assertRaisesRegex(
RuntimeError,
"loading method forward failed with error 0x30",
):
executorch_module = _load_for_executorch_from_buffer(buff)
1 change: 1 addition & 0 deletions exir/program/TARGETS
@@ -31,6 +31,7 @@ python_library(
"//executorch/exir/_serialize:lib",
"//executorch/exir/backend:backend_api",
"//executorch/exir/backend:partitioner",
"//executorch/exir/backend/canonical_partitioners:canonical_partitioner_lib",
"//executorch/exir/capture:config",
"//executorch/exir/emit:emit",
"//executorch/exir/emit:lib",