@@ -6,10 +6,15 @@
 # pyre-unsafe
 
 import torch
+from executorch.backends.transforms.utils import (
+    create_constant_placeholder,
+    delete_constant_placeholder,
+)
 from executorch.exir import ExportedProgram
 from executorch.exir.dialects._ops import ops as exir_ops
 from executorch.exir.pass_base import ExportPass, PassResult
 from torch._export.utils import get_buffer, get_param
+from torch.export.graph_signature import InputKind
 from torch.fx import Node
 from torch.nn.utils.fusion import fuse_conv_bn_weights
 
@@ -23,7 +28,7 @@ def __init__(self, exported_program: ExportedProgram):
         self.exported_program = exported_program
         super().__init__()
 
-    def is_fuseable_conv_bn(self, node: Node):
+    def is_fuseable_conv_bn(self, node: Node) -> bool:
         """Returns True if node is a batchnorm that can be fused into
         a parent convolution."""
         if node.op != "call_function":
@@ -44,15 +49,19 @@ def is_fuseable_conv_bn(self, node: Node):
         # Since we change the output of the conv, fuse only if it has a single user.
         if len(conv.users) > 1:
             return False
-        # For similar reasons, only fuse if conv parameters have single user.
-        if len(conv.all_input_nodes[1].users) > 1:
-            return False
-        if len(conv.all_input_nodes) > 2 and len(conv.all_input_nodes[2].users) > 1:
-            return False
         return True
 
+    def get_bias_name(self, conv_weight_node: Node, conv_bias_node: Node) -> str:
+        if conv_bias_node:
+            return conv_bias_node.name + "_fused_bn"
+        elif "weight" in conv_weight_node.name:
+            return conv_weight_node.name.replace("weight", "bias") + "_fused_bn"
+        else:
+            return conv_weight_node.name + "_bias_fused_bn"
+
     def call(self, graph_module: torch.fx.GraphModule) -> PassResult:  # noqa: C901
         modified = False
+        constant_placeholders_to_delete = set()
         for node in graph_module.graph.nodes:
             if not self.is_fuseable_conv_bn(node):
                 continue
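The two per-parameter single-user checks above are dropped: the rewritten pass no longer overwrites parameter entries in place, but creates fresh constant placeholders for the fused values, so a weight or bias shared with another node is left untouched (old placeholders are only deleted later, and only once they have no remaining users).

As a quick illustration of the naming scheme in get_bias_name, a standalone sketch on plain strings instead of fx Nodes (the names are hypothetical):

    def fused_bias_name(weight_name: str, bias_name: str | None) -> str:
        # Mirrors the branches of get_bias_name above.
        if bias_name:
            return bias_name + "_fused_bn"
        elif "weight" in weight_name:
            return weight_name.replace("weight", "bias") + "_fused_bn"
        else:
            return weight_name + "_bias_fused_bn"

    assert fused_bias_name("conv1_weight", "conv1_bias") == "conv1_bias_fused_bn"
    assert fused_bias_name("conv1_weight", None) == "conv1_bias_fused_bn"
    assert fused_bias_name("conv1_param", None) == "conv1_param_bias_fused_bn"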
@@ -64,68 +73,93 @@ def get_param_or_none(arg) -> torch.nn.Parameter | None:
                 )
 
             # Get weight, bias, mean, var and epsilon from the batchnorm
-            bn = node
-            conv, bn_weight_node, bn_bias_node, bn_mean_node, bn_var_node = bn.args[0:5]
-            bn_weight = get_param_or_none(bn_weight_node)
-            bn_bias = get_param_or_none(bn_bias_node)
-
-            running_mean = get_buffer(self.exported_program, bn_mean_node)
-            running_var = get_buffer(self.exported_program, bn_var_node)
-            if running_mean is None or running_var is None:
+            bn_node = node
+            conv, bn_weight_node, bn_bias_node, bn_mean_node, bn_var_node = (
+                bn_node.args[0:5]
+            )
+            bn_weight_tensor = get_param_or_none(bn_weight_node)
+            bn_bias_tensor = get_param_or_none(bn_bias_node)
+            bn_mean_tensor = get_buffer(self.exported_program, bn_mean_node)
+            bn_var_tensor = get_buffer(self.exported_program, bn_var_node)
+            if bn_mean_tensor is None or bn_var_tensor is None:
                 raise ValueError(
                     "Parameters running_mean and running_var of batchnorm can't be None."
                 )
-            epsilon = bn.args[-1]
+            epsilon = bn_node.args[-1]
 
             # Get weight and bias from conv
             conv_weight_node, conv_bias_node = conv.args[1:3]
-            conv_weight = get_param(self.exported_program, conv_weight_node)
-            conv_bias = get_param_or_none(conv_bias_node)
-            if conv_weight is None:
+            conv_weight_tensor = get_param(self.exported_program, conv_weight_node)
+            conv_bias_tensor = get_param_or_none(conv_bias_node)
+            if conv_weight_tensor is None:
                 raise ValueError("Parameter weight of convolution can't be None.")
 
             # Compute conv parameters folded with batchnorm
             fused_conv_weight, fused_conv_bias = fuse_conv_bn_weights(
-                conv_weight,
-                conv_bias,
-                running_mean,
-                running_var,
+                conv_weight_tensor,
+                conv_bias_tensor,
+                bn_mean_tensor,
+                bn_var_tensor,
                 epsilon,
-                bn_weight,
-                bn_bias,
+                bn_weight_tensor,
+                bn_bias_tensor,
             )
 
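For reference, fuse_conv_bn_weights folds the batchnorm's normalize-scale-shift into a per-output-channel rescale of the conv weight plus a new bias. A minimal sketch of the algebra, assuming a non-transposed 2D convolution where the conv bias and both batchnorm affine parameters are present (the torch helper also accepts the None cases that get_param_or_none can produce):

    import torch

    def fold_bn_into_conv(w, b, running_mean, running_var, eps, bn_w, bn_b):
        # Batchnorm computes bn_w * (y - running_mean) / sqrt(running_var + eps) + bn_b;
        # applied to y = conv(x), this is a per-channel rescale and shift.
        scale = bn_w / torch.sqrt(running_var + eps)
        fused_w = w * scale.reshape(-1, 1, 1, 1)
        fused_b = (b - running_mean) * scale + bn_b
        return fused_w, fused_b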
-            # Set the conv parameters to fused value
-            def try_set_param(
-                param_node: Node | None, param_value: torch.nn.Parameter
-            ) -> bool:
-                """set_param but check if param_node is None first. Return True if param was set successfully, otherwise False."""
-                if param_node is not None:
-                    param_name = (
-                        self.exported_program.graph_signature.inputs_to_parameters[
-                            param_node.name
-                        ]
+            # Create fused weight and bias for the conv and replace its args
+            with graph_module.graph.inserting_before(conv_weight_node):
+                fused_conv_weight_node = create_constant_placeholder(
+                    exp_program=self.exported_program,
+                    graph=graph_module.graph,
+                    kind=InputKind.PARAMETER,
+                    name=conv_weight_node.name + "_fused_bn",
+                    data=fused_conv_weight,
+                )
+
+                if fused_conv_bias is not None:
+                    fused_conv_bias_node = create_constant_placeholder(
+                        exp_program=self.exported_program,
+                        graph=graph_module.graph,
+                        kind=InputKind.PARAMETER,
+                        name=self.get_bias_name(conv_weight_node, conv_bias_node),
+                        data=fused_conv_bias,
                     )
-                    self.exported_program.state_dict[param_name] = param_value
-                    return True
-                return False
+                else:
+                    fused_conv_bias_node = None
+
+                conv.args = (
+                    conv.args[0],
+                    fused_conv_weight_node,
+                    fused_conv_bias_node,
+                    *conv.args[3:],
+                )
 
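The args rewiring assumes the convolution node's positional arguments follow the aten convolution schema, so slots 1 and 2 are exactly the weight and bias placeholders being swapped out. A reference annotation (not part of the patch):

    # aten.convolution(input, weight, bias, stride, padding, dilation,
    #                  transposed, output_padding, groups)
    # conv.args[0]   input  -> kept as-is
    # conv.args[1]   weight -> replaced by fused_conv_weight_node
    # conv.args[2]   bias   -> replaced by fused_conv_bias_node (or None)
    # conv.args[3:]  stride, padding, ...  -> kept as-is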
-            try_set_param(conv_weight_node, fused_conv_weight)
-            if not try_set_param(conv_bias_node, fused_conv_bias) and try_set_param(
-                bn_bias_node, fused_conv_bias
-            ):
-                # pyre-ignore[60]
-                # Conv didn't have bias but batchnorm did, steal bias from batchnorm.
-                conv_args = (*conv.args[0:2], bn_bias_node, *conv.args[3:])
-                conv.args = conv_args
-
-            # Erasing nodes is handled by dead-code elimination.
-            for user in bn.users:
+            # Erasing batch-norm nodes is handled by dead-code elimination.
+            # After that, we may remove their constant placeholder inputs.
+            for user in bn_node.users:
                 user.replace_all_uses_with(conv)
+
+            constant_placeholders_to_delete.update(
+                [
+                    bn_weight_node,
+                    bn_bias_node,
+                    bn_mean_node,
+                    bn_var_node,
+                    conv_weight_node,
+                    conv_bias_node,
+                ]
+            )
             modified = True
 
         if modified:
             graph_module.graph.eliminate_dead_code()
+            for constant_placeholder in constant_placeholders_to_delete:
+                if (constant_placeholder is not None) and (
+                    len(constant_placeholder.users) == 0
+                ):
+                    delete_constant_placeholder(
+                        self.exported_program, constant_placeholder
+                    )
+
             graph_module.recompile()
             graph_module = super().call(graph_module).graph_module
+
         return PassResult(graph_module=graph_module, modified=modified)
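Note that the cleanup loop deletes an old placeholder only after dead-code elimination has left it with zero users, so parameters still referenced elsewhere in the graph survive. A minimal usage sketch; the pass class name and the backend's pass-pipeline wiring are outside this excerpt, so the invocation below is hypothetical:

    import torch

    class ConvBn(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.conv = torch.nn.Conv2d(3, 8, kernel_size=3)
            self.bn = torch.nn.BatchNorm2d(8)

        def forward(self, x):
            return self.bn(self.conv(x))

    ep = torch.export.export(ConvBn().eval(), (torch.randn(1, 3, 16, 16),))
    # Hypothetical invocation; in practice the pass runs as part of the
    # backend's pass pipeline on an edge-dialect program:
    # result = FuseBatchnormPass(ep).call(ep.graph_module)
    # assert result.modified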