10 files changed: +145 -0 lines changed
Operators package import list: register op_tanh.

@@ -50,5 +50,6 @@
     op_static_constant_pad,
     op_static_resize_bilinear_2d,
     op_sub,
+    op_tanh,
     op_to_copy,
 )
New file: node visitor that serializes aten.tanh.default into an XNNTanh node.

@@ -0,0 +1,52 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+from typing import Dict
+
+import torch
+from executorch.backends.xnnpack.operators.node_visitor import (
+    NodeVisitor,
+    register_node_visitor,
+)
+from executorch.backends.xnnpack.serialization.xnnpack_graph_schema import (
+    XNNGraph,
+    XNNTanh,
+    XNode,
+)
+from executorch.backends.xnnpack.utils.utils import get_input_node
+
+
+@register_node_visitor
+class TanhVisitor(NodeVisitor):
+    target = "aten.tanh.default"
+
+    def __init__(self, *args) -> None:
+        super().__init__(*args)
+
+    def define_node(
+        self,
+        node: torch.fx.Node,
+        xnn_graph: XNNGraph,
+        vals_to_ids: Dict[torch.fx.Node, int],
+        debug_handle: int,
+    ) -> None:
+        self.define_nodes_tensor_inputs_outputs(node, xnn_graph, vals_to_ids)
+
+        # input
+        input_id = vals_to_ids[get_input_node(node, 0)]
+
+        # output
+        output_id = vals_to_ids[node]
+
+        ser_node = XNode(
+            xnode_union=XNNTanh(
+                input_id=input_id,
+                output_id=output_id,
+                flags=0,
+            ),
+            debug_handle=debug_handle,
+        )
+        xnn_graph.xnodes.append(ser_node)
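For context, a minimal sketch of the flow this visitor plugs into, assuming the standard ExecuTorch export entry points (torch.export.export, to_edge_transform_and_lower, XnnpackPartitioner); the module and shapes are illustrative, not part of this diff:

    import torch
    from executorch.backends.xnnpack.partition.xnnpack_partitioner import XnnpackPartitioner
    from executorch.exir import to_edge_transform_and_lower


    class TanhModule(torch.nn.Module):
        def forward(self, x):
            return torch.tanh(x)


    # Export, then partition and lower; with this change, aten.tanh.default is
    # claimed by the XNNPACK partitioner and serialized by TanhVisitor above.
    exported = torch.export.export(TanhModule(), (torch.randn(4, 8),))
    lowered = to_edge_transform_and_lower(exported, partitioner=[XnnpackPartitioner()])
    et_program = lowered.to_executorch()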
Partitioner config package: import TanhConfig and add it to the registered configs.

@@ -49,6 +49,7 @@
     SoftmaxConfig,
     SquareRootConfig,
     SubConfig,
+    TanhConfig,
     UpsampleBilinear2dConfig,
 )
 from executorch.backends.xnnpack.partition.config.node_configs import (
@@ -99,6 +100,7 @@
     PreluConfig,
     ReciprocalSquareRootConfig,
     ReLUConfig,
+    TanhConfig,
     # SDPAConfig, TODO: D60553559: preserving SDPA for fairseq fails
     SigmoidConfig,
     SliceCopyConfig,
Generic node partitioner configs: add a TanhConfig keyed on the tanh.default target.

@@ -371,6 +371,13 @@ def supported_precision_types(self) -> List[ConfigPrecisionType]:
         return [ConfigPrecisionType.FP32]


+class TanhConfig(GenericNodePartitionerConfig):
+    target_name = "tanh.default"
+
+    def supported_precision_types(self) -> List[ConfigPrecisionType]:
+        return [ConfigPrecisionType.FP32]
+
+
 class MeanDimConfig(GenericNodePartitionerConfig):
     target_name = "mean.dim"
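A brief sketch of how a config like this is selected at partition time; the config_precisions argument and import paths are assumptions from the partitioner API, not part of this diff:

    from executorch.backends.xnnpack.partition.config.xnnpack_config import ConfigPrecisionType
    from executorch.backends.xnnpack.partition.xnnpack_partitioner import XnnpackPartitioner

    # Restrict partitioning to FP32 configs; TanhConfig above reports FP32 as its
    # only supported precision type, so tanh nodes are still claimed in this mode.
    partitioner = XnnpackPartitioner(config_precisions=ConfigPrecisionType.FP32)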
Supported-ops list for the partitioner: mark tanh as supported.

@@ -66,6 +66,7 @@
     exir_ops.edge.aten.rsqrt.default,
     exir_ops.edge.aten.log.default,
     exir_ops.edge.aten.gelu.default,
+    exir_ops.edge.aten.tanh.default,
 ]

 SUPPORTED_MODULES = [
Runtime graph builder (C++): define the tanh node and register its dispatch.

@@ -1162,6 +1162,36 @@ Error defineArgMaxPooling2dNode(
   return Error::Ok;
 }

+/*
+Defines the serialized tanh node into the subgraph, using the remapped ids
+to map the serialized ids to the new ids generated when defining the
+tensor values.
+*/
+Error defineTanhNode(
+    xnn_subgraph_t subgraph_ptr,
+    const std::unordered_map<uint32_t, uint32_t>& remapped_ids,
+    const NodePtr node,
+    const fb_xnnpack::XNNGraph* graph) noexcept {
+  MAYBE_UNUSED(graph);
+
+  auto graph_node = node->xnode_union_as_XNNTanh();
+
+  xnn_status status = xnn_define_tanh(
+      subgraph_ptr,
+      remapped_ids.at(graph_node->input_id()),
+      remapped_ids.at(graph_node->output_id()),
+      graph_node->flags());
+
+  ET_CHECK_OR_RETURN_ERROR(
+      status == xnn_status_success,
+      Internal,
+      "Failed to create tanh node %i with code: %s",
+      node->debug_handle(),
+      xnn_status_to_string(status));
+
+  return Error::Ok;
+}
+
 /*
 Defines serialized prelu node into the subgraph,
 using the remapped ids to map the serialized ids,
@@ -1697,6 +1727,7 @@ DefineNodeFunc getDefineNodeFunc(fb_xnnpack::XNodeUnion nodeType) {
     _DEFINE(Gelu)
     _DEFINE(Hardswish)
     _DEFINE(Log)
+    _DEFINE(Tanh)
     _DEFINE(Negate)
     _DEFINE(Square)
     _DEFINE(Clamp)
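To make the handshake between the Python serializer and this runtime path concrete, here is a minimal sketch of the serialized node that getDefineNodeFunc routes to defineTanhNode, built with the schema dataclasses from this diff; the ids and debug handle are illustrative:

    from executorch.backends.xnnpack.serialization.xnnpack_graph_schema import (
        XNNTanh,
        XNode,
    )

    # TanhVisitor emits this structure; at runtime, the XNNTanh union member is
    # dispatched via _DEFINE(Tanh) to defineTanhNode, which remaps the ids and
    # calls xnn_define_tanh.
    ser_node = XNode(
        xnode_union=XNNTanh(input_id=0, output_id=1, flags=0),
        debug_handle=42,
    )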
Flatbuffer schema: add XNNTanh to the node union.

@@ -154,6 +154,7 @@ union XNodeUnion {
   XNNReciprocalSquareRoot: _XNNNode1x1,
   XNNLog: _XNNNode1x1,
   XNNGelu: _XNNNode1x1,
+  XNNTanh: _XNNNode1x1,
 }

 union XValueUnion {
Companion flatbuffer schema: the same XNNTanh entry, keeping the two schema definitions in sync.

@@ -150,6 +150,7 @@ union XNodeUnion {
   XNNReciprocalSquareRoot: _XNNNode1x1,
   XNNLog: _XNNNode1x1,
   XNNGelu: _XNNNode1x1,
+  XNNTanh: _XNNNode1x1,
 }

 union XValueUnion {
Python schema dataclasses: add XNNTanh and include it in the node union list.

@@ -319,6 +319,11 @@ class XNNLog(XNNNode1x1):
     pass


+@dataclass
+class XNNTanh(XNNNode1x1):
+    pass
+
+
 @dataclass
 class XNNMaximum(XNNNode2x1):
     pass
@@ -391,6 +396,7 @@ class XNNScaledDotProductAttention:
     XNNReciprocalSquareRoot,
     XNNLog,
     XNNGelu,
+    XNNTanh,
 ]
New file: unit tests lowering tanh through the XNNPACK delegate in fp32 and fp16.

@@ -0,0 +1,43 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+import unittest
+
+import torch
+from executorch.backends.xnnpack.test.tester import Tester
+
+
+class TestTanh(unittest.TestCase):
+    def setUp(self):
+        torch._dynamo.reset()
+
+    class Tanh(torch.nn.Module):
+        def __init__(self):
+            super().__init__()
+
+        def forward(self, x):
+            return torch.tanh(x)
+
+    def run_tanh_test(self, inputs):
+        (
+            Tester(self.Tanh(), inputs)
+            .export()
+            .check_count({"torch.ops.aten.tanh.default": 1})
+            .to_edge_transform_and_lower()
+            .check_count({"torch.ops.higher_order.executorch_call_delegate": 1})
+            .check_not(["executorch_exir_dialects_edge__ops_aten_tanh_default"])
+            .to_executorch()
+            .serialize()
+            .run_method_and_compare_outputs()
+        )
+
+    def test_fp16_tanh(self):
+        inputs = (torch.randn(20).to(torch.float16),)
+        self.run_tanh_test(inputs)
+
+    def test_fp32_tanh(self):
+        inputs = (torch.randn(20),)
+        self.run_tanh_test(inputs)
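For reference, a minimal sketch of running just these cases programmatically; the module path is hypothetical and depends on where the test file lands in the tree:

    import unittest

    # Hypothetical module path for the new test file; adjust to its actual location.
    from executorch.backends.xnnpack.test.ops.test_tanh import TestTanh

    # Load and run only the new tanh test cases.
    suite = unittest.defaultTestLoader.loadTestsFromTestCase(TestTanh)
    unittest.TextTestRunner(verbosity=2).run(suite)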