
Commit 8ba9da6

apaszke authored and Google-ML-Automation committed
[Mosaic GPU] Bug reproducer
PiperOrigin-RevId: 763812993
1 parent 3b3c338 commit 8ba9da6

File tree

4 files changed (+170, -84 lines):
  jax/experimental/mosaic/gpu/examples/matmul_blackwell.py
  jax/experimental/mosaic/gpu/fragmented_array.py
  jax/experimental/mosaic/gpu/tcgen05.py
  jax/experimental/mosaic/gpu/utils.py


jax/experimental/mosaic/gpu/examples/matmul_blackwell.py

Lines changed: 125 additions & 66 deletions
@@ -15,6 +15,7 @@
 """Matmul kernel for Blackwell."""
 
 import itertools
+import math
 
 import jax
 from jax._src.interpreters import mlir
@@ -81,8 +82,11 @@ def build_kernel(
   if (m // block_tile_m) % grid_tile_m:
     raise ValueError(f"{m=} // {tile_m=} must be divisible by {grid_tile_m=}")
 
+  # We intend this to be iterated in column-major order.
+  logical_grid = (grid_tile_m, n // tile_n, m // (block_tile_m * grid_tile_m))
+
   def kernel(ctx, a, b, d, smem):
-    ((a_smem, b_smem), d_smem), barriers, mma_done_barrier, acc = smem
+    ((a_smem, b_smem), d_smem), barriers, mma_done_barrier, tmem_done_barrier, acc = smem
     (ab_full_barriers, ab_empty_barriers) = barriers
 
     warp_idx = mgpu.warp_idx(sync=True)
@@ -93,18 +97,29 @@ def kernel(ctx, a, b, d, smem):
     is_leader_block = arith.cmpi(
         arith.CmpIPredicate.eq, ctx.cluster_idx(gpu.Dimension.x), c(0, index)
     )
+    is_store_warpgroup = arith.cmpi(
+        arith.CmpIPredicate.eq, mgpu.warpgroup_idx(sync=True), c(1, i32)
+    )
 
-    # This function executes the kernel for a single output tile.
-    def compute_output(block_m_start, n_start):
-      """Compute and store a single output tile."""
+    def compute_output(block_m_start, n_start, call_counter):
+      """Compute and store a single output tile.
+
+      call_counter should be 0 the first time this function is called and
+      incremented by 1 before each subsequent call.
+      """
+      isnt_first_call = arith.cmpi(
+          arith.CmpIPredicate.ne, call_counter, c(0, index)
+      )
       # All blocks in the cluster share the same m_start -- align it!
       m_start = arith.muli(arith.divui(block_m_start, c(tile_m, index)), c(tile_m, index))
       with mgpu.when(is_leader_of(TMA_WARP)):
         @mgpu.fori(c(k_loop_iter, index), None)
         def _tma_body(ki, _):
           slot = arith.remui(ki, c(max_concurrent_steps, index))
-          # TODO(apaszke): Use a predicate instead of a conditional.
-          with mgpu.when(arith.cmpi(arith.CmpIPredicate.uge, ki, c(max_concurrent_steps, index))):
+          isnt_warmup = arith.cmpi(
+              arith.CmpIPredicate.uge, ki, c(max_concurrent_steps, index)
+          )
+          with mgpu.when(arith.ori(isnt_first_call, isnt_warmup)):
             ab_empty_barriers[slot].wait()
           full_barrier = ab_full_barriers[slot]
           with mgpu.when(is_leader_block):
@@ -135,6 +150,9 @@ def _tma_body(ki, _):
               **common_args,
           )
 
+      # We wait in all blocks in the cluster to avoid double arrival errors.
+      with mgpu.when(arith.andi(is_leader_of(MMA_WARP), isnt_first_call)):
+        tmem_done_barrier.wait(for_tensor_core=True)
       with mgpu.when(arith.andi(is_leader_of(MMA_WARP), is_leader_block)):
         @mgpu.fori(c(k_loop_iter, index), arith.constant(i1, 0))
         def _mma_body(ki, accumulate):
@@ -150,41 +168,68 @@ def _mma_body(ki, accumulate):
              collective=collective,
          )
          accumulate = arith.constant(i1, 1)
+          tcgen05.commit_arrive(ab_empty_barriers[slot], collective=collective, ctx=ctx)
          is_last_iter = arith.cmpi(
              arith.CmpIPredicate.eq, ki, c(k_loop_iter - 1, index)
          )
-          barrier_ptr = arith.select(
-              is_last_iter,
-              mma_done_barrier.get_ptr(),
-              ab_empty_barriers[slot].get_ptr(),
-          )
-          tcgen05.commit_arrive(barrier_ptr, collective=collective, ctx=ctx)
+          with mgpu.when(is_last_iter):
+            tcgen05.commit_arrive(mma_done_barrier, collective=collective, ctx=ctx)
          return accumulate
 
-      gpu.barrier()
-      mma_done_barrier.wait(for_tensor_core=True)
+      with mgpu.when(is_store_warpgroup):
+        mma_done_barrier.wait(for_tensor_core=True)
+        final_acc = acc.load().astype(mlir.dtype_to_ir_type(jnp.dtype(dtype)))
+        assert tile_n % epilogue_tile_n == 0
+        for ni in range(tile_n // epilogue_tile_n):
+          n_slice = ds(ni * epilogue_tile_n, epilogue_tile_n)
+          final_acc[:, n_slice].store_tiled(d_smem, swizzle=128)
+          # We store the first tile before arriving to reduce register pressure.
+          if ni == 0:
+            # Make sure we've loaded all of TMEM before we arrive.
+            tcgen05.wait_tmem_load()
+            mgpu.warpgroup_barrier()
+            tmem_done_barrier.arrive(for_tensor_core=True)
+          mgpu.commit_shared()
+          store_n_start = arith.addi(n_start, c(ni * epilogue_tile_n, index))
+          ctx.async_copy(
+              src_ref=d_smem,
+              dst_ref=d,
+              gmem_slice=(
+                  ds(block_m_start, block_tile_m),
+                  ds(store_n_start, epilogue_tile_n),
+              ),
+              gmem_transform=mgpu.TileTransform((128, swizzle_elems)),
+              swizzle=128,
+          )
+          ctx.await_async_copy(0)
+
+    # We statically assign the tiles to SMs.
+    logical_grid_size = math.prod(logical_grid)
+    sm_id = gpu.block_id(gpu.Dimension.x)
+    extra_step = arith.cmpi(
+        arith.CmpIPredicate.slt, sm_id, c(logical_grid_size % num_sms, index)
+    )  # Some SMs do an extra step when grid size isn't divisible by SM count.
+    mn_steps = arith.addi(
+        mgpu.c(logical_grid_size // num_sms, index),
+        arith.index_castui(index, extra_step),
+    )
 
-      final_acc = acc.load().astype(mlir.dtype_to_ir_type(jnp.dtype(dtype)))
-      final_acc.store_tiled(d_smem, swizzle=128)
-      mgpu.commit_shared()
-      ctx.async_copy(
-          src_ref=d_smem,
-          dst_ref=d,
-          gmem_slice=(ds(block_m_start, block_tile_m), ds(n_start, tile_n)),
-          gmem_transform=mgpu.TileTransform((128, swizzle_elems)),
-          swizzle=swizzle,
+    @mgpu.fori(mn_steps, None)
+    def _mn_loop(local_mn_step, _):
+      global_mn_step = arith.addi(
+          sm_id, arith.muli(local_mn_step, mgpu.c(num_sms, index))
      )
-      ctx.await_async_copy(0)
+      logical_idxs = []
+      for dim_size in logical_grid:
+        logical_idxs.append(arith.remui(global_mn_step, mgpu.c(dim_size, index)))
+        global_mn_step = arith.divui(global_mn_step, mgpu.c(dim_size, index))
+      lx, ly, lz = logical_idxs
+      m_idx = arith.addi(lx, arith.muli(lz, c(grid_tile_m, index)))
+      n_idx = ly
 
-    m_idx = arith.addi(
-        gpu.block_id(gpu.Dimension.x),
-        arith.muli(gpu.block_id(gpu.Dimension.z), c(grid_tile_m, index)),
-    )
-    n_idx = gpu.block_id(gpu.Dimension.y)
-    block_m_start = arith.muli(m_idx, c(block_tile_m, index))
-    n_start = arith.muli(n_idx, c(tile_n, index))
-    # This is not a persistent kernel, so we only process one tile.
-    compute_output(block_m_start, n_start)
+      block_m_start = arith.muli(m_idx, c(block_tile_m, index))
+      n_start = arith.muli(n_idx, c(tile_n, index))
+      compute_output(block_m_start, n_start, local_mn_step)
 
   compute_buffers = (
       jax.ShapeDtypeStruct(
@@ -194,20 +239,23 @@ def _mma_body(ki, accumulate):
           mgpu.tile_shape((max_concurrent_steps, block_tile_n, tile_k), tiling),
           dtype),
   )
+  epilogue_tile_n = 64
   epilogue_buffer = jax.ShapeDtypeStruct(
-      mgpu.tile_shape((block_tile_m, tile_n), (128, swizzle_elems)),
+      mgpu.tile_shape((block_tile_m, epilogue_tile_n), (128, swizzle_elems)),
       dtype)
-  smem_buffers = mgpu.Union([compute_buffers, epilogue_buffer])
+  smem_buffers = [compute_buffers, epilogue_buffer]
   smem = (
       smem_buffers,
       [mgpu.Barrier(arrival_count=1, num_barriers=max_concurrent_steps)] * 2,
       mgpu.Barrier(arrival_count=1),
+      mgpu.ClusterBarrier(collective_dims=(gpu.Dimension.x,), num_barriers=1),
       mgpu.TMEM((128, tile_n), jnp.float32, collective=collective),
   )
+  num_sms = 148
   return mgpu.as_gpu_kernel(
       kernel,
-      (grid_tile_m, n // tile_n, m // (block_tile_m * grid_tile_m)),
-      (128, 1, 1),
+      (num_sms, 1, 1),  # This is a persistent kernel.
+      (2 * 128, 1, 1),
      (
          jax.ShapeDtypeStruct((m, k), dtype),
          jax.ShapeDtypeStruct((n, k), dtype),
@@ -219,7 +267,7 @@ def _mma_body(ki, accumulate):
 
 
 def main(unused_argv):
-  m, k, n = 8192, 4096, 8192
+  m, k, n = 2048, 128, 2048
 
   ka, kb = jr.split(jr.key(0), 2)
   a = jr.normal(key=ka, shape=(m, k), dtype=jnp.float16)
@@ -234,34 +282,45 @@ def main(unused_argv):
   names = ("collective", "tile_m", "tile_n", "grid_tile_m", "max_concurrent_steps")
   best_runtime = float("inf")
   best_kwargs = {}
-  for config in configs:
-    kwargs = dict(zip(names, config))
-    tile_m = kwargs["tile_m"]
-    tile_n = kwargs["tile_n"]
-    if kwargs["collective"]:
-      tile_m *= 2
-      tile_n *= 2
-    if m < tile_m or n < tile_n:
-      continue
-    if tile_n > 512:
-      continue
-    if (m // tile_m) % kwargs["grid_tile_m"]:
-      continue
-    try:
-      with mlir.make_ir_context(), ir.Location.unknown():
-        f = build_kernel(m, k, n, jnp.float16, **kwargs)
-        _, runtime = profiler.measure(f)(a, b)
-    except ValueError as e:
-      if "Mosaic GPU kernel exceeds available shared memory" not in str(e):
-        raise
-      runtime = float("inf")
-    else:
-      print(" ".join(f"{k}={v}" for k, v in kwargs.items()), int(runtime * 1000))
-      if runtime < best_runtime:
-        best_runtime = runtime
-        best_kwargs = kwargs
-  if not best_kwargs:
-    raise ValueError("No valid configuration found")
+  # for config in configs:
+  #   kwargs = dict(zip(names, config))
+  #   tile_m = kwargs["tile_m"]
+  #   tile_n = kwargs["tile_n"]
+  #   if kwargs["collective"]:
+  #     tile_m *= 2
+  #     tile_n *= 2
+  #   if m < tile_m or n < tile_n:
+  #     continue
+  #   if tile_n > 512:
+  #     continue
+  #   if (m // tile_m) % kwargs["grid_tile_m"]:
+  #     continue
+  #   try:
+  #     with mlir.make_ir_context(), ir.Location.unknown():
+  #       f = build_kernel(m, k, n, jnp.float16, **kwargs)
+  #       _, runtime = profiler.measure(f)(a, b)
+  #   except ValueError as e:
+  #     if "Mosaic GPU kernel exceeds available shared memory" not in str(e):
+  #       raise
+  #     runtime = float("inf")
+  #   else:
+  #     print(" ".join(f"{k}={v}" for k, v in kwargs.items()), int(runtime * 1000))
+  #     if runtime < best_runtime:
+  #       best_runtime = runtime
+  #       best_kwargs = kwargs
+  # if not best_kwargs:
+  #   raise ValueError("No valid configuration found")
+  best_kwargs = dict(
+      max_concurrent_steps=2,
+      collective=True,
+      tile_m=128,
+      tile_n=256,
+      grid_tile_m=4,
+  )
+  with mlir.make_ir_context(), ir.Location.unknown():
+    kernel = build_kernel(m, k, n, jnp.float16, **best_kwargs)
+  for i in range(50):
+    jax.block_until_ready(kernel(a, b))
 
   with mlir.make_ir_context(), ir.Location.unknown():
    d, runtime = profiler.measure(build_kernel(m, k, n, jnp.float16, **best_kwargs))(a, b)
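
The trickiest new logic above is the static, column-major assignment of logical grid tiles to SMs in _mn_loop. The following stand-alone Python sketch (not part of this commit; the grid shape and num_sms in the demo call are purely illustrative) mirrors that index arithmetic on the host and can be used to sanity-check which (block_m_start, n_start) tiles each SM processes:

import math

def tile_assignment(logical_grid, num_sms, grid_tile_m, block_tile_m, tile_n):
  """Host-side mirror of _mn_loop: the (block_m_start, n_start) tiles per SM."""
  logical_grid_size = math.prod(logical_grid)
  assignment = []
  for sm_id in range(num_sms):
    # Some SMs do an extra step when the grid size isn't divisible by num_sms.
    extra_step = sm_id < logical_grid_size % num_sms
    mn_steps = logical_grid_size // num_sms + extra_step
    tiles = []
    for local_mn_step in range(mn_steps):
      global_mn_step = sm_id + local_mn_step * num_sms
      logical_idxs = []
      for dim_size in logical_grid:  # column-major: first dimension varies fastest
        logical_idxs.append(global_mn_step % dim_size)
        global_mn_step //= dim_size
      lx, ly, lz = logical_idxs
      m_idx = lx + lz * grid_tile_m
      n_idx = ly
      tiles.append((m_idx * block_tile_m, n_idx * tile_n))
    assignment.append(tiles)
  return assignment

# Illustrative shapes only: a (4, 8, 2) logical grid spread over 12 SMs.
tiles_per_sm = tile_assignment((4, 8, 2), num_sms=12, grid_tile_m=4,
                               block_tile_m=128, tile_n=256)
assert sum(len(t) for t in tiles_per_sm) == 4 * 8 * 2  # every tile assigned once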

jax/experimental/mosaic/gpu/fragmented_array.py

Lines changed: 15 additions & 15 deletions
@@ -1376,26 +1376,26 @@ def bitcast(self, elt: ir.Type, *, output_is_signed: bool | None = None):
     )
 
   def __getitem__(self, idx):
-    if self.layout != WGMMA_LAYOUT:
-      raise NotImplementedError("Only WGMMA layouts support slicing")
+    if not isinstance(self.layout, TiledLayout):
+      raise NotImplementedError("Only arrays with tiled layouts can be sliced")
     base_idx, slice_shape, is_squeezed = utils.parse_indices(idx, self.shape)
+    if any(isinstance(idx, ir.Value) for idx in base_idx):
+      raise ValueError("Only static slicing allowed")
     if any(is_squeezed):
       raise NotImplementedError("Only slicing implemented")
-    if (
-        base_idx[0] % 64
-        or slice_shape[0] % 64
-        or base_idx[1] % 8
-        or slice_shape[1] % 8
+    base_tile_shape = self.layout.base_tile_shape
+    if len(base_tile_shape) != len(self.shape):
+      raise NotImplementedError("Tiling has different rank than array")
+    if any(
+        b % t or l % t
+        for b, l, t in zip(base_idx, slice_shape, base_tile_shape, strict=True)
     ):
       raise NotImplementedError("Only tile aligned slicing supported")
-    base_idx[0] //= 64
-    slice_shape[0] //= 64
-    base_idx[1] //= 8
-    slice_shape[1] //= 8
-    new_regs = self.registers[
-        base_idx[0] : base_idx[0] + slice_shape[0],
-        base_idx[1] : base_idx[1] + slice_shape[1],
-    ]
+    register_slices = tuple(
+        slice(b // t, (b + l) // t)
+        for b, l, t in zip(base_idx, slice_shape, base_tile_shape, strict=True)
+    )
+    new_regs = self.registers[register_slices]
     return FragmentedArray(
         _registers=new_regs, _layout=self.layout, _is_signed=self.is_signed
     )
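
For context on the __getitem__ change: a tile-aligned element slice only has to select whole register blocks, one per base tile along each dimension. A minimal stand-alone sketch of that arithmetic (plain Python, not the FragmentedArray implementation; the (64, 8) tile in the example is just the old hard-coded WGMMA base tile):

def register_slices(base_idx, slice_shape, base_tile_shape):
  """Maps a tile-aligned element slice to slices of the register-block array."""
  if any(b % t or l % t
         for b, l, t in zip(base_idx, slice_shape, base_tile_shape, strict=True)):
    raise NotImplementedError("Only tile aligned slicing supported")
  return tuple(
      slice(b // t, (b + l) // t)
      for b, l, t in zip(base_idx, slice_shape, base_tile_shape, strict=True)
  )

# Rows 64:128 and columns 0:256 of a (64, 8)-tiled array correspond to
# register blocks [1:2, 0:32], matching the divisions this commit replaces.
assert register_slices((64, 0), (64, 256), (64, 8)) == (slice(1, 2), slice(0, 32))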

jax/experimental/mosaic/gpu/tcgen05.py

Lines changed: 11 additions & 1 deletion
@@ -477,6 +477,16 @@ def tmem_load(tmem_addr, shape, num, pack: bool):
   return [llvm.extractvalue(i32, regs, [i]) for i in range(num_out_regs)]
 
 
+def wait_tmem_load():
+  llvm.inline_asm(
+      ir.Type.parse("!llvm.void"),
+      [],
+      "tcgen05.wait::ld.sync.aligned;",
+      "",
+      has_side_effects=True,
+  )
+
+
 def tmem_store(tmem_addr, shape, num, regs, unpack: bool):
   num_out_regs, regs_vector = _tmem_access_helper(shape, num)
   pack_mod = ".unpack::16b" if unpack else ""
@@ -832,7 +842,7 @@ def _transfer_32xcols(
   regs_per_instr = atom_shape[0] * atom_shape[1] // (utils.WARP_SIZE * reg_packing)
   # We artificially lower the instr_num compared to its limits, because higher
   # values can lead to register spills..
-  instr_num = min(total_num, 64 // regs_per_instr)
+  instr_num = min(total_num, 32 // regs_per_instr)
   assert 32 % atom_rows == 0
   num_row_steps = 32 // atom_rows
   for lane_step in range(num_row_steps):
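
The _transfer_32xcols change halves the per-instruction register budget used to cap instr_num. A quick hedged arithmetic check (atom shape, packing, and total_num below are hypothetical, chosen only to illustrate the effect of the new cap):

WARP_SIZE = 32  # assumption: mirrors utils.WARP_SIZE
atom_shape, reg_packing, total_num = (32, 8), 1, 16  # hypothetical values
regs_per_instr = atom_shape[0] * atom_shape[1] // (WARP_SIZE * reg_packing)  # -> 8
old_instr_num = min(total_num, 64 // regs_per_instr)  # old cap: <= 64 regs/instr -> 8
new_instr_num = min(total_num, 32 // regs_per_instr)  # new cap: <= 32 regs/instr -> 4
assert (old_instr_num, new_instr_num) == (8, 4)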

jax/experimental/mosaic/gpu/utils.py

Lines changed: 19 additions & 2 deletions
@@ -817,8 +817,19 @@ def update_parities(self, parities: ir.Value) -> tuple[ir.Value, ir.Value]:
     )
     return parity, arith.xori(parities, bitmask)
 
-  def arrive(self, arrival_count: int = 1, can_complete: bool = True):
+  def arrive(
+      self,
+      arrival_count: int = 1,
+      can_complete: bool = True,
+      for_tensor_core: bool = False,
+  ):
     i64 = ir.IntegerType.get_signless(64)
+    if for_tensor_core:
+      llvm.inline_asm(
+          ir.Type.parse("!llvm.void"),
+          [], "tcgen05.fence::before_thread_sync;", "",
+          has_side_effects=True,
+      )
     if can_complete:
       if arrival_count > 1:
         count = c(arrival_count - 1, ir.IntegerType.get_signless(32))
@@ -982,11 +993,17 @@ def __iter__(self):
   def __getitem__(self, offset):
     return CollectiveBarrierRef(self.barrier[offset], self.cluster_mask)
 
-  def arrive(self):
+  def arrive(self, for_tensor_core: bool = False):
     """Arrives on a barrier in all blocks that share at least one of the coordinates along the collective dimensions.
 
     Note that unlike in arrive, each warpgroup arrives once.
     """
+    if for_tensor_core:
+      llvm.inline_asm(
+          ir.Type.parse("!llvm.void"),
+          [], "tcgen05.fence::before_thread_sync;", "",
+          has_side_effects=True,
+      )
     if self.barrier.num_barriers != 1:
       raise ValueError("Can only arrive on a single barrier")
     if self.cluster_mask is None: