@@ -59,6 +59,11 @@ def _alter_conv2d_layout(attrs, inputs, tinfos, out_type):
5959 data , kernel = tinfos
6060 out_dtype = out_type .dtype
6161
62+ # Extract data types
63+ data_tensor , kernel_tensor = tinfos
64+ data_dtype = data_tensor .dtype
65+ kernel_dtype = kernel_tensor .dtype
66+
6267 idxd = tvm .tir .indexdiv
6368
6469 if topi_tmpl == "conv2d_nchw_spatial_pack.arm_cpu" :
@@ -169,4 +174,60 @@ def _alter_conv2d_layout(attrs, inputs, tinfos, out_type):
169174
170175 return relay .nn .conv2d (* inputs , ** new_attrs )
171176
177+ if topi_tmpl == "conv2d_NCHWc.x86" :
178+ # Converting NCHW to NCHWc.
179+ assert data_layout == "NCHW" and kernel_layout == "OIHW"
180+ if cfg .is_fallback :
181+ _get_default_config (cfg , data_tensor , kernel_tensor , strides , padding ,
182+ out_dtype , False , data_layout )
183+ batch_size , in_channel , height , width = get_const_tuple (data_tensor .shape )
184+ out_channel , _ , kh , kw = get_const_tuple (kernel_tensor .shape )
185+ ic_bn , oc_bn = cfg ["tile_ic" ].size [- 1 ], cfg ["tile_oc" ].size [- 1 ]
186+
187+ # update new attrs
188+ new_attrs ['channels' ] = out_channel
189+ new_attrs ['data_layout' ] = 'NCHW%dc' % ic_bn
190+ # (oc, ic, h, w) -> (OC, IC, h, w, ic, oc)
191+ new_attrs ['kernel_layout' ] = 'OIHW%di%do' % (ic_bn , oc_bn )
192+ new_attrs ['out_layout' ] = 'NCHW%dc' % oc_bn
193+
194+ # Store altered operator's config
195+ new_data = te .placeholder ((batch_size , in_channel // ic_bn , height , width , ic_bn ),
196+ dtype = data_dtype )
197+ new_kernel = te .placeholder ((out_channel // oc_bn , in_channel // ic_bn ,
198+ kh , kw , ic_bn , oc_bn ), dtype = kernel_tensor .dtype )
199+ new_workload = autotvm .task .args_to_workload (
200+ [new_data , new_kernel , strides , padding , dilation , new_attrs ["data_layout" ],
201+ new_attrs ["out_layout" ], out_dtype ], topi_tmpl )
202+ dispatch_ctx .update (target , new_workload , cfg )
203+ return relay .nn .contrib_conv2d_nchwc (* inputs , ** new_attrs )
204+
205+ if topi_tmpl == "depthwise_conv2d_NCHWc.x86" :
206+ # Converting NCHW to NCHWc.
207+ assert data_layout == "NCHW" and kernel_layout == "OIHW"
208+ if cfg .is_fallback :
209+ _get_default_config (cfg , data_tensor , kernel_tensor , strides , padding ,
210+ out_dtype , True , data_layout )
211+
212+ batch_size , in_channel , height , width = get_const_tuple (data_tensor .shape )
213+ out_channel , channel_multiplier , kh , kw = get_const_tuple (kernel_tensor .shape )
214+ ic_bn , oc_bn = cfg ["tile_ic" ].size [- 1 ], cfg ["tile_oc" ].size [- 1 ]
215+ assert channel_multiplier == 1
216+
217+ # update new attrs
218+ new_attrs ['channels' ] = out_channel
219+ new_attrs ['data_layout' ] = 'NCHW%dc' % ic_bn
220+ new_attrs ['kernel_layout' ] = 'OIHW1i%do' % oc_bn
221+ new_attrs ['out_layout' ] = 'NCHW%dc' % oc_bn
222+
223+ # Store altered operator's config.
224+ new_data = te .placeholder ((batch_size , in_channel // ic_bn , height , width , ic_bn ),
225+ dtype = data_dtype )
226+ new_kernel = te .placeholder ((out_channel // oc_bn , 1 , kh , kw , 1 , oc_bn ), dtype = kernel_dtype )
227+ new_workload = autotvm .task .args_to_workload (
228+ [new_data , new_kernel , strides , padding , dilation , new_attrs ['data_layout' ],
229+ new_attrs ['out_layout' ], out_dtype ], topi_tmpl )
230+ dispatch_ctx .update (target , new_workload , cfg )
231+ return relay .nn .contrib_depthwise_conv2d_nchwc (* inputs , ** new_attrs )
232+
172233 return None
0 commit comments