Skip to content

Commit deac10c

Browse files
swolchok authored and YIWENX14 committed
Support rsqrt in XNNPACK backend (#7992)
I think I updated everywhere that needs updating? Test Plan: python -m unittest backends/xnnpack/test/ops/test_rsqrt.py
1 parent e4ad233 commit deac10c

File tree

10 files changed

+144
-0
lines changed

10 files changed

+144
-0
lines changed

backends/xnnpack/operators/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@
3737
op_prelu,
3838
op_quantize_per_tensor,
3939
op_relu,
40+
op_rsqrt,
4041
op_sdpa,
4142
op_sigmoid,
4243
op_skip_ops,
backends/xnnpack/operators/op_rsqrt.py

Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,52 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the BSD-style license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
from typing import Dict
8+
9+
import torch
10+
from executorch.backends.xnnpack.operators.node_visitor import (
11+
NodeVisitor,
12+
register_node_visitor,
13+
)
14+
from executorch.backends.xnnpack.serialization.xnnpack_graph_schema import (
15+
XNNGraph,
16+
XNNReciprocalSquareRoot,
17+
XNode,
18+
)
19+
from executorch.backends.xnnpack.utils.utils import get_input_node
20+
21+
22+
@register_node_visitor
class ReciprocalSquareRootVisitor(NodeVisitor):
    """Node visitor that serializes ``aten.rsqrt.default`` nodes.

    Each visited node is appended to the XNNPACK graph as an ``XNode``
    wrapping an ``XNNReciprocalSquareRoot`` payload.
    """

    target = "aten.rsqrt.default"

    def __init__(self, *args) -> None:
        super().__init__(*args)

    def define_node(
        self,
        node: torch.fx.Node,
        xnn_graph: XNNGraph,
        vals_to_ids: Dict[torch.fx.Node, int],
        debug_handle: int,
    ) -> None:
        """Append the serialized form of ``node`` to ``xnn_graph.xnodes``.

        Args:
            node: The FX node being serialized.
            xnn_graph: Graph whose ``xnodes`` list receives the new node.
            vals_to_ids: Mapping from FX nodes to serialized value ids.
            debug_handle: Handle stored on the emitted ``XNode``.
        """
        # Keep this call ahead of the lookups below: they read ``vals_to_ids``,
        # which this inherited helper presumably populates for the node's
        # tensors (NOTE(review): confirm against NodeVisitor).
        self.define_nodes_tensor_inputs_outputs(node, xnn_graph, vals_to_ids)

        # Single input (argument 0) and the node's own output value.
        rsqrt_payload = XNNReciprocalSquareRoot(
            input_id=vals_to_ids[get_input_node(node, 0)],
            output_id=vals_to_ids[node],
            flags=0,
        )

        xnn_graph.xnodes.append(
            XNode(xnode_union=rsqrt_payload, debug_handle=debug_handle)
        )

backends/xnnpack/partition/config/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,7 @@
3939
PermuteConfig,
4040
PowConfig,
4141
QuantizedPerTensorConfig,
42+
ReciprocalSquareRootConfig,
4243
ReLUConfig,
4344
# SDPAConfig, TODO: D60553559: preserving SDPA for fairseq fails
4445
SigmoidConfig,
@@ -92,6 +93,7 @@
9293
PermuteConfig,
9394
PowConfig,
9495
PreluConfig,
96+
ReciprocalSquareRootConfig,
9597
ReLUConfig,
9698
# SDPAConfig, TODO: D60553559: preserving SDPA for fairseq fails
9799
SigmoidConfig,

backends/xnnpack/partition/config/generic_node_configs.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -482,6 +482,13 @@ def supported_precision_types(self) -> List[ConfigPrecisionType]:
482482
return [ConfigPrecisionType.FP32]
483483

484484

485+
class ReciprocalSquareRootConfig(GenericNodePartitionerConfig):
    """Partitioner config for the ``rsqrt.default`` target."""

    target_name = "rsqrt.default"

    def supported_precision_types(self) -> List[ConfigPrecisionType]:
        """Return the precision types reported as supported (FP32 only)."""
        fp32_only = [ConfigPrecisionType.FP32]
        return fp32_only
490+
491+
485492
class ConstantPadConfig(GenericNodePartitionerConfig):
486493
target_name = "constant_pad_nd.default"
487494

backends/xnnpack/partition/configs.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,7 @@
6363
exir_ops.edge.aten.avg_pool2d.default,
6464
exir_ops.edge.aten.leaky_relu.default,
6565
exir_ops.edge.aten.addmm.default, # TODO(T163877189) add constraint for addmm
66+
exir_ops.edge.aten.rsqrt.default,
6667
]
6768

