@@ -235,5 +235,70 @@ def expected(A: T.Buffer((128, 128), "float32"), B: T.Buffer(128, "float32")):
235235 B [i ] = reduce [0 ]
236236
237237
class TestMultiGroupMask(BaseCompare):
    """Before/after pair for an all-reduce with several thread groups.

    Both functions launch 32 ``threadIdx.y`` threads and 32 ``threadIdx.x``
    threads.  ``before`` sums ``A`` through ``T.tvm_thread_allreduce`` with
    ``threadIdx.x`` passed as the trailing thread argument; ``expected``
    spells the same sum out as warp shuffles guarded by a per-``threadIdx.y``
    mask.

    ``BaseCompare`` is defined outside this block; presumably it runs the
    pass under test on ``before`` and compares the result with ``expected``
    -- confirm against the base class.

    NOTE(review): local names are Python-level labels only and are assumed
    not to take part in the before/after comparison -- confirm.
    NOTE(review): the ``if`` that writes ``B`` is placed after the
    ``reduce_scope`` block in both functions -- confirm against upstream.
    """

    @T.prim_func
    def before(A: T.Buffer((32, 32), "float32"), B: T.Buffer((32,), "float32")):
        T.func_attr({"target": T.target("cuda", host="llvm")})
        ty = T.launch_thread("threadIdx.y", 32)
        acc_data = T.allocate([1], "float32", "local")
        tx = T.launch_thread("threadIdx.x", 32)
        # One-element "local" buffer that receives the reduction result.
        acc = T.Buffer((1,), data=acc_data, scope="local")
        with T.attr(
            T.comm_reducer(lambda x0, y0: x0 + y0, [T.float32(0)]),
            "reduce_scope",
            T.reinterpret("handle", T.uint64(0)),
        ):
            # Flat 1024-element view over A's storage.
            A_flat = T.Buffer((1024,), data=A.data)
            T.tvm_thread_allreduce(
                T.uint32(1),
                A_flat[ty * 32 + tx],
                T.bool(True),
                acc[0],
                tx,
            )
        # Only the thread with threadIdx.x == 0 stores the result for its ty.
        if tx == 0:
            B_out = T.Buffer((32,), data=B.data)
            B_out[ty] = acc[0]

    @T.prim_func
    def expected(A: T.Buffer((32, 32), "float32"), B: T.Buffer((32,), "float32")):
        T.func_attr({"target": T.target("cuda", host="llvm")})
        ty = T.launch_thread("threadIdx.y", 32)
        red_data = T.allocate([1], "float32", "local")
        tx = T.launch_thread("threadIdx.x", 32)
        red = T.Buffer((1,), data=red_data, scope="local")
        with T.attr(
            T.comm_reducer(lambda x0, y0: x0 + y0, [T.float32(0)]),
            "reduce_scope",
            T.reinterpret("handle", T.uint64(0)),
        ):
            mask_data = T.allocate([1], "uint32", "local")
            tmp_data = T.allocate([1], "float32", "local")
            A_flat = T.Buffer((1024,), data=A.data)
            # Seed the accumulator with this thread's own element.
            red[0] = A_flat[ty * 32 + tx]

            # Active-mask ANDed with an all-ones word shifted left by 32 * ty.
            # NOTE(review): for ty >= 1 the shift amount is >= 32, the bit
            # width of uint32 -- confirm this is the intended expectation.
            mask = T.Buffer((1,), "uint32", data=mask_data, scope="local")
            mask[0] = T.bitwise_and(
                T.tvm_warp_activemask(),
                T.shift_left(T.uint32(0xFFFFFFFF), T.uint32(32) * T.Cast("uint32", ty)),
            )

            # Shuffle-down / add steps with offsets 16, 8, 4, 2, 1.
            tmp = T.Buffer((1,), data=tmp_data, scope="local")
            tmp[0] = T.tvm_warp_shuffle_down(mask[0], red[0], 16, 32, 32)
            red[0] = red[0] + tmp[0]
            tmp[0] = T.tvm_warp_shuffle_down(mask[0], red[0], 8, 32, 32)
            red[0] = red[0] + tmp[0]
            tmp[0] = T.tvm_warp_shuffle_down(mask[0], red[0], 4, 32, 32)
            red[0] = red[0] + tmp[0]
            tmp[0] = T.tvm_warp_shuffle_down(mask[0], red[0], 2, 32, 32)
            red[0] = red[0] + tmp[0]
            tmp[0] = T.tvm_warp_shuffle_down(mask[0], red[0], 1, 32, 32)
            red[0] = red[0] + tmp[0]
            # Final shuffle, passed 32 * ty as its third argument.
            red[0] = T.tvm_warp_shuffle(mask[0], red[0], 32 * ty, 32, 32)
        if tx == 0:
            B_out = T.Buffer((32,), data=B.data)
            B_out[ty] = red[0]
302+
# Script entry point: call tvm.testing.main() when this file is executed
# directly rather than imported.
if __name__ == "__main__":
    tvm.testing.main()
0 commit comments