Skip to content

Commit 4bb7959

Browse files
committed
Add tanh op to XNNPACK backend
1 parent daebcde commit 4bb7959

File tree

10 files changed

+511
-0
lines changed

10 files changed

+511
-0
lines changed

backends/xnnpack/operators/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -50,5 +50,6 @@
5050
op_static_constant_pad,
5151
op_static_resize_bilinear_2d,
5252
op_sub,
53+
op_tanh,
5354
op_to_copy,
5455
)

backends/xnnpack/operators/op_tanh.py

Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,52 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the BSD-style license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
from typing import Dict
8+
9+
import torch
10+
from executorch.backends.xnnpack.operators.node_visitor import (
11+
NodeVisitor,
12+
register_node_visitor,
13+
)
14+
from executorch.backends.xnnpack.serialization.xnnpack_graph_schema import (
15+
XNNGraph,
16+
XNNTanh,
17+
XNode,
18+
)
19+
from executorch.backends.xnnpack.utils.utils import get_input_node
20+
21+
22+
@register_node_visitor
23+
class TanhVisitor(NodeVisitor):
24+
target = "aten.tanh.default"
25+
26+
def __init__(self, *args) -> None:
27+
super().__init__(*args)
28+
29+
def define_node(
30+
self,
31+
node: torch.fx.Node,
32+
xnn_graph: XNNGraph,
33+
vals_to_ids: Dict[torch.fx.Node, int],
34+
debug_handle: int,
35+
) -> None:
36+
self.define_nodes_tensor_inputs_outputs(node, xnn_graph, vals_to_ids)
37+
38+
# input
39+
input_id = vals_to_ids[get_input_node(node, 0)]
40+
41+
# output
42+
output_id = vals_to_ids[node]
43+
44+
ser_node = XNode(
45+
xnode_union=XNNTanh(
46+
input_id=input_id,
47+
output_id=output_id,
48+
flags=0,
49+
),
50+
debug_handle=debug_handle,
51+
)
52+
xnn_graph.xnodes.append(ser_node)

backends/xnnpack/partition/config/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,7 @@
4949
SoftmaxConfig,
5050
SquareRootConfig,
5151
SubConfig,
52+
TanhConfig,
5253
UpsampleBilinear2dConfig,
5354
)
5455
from executorch.backends.xnnpack.partition.config.node_configs import (
@@ -99,6 +100,7 @@
99100
PreluConfig,
100101
ReciprocalSquareRootConfig,
101102
ReLUConfig,
103+
TanhConfig,
102104
# SDPAConfig, TODO: D60553559: preserving SDPA for fairseq fails
103105
SigmoidConfig,
104106
SliceCopyConfig,

backends/xnnpack/partition/config/generic_node_configs.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -371,6 +371,13 @@ def supported_precision_types(self) -> List[ConfigPrecisionType]:
371371
return [ConfigPrecisionType.FP32]
372372

373373

374+
class TanhConfig(GenericNodePartitionerConfig):
375+
target_name = "tanh.default"
376+
377+
def supported_precision_types(self) -> List[ConfigPrecisionType]:
378+
return [ConfigPrecisionType.FP32]
379+
380+
374381
class MeanDimConfig(GenericNodePartitionerConfig):
375382
target_name = "mean.dim"
376383

backends/xnnpack/partition/configs.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -66,6 +66,7 @@
6666
exir_ops.edge.aten.rsqrt.default,
6767
exir_ops.edge.aten.log.default,
6868
exir_ops.edge.aten.gelu.default,
69+
exir_ops.edge.aten.tanh.default,
6970
]
7071

7172
SUPPORTED_MODULES = [

0 commit comments

Comments
 (0)