@@ -158,20 +158,19 @@ def _alter_conv2d_layout(attrs, inputs, tinfo, F):
158158 return None
159159 return F .nn .contrib_conv2d_nchwc_int8 (* copy_inputs , ** new_attrs )
160160
161- if data_layout == 'NCHW' and attrs ['kernel_layout' ] == 'OIHW' :
162- # (oc, ic, h, w) -> (OC, IC, h, w, ic, oc)
163- new_attrs ['kernel_layout' ] = 'OIHW%di%do' % (ic_bn , oc_bn )
164- # Store altered operator's config
165- new_kernel = tvm .placeholder ((out_channel // oc_bn , in_channel // ic_bn ,
166- kh , kw , ic_bn , oc_bn ), dtype = kernel_tensor .dtype )
167- new_workload = autotvm .task .args_to_workload (
168- [new_data , new_kernel , strides , padding , dilation , new_attrs [layout_name ],
169- new_attrs ['out_layout' ], out_dtype ], conv2d_NCHWc )
170- dispatch_ctx .update (target , new_workload , cfg )
171-
172- if F .__name__ == 'nnvm.symbol' :
173- return F .contrib .conv2d_NCHWc (* copy_inputs , ** new_attrs )
174- return F .nn .contrib_conv2d_nchwc (* copy_inputs , ** new_attrs )
161+ # (oc, ic, h, w) -> (OC, IC, h, w, ic, oc)
162+ new_attrs ['kernel_layout' ] = 'OIHW%di%do' % (ic_bn , oc_bn )
163+ # Store altered operator's config
164+ new_kernel = tvm .placeholder ((out_channel // oc_bn , in_channel // ic_bn ,
165+ kh , kw , ic_bn , oc_bn ), dtype = kernel_tensor .dtype )
166+ new_workload = autotvm .task .args_to_workload (
167+ [new_data , new_kernel , strides , padding , dilation , new_attrs [layout_name ],
168+ new_attrs ['out_layout' ], out_dtype ], conv2d_NCHWc )
169+ dispatch_ctx .update (target , new_workload , cfg )
170+
171+ if F .__name__ == 'nnvm.symbol' :
172+ return F .contrib .conv2d_NCHWc (* copy_inputs , ** new_attrs )
173+ return F .nn .contrib_conv2d_nchwc (* copy_inputs , ** new_attrs )
175174 return None
176175
177176
0 commit comments