@@ -30,7 +30,7 @@ def depthwise_conv2d_nchw(cfg, data, kernel, strides, padding, dilation, out_dty
3030 return nn .depthwise_conv2d_nchw (data , kernel , strides , padding , dilation , out_dtype )
3131
3232
33- # register customized schedule for arm cpu .
33+ # register customized schedule for Mali .
3434@autotvm .register_topi_schedule ("depthwise_conv2d_nchw.mali" )
3535def schedule_depthwise_conv2d_nchw (cfg , outs ):
3636 """Schedule depthwise conv2d
@@ -70,7 +70,7 @@ def depthwise_conv2d_nhwc(cfg, data, kernel, strides, padding, dilation, out_dty
7070 return nn .depthwise_conv2d_nhwc (data , kernel , strides , padding , dilation , out_dtype )
7171
7272
73- # register customized schedule for arm cpu .
73+ # register customized schedule for Mali .
7474@autotvm .register_topi_schedule ("depthwise_conv2d_nhwc.mali" )
7575def schedule_depthwise_conv2d_nhwc (cfg , outs ):
7676 """Schedule depthwise conv2d
@@ -124,8 +124,13 @@ def _schedule(cfg, s, pad_data, kernel, conv, layout):
124124
125125 # fallback support
126126 if cfg .is_fallback :
127- ref_log = autotvm .tophub .load_reference_log ("mali" , "rk3399" , "depthwise_conv2d_nchw.mali" )
128- cfg .fallback_with_reference_log (ref_log )
127+ if layout == "NCHW" :
128+ ref_log = autotvm .tophub .load_reference_log ("mali" , "rk3399" , "depthwise_conv2d_nchw.mali" )
129+ cfg .fallback_with_reference_log (ref_log )
130+ else :
131+ cfg .fallback_split ("tile_c" , [- 1 , 4 , 2 ])
132+ cfg .fallback_split ("tile_y" , [- 1 , 4 , 2 ])
133+ cfg .fallback_split ("tile_x" , [- 1 , 4 , 2 ])
129134 ###### space definition end ######
130135
131136 # schedule padding
0 commit comments