@@ -98,20 +98,21 @@ def intrin_func(ins, outs):
     return tvm.decl_tensor_intrin(C.op, intrin_func, binds={A: BA, C: BC})


-def test_tensor_core_gemm():
-    n = 4096
+def test_tensor_core_batch_matmal():
+    batch_size = 20
+    n = 2048
     m, l = n, n
     assert (n % 16 == 0)
     assert (m % 16 == 0)
     assert (l % 16 == 0)
     nn, mm, ll = n // 16, m // 16, l // 16
-    A = tvm.placeholder((nn, ll, 16, 16), name='A', dtype='float16')
-    B = tvm.placeholder((ll, mm, 16, 16), name='B', dtype='float16')
+    A = tvm.placeholder((batch_size, nn, ll, 16, 16), name='A', dtype='float16')
+    B = tvm.placeholder((batch_size, ll, mm, 16, 16), name='B', dtype='float16')
     k1 = tvm.reduce_axis((0, ll), name='k1')
     k2 = tvm.reduce_axis((0, 16), name='k2')
-    C = tvm.compute((nn, mm, 16, 16),
-                    lambda i, j, ii, jj:
-                    tvm.sum(A[i, k1, ii, k2].astype('float') * B[k1, j, k2, jj].astype('float'), axis=[k1, k2]),
+    C = tvm.compute((batch_size, nn, mm, 16, 16),
+                    lambda b, i, j, ii, jj:
+                    tvm.sum(A[b, i, k1, ii, k2].astype('float') * B[b, k1, j, k2, jj].astype('float'), axis=[k1, k2]),
                     name='Fragment_C')
     s = tvm.create_schedule(C.op)

@@ -125,6 +126,7 @@ def test_tensor_core_gemm():

     block_x = tvm.thread_axis('blockIdx.x')
     block_y = tvm.thread_axis('blockIdx.y')
+    block_z = tvm.thread_axis('blockIdx.z')
     thread_x = tvm.thread_axis('threadIdx.x')
     thread_y = tvm.thread_axis('threadIdx.y')
     thread_z = tvm.thread_axis('threadIdx.z')
@@ -135,19 +137,20 @@ def test_tensor_core_gemm():
     BF = s.cache_read(BS, 'wmma.matrix_b', [C])
     CF = s.cache_write(C, 'wmma.accumulator')

-    i, j, kernel_i, kernel_j = s[C].op.axis
+    b, i, j, kernel_i, kernel_j = s[C].op.axis
     i, ii = s[C].split(i, factor=warp_row_tiles)
     block_i, i = s[C].split(i, factor=block_row_warps)
     j, jj = s[C].split(j, factor=warp_col_tiles)
     block_j, j = s[C].split(j, factor=block_col_warps)
     s[C].reorder(block_i, block_j, i, j, ii, jj, kernel_i, kernel_j)
+    s[C].bind(b, block_z)
     s[C].bind(block_i, block_x)
     s[C].bind(block_j, block_y)
     s[C].bind(i, thread_y)
     s[C].bind(j, thread_z)

     s[CF].compute_at(s[C], j)
-    warp_i, warp_j, _i, _j = s[CF].op.axis
+    b, warp_i, warp_j, _i, _j = s[CF].op.axis
     k, _k = CF.op.reduce_axis
     ko, ki = s[CF].split(k, factor=chunk)
     s[CF].reorder(ko, ki, warp_i, warp_j, _i, _j, _k)
@@ -156,7 +159,7 @@ def test_tensor_core_gemm():
     s[BF].compute_at(s[CF], ki)

     s[AS].compute_at(s[CF], ko)
-    xo, yo, xi, yi = AS.op.axis
+    b, xo, yo, xi, yi = AS.op.axis
     tx, xo = s[AS].split(xo, nparts=block_row_warps)
     ty, yo = s[AS].split(yo, nparts=block_col_warps)
     t = s[AS].fuse(xi, yi)
@@ -167,7 +170,7 @@ def test_tensor_core_gemm():
     s[AS].vectorize(ti)

     s[BS].compute_at(s[CF], ko)
-    xo, yo, xi, yi = BS.op.axis
+    b, xo, yo, xi, yi = BS.op.axis
     tx, xo = s[BS].split(xo, nparts=block_row_warps)
     ty, yo = s[BS].split(yo, nparts=block_col_warps)
     t = s[BS].fuse(xi, yi)
@@ -184,23 +187,23 @@ def test_tensor_core_gemm():
     func = tvm.build(s, [A, B, C], 'cuda')

     ctx = tvm.gpu(0)
-    a_np = np.random.uniform(size=(nn, nn, 16, 16)).astype(A.dtype)
-    b_np = np.random.uniform(size=(nn, nn, 16, 16)).astype(B.dtype)
+    a_np = np.random.uniform(size=(batch_size, nn, nn, 16, 16)).astype(A.dtype)
+    b_np = np.random.uniform(size=(batch_size, nn, nn, 16, 16)).astype(B.dtype)
     a = tvm.nd.array(a_np, ctx)
     b = tvm.nd.array(b_np, ctx)
-    c = tvm.nd.array(np.zeros((nn, nn, 16, 16), dtype=C.dtype), ctx)
+    c = tvm.nd.array(np.zeros((batch_size, nn, nn, 16, 16), dtype=C.dtype), ctx)
     evaluator = func.time_evaluator(func.entry_name, ctx, number=3)
     print('gemm with tensor core: %f ms' % (evaluator(a, b, c).mean * 1e3))

     if VERIFY:
         func(a, b, c)
-        a_np = a_np.transpose(0, 2, 1, 3).reshape(n, n)
-        b_np = b_np.transpose(0, 2, 1, 3).reshape(n, n)
-        c_np = c.asnumpy().transpose(0, 2, 1, 3).reshape(n, n)
-        np.testing.assert_allclose(c_np, np.dot(a_np.astype(C.dtype), b_np.astype(C.dtype)), rtol=1e-4, atol=1e-4)
+        a_np = a_np.transpose((0, 1, 3, 2, 4)).reshape(batch_size, n, n)
+        b_np = b_np.transpose((0, 1, 3, 2, 4)).reshape(batch_size, n, n)
+        c_np = c.asnumpy().transpose((0, 1, 3, 2, 4)).reshape(batch_size, n, n)
+        np.testing.assert_allclose(c_np, np.matmul(a_np.astype(C.dtype), b_np.astype(C.dtype)), rtol=1e-4, atol=1e-4)


-def test_tensor_core_conv():
+def test_tensor_core_batch_conv():
     # The sizes of inputs and filters
     batch_size = 256
     height = 14
@@ -364,5 +367,5 @@ def test_tensor_core_conv():
     if not nvcc.have_tensorcore(ctx.compute_version):
         print("skip because gpu does not support tensor core")
     else:
-        test_tensor_core_gemm()
-        test_tensor_core_conv()
+        test_tensor_core_batch_matmal()
+        test_tensor_core_batch_conv()
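For reference, a minimal NumPy-only sketch of what the new verification block checks. It assumes the same blocked fragment layout (batch, n // 16, n // 16, 16, 16) used by the placeholders above; the small sizes and the helper name blocked_to_dense are purely illustrative and not part of the patch.

import numpy as np

batch_size, n = 2, 64   # the test itself uses 20 and 2048
nn = n // 16

def blocked_to_dense(x):
    # (batch, row_block, col_block, 16, 16) -> (batch, n, n):
    # interleave the 16x16 fragments back into full rows and columns,
    # exactly what transpose((0, 1, 3, 2, 4)).reshape(batch_size, n, n) does above.
    return x.transpose((0, 1, 3, 2, 4)).reshape(batch_size, n, n)

a_np = np.random.uniform(size=(batch_size, nn, nn, 16, 16)).astype('float16')
b_np = np.random.uniform(size=(batch_size, nn, nn, 16, 16)).astype('float16')

# Batched reference result in float32, analogous to what Fragment_C accumulates
# and what np.matmul is compared against in the test.
c_ref = np.matmul(blocked_to_dense(a_np).astype('float32'),
                  blocked_to_dense(b_np).astype('float32'))
print(c_ref.shape)      # (batch_size, n, n)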