@@ -51,77 +51,14 @@ def batch_norm(
     # We perform constant folding for batch norm when the weight, bias, running_mean, and running_var are all tensors.
     # Batch norm operation can be fused into a single layer, which is more efficient than the original implementation.
     # In this way, the batch norm layer will be fused with the Convolution layer and get a performance boost.
-    if all(
+    if not all(
         [
             isinstance(weight, torch.Tensor),
             isinstance(bias, torch.Tensor),
             isinstance(running_mean, torch.Tensor),
             isinstance(running_var, torch.Tensor),
         ]
     ):
-        if weight is None:
-            weight = 1.0
-
-        if bias is None:
-            bias = 0.0
-
-        if running_mean is None:
-            running_mean = 0.0
-
-        if running_var is None:
-            running_var = 1.0
-        adjusted_scale = weight / torch.sqrt(running_var + eps)
-        adjusted_bias = bias - running_mean * adjusted_scale
-        power = torch.ones_like(adjusted_scale)
-        adjusted_scale = to_trt_weights(
-            ctx,
-            adjusted_scale,
-            name,
-            layer_type_name="SCALE",
-            weight_type_name="SCALE",
-            target=target,
-            source_ir=source_ir,
-        )
-        adjusted_bias = to_trt_weights(
-            ctx,
-            adjusted_bias,
-            name,
-            layer_type_name="SCALE",
-            weight_type_name="SHIFT",
-            target=target,
-            source_ir=source_ir,
-        )
-
-        power = to_trt_weights(
-            ctx,
-            power,
-            name,
-            layer_type_name="SCALE",
-            weight_type_name="POWER",
-            target=target,
-            source_ir=source_ir,
-        )
-
-        output_shape = input.shape
-        if len(input.shape) < 4:
-
-            new_shape = (
-                (input.shape[0], input.shape[1], 1, 1)
-                if len(input.shape) == 2
-                else (input.shape[0], input.shape[1], input.shape[2], 1)
-            )
-            input = impl.shuffle.reshape(
-                ctx, target, source_ir, f"{name}_reshape_2d", input, new_shape
-            )
-
-        layer = ctx.net.add_scale_nd(
-            input, trt.ScaleMode.CHANNEL, adjusted_bias, adjusted_scale, power, 1
-        )
-        set_layer_name(layer, target, name, source_ir)
-        output = layer.get_output(0)
-
-    else:
-
         # We name the weight here according to the state_dict name
         weight = (
             get_trt_tensor(ctx, 1.0, f"{name}_weight")
@@ -206,6 +143,70 @@ def batch_norm(
             bias_adjusted_reshape,
         )

+    else:
+        if weight is None:
+            weight = 1.0
+
+        if bias is None:
+            bias = 0.0
+
+        if running_mean is None:
+            running_mean = 0.0
+
+        if running_var is None:
+            running_var = 1.0
+        adjusted_scale, adjusted_bias = batch_norm_constant_folding(
+            weight, bias, running_mean, running_var, eps
+        )
+        power = torch.ones_like(adjusted_scale)
+
+        adjusted_scale = to_trt_weights(
+            ctx,
+            adjusted_scale,
+            name,
+            layer_type_name="SCALE",
+            weight_type_name="SCALE",
+            target=target,
+            source_ir=source_ir,
+        )
+        adjusted_bias = to_trt_weights(
+            ctx,
+            adjusted_bias,
+            name,
+            layer_type_name="SCALE",
+            weight_type_name="SHIFT",
+            target=target,
+            source_ir=source_ir,
+        )
+
+        power = to_trt_weights(
+            ctx,
+            power,
+            name,
+            layer_type_name="SCALE",
+            weight_type_name="POWER",
+            target=target,
+            source_ir=source_ir,
+        )
+
+        output_shape = input.shape
+        if len(input.shape) < 4:
+
+            new_shape = (
+                (input.shape[0], input.shape[1], 1, 1)
+                if len(input.shape) == 2
+                else (input.shape[0], input.shape[1], input.shape[2], 1)
+            )
+            input = impl.shuffle.reshape(
+                ctx, target, source_ir, f"{name}_reshape_2d", input, new_shape
+            )
+
+        layer = ctx.net.add_scale_nd(
+            input, trt.ScaleMode.CHANNEL, adjusted_bias, adjusted_scale, power, 1
+        )
+        set_layer_name(layer, target, name, source_ir)
+        output = layer.get_output(0)
+
     # For BatchNorm1d, reshape output back to original shape if necessary
     if len(output_shape) < 4:
         output = impl.shuffle.reshape(
@@ -224,6 +225,18 @@ def batch_norm(
     return output


+def batch_norm_constant_folding(
+    weight: torch.Tensor,
+    bias: torch.Tensor,
+    running_mean: torch.Tensor,
+    running_var: torch.Tensor,
+    eps: float,
+) -> Tuple[torch.Tensor, torch.Tensor]:
+    adjusted_scale = weight / torch.sqrt(running_var + eps)
+    adjusted_bias = bias - running_mean * adjusted_scale
+    return adjusted_scale, adjusted_bias
+
+
 def native_layer_norm(
     ctx: ConversionContext,
     target: Target,
@@ -303,7 +316,7 @@ def native_group_norm(
         ctx, target, source_ir, f"{name}_expand_bias_zero", bias_zero, shape
     )

-    axes = get_axes_for_reduce_op([i for i in range(1 if group == 1 else 2, rank)])
+    axes = get_axes_for_reduce_op(list(range(1 if group == 1 else 2, rank)))

     # INormalizationLayer scales the normalized output per-group, but PyTorch scales the normalized output per-channel,
     # hence causing divergent results. Let TensorRT do a no-op for scaling here, and do the scaling ourselves later.
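For reference, the constant folding described in the comments at the top of the first hunk rests on a simple identity: in inference mode, batch norm is an affine per-channel transform, so y = (x - running_mean) / sqrt(running_var + eps) * weight + bias can be rewritten as y = x * adjusted_scale + adjusted_bias, with adjusted_scale = weight / sqrt(running_var + eps) and adjusted_bias = bias - running_mean * adjusted_scale. The sketch below (separate from the patch; the helper name, shapes, random inputs, and tolerance are illustrative assumptions) checks that identity against torch.nn.functional.batch_norm using the same formula as the new batch_norm_constant_folding helper.

import torch
import torch.nn.functional as F

torch.manual_seed(0)


def fold_batch_norm(weight, bias, running_mean, running_var, eps):
    # Same math as batch_norm_constant_folding in the patch.
    adjusted_scale = weight / torch.sqrt(running_var + eps)
    adjusted_bias = bias - running_mean * adjusted_scale
    return adjusted_scale, adjusted_bias


C = 8
x = torch.randn(4, C, 16, 16)
weight, bias = torch.randn(C), torch.randn(C)
running_mean = torch.randn(C)
running_var = torch.rand(C) + 0.5  # keep variance positive
eps = 1e-5

scale, shift = fold_batch_norm(weight, bias, running_mean, running_var, eps)
# Per-channel scale/shift broadcast over (N, C, H, W), mirroring what
# add_scale_nd with trt.ScaleMode.CHANNEL applies inside the engine.
folded = x * scale.view(1, C, 1, 1) + shift.view(1, C, 1, 1)

reference = F.batch_norm(
    x, running_mean, running_var, weight, bias, training=False, eps=eps
)
assert torch.allclose(folded, reference, atol=1e-5)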