10 files changed, +146 -0 lines changed
Operator visitor imports (op_gelu added to the module list):

     op_dynamic_quantize_ops,
     op_elu,
     op_floor,
+    op_gelu,
     op_hardswish,
     op_hardtanh,
     op_leaky_relu,
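A note on this hunk: importing op_gelu is what triggers the @register_node_visitor decorator in the new file below, making the GELU visitor discoverable by its target string. The sketch that follows illustrates only the general decorator-registration pattern; the registry and decorator names in it are hypothetical, not the actual ExecuTorch API.

# Hypothetical sketch of the decorator-based registration pattern (names are made up).
from typing import Dict, Type

_VISITORS: Dict[str, Type] = {}

def register_visitor(cls):
    # Key each visitor class by the aten target it handles.
    _VISITORS[cls.target] = cls
    return cls

@register_visitor
class DemoGeluVisitor:
    target = "aten.gelu.default"

print(_VISITORS)  # {'aten.gelu.default': <class '...DemoGeluVisitor'>}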
New file: the GELU node visitor (all lines added):

+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+from typing import Dict
+
+import torch
+from executorch.backends.xnnpack.operators.node_visitor import (
+    NodeVisitor,
+    register_node_visitor,
+)
+from executorch.backends.xnnpack.serialization.xnnpack_graph_schema import (
+    XNNGelu,
+    XNNGraph,
+    XNode,
+)
+from executorch.backends.xnnpack.utils.utils import get_input_node
+
+
+@register_node_visitor
+class GeluVisitor(NodeVisitor):
+    target = "aten.gelu.default"
+
+    def __init__(self, *args) -> None:
+        super().__init__(*args)
+
+    def define_node(
+        self,
+        node: torch.fx.Node,
+        xnn_graph: XNNGraph,
+        vals_to_ids: Dict[torch.fx.Node, int],
+        debug_handle: int,
+    ) -> None:
+        self.define_nodes_tensor_inputs_outputs(node, xnn_graph, vals_to_ids)
+
+        # input
+        input_id = vals_to_ids[get_input_node(node, 0)]
+
+        # output
+        output_id = vals_to_ids[node]
+
+        ser_node = XNode(
+            xnode_union=XNNGelu(
+                input_id=input_id,
+                output_id=output_id,
+                flags=0,
+            ),
+            debug_handle=debug_handle,
+        )
+        xnn_graph.xnodes.append(ser_node)
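Aside (not part of the diff): the target string "aten.gelu.default" is exactly what torch.export produces for a GELU call, which is why the visitor keys on it. A minimal sketch using only public PyTorch APIs:

# Minimal sketch: exporting a GELU call and listing the aten ops in the graph.
import torch

class TinyGelu(torch.nn.Module):
    def forward(self, x):
        return torch.nn.functional.gelu(x)

ep = torch.export.export(TinyGelu(), (torch.randn(4),))
# Expect torch.ops.aten.gelu.default among the call_function targets.
print([node.target for node in ep.graph.nodes if node.op == "call_function"])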
Partitioner config imports and the registered config list (GeluConfig added to both):

     DeQuantizedPerTensorConfig,
     DivConfig,
     FloorConfig,
+    GeluConfig,
     HardswishConfig,
     # EluConfig,
     HardtanhConfig,

...

     DivConfig,
     # EluConfig, # Waiting for PyTorch Pin Update
     FloorConfig,
+    GeluConfig,
     HardtanhConfig,
     HardswishConfig,
     LeakyReLUConfig,
Generic node partitioner configs (new GeluConfig class):

@@ -343,6 +343,13 @@ def supported_precision_types(self) -> List[ConfigPrecisionType]:
         return [ConfigPrecisionType.FP32]


+class GeluConfig(GenericNodePartitionerConfig):
+    target_name = "gelu.default"
+
+    def supported_precision_types(self) -> List[ConfigPrecisionType]:
+        return [ConfigPrecisionType.FP32]
+
+
 class HardswishConfig(GenericNodePartitionerConfig):
     target_name = "hardswish.default"
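With GeluConfig registered, a GELU node should be claimed by the XNNPACK delegate during lowering. Below is a hedged end-to-end sketch, assuming the standard ExecuTorch lowering entry points (torch.export.export, executorch.exir.to_edge_transform_and_lower, and XnnpackPartitioner); treat the exact import paths as assumptions if this tree has moved them.

# Hedged sketch: lower a GELU model and inspect the delegated graph.
import torch
from executorch.backends.xnnpack.partition.xnnpack_partitioner import XnnpackPartitioner
from executorch.exir import to_edge_transform_and_lower

class GeluModel(torch.nn.Module):
    def forward(self, x):
        return torch.nn.functional.gelu(x)

ep = torch.export.export(GeluModel(), (torch.randn(8),))
edge = to_edge_transform_and_lower(ep, partitioner=[XnnpackPartitioner()])
# The gelu op should now sit inside an executorch_call_delegate partition.
print(edge.exported_program().graph_module.graph)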
Module-based partitioner supported ops (aten.gelu.default added):

     exir_ops.edge.aten.addmm.default,  # TODO(T163877189) add constraint for addmm
     exir_ops.edge.aten.rsqrt.default,
     exir_ops.edge.aten.log.default,
+    exir_ops.edge.aten.gelu.default,
 ]

 SUPPORTED_MODULES = [
Runtime graph construction: defineGeluNode plus its _DEFINE dispatch entry:

@@ -1448,6 +1448,36 @@ Error defineLogNode(
   return Error::Ok;
 }

+/*
+Define serialized gelu node into the subgraph, using the remapped ids
+to map the serialized ids, to the new ids generated when defining the
+tensor value
+*/
+Error defineGeluNode(
+    xnn_subgraph_t subgraph_ptr,
+    const std::unordered_map<uint32_t, uint32_t>& remapped_ids,
+    const NodePtr node,
+    const fb_xnnpack::XNNGraph* graph) noexcept {
+  MAYBE_UNUSED(graph);
+
+  auto graph_node = node->xnode_union_as_XNNGelu();
+
+  xnn_status status = xnn_define_gelu(
+      subgraph_ptr,
+      remapped_ids.at(graph_node->input_id()),
+      remapped_ids.at(graph_node->output_id()),
+      graph_node->flags());
+
+  ET_CHECK_OR_RETURN_ERROR(
+      status == xnn_status_success,
+      Internal,
+      "Failed to create gelu node %i with code: %s",
+      node->debug_handle(),
+      xnn_status_to_string(status));
+
+  return Error::Ok;
+}
+
 /*
 Define serialized ceiling node into the subgraph, using the remapped ids
 to map the serialized ids, to the new ids generated when defining the
@@ -2009,6 +2039,7 @@ DefineNodeFunc getDefineNodeFunc(fb_xnnpack::XNodeUnion nodeType) {
     _DEFINE(SquareRoot)
     _DEFINE(ReciprocalSquareRoot)
     _DEFINE(Ceiling)
+    _DEFINE(Gelu)
     _DEFINE(Hardswish)
     _DEFINE(LeakyReLU)
     _DEFINE(Log)
Flatbuffer node schema (first of the two schema copies), XNNGelu added to XNodeUnion:

@@ -140,6 +140,7 @@ union XNodeUnion {
   XNNConvTranspose2d: _XNNNodeConv,
   XNNReciprocalSquareRoot: _XNNNode1x1,
   XNNLog: _XNNNode1x1,
+  XNNGelu: _XNNNode1x1,
 }

 union XValueUnion {
The second schema copy gets the identical addition, keeping the two node unions in sync:

@@ -136,6 +136,7 @@ union XNodeUnion {
   XNNConvTranspose2d: _XNNNodeConv,
   XNNReciprocalSquareRoot: _XNNNode1x1,
   XNNLog: _XNNNode1x1,
+  XNNGelu: _XNNNode1x1,
 }

 union XValueUnion {
Python serialization schema: the XNNGelu dataclass and its entry in the node union:

@@ -291,6 +291,11 @@ class XNNCeiling(XNNNode1x1):
     pass


+@dataclass
+class XNNGelu(XNNNode1x1):
+    pass
+
+
 @dataclass
 class XNNHardswish(XNNNode1x1):
     pass
@@ -385,6 +390,7 @@ class XNNScaledDotProductAttention:
     XNNBatchMatrixMultiply,
     XNNReciprocalSquareRoot,
     XNNLog,
+    XNNGelu,
 ]
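A small usage sketch of the new dataclass, using only the constructor fields that GeluVisitor passes in the diff above (input_id, output_id, and flags on the node; xnode_union and debug_handle on XNode):

# Sketch: constructing the serialized GELU node by hand, mirroring GeluVisitor.
from executorch.backends.xnnpack.serialization.xnnpack_graph_schema import XNNGelu, XNode

gelu_node = XNode(
    xnode_union=XNNGelu(input_id=0, output_id=1, flags=0),
    debug_handle=0,
)
print(gelu_node)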
New file: GELU operator test (all lines added):

+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+import unittest
+
+import torch
+from executorch.backends.xnnpack.test.tester import Tester
+
+
+class TestGelu(unittest.TestCase):
+    def setUp(self):
+        torch._dynamo.reset()
+
+    class Gelu(torch.nn.Module):
+        def __init__(self):
+            super().__init__()
+            self.gelu = torch.nn.GELU()
+
+        def forward(self, x):
+            return self.gelu(x)
+
+    def run_gelu_test(self, inputs):
+        (
+            Tester(self.Gelu(), inputs)
+            .export()
+            .check_count({"torch.ops.aten.gelu.default": 1})
+            .to_edge_transform_and_lower()
+            .check_count({"torch.ops.higher_order.executorch_call_delegate": 1})
+            .check_not(["executorch_exir_dialects_edge__ops_aten_gelu_default"])
+            .to_executorch()
+            .serialize()
+            .run_method_and_compare_outputs()
+        )
+
+    def test_fp16_gelu(self):
+        inputs = (torch.randn(20).to(torch.float16),)
+        self.run_gelu_test(inputs)
+
+    def test_fp32_gelu(self):
+        inputs = (torch.randn(20),)
+        self.run_gelu_test(inputs)
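A hedged note on running the new tests: both the fp16 and fp32 cases go through serialize() and run_method_and_compare_outputs(), so the XNNPACK delegate must be built in the local tree. The path below is a placeholder, since the diff does not show where test_gelu.py lives.

# Placeholder invocation; substitute the actual location of the new test file.
#   python -m pytest <path/to>/test_gelu.py -v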