[Mosaic GPU] Bug reproducer #29034

191 changes: 125 additions & 66 deletions jax/experimental/mosaic/gpu/examples/matmul_blackwell.py
@@ -15,6 +15,7 @@
"""Matmul kernel for Blackwell."""

import itertools
import math

import jax
from jax._src.interpreters import mlir
@@ -81,8 +82,11 @@ def build_kernel(
if (m // block_tile_m) % grid_tile_m:
raise ValueError(f"{m=} // {tile_m=} must be divisible by {grid_tile_m=}")

# We intend this to be iterated in column-major order.
logical_grid = (grid_tile_m, n // tile_n, m // (block_tile_m * grid_tile_m))

def kernel(ctx, a, b, d, smem):
((a_smem, b_smem), d_smem), barriers, mma_done_barrier, acc = smem
((a_smem, b_smem), d_smem), barriers, mma_done_barrier, tmem_done_barrier, acc = smem
(ab_full_barriers, ab_empty_barriers) = barriers

warp_idx = mgpu.warp_idx(sync=True)
@@ -93,18 +97,29 @@ def kernel(ctx, a, b, d, smem):
is_leader_block = arith.cmpi(
arith.CmpIPredicate.eq, ctx.cluster_idx(gpu.Dimension.x), c(0, index)
)
is_store_warpgroup = arith.cmpi(
arith.CmpIPredicate.eq, mgpu.warpgroup_idx(sync=True), c(1, i32)
)

# This function executes the kernel for a single output tile.
def compute_output(block_m_start, n_start):
"""Compute and store a single output tile."""
def compute_output(block_m_start, n_start, call_counter):
"""Compute and store a single output tile.

call_counter should be 0 the first time this function is called and
incremented by 1 before each subsequent call.
"""
isnt_first_call = arith.cmpi(
arith.CmpIPredicate.ne, call_counter, c(0, index)
)
# All blocks in the cluster share the same m_start -- align it!
m_start = arith.muli(arith.divui(block_m_start, c(tile_m, index)), c(tile_m, index))
with mgpu.when(is_leader_of(TMA_WARP)):
@mgpu.fori(c(k_loop_iter, index), None)
def _tma_body(ki, _):
slot = arith.remui(ki, c(max_concurrent_steps, index))
# TODO(apaszke): Use a predicate instead of a conditional.
with mgpu.when(arith.cmpi(arith.CmpIPredicate.uge, ki, c(max_concurrent_steps, index))):
isnt_warmup = arith.cmpi(
arith.CmpIPredicate.uge, ki, c(max_concurrent_steps, index)
)
with mgpu.when(arith.ori(isnt_first_call, isnt_warmup)):
ab_empty_barriers[slot].wait()
full_barrier = ab_full_barriers[slot]
with mgpu.when(is_leader_block):
@@ -135,6 +150,9 @@ def _tma_body(ki, _):
**common_args,
)

# We wait in all blocks in the cluster to avoid double arrival errors.
with mgpu.when(arith.andi(is_leader_of(MMA_WARP), isnt_first_call)):
tmem_done_barrier.wait(for_tensor_core=True)
with mgpu.when(arith.andi(is_leader_of(MMA_WARP), is_leader_block)):
@mgpu.fori(c(k_loop_iter, index), arith.constant(i1, 0))
def _mma_body(ki, accumulate):
@@ -150,41 +168,68 @@ def _mma_body(ki, accumulate):
collective=collective,
)
accumulate = arith.constant(i1, 1)
tcgen05.commit_arrive(ab_empty_barriers[slot], collective=collective, ctx=ctx)
is_last_iter = arith.cmpi(
arith.CmpIPredicate.eq, ki, c(k_loop_iter - 1, index)
)
barrier_ptr = arith.select(
is_last_iter,
mma_done_barrier.get_ptr(),
ab_empty_barriers[slot].get_ptr(),
)
tcgen05.commit_arrive(barrier_ptr, collective=collective, ctx=ctx)
with mgpu.when(is_last_iter):
tcgen05.commit_arrive(mma_done_barrier, collective=collective, ctx=ctx)
return accumulate

gpu.barrier()
mma_done_barrier.wait(for_tensor_core=True)
with mgpu.when(is_store_warpgroup):
mma_done_barrier.wait(for_tensor_core=True)
final_acc = acc.load().astype(mlir.dtype_to_ir_type(jnp.dtype(dtype)))
assert tile_n % epilogue_tile_n == 0
for ni in range(tile_n // epilogue_tile_n):
n_slice = ds(ni * epilogue_tile_n, epilogue_tile_n)
final_acc[:, n_slice].store_tiled(d_smem, swizzle=128)
# We store the first tile before arriving to reduce register pressure.
if ni == 0:
# Make sure we've loaded all of TMEM before we arrive.
tcgen05.wait_tmem_load()
mgpu.warpgroup_barrier()
tmem_done_barrier.arrive(for_tensor_core=True)
mgpu.commit_shared()
store_n_start = arith.addi(n_start, c(ni * epilogue_tile_n, index))
ctx.async_copy(
src_ref=d_smem,
dst_ref=d,
gmem_slice=(
ds(block_m_start, block_tile_m),
ds(store_n_start, epilogue_tile_n),
),
gmem_transform=mgpu.TileTransform((128, swizzle_elems)),
swizzle=128,
)
ctx.await_async_copy(0)

# We statically assign the tiles to SMs.
logical_grid_size = math.prod(logical_grid)
sm_id = gpu.block_id(gpu.Dimension.x)
extra_step = arith.cmpi(
arith.CmpIPredicate.slt, sm_id, c(logical_grid_size % num_sms, index)
) # Some SMs do an extra step when grid size isn't divisible by SM count.
mn_steps = arith.addi(
mgpu.c(logical_grid_size // num_sms, index),
arith.index_castui(index, extra_step),
)

final_acc = acc.load().astype(mlir.dtype_to_ir_type(jnp.dtype(dtype)))
final_acc.store_tiled(d_smem, swizzle=128)
mgpu.commit_shared()
ctx.async_copy(
src_ref=d_smem,
dst_ref=d,
gmem_slice=(ds(block_m_start, block_tile_m), ds(n_start, tile_n)),
gmem_transform=mgpu.TileTransform((128, swizzle_elems)),
swizzle=swizzle,
@mgpu.fori(mn_steps, None)
def _mn_loop(local_mn_step, _):
global_mn_step = arith.addi(
sm_id, arith.muli(local_mn_step, mgpu.c(num_sms, index))
)
ctx.await_async_copy(0)
logical_idxs = []
for dim_size in logical_grid:
logical_idxs.append(arith.remui(global_mn_step, mgpu.c(dim_size, index)))
global_mn_step = arith.divui(global_mn_step, mgpu.c(dim_size, index))
lx, ly, lz = logical_idxs
m_idx = arith.addi(lx, arith.muli(lz, c(grid_tile_m, index)))
n_idx = ly

m_idx = arith.addi(
gpu.block_id(gpu.Dimension.x),
arith.muli(gpu.block_id(gpu.Dimension.z), c(grid_tile_m, index)),
)
n_idx = gpu.block_id(gpu.Dimension.y)
block_m_start = arith.muli(m_idx, c(block_tile_m, index))
n_start = arith.muli(n_idx, c(tile_n,index))
# This is not a persistent kernel, so we only process one tile.
compute_output(block_m_start, n_start)
block_m_start = arith.muli(m_idx, c(block_tile_m, index))
n_start = arith.muli(n_idx, c(tile_n,index))
compute_output(block_m_start, n_start, local_mn_step)

compute_buffers = (
jax.ShapeDtypeStruct(
@@ -194,20 +239,23 @@ def _mma_body(ki, accumulate):
mgpu.tile_shape((max_concurrent_steps, block_tile_n, tile_k), tiling),
dtype),
)
epilogue_tile_n = 64
epilogue_buffer = jax.ShapeDtypeStruct(
mgpu.tile_shape((block_tile_m, tile_n), (128, swizzle_elems)),
mgpu.tile_shape((block_tile_m, epilogue_tile_n), (128, swizzle_elems)),
dtype)
smem_buffers = mgpu.Union([compute_buffers, epilogue_buffer])
smem_buffers = [compute_buffers, epilogue_buffer]
smem = (
smem_buffers,
[mgpu.Barrier(arrival_count=1, num_barriers=max_concurrent_steps)] * 2,
mgpu.Barrier(arrival_count=1),
mgpu.ClusterBarrier(collective_dims=(gpu.Dimension.x,), num_barriers=1),
mgpu.TMEM((128, tile_n), jnp.float32, collective=collective),
)
num_sms = 148
return mgpu.as_gpu_kernel(
kernel,
(grid_tile_m, n // tile_n, m // (block_tile_m * grid_tile_m)),
(128, 1, 1),
(num_sms, 1, 1), # This is a persistent kernel.
(2 * 128, 1, 1),
(
jax.ShapeDtypeStruct((m, k), dtype),
jax.ShapeDtypeStruct((n, k), dtype),
@@ -219,7 +267,7 @@ def _mma_body(ki, accumulate):


def main(unused_argv):
m, k, n = 8192, 4096, 8192
m, k, n = 2048, 128, 2048

ka, kb = jr.split(jr.key(0), 2)
a = jr.normal(key=ka, shape=(m, k), dtype=jnp.float16)
@@ -234,34 +282,45 @@ def main(unused_argv):
names = ("collective", "tile_m", "tile_n", "grid_tile_m", "max_concurrent_steps")
best_runtime = float("inf")
best_kwargs = {}
for config in configs:
kwargs = dict(zip(names, config))
tile_m = kwargs["tile_m"]
tile_n = kwargs["tile_n"]
if kwargs["collective"]:
tile_m *= 2
tile_n *= 2
if m < tile_m or n < tile_n:
continue
if tile_n > 512:
continue
if (m // tile_m) % kwargs["grid_tile_m"]:
continue
try:
with mlir.make_ir_context(), ir.Location.unknown():
f = build_kernel(m, k, n, jnp.float16, **kwargs)
_, runtime = profiler.measure(f)(a, b)
except ValueError as e:
if "Mosaic GPU kernel exceeds available shared memory" not in str(e):
raise
runtime = float("inf")
else:
print(" ".join(f"{k}={v}" for k, v in kwargs.items()), int(runtime * 1000))
if runtime < best_runtime:
best_runtime = runtime
best_kwargs = kwargs
if not best_kwargs:
raise ValueError("No valid configuration found")
# for config in configs:
# kwargs = dict(zip(names, config))
# tile_m = kwargs["tile_m"]
# tile_n = kwargs["tile_n"]
# if kwargs["collective"]:
# tile_m *= 2
# tile_n *= 2
# if m < tile_m or n < tile_n:
# continue
# if tile_n > 512:
# continue
# if (m // tile_m) % kwargs["grid_tile_m"]:
# continue
# try:
# with mlir.make_ir_context(), ir.Location.unknown():
# f = build_kernel(m, k, n, jnp.float16, **kwargs)
# _, runtime = profiler.measure(f)(a, b)
# except ValueError as e:
# if "Mosaic GPU kernel exceeds available shared memory" not in str(e):
# raise
# runtime = float("inf")
# else:
# print(" ".join(f"{k}={v}" for k, v in kwargs.items()), int(runtime * 1000))
# if runtime < best_runtime:
# best_runtime = runtime
# best_kwargs = kwargs
# if not best_kwargs:
# raise ValueError("No valid configuration found")
best_kwargs = dict(
max_concurrent_steps=2,
collective=True,
tile_m=128,
tile_n=256,
grid_tile_m=4,
)
with mlir.make_ir_context(), ir.Location.unknown():
kernel = build_kernel(m, k, n, jnp.float16, **best_kwargs)
for i in range(50):
jax.block_until_ready(kernel(a, b))

with mlir.make_ir_context(), ir.Location.unknown():
d, runtime = profiler.measure(build_kernel(m, k, n, jnp.float16, **best_kwargs))(a, b)
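For reference, a small host-side sketch of the static tile-to-SM assignment used by the persistent kernel above: each SM walks the flattened logical grid with a stride of num_sms, and each flat step is decomposed in column-major order into (grid_tile_m, n, outer m) coordinates, matching the remui/divui chain in _mn_loop. The helper name and the example grid/SM count in the last line are illustrative only, not values taken from the kernel.

import math

def assigned_tiles(sm_id, num_sms, logical_grid):
    """Yields the (m_idx, n_idx) output tiles processed by one SM."""
    grid_tile_m, n_tiles, m_outer = logical_grid
    for step in range(sm_id, math.prod(logical_grid), num_sms):
        lx = step % grid_tile_m                         # fastest-varying dimension
        ly = (step // grid_tile_m) % n_tiles
        lz = (step // (grid_tile_m * n_tiles)) % m_outer
        yield lx + lz * grid_tile_m, ly                 # (m_idx, n_idx)

print(list(assigned_tiles(3, 8, (4, 8, 2))))  # hypothetical 8-SM device, small grid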
30 changes: 15 additions & 15 deletions jax/experimental/mosaic/gpu/fragmented_array.py
@@ -1376,26 +1376,26 @@ def bitcast(self, elt: ir.Type, *, output_is_signed: bool | None = None):
)

def __getitem__(self, idx):
if self.layout != WGMMA_LAYOUT:
raise NotImplementedError("Only WGMMA layouts support slicing")
if not isinstance(self.layout, TiledLayout):
raise NotImplementedError("Only arrays with tiled layouts can be sliced")
base_idx, slice_shape, is_squeezed = utils.parse_indices(idx, self.shape)
if any(isinstance(idx, ir.Value) for idx in base_idx):
raise ValueError("Only static slicing allowed")
if any(is_squeezed):
raise NotImplementedError("Only slicing implemented")
if (
base_idx[0] % 64
or slice_shape[0] % 64
or base_idx[1] % 8
or slice_shape[1] % 8
base_tile_shape = self.layout.base_tile_shape
if len(base_tile_shape) != len(self.shape):
raise NotImplementedError("Tiling has different rank than array")
if any(
b % t or l % t
for b, l, t in zip(base_idx, slice_shape, base_tile_shape, strict=True)
):
raise NotImplementedError("Only tile aligned slicing supported")
base_idx[0] //= 64
slice_shape[0] //= 64
base_idx[1] //= 8
slice_shape[1] //= 8
new_regs = self.registers[
base_idx[0] : base_idx[0] + slice_shape[0],
base_idx[1] : base_idx[1] + slice_shape[1],
]
register_slices = tuple(
slice(b // t, (b + l) // t)
for b, l, t in zip(base_idx, slice_shape, base_tile_shape, strict=True)
)
new_regs = self.registers[register_slices]
return FragmentedArray(
_registers=new_regs, _layout=self.layout, _is_signed=self.is_signed
)
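The rewritten __getitem__ above generalizes slicing from the hard-coded WGMMA (64, 8) tile to whatever base tile shape the array's TiledLayout declares: a slice is accepted only when its start and length are multiples of the base tile along every dimension, and selecting the registers is then plain integer division. Below is a standalone NumPy sketch of that logic; the function name and the shapes in the usage lines are made up for illustration and are not the Mosaic GPU API.

import numpy as np

def slice_registers(registers, base_idx, slice_shape, base_tile_shape):
    """Selects the block of register tiles covering a tile-aligned slice."""
    if any(b % t or l % t
           for b, l, t in zip(base_idx, slice_shape, base_tile_shape, strict=True)):
        raise NotImplementedError("Only tile aligned slicing supported")
    return registers[tuple(
        slice(b // t, (b + l) // t)
        for b, l, t in zip(base_idx, slice_shape, base_tile_shape, strict=True)
    )]

regs = np.arange(4 * 8).reshape(4, 8)                      # a 4x8 grid of register tiles
print(slice_registers(regs, (64, 8), (128, 16), (64, 8)))  # selects a 2x2 block of tiles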
12 changes: 11 additions & 1 deletion jax/experimental/mosaic/gpu/tcgen05.py
@@ -477,6 +477,16 @@ def tmem_load(tmem_addr, shape, num, pack: bool):
return [llvm.extractvalue(i32, regs, [i]) for i in range(num_out_regs)]


def wait_tmem_load():
llvm.inline_asm(
ir.Type.parse("!llvm.void"),
[],
"tcgen05.wait::ld.sync.aligned;",
"",
has_side_effects=True,
)


def tmem_store(tmem_addr, shape, num, regs, unpack: bool):
num_out_regs, regs_vector = _tmem_access_helper(shape, num)
pack_mod = ".unpack::16b" if unpack else ""
@@ -832,7 +842,7 @@ def _transfer_32xcols(
regs_per_instr = atom_shape[0] * atom_shape[1] // (utils.WARP_SIZE * reg_packing)
# We artificially lower the instr_num compared to its limits, because higher
# values can lead to register spills.
instr_num = min(total_num, 64 // regs_per_instr)
instr_num = min(total_num, 32 // regs_per_instr)
assert 32 % atom_rows == 0
num_row_steps = 32 // atom_rows
for lane_step in range(num_row_steps):
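To put the lowered register budget in perspective, a quick back-of-the-envelope using the formula from the diff: with 32-bit registers, no packing, and a hypothetical (32, 8) transfer atom, each instruction holds 32 * 8 / 32 = 8 registers, so the cap drops from 64 // 8 = 8 to 32 // 8 = 4 instructions' worth of registers in flight. The atom shape and packing below are assumed example values, not constants taken from the library.

WARP_SIZE = 32  # threads per warp

def instr_num(total_num, atom_shape=(32, 8), reg_packing=1, reg_budget=32):
    # Only the min(total, budget // regs_per_instr) shape of this formula comes
    # from the diff above; the default arguments are assumptions.
    regs_per_instr = atom_shape[0] * atom_shape[1] // (WARP_SIZE * reg_packing)
    return min(total_num, reg_budget // regs_per_instr)

print(instr_num(64, reg_budget=64))  # 8 instructions with the old budget
print(instr_num(64, reg_budget=32))  # 4 instructions with the new budget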
21 changes: 19 additions & 2 deletions jax/experimental/mosaic/gpu/utils.py
@@ -817,8 +817,19 @@ def update_parities(self, parities: ir.Value) -> tuple[ir.Value, ir.Value]:
)
return parity, arith.xori(parities, bitmask)

def arrive(self, arrival_count: int = 1, can_complete: bool = True):
def arrive(
self,
arrival_count: int = 1,
can_complete: bool = True,
for_tensor_core: bool = False,
):
i64 = ir.IntegerType.get_signless(64)
if for_tensor_core:
llvm.inline_asm(
ir.Type.parse("!llvm.void"),
[], "tcgen05.fence::before_thread_sync;", "",
has_side_effects=True,
)
if can_complete:
if arrival_count > 1:
count = c(arrival_count - 1, ir.IntegerType.get_signless(32))
@@ -982,11 +993,17 @@ def __iter__(self):
def __getitem__(self, offset):
return CollectiveBarrierRef(self.barrier[offset], self.cluster_mask)

def arrive(self):
def arrive(self, for_tensor_core: bool = False):
"""Arrives on a barrier in all blocks that share at least one of the coordinates along the collective dimensions.

Note that unlike in arrive, each warpgroup arrives once.
"""
if for_tensor_core:
llvm.inline_asm(
ir.Type.parse("!llvm.void"),
[], "tcgen05.fence::before_thread_sync;", "",
has_side_effects=True,
)
if self.barrier.num_barriers != 1:
raise ValueError("Can only arrive on a single barrier")
if self.cluster_mask is None:
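The for_tensor_core flag added above emits a tcgen05.fence::before_thread_sync before the arrival, so tensor-core-visible work (here, the TMEM reads in the store warpgroup) is ordered before the barrier signal. Below is a CPU-side analogy of the resulting mma_done/tmem_done handshake from the matmul example: threading.Event stands in for the barriers, the fence is modelled only by the happens-before ordering Event already provides, and none of this is GPU code.

import threading

mma_done = threading.Event()   # ~ mma_done_barrier
tmem_done = threading.Event()  # ~ tmem_done_barrier

def mma_warp(num_tiles):
    for tile in range(num_tiles):
        if tile != 0:
            tmem_done.wait()   # ~ tmem_done_barrier.wait(for_tensor_core=True)
            tmem_done.clear()
        # ... issue MMAs that overwrite the shared TMEM accumulator ...
        mma_done.set()         # ~ tcgen05.commit_arrive(mma_done_barrier)

def store_warpgroup(num_tiles):
    for tile in range(num_tiles):
        mma_done.wait()        # ~ mma_done_barrier.wait(for_tensor_core=True)
        mma_done.clear()
        # ... load the accumulator out of TMEM and copy it to GMEM ...
        tmem_done.set()        # ~ tmem_done_barrier.arrive(for_tensor_core=True)

threads = [threading.Thread(target=mma_warp, args=(3,)),
           threading.Thread(target=store_warpgroup, args=(3,))]
for t in threads:
    t.start()
for t in threads:
    t.join()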