3434
3535@autotvm .register_topi_compute ("conv2d_nchw_spatial_pack.arm_cpu" )
3636def conv2d_nchw_spatial_pack (cfg , data , kernel , strides , padding , dilation , out_dtype ):
37+ """Compute conv2d with NCHW layout"""
3738 return conv2d_spatial_pack_nchw (cfg , data , kernel , strides , padding ,
3839 dilation , out_dtype , num_tile = 2 )
3940
4041
4142@autotvm .register_topi_schedule ("conv2d_nchw_spatial_pack.arm_cpu" )
4243def schedule_conv2d_nchw_spatial_pack (cfg , outs ):
44+ """Create schedule for conv2d_nchw"""
4345 s = tvm .create_schedule ([x .op for x in outs ])
4446
4547 def _callback (op ):
@@ -69,12 +71,14 @@ def _callback(op):
6971
7072@autotvm .register_topi_compute ("conv2d_nhwc_spatial_pack.arm_cpu" )
7173def conv2d_nhwc_spatial_pack (cfg , data , kernel , strides , padding , dilation , out_dtype ):
74+ """Compute conv2d with NHWC layout"""
7275 return conv2d_spatial_pack_nhwc (cfg , data , kernel , strides , padding ,
7376 dilation , out_dtype )
7477
7578
7679@autotvm .register_topi_schedule ("conv2d_nhwc_spatial_pack.arm_cpu" )
7780def schedule_conv2d_nhwc_spatial_pack (cfg , outs ):
81+ """Create schedule for conv2d_nhwc"""
7882 s = tvm .create_schedule ([x .op for x in outs ])
7983
8084 def _callback (op ):
@@ -87,13 +91,15 @@ def _callback(op):
8791
8892@autotvm .register_topi_compute ("conv2d_nchw_winograd.arm_cpu" )
8993def conv2d_nchw_winograd (cfg , data , kernel , strides , padding , dilation , out_dtype ):
94+ """Compute conv2d_nchw layout using Winograd with weight transform"""
9095 tile_size = 4
9196 return _decl_winograd (cfg , data , kernel , strides , padding , dilation ,
9297 out_dtype , tile_size )
9398
9499
95100@autotvm .register_topi_schedule ("conv2d_nchw_winograd.arm_cpu" )
96101def schedule_conv2d_nchw_winograd (cfg , outs ):
102+ """Create schedule for conv2d_nchw_winograd"""
97103 s = tvm .create_schedule ([x .op for x in outs ])
98104
99105 def _callback (op ):
@@ -286,6 +292,7 @@ def _schedule_winograd(cfg, s, output, last):
286292
287293@autotvm .register_topi_compute ("conv2d_nchw_winograd_nnpack.arm_cpu" )
288294def conv2d_nchw_winograd_nnpack (cfg , data , kernel , strides , padding , dilation , out_dtype ):
295+ """Compute conv2d_nchw using nnpack Winograd implementation"""
289296 dtype = data .dtype
290297 if dtype == "float32" :
291298 return _conv2d_arm_cpu_winograd_nnpack (
@@ -302,6 +309,7 @@ def conv2d_nchw_winograd_nnpack(cfg, data, kernel, strides, padding, dilation, o
302309
303310@autotvm .register_topi_schedule ("conv2d_nchw_winograd_nnpack.arm_cpu" )
304311def schedule_conv2d_nchw_winograd_nnpack (cfg , outs ):
312+ """Create schedule for conv2d_nchw_winograd_nnpack"""
305313 s = tvm .create_schedule ([x .op for x in outs ])
306314
307315 def _callback (op ):
@@ -371,6 +379,7 @@ def _schedule_winograd_nnpack(cfg, s, output, last):
371379@autotvm .register_topi_compute ("conv2d_nchw_winograd_nnpack_without_weight_transform.arm_cpu" )
372380def conv2d_nchw_winograd_nnpack_without_weight_transform (
373381 cfg , data , transformed_kernel , bias , strides , padding , dilation , out_dtype ):
382+ """Compute conv2d_nchw using NNPack winograd without weight transform"""
374383 N , CI , IH , IW = get_const_tuple (data .shape )
375384 if isinstance (dilation , int ):
376385 dilation_h = dilation_w = dilation
0 commit comments