6869
SUPPORTED_MODULES = [

backends/xnnpack/runtime/XNNCompiler.cpp

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1345,6 +1345,36 @@ Error defineSquareRootNode(
13451345
return Error::Ok;
13461346
}
13471347

1348+
/*
1349+
Define serialized square root node into the subgraph, using the remapped ids
1350+
to map the serialized ids, to the new ids generated when defining the
1351+
tensor value
1352+
*/
1353+
Error defineReciprocalSquareRootNode(
1354+
xnn_subgraph_t subgraph_ptr,
1355+
const std::unordered_map<uint32_t, uint32_t>& remapped_ids,
1356+
const NodePtr node,
1357+
const fb_xnnpack::XNNGraph* graph) noexcept {
1358+
MAYBE_UNUSED(graph);
1359+
1360+
auto graph_node = node->xnode_union_as_XNNReciprocalSquareRoot();
1361+
1362+
xnn_status status = xnn_define_reciprocal_square_root(
1363+
subgraph_ptr,
1364+
remapped_ids.at(graph_node->input_id()),
1365+
remapped_ids.at(graph_node->output_id()),
1366+
graph_node->flags());
1367+
1368+
ET_CHECK_OR_RETURN_ERROR(
1369+
status == xnn_status_success,
1370+
Internal,
1371+
"Failed to create reciprocal square root node %i with code: %s",
1372+
node->debug_handle(),
1373+
xnn_status_to_string(status));
1374+
1375+
return Error::Ok;
1376+
}
1377+
13481378
/*
13491379
Define serialized ceiling node into the subgraph, using the remapped ids
13501380
to map the serialized ids, to the new ids generated when defining the
@@ -1904,6 +1934,7 @@ DefineNodeFunc getDefineNodeFunc(fb_xnnpack::XNodeUnion nodeType) {
19041934
_DEFINE(StaticReshape)
19051935
_DEFINE(ArgMaxPooling2d)
19061936
_DEFINE(SquareRoot)
1937+
_DEFINE(ReciprocalSquareRoot)
19071938
_DEFINE(Ceiling)
19081939
_DEFINE(Hardswish)
19091940
_DEFINE(LeakyReLU)

backends/xnnpack/serialization/runtime_schema.fbs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -138,6 +138,7 @@ union XNodeUnion {
138138
XNNBatchMatrixMultiply: _XNNNode2x1,
139139
XNNConcatenate5: _XNNCat,
140140
XNNConvTranspose2d: _XNNNodeConv,
141+
XNNReciprocalSquareRoot: _XNNNode1x1,
141142
}
142143

143144
union XValueUnion {

backends/xnnpack/serialization/schema.fbs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -134,6 +134,7 @@ union XNodeUnion {
134134
XNNBatchMatrixMultiply: _XNNNode2x1,
135135
XNNConcatenate5: _XNNCat,
136136
XNNConvTranspose2d: _XNNNodeConv,
137+
XNNReciprocalSquareRoot: _XNNNode1x1,
137138
}
138139

139140
union XValueUnion {

backends/xnnpack/serialization/xnnpack_graph_schema.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -281,6 +281,11 @@ class XNNSquareRoot(XNNNode1x1):
281281
pass
282282

283283

284+
# Reciprocal square root node; declares no fields of its own beyond those
# inherited from XNNNode1x1.
@dataclass
class XNNReciprocalSquareRoot(XNNNode1x1):
    ...
287+
288+
284289
@dataclass
285290
class XNNCeiling(XNNNode1x1):
286291
pass
@@ -373,6 +378,7 @@ class XNNScaledDotProductAttention:
373378
XNNStaticSlice,
374379
XNNScaledDotProductAttention,
375380
XNNBatchMatrixMultiply,
381+
XNNReciprocalSquareRoot,
376382
]
377383

378384

backends/xnnpack/test/ops/test_rsqrt.py

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,42 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the BSD-style license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
import unittest
8+
9+
import torch
10+
from executorch.backends.xnnpack.test.tester import Tester
11+
12+
13+
class TestRsqrt(unittest.TestCase):
    """Export, lower and run a small ``torch.rsqrt`` module through ``Tester``."""

    class Rsqrt(torch.nn.Module):
        """Computes ``rsqrt(abs(x))`` so the rsqrt operand is never negative."""

        def __init__(self):
            super().__init__()

        def forward(self, x):
            return torch.rsqrt(torch.abs(x))

    def _test_rsqrt(self, inputs):
        """Run the full pipeline for ``inputs`` and compare outputs.

        Checks that the exported graph has exactly one rsqrt op, and that after
        lowering there is exactly one delegate call and no edge rsqrt op left.
        """
        exported = Tester(self.Rsqrt(), inputs).export()
        exported = exported.check_count({"torch.ops.aten.rsqrt.default": 1})

        lowered = exported.to_edge_transform_and_lower()
        lowered = lowered.check_count(
            {"torch.ops.higher_order.executorch_call_delegate": 1}
        )
        lowered = lowered.check_not(
            ["executorch_exir_dialects_edge__ops_aten_rsqrt_default"]
        )

        lowered.to_executorch().serialize().run_method_and_compare_outputs()

    def test_fp16_rsqrt(self):
        self._test_rsqrt((torch.randn(20).to(torch.float16),))

    def test_fp32_rsqrt(self):
        self._test_rsqrt((torch.randn(20),))

0 commit comments

Comments
 (0)