 """Matmul kernel for Blackwell."""

 import itertools
+import math

 import jax
 from jax._src.interpreters import mlir
@@ -81,8 +82,11 @@ def build_kernel(
   if (m // block_tile_m) % grid_tile_m:
     raise ValueError(f"{m=} // {tile_m=} must be divisible by {grid_tile_m=}")

+  # We intend this to be iterated in column-major order.
+  logical_grid = (grid_tile_m, n // tile_n, m // (block_tile_m * grid_tile_m))
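+  # grid_tile_m is the fastest-varying dimension, so each group of grid_tile_m
+  # consecutive steps covers neighbouring m tiles that share one n tile
+  # (see the index decomposition in _mn_loop below).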
+
   def kernel(ctx, a, b, d, smem):
-    ((a_smem, b_smem), d_smem), barriers, mma_done_barrier, acc = smem
+    ((a_smem, b_smem), d_smem), barriers, mma_done_barrier, tmem_done_barrier, acc = smem
     (ab_full_barriers, ab_empty_barriers) = barriers

     warp_idx = mgpu.warp_idx(sync=True)
@@ -93,18 +97,29 @@ def kernel(ctx, a, b, d, smem):
     is_leader_block = arith.cmpi(
         arith.CmpIPredicate.eq, ctx.cluster_idx(gpu.Dimension.x), c(0, index)
     )
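+    # The kernel is launched with two warpgroups per block (the (2 * 128, 1, 1)
+    # block shape below): warpgroup 0 runs the TMA and MMA leader warps, while
+    # warpgroup 1 is dedicated to storing the output tiles.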
+    is_store_warpgroup = arith.cmpi(
+        arith.CmpIPredicate.eq, mgpu.warpgroup_idx(sync=True), c(1, i32)
+    )

-    # This function executes the kernel for a single output tile.
-    def compute_output(block_m_start, n_start):
-      """Compute and store a single output tile."""
+    def compute_output(block_m_start, n_start, call_counter):
+      """Compute and store a single output tile.
+
+      call_counter should be 0 the first time this function is called and
+      incremented by 1 before each subsequent call.
+      """
+      isnt_first_call = arith.cmpi(
+          arith.CmpIPredicate.ne, call_counter, c(0, index)
+      )
       # All blocks in the cluster share the same m_start -- align it!
       m_start = arith.muli(arith.divui(block_m_start, c(tile_m, index)), c(tile_m, index))
       with mgpu.when(is_leader_of(TMA_WARP)):
         @mgpu.fori(c(k_loop_iter, index), None)
         def _tma_body(ki, _):
           slot = arith.remui(ki, c(max_concurrent_steps, index))
-          # TODO(apaszke): Use a predicate instead of a conditional.
-          with mgpu.when(arith.cmpi(arith.CmpIPredicate.uge, ki, c(max_concurrent_steps, index))):
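+          # Only the warmup steps of the very first tile can skip this wait; on
+          # later tiles the slots may still be in use by the previous tile's MMAs.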
+          isnt_warmup = arith.cmpi(
+              arith.CmpIPredicate.uge, ki, c(max_concurrent_steps, index)
+          )
+          with mgpu.when(arith.ori(isnt_first_call, isnt_warmup)):
             ab_empty_barriers[slot].wait()
           full_barrier = ab_full_barriers[slot]
           with mgpu.when(is_leader_block):
@@ -135,6 +150,9 @@ def _tma_body(ki, _):
               **common_args,
           )

+      # We wait in all blocks in the cluster to avoid double arrival errors.
+      with mgpu.when(arith.andi(is_leader_of(MMA_WARP), isnt_first_call)):
+        tmem_done_barrier.wait(for_tensor_core=True)
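+      # The store warpgroup arrives on tmem_done_barrier once it has finished
+      # reading the previous tile's accumulator out of TMEM, so the MMAs below
+      # are free to overwrite it.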
       with mgpu.when(arith.andi(is_leader_of(MMA_WARP), is_leader_block)):
         @mgpu.fori(c(k_loop_iter, index), arith.constant(i1, 0))
         def _mma_body(ki, accumulate):
@@ -150,41 +168,68 @@ def _mma_body(ki, accumulate):
               collective=collective,
           )
           accumulate = arith.constant(i1, 1)
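+          # Release the A/B slot for the TMA warp as soon as this MMA completes.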
+          tcgen05.commit_arrive(ab_empty_barriers[slot], collective=collective, ctx=ctx)
           is_last_iter = arith.cmpi(
               arith.CmpIPredicate.eq, ki, c(k_loop_iter - 1, index)
           )
-          barrier_ptr = arith.select(
-              is_last_iter,
-              mma_done_barrier.get_ptr(),
-              ab_empty_barriers[slot].get_ptr(),
-          )
-          tcgen05.commit_arrive(barrier_ptr, collective=collective, ctx=ctx)
+          with mgpu.when(is_last_iter):
+            tcgen05.commit_arrive(mma_done_barrier, collective=collective, ctx=ctx)
           return accumulate

-      gpu.barrier()
-      mma_done_barrier.wait(for_tensor_core=True)
+      with mgpu.when(is_store_warpgroup):
+        mma_done_barrier.wait(for_tensor_core=True)
+        final_acc = acc.load().astype(mlir.dtype_to_ir_type(jnp.dtype(dtype)))
+        assert tile_n % epilogue_tile_n == 0
+        for ni in range(tile_n // epilogue_tile_n):
+          n_slice = ds(ni * epilogue_tile_n, epilogue_tile_n)
+          final_acc[:, n_slice].store_tiled(d_smem, swizzle=128)
+          # We store the first tile before arriving to reduce register pressure.
+          if ni == 0:
+            # Make sure we've loaded all of TMEM before we arrive.
+            tcgen05.wait_tmem_load()
+            mgpu.warpgroup_barrier()
+            tmem_done_barrier.arrive(for_tensor_core=True)
+          mgpu.commit_shared()
+          store_n_start = arith.addi(n_start, c(ni * epilogue_tile_n, index))
+          ctx.async_copy(
+              src_ref=d_smem,
+              dst_ref=d,
+              gmem_slice=(
+                  ds(block_m_start, block_tile_m),
+                  ds(store_n_start, epilogue_tile_n),
+              ),
+              gmem_transform=mgpu.TileTransform((128, swizzle_elems)),
+              swizzle=128,
+          )
+          ctx.await_async_copy(0)
+
+    # We statically assign the tiles to SMs.
+    logical_grid_size = math.prod(logical_grid)
+    sm_id = gpu.block_id(gpu.Dimension.x)
+    extra_step = arith.cmpi(
+        arith.CmpIPredicate.slt, sm_id, c(logical_grid_size % num_sms, index)
+    )  # Some SMs do an extra step when grid size isn't divisible by SM count.
+    mn_steps = arith.addi(
+        mgpu.c(logical_grid_size // num_sms, index),
+        arith.index_castui(index, extra_step),
+    )

-      final_acc = acc.load().astype(mlir.dtype_to_ir_type(jnp.dtype(dtype)))
-      final_acc.store_tiled(d_smem, swizzle=128)
-      mgpu.commit_shared()
-      ctx.async_copy(
-          src_ref=d_smem,
-          dst_ref=d,
-          gmem_slice=(ds(block_m_start, block_tile_m), ds(n_start, tile_n)),
-          gmem_transform=mgpu.TileTransform((128, swizzle_elems)),
-          swizzle=swizzle,
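+    # Each SM starts at its own sm_id and strides through the linearized grid
+    # with a stride of num_sms, computing one output tile per step.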
+    @mgpu.fori(mn_steps, None)
+    def _mn_loop(local_mn_step, _):
+      global_mn_step = arith.addi(
+          sm_id, arith.muli(local_mn_step, mgpu.c(num_sms, index))
       )
-      ctx.await_async_copy(0)
+      logical_idxs = []
+      for dim_size in logical_grid:
+        logical_idxs.append(arith.remui(global_mn_step, mgpu.c(dim_size, index)))
+        global_mn_step = arith.divui(global_mn_step, mgpu.c(dim_size, index))
+      lx, ly, lz = logical_idxs
+      m_idx = arith.addi(lx, arith.muli(lz, c(grid_tile_m, index)))
+      n_idx = ly

-    m_idx = arith.addi(
-        gpu.block_id(gpu.Dimension.x),
-        arith.muli(gpu.block_id(gpu.Dimension.z), c(grid_tile_m, index)),
-    )
-    n_idx = gpu.block_id(gpu.Dimension.y)
-    block_m_start = arith.muli(m_idx, c(block_tile_m, index))
-    n_start = arith.muli(n_idx, c(tile_n, index))
-    # This is not a persistent kernel, so we only process one tile.
-    compute_output(block_m_start, n_start)
+      block_m_start = arith.muli(m_idx, c(block_tile_m, index))
+      n_start = arith.muli(n_idx, c(tile_n, index))
+      compute_output(block_m_start, n_start, local_mn_step)

   compute_buffers = (
       jax.ShapeDtypeStruct(
@@ -194,20 +239,23 @@ def _mma_body(ki, accumulate):
           mgpu.tile_shape((max_concurrent_steps, block_tile_n, tile_k), tiling),
           dtype),
   )
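+  # The epilogue buffer only holds a 64-column slice of the output tile, so
+  # compute_output above writes each tile out in tile_n // epilogue_tile_n chunks.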
+  epilogue_tile_n = 64
   epilogue_buffer = jax.ShapeDtypeStruct(
-      mgpu.tile_shape((block_tile_m, tile_n), (128, swizzle_elems)),
+      mgpu.tile_shape((block_tile_m, epilogue_tile_n), (128, swizzle_elems)),
       dtype)
-  smem_buffers = mgpu.Union([compute_buffers, epilogue_buffer])
+  smem_buffers = [compute_buffers, epilogue_buffer]
   smem = (
       smem_buffers,
       [mgpu.Barrier(arrival_count=1, num_barriers=max_concurrent_steps)] * 2,
       mgpu.Barrier(arrival_count=1),
+      mgpu.ClusterBarrier(collective_dims=(gpu.Dimension.x,), num_barriers=1),
       mgpu.TMEM((128, tile_n), jnp.float32, collective=collective),
   )
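+  # One block per SM in the launch below; 148 matches the SM count of a B200.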
+  num_sms = 148
   return mgpu.as_gpu_kernel(
       kernel,
-      (grid_tile_m, n // tile_n, m // (block_tile_m * grid_tile_m)),
-      (128, 1, 1),
+      (num_sms, 1, 1),  # This is a persistent kernel.
+      (2 * 128, 1, 1),
       (
           jax.ShapeDtypeStruct((m, k), dtype),
           jax.ShapeDtypeStruct((n, k), dtype),
@@ -219,7 +267,7 @@ def _mma_body(ki, accumulate):


 def main(unused_argv):
-  m, k, n = 8192, 4096, 8192
+  m, k, n = 2048, 128, 2048

   ka, kb = jr.split(jr.key(0), 2)
   a = jr.normal(key=ka, shape=(m, k), dtype=jnp.float16)
@@ -234,34 +282,45 @@ def main(unused_argv):
   names = ("collective", "tile_m", "tile_n", "grid_tile_m", "max_concurrent_steps")
   best_runtime = float("inf")
   best_kwargs = {}
-  for config in configs:
-    kwargs = dict(zip(names, config))
-    tile_m = kwargs["tile_m"]
-    tile_n = kwargs["tile_n"]
-    if kwargs["collective"]:
-      tile_m *= 2
-      tile_n *= 2
-    if m < tile_m or n < tile_n:
-      continue
-    if tile_n > 512:
-      continue
-    if (m // tile_m) % kwargs["grid_tile_m"]:
-      continue
-    try:
-      with mlir.make_ir_context(), ir.Location.unknown():
-        f = build_kernel(m, k, n, jnp.float16, **kwargs)
-        _, runtime = profiler.measure(f)(a, b)
-    except ValueError as e:
-      if "Mosaic GPU kernel exceeds available shared memory" not in str(e):
-        raise
-      runtime = float("inf")
-    else:
-      print(" ".join(f"{k}={v}" for k, v in kwargs.items()), int(runtime * 1000))
-    if runtime < best_runtime:
-      best_runtime = runtime
-      best_kwargs = kwargs
-  if not best_kwargs:
-    raise ValueError("No valid configuration found")
+  # for config in configs:
+  #   kwargs = dict(zip(names, config))
+  #   tile_m = kwargs["tile_m"]
+  #   tile_n = kwargs["tile_n"]
+  #   if kwargs["collective"]:
+  #     tile_m *= 2
+  #     tile_n *= 2
+  #   if m < tile_m or n < tile_n:
+  #     continue
+  #   if tile_n > 512:
+  #     continue
+  #   if (m // tile_m) % kwargs["grid_tile_m"]:
+  #     continue
+  #   try:
+  #     with mlir.make_ir_context(), ir.Location.unknown():
+  #       f = build_kernel(m, k, n, jnp.float16, **kwargs)
+  #       _, runtime = profiler.measure(f)(a, b)
+  #   except ValueError as e:
+  #     if "Mosaic GPU kernel exceeds available shared memory" not in str(e):
+  #       raise
+  #     runtime = float("inf")
+  #   else:
+  #     print(" ".join(f"{k}={v}" for k, v in kwargs.items()), int(runtime * 1000))
+  #   if runtime < best_runtime:
+  #     best_runtime = runtime
+  #     best_kwargs = kwargs
+  # if not best_kwargs:
+  #   raise ValueError("No valid configuration found")
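+  # The autotuning sweep above is disabled; use a known-good configuration instead.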
+  best_kwargs = dict(
+      max_concurrent_steps=2,
+      collective=True,
+      tile_m=128,
+      tile_n=256,
+      grid_tile_m=4,
+  )
+  with mlir.make_ir_context(), ir.Location.unknown():
+    kernel = build_kernel(m, k, n, jnp.float16, **best_kwargs)
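+  # Presumably warm-up (or profiler-capture) runs ahead of the timed measurement below.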
+  for i in range(50):
+    jax.block_until_ready(kernel(a, b))

   with mlir.make_ir_context(), ir.Location.unknown():
     d, runtime = profiler.measure(build_kernel(m, k, n, jnp.float16, **best_kwargs))(a, b)