Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -14,12 +14,14 @@ buck2-bin/
build/
cmake-android-out/
cmake-ios-out/
cmake-android-out*
cmake-out*
cmake-out-android/
build-android/
build-x86/
dist/
arm-scratch/
ethos-u-scratch/
executorch.egg-info
pip-out/
build-profiling/
Expand Down
8 changes: 4 additions & 4 deletions backends/xnnpack/partition/xnnpack_partitioner.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,16 +18,16 @@
from executorch.backends.xnnpack.xnnpack_preprocess import XnnpackBackend
from executorch.exir.backend.backend_details import ExportedProgram
from executorch.exir.backend.canonical_partitioners.config_partitioner import (
ConfigerationBasedPartitioner,
ConfigurationBasedPartitioner,
)
from executorch.exir.backend.partitioner import DelegationSpec
from executorch.exir.backend.partitioner import DelegationSpec, PartitionResult
from torch.fx.passes.infra.partitioner import Partition

logging.basicConfig(level=logging.WARNING)
logger = logging.getLogger(__name__)


class XnnpackPartitioner(ConfigerationBasedPartitioner):
class XnnpackPartitioner(ConfigurationBasedPartitioner):
def __init__(
self,
configs: Optional[List[Type[XNNPartitionerConfig]]] = None,
Expand Down Expand Up @@ -83,7 +83,7 @@ def _check_if_called_from_to_backend(self) -> bool:
return True
return False

def partition(self, exported_program):
def partition(self, exported_program: ExportedProgram) -> PartitionResult:
"""
Override partition to add deprecation warning when called from to_backend.
"""
Expand Down
219 changes: 219 additions & 0 deletions backends/xnnpack/test/fragments/test_hop.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,219 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import unittest

import torch
from executorch.backends.xnnpack.test.tester import Tester
from torch._higher_order_ops.map import map as torch_map
from torch._higher_order_ops.scan import scan


class TestHigherOrderOps(unittest.TestCase):
    """Tests that ops inside higher-order operators (torch.cond, torch.map,
    torch.scan) are delegated to the XNNPACK backend when lowering.

    Each test exports a small model, lowers it through the XNNPACK
    partitioner, checks that the interesting aten ops do NOT remain as
    undelegated operators in the emitted executorch program, and finally
    runs the serialized program comparing outputs against eager mode.
    """

    def setUp(self):
        # Reset dynamo caches so each test exports from a clean state.
        torch._dynamo.reset()

    @staticmethod
    def _lower(model, inputs):
        """Export *model* with *inputs* and lower it to an executorch
        program via the XNNPACK partitioner; returns the Tester for
        further inspection and execution."""
        return (
            Tester(model, inputs)
            .export()
            .to_edge_transform_and_lower()
            .to_executorch()
        )

    @staticmethod
    def _program(tester):
        """Return the emitted executorch program held by *tester*."""
        return tester.get_artifact()._emitter_output.program

    @classmethod
    def _operator_names(cls, tester):
        """Names of all operators left undelegated in the lowered program."""
        program = cls._program(tester)
        return [op.name for plan in program.execution_plan for op in plan.operators]

    def test_cond(self):
        """
        Test that torch.cond with add/sub branches can be lowered to XNNPACK.

        The model returns x + y if x[0] > 0, else x - y.
        Verifies that add and sub ops are delegated to XNNPACK (not present
        as undelegated operators in the executorch program).
        """

        class CondModel(torch.nn.Module):
            def true_fn(self, x, y):
                return x + y

            def false_fn(self, x, y):
                return x - y

            def forward(self, x, y):
                return torch.cond(x[0] > 0, self.true_fn, self.false_fn, [x, y])

        model = CondModel()
        inputs = (torch.randn(4), torch.randn(4))
        tester = self._lower(model, inputs)

        # add/sub must have been absorbed into the delegate payloads.
        operator_names = self._operator_names(tester)
        self.assertNotIn(
            "aten::add",
            operator_names,
            "add op should be delegated",
        )
        self.assertNotIn(
            "aten::sub",
            operator_names,
            "sub op should be delegated",
        )

        # Each cond branch is lowered into its own XNNPACK delegate.
        program = self._program(tester)
        delegates = [d for plan in program.execution_plan for d in plan.delegates]
        xnnpack_delegates = [d for d in delegates if d.id == "XnnpackBackend"]
        self.assertEqual(
            len(xnnpack_delegates),
            2,
            "Expected 2 XNNPACK delegates (one for each branch)",
        )

        # Verify execution produces correct results.
        tester.serialize().run_method_and_compare_outputs()

    def test_cond_with_linear(self):
        """
        Test that torch.cond with a linear module in one branch can be lowered.

        The model returns linear(x) if x[0] > 0, else x * 2.
        This test verifies that lowering and execution succeed.
        """

        class CondLinearModel(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.linear = torch.nn.Linear(4, 4)

            def true_fn(self, x):
                return self.linear(x)

            def false_fn(self, x):
                return x * 2

            def forward(self, x):
                return torch.cond(x[0] > 0, self.true_fn, self.false_fn, [x])

        model = CondLinearModel()
        inputs = (torch.randn(4),)

        # Lowering and execution must both succeed; outputs are compared
        # against the eager-mode reference.
        self._lower(model, inputs).serialize().run_method_and_compare_outputs()

    def test_map(self):
        """
        Test that torch.map with add operation can be lowered to XNNPACK.

        Maps a function that adds y to each element of xs.
        Verifies that add ops are delegated to XNNPACK.
        """

        class MapModel(torch.nn.Module):
            def forward(self, xs, y):
                def f(x, y):
                    return x + y

                return torch_map(f, xs, y)

        model = MapModel()
        inputs = (torch.randn(3, 4), torch.randn(4))
        tester = self._lower(model, inputs)

        # The mapped add must have been delegated, not left as an operator.
        self.assertNotIn(
            "aten::add",
            self._operator_names(tester),
            "add op should be delegated",
        )

        # Verify execution produces correct results.
        tester.serialize().run_method_and_compare_outputs()

    def test_scan(self):
        """
        Test that torch.scan (cumulative sum) can be lowered to XNNPACK.

        Performs a cumulative sum over the input tensor.
        Verifies that add ops inside scan are delegated to XNNPACK.
        """

        class ScanModel(torch.nn.Module):
            def forward(self, xs):
                def combine_fn(carry, x):
                    new_carry = carry + x
                    # `+ 0` creates a distinct output node for the per-step
                    # output alongside the carry.
                    return new_carry, new_carry + 0

                init = torch.zeros_like(xs[0])
                return scan(combine_fn, init, xs)

        model = ScanModel()
        inputs = (torch.randn(5, 4),)
        tester = self._lower(model, inputs)

        # The adds inside the scan body must have been delegated.
        self.assertNotIn(
            "aten::add",
            self._operator_names(tester),
            "add op should be delegated",
        )

        # Verify execution produces correct results.
        tester.serialize().run_method_and_compare_outputs()


if __name__ == "__main__":
unittest.main()
90 changes: 51 additions & 39 deletions backends/xnnpack/utils/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,23 +8,15 @@

import executorch.exir as exir
import torch

from executorch.backends.xnnpack.utils.configs import (
get_transform_passes,
get_xnnpack_capture_config,
get_xnnpack_edge_compile_config,
)
from executorch.exir import ExportedProgram
from executorch.exir.dialects._ops import ops as exir_ops

from torch._export.utils import (
get_buffer,
get_lifted_tensor_constant,
get_param,
is_buffer,
is_lifted_tensor_constant,
is_param,
)
from executorch.exir.passes.propagate_input_spec import INPUT_SPEC_KEY
from torch.export.graph_signature import InputKind
from torchao.quantization.pt2e.utils import _is_conv_node, _is_conv_transpose_node


Expand Down Expand Up @@ -104,48 +96,68 @@ def is_get_attr_node(node: torch.fx.Node) -> bool:


def is_param_node(exp_prog: ExportedProgram, node: torch.fx.Node) -> bool:
return (
is_get_attr_node(node)
or is_param(exp_prog, node)
or is_buffer(exp_prog, node)
or is_lifted_tensor_constant(exp_prog, node)
)
# Check if node is a get_attr node first
if is_get_attr_node(node):
return True

# Check if node can be resolved via input_spec and state_dict/constants
input_spec = node.meta.get(INPUT_SPEC_KEY, None)
if input_spec is not None and input_spec.target is not None:
target = input_spec.target
if input_spec.kind == InputKind.PARAMETER:
if target in exp_prog.state_dict or target in exp_prog.constants:
return True
elif input_spec.kind == InputKind.BUFFER:
if target in exp_prog.state_dict or target in exp_prog.constants:
return True
elif input_spec.kind == InputKind.CONSTANT_TENSOR:
if target in exp_prog.constants:
return True

return False


def get_param_tensor(
exp_prog: ExportedProgram, node: torch.fx.Node
) -> Optional[torch.Tensor]:
if node is None:
return None
elif is_param(exp_prog, node):
return get_param(exp_prog, node)
elif is_buffer(exp_prog, node):
return get_buffer(exp_prog, node)
elif is_lifted_tensor_constant(exp_prog, node):
return get_lifted_tensor_constant(exp_prog, node)
elif is_get_attr_node(node):
# This is a hack to support both lifted and unlifted graph
try:
return getattr(node.graph.owning_module, node.target)
except AttributeError:
return getattr(exp_prog.graph_module, node.target)
raise RuntimeError(f"unsupported param type, {node.op}.")

input_spec = node.meta.get(INPUT_SPEC_KEY, None)

if input_spec is not None and input_spec.target is not None:
target = input_spec.target
if input_spec.kind == InputKind.PARAMETER:
if target in exp_prog.state_dict:
return exp_prog.state_dict[target]
if target in exp_prog.constants:
return exp_prog.constants[target]
elif input_spec.kind == InputKind.BUFFER:
if input_spec.persistent and target in exp_prog.state_dict:
return exp_prog.state_dict[target]
if target in exp_prog.constants:
return exp_prog.constants[target]
elif input_spec.kind == InputKind.CONSTANT_TENSOR:
if target in exp_prog.constants:
return exp_prog.constants[target]

# This is a hack to support both lifted and unlifted graph
try:
return getattr(node.graph.owning_module, node.target)
except AttributeError:
return getattr(exp_prog.graph_module, node.target)


def get_tensor_name(exp_prog: ExportedProgram, node: torch.fx.Node) -> str:
if node is None:
return ""
if is_param(exp_prog, node):
return exp_prog.graph_signature.inputs_to_parameters[node.name]
elif is_buffer(exp_prog, node):
return exp_prog.graph_signature.inputs_to_buffers[node.name]
elif is_lifted_tensor_constant(exp_prog, node):
return exp_prog.graph_signature.inputs_to_lifted_tensor_constants[node.name]
else:
assert isinstance(node.target, str)
return node.target

return ""
input_spec = node.meta.get(INPUT_SPEC_KEY, None)
if input_spec is not None:
return input_spec.target

assert isinstance(node.target, str)
return node.target


def get_source_fn(node: torch.fx.Node) -> Optional[torch.fx.Node]:
Expand Down
Loading
Loading