@@ -54,12 +54,13 @@ def schedule_conv2d_nhwc_direct(cfg, s, Conv):
5454 cfg .define_knob ("vthread_n" , [1 ] if dynamic_batch else [1 , 2 ])
5555 cfg .define_knob ("vthread_c" , [1 , 2 ])
5656 cfg .define_knob ("step" , [16 , 3 , 32 , 64 ])
57+ cfg .define_knob ("vectorize" , [1 , 2 , 4 , 8 ])
5758
5859 # fallback support
5960 target = tvm .target .Target .current ()
6061 if cfg .is_fallback :
6162 ref_log = autotvm .tophub .load_reference_log (
62- target .kind .name , target .model , "conv2d_nhwc.cuda "
63+ target .kind .name , target .model , "conv2d_nhwc.gpu "
6364 )
6465 cfg .fallback_with_reference_log (ref_log )
6566
@@ -70,6 +71,7 @@ def schedule_conv2d_nhwc_direct(cfg, s, Conv):
7071 vthread_n = cfg ["vthread_n" ].val
7172 vthread_c = cfg ["vthread_c" ].val
7273 step = cfg ["step" ].val
74+ vec_factor = cfg ["vectorize" ].val
7375 block_factor_c = tile_c * num_thread_c * vthread_c
7476
7577 offset = 8
@@ -85,15 +87,17 @@ def schedule_conv2d_nhwc_direct(cfg, s, Conv):
8587 thread_yz = te .thread_axis ((0 , vthread_n ), "vthread" , name = "vy" )
8688
8789 # Schedule for output
88- ni , hi , wi , fi = s [output ].op .axis
89- bx = s [output ].fuse (hi , wi )
90+ ni , _ , wi , fi = s [output ].op .axis
91+ bx = wi
92+ fi , vec = s [output ].split (fi , factor = vec_factor )
93+ s [output ].vectorize (vec )
9094 tx , fi = s [output ].split (fi , factor = tile_c )
9195 txz , tx = s [output ].split (tx , factor = num_thread_c )
9296 bz , txz = s [output ].split (txz , factor = vthread_c )
9397 ty , ni = s [output ].split (ni , factor = tile_n )
9498 tyz , ty = s [output ].split (ty , factor = num_thread_n )
9599 by , tyz = s [output ].split (tyz , factor = vthread_n )
96- s [output ].reorder (bx , by , bz , tyz , txz , ty , tx , ni , fi )
100+ s [output ].reorder (bx , by , bz , tyz , txz , ty , tx , ni , fi , vec )
97101 s [output ].bind (bz , block_z )
98102 s [output ].bind (by , block_y )
99103 s [output ].bind (bx , block_x )
@@ -106,6 +110,7 @@ def schedule_conv2d_nhwc_direct(cfg, s, Conv):
106110 ni , yi , xi , fi = s [OL ].op .axis
107111 ry , rx , rc = s [OL ].op .reduce_axis
108112 rco , rci = s [OL ].split (rc , factor = step )
113+ s [OL ].vectorize (fi )
109114 s [OL ].reorder (rco , ry , rx , rci , ni , fi )
110115
111116 s [AA ].compute_at (s [OL ], rx )
@@ -125,6 +130,8 @@ def schedule_conv2d_nhwc_direct(cfg, s, Conv):
125130 _ , _ , ic , o = s [WW ].op .axis
126131 t = s [WW ].fuse (ic , o )
127132 s [WW ].storage_align (ic , W_align - 1 , W_align )
133+ t , vec = s [WW ].split (t , factor = vec_factor )
134+ s [WW ].vectorize (vec )
128135 ty , tx = s [WW ].split (t , factor = num_thread_c )
129136 _ , ty = s [WW ].split (ty , factor = num_thread_n )
130137 s [WW ].bind (tx , thread_x )
0 commit comments