Skip to content

Support rsqrt in XNNPACK backend #7992

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Jan 28, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions backends/xnnpack/operators/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@
op_prelu,
op_quantize_per_tensor,
op_relu,
op_rsqrt,
op_sdpa,
op_sigmoid,
op_skip_ops,
Expand Down
52 changes: 52 additions & 0 deletions backends/xnnpack/operators/op_rsqrt.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from typing import Dict

import torch
from executorch.backends.xnnpack.operators.node_visitor import (
NodeVisitor,
register_node_visitor,
)
from executorch.backends.xnnpack.serialization.xnnpack_graph_schema import (
XNNGraph,
XNNReciprocalSquareRoot,
XNode,
)
from executorch.backends.xnnpack.utils.utils import get_input_node


@register_node_visitor
class ReciprocalSquareRootVisitor(NodeVisitor):
target = "aten.rsqrt.default"

def __init__(self, *args) -> None:
super().__init__(*args)

def define_node(
self,
node: torch.fx.Node,
xnn_graph: XNNGraph,
vals_to_ids: Dict[torch.fx.Node, int],
debug_handle: int,
) -> None:
self.define_nodes_tensor_inputs_outputs(node, xnn_graph, vals_to_ids)

# input
input_id = vals_to_ids[get_input_node(node, 0)]

# output
output_id = vals_to_ids[node]

ser_node = XNode(
xnode_union=XNNReciprocalSquareRoot(
input_id=input_id,
output_id=output_id,
flags=0,
),
debug_handle=debug_handle,
)
xnn_graph.xnodes.append(ser_node)
2 changes: 2 additions & 0 deletions backends/xnnpack/partition/config/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@
PermuteConfig,
PowConfig,
QuantizedPerTensorConfig,
ReciprocalSquareRootConfig,
ReLUConfig,
# SDPAConfig, TODO: D60553559: preserving SDPA for fairseq fails
SigmoidConfig,
Expand Down Expand Up @@ -92,6 +93,7 @@
PermuteConfig,
PowConfig,
PreluConfig,
ReciprocalSquareRootConfig,
ReLUConfig,
# SDPAConfig, TODO: D60553559: preserving SDPA for fairseq fails
SigmoidConfig,
Expand Down
7 changes: 7 additions & 0 deletions backends/xnnpack/partition/config/generic_node_configs.py
Original file line number Diff line number Diff line change
Expand Up @@ -482,6 +482,13 @@ def supported_precision_types(self) -> List[ConfigPrecisionType]:
return [ConfigPrecisionType.FP32]


class ReciprocalSquareRootConfig(GenericNodePartitionerConfig):
target_name = "rsqrt.default"

def supported_precision_types(self) -> List[ConfigPrecisionType]:
return [ConfigPrecisionType.FP32]


class ConstantPadConfig(GenericNodePartitionerConfig):
target_name = "constant_pad_nd.default"

Expand Down
1 change: 1 addition & 0 deletions backends/xnnpack/partition/configs.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,7 @@
exir_ops.edge.aten.avg_pool2d.default,
exir_ops.edge.aten.leaky_relu.default,
exir_ops.edge.aten.addmm.default, # TODO(T163877189) add constraint for addmm
exir_ops.edge.aten.rsqrt.default,
]

SUPPORTED_MODULES = [
Expand Down
31 changes: 31 additions & 0 deletions backends/xnnpack/runtime/XNNCompiler.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1345,6 +1345,36 @@ Error defineSquareRootNode(
return Error::Ok;
}

/*
Define serialized square root node into the subgraph, using the remapped ids
to map the serialized ids, to the new ids generated when defining the
tensor value
*/
Error defineReciprocalSquareRootNode(
xnn_subgraph_t subgraph_ptr,
const std::unordered_map<uint32_t, uint32_t>& remapped_ids,
const NodePtr node,
const fb_xnnpack::XNNGraph* graph) noexcept {
MAYBE_UNUSED(graph);

auto graph_node = node->xnode_union_as_XNNReciprocalSquareRoot();

xnn_status status = xnn_define_reciprocal_square_root(
subgraph_ptr,
remapped_ids.at(graph_node->input_id()),
remapped_ids.at(graph_node->output_id()),
graph_node->flags());

ET_CHECK_OR_RETURN_ERROR(
status == xnn_status_success,
Internal,
"Failed to create reciprocal square root node %i with code: %s",
node->debug_handle(),
xnn_status_to_string(status));

return Error::Ok;
}

/*
Define serialized ceiling node into the subgraph, using the remapped ids
to map the serialized ids, to the new ids generated when defining the
Expand Down Expand Up @@ -1904,6 +1934,7 @@ DefineNodeFunc getDefineNodeFunc(fb_xnnpack::XNodeUnion nodeType) {
_DEFINE(StaticReshape)
_DEFINE(ArgMaxPooling2d)
_DEFINE(SquareRoot)
_DEFINE(ReciprocalSquareRoot)
_DEFINE(Ceiling)
_DEFINE(Hardswish)
_DEFINE(LeakyReLU)
Expand Down
1 change: 1 addition & 0 deletions backends/xnnpack/serialization/runtime_schema.fbs
Original file line number Diff line number Diff line change
Expand Up @@ -138,6 +138,7 @@ union XNodeUnion {
XNNBatchMatrixMultiply: _XNNNode2x1,
XNNConcatenate5: _XNNCat,
XNNConvTranspose2d: _XNNNodeConv,
XNNReciprocalSquareRoot: _XNNNode1x1,
}

union XValueUnion {
Expand Down
1 change: 1 addition & 0 deletions backends/xnnpack/serialization/schema.fbs
Original file line number Diff line number Diff line change
Expand Up @@ -134,6 +134,7 @@ union XNodeUnion {
XNNBatchMatrixMultiply: _XNNNode2x1,
XNNConcatenate5: _XNNCat,
XNNConvTranspose2d: _XNNNodeConv,
XNNReciprocalSquareRoot: _XNNNode1x1,
}

union XValueUnion {
Expand Down
6 changes: 6 additions & 0 deletions backends/xnnpack/serialization/xnnpack_graph_schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -281,6 +281,11 @@ class XNNSquareRoot(XNNNode1x1):
pass


@dataclass
class XNNReciprocalSquareRoot(XNNNode1x1):
pass


@dataclass
class XNNCeiling(XNNNode1x1):
pass
Expand Down Expand Up @@ -373,6 +378,7 @@ class XNNScaledDotProductAttention:
XNNStaticSlice,
XNNScaledDotProductAttention,
XNNBatchMatrixMultiply,
XNNReciprocalSquareRoot,
]


Expand Down
42 changes: 42 additions & 0 deletions backends/xnnpack/test/ops/test_rsqrt.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import unittest

import torch
from executorch.backends.xnnpack.test.tester import Tester


class TestRsqrt(unittest.TestCase):
class Rsqrt(torch.nn.Module):
def __init__(self):
super().__init__()

def forward(self, x):
x = torch.abs(x)
z = torch.rsqrt(x)
return z

def _test_rsqrt(self, inputs):
(
Tester(self.Rsqrt(), inputs)
Copy link
Contributor

@digantdesai digantdesai Jan 28, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nit: Do you want to improve this test by adding dynamic_shape support like this?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

improve this test by adding dynamic_shape support

it's a pointwise operator, why does it care about the shape? Everything is the same as sqrt including the test

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah, just checking boxes, i.e. adding dynamic shape test for every operator.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

out of scope for this change! :)

.export()
.check_count({"torch.ops.aten.rsqrt.default": 1})
.to_edge_transform_and_lower()
.check_count({"torch.ops.higher_order.executorch_call_delegate": 1})
.check_not(["executorch_exir_dialects_edge__ops_aten_rsqrt_default"])
.to_executorch()
.serialize()
.run_method_and_compare_outputs()
)

def test_fp16_rsqrt(self):
inputs = (torch.randn(20).to(torch.float16),)
self._test_rsqrt(inputs)

def test_fp32_rsqrt(self):
inputs = (torch.randn(20),)
self._test_rsqrt(inputs)