3 changes: 3 additions & 0 deletions .clang-tidy
@@ -42,7 +42,10 @@ Checks: >
-cppcoreguidelines-pro-type-static-cast-downcast,
-performance-unnecessary-value-param,
-performance-enum-size,
-cppcoreguidelines-pro-bounds-pointer-arithmetic,
-cppcoreguidelines-pro-bounds-array-to-pointer-decay,
-clang-analyzer-deadcode.DeadStores,
-clang-analyzer-optin.cplusplus.VirtualCall,
Comment on lines +45 to +48
⚠️ Potential issue | 🟠 Major

Undo the blanket disable of the bounds and virtual-call checks

These three checks were part of the “Retained categories” for a reason—they regularly flag genuine memory-safety and lifecycle bugs across our CUDA backends. Disabling them globally trades concrete signal for convenience and will let regressions slip through (e.g., silent out-of-bounds pointer arithmetic or virtual dispatch in constructors/destructors). If there are nuisance warnings in new SM100 codepaths, please suppress them locally with targeted NOLINT or refactor the offending spots instead of removing the protections repo-wide.

-  -cppcoreguidelines-pro-bounds-pointer-arithmetic,
-  -cppcoreguidelines-pro-bounds-array-to-pointer-decay,
-  -clang-analyzer-optin.cplusplus.VirtualCall,
🤖 Prompt for AI Agents
In .clang-tidy around lines 45 to 48, the rules
cppcoreguidelines-pro-bounds-pointer-arithmetic,
cppcoreguidelines-pro-bounds-array-to-pointer-decay and
clang-analyzer-optin.cplusplus.VirtualCall were removed; restore these checks by
removing them from the negative list so they run globally, and instead address
any false positives in SM100 or other new code by adding targeted NOLINT
comments at the exact offending lines or refactoring the code to eliminate the
warning (e.g., avoid pointer-arithmetic that can overflow, use safe indexing,
and avoid virtual calls in constructors/destructors).


WarningsAsErrors: '*'

106 changes: 106 additions & 0 deletions examples/gemm_sm100/README.md
@@ -0,0 +1,106 @@
# TileLang SM100 Support (Preview)

This directory contains examples for TileLang's experimental SM100 architecture support. **This is a preview version** with limited functionality.

## Current Limitations (Manual Implementation Required)

### 1. Manual TCGEN5.MMA Management
Users must manually handle TCGEN5MMA operations using the following primitives (a condensed sketch follows the list):
- `T.alloc_tmem()` - Allocate Tensor Memory
- `T.gemm()` with `wg_wait=-1` - Launch TCGEN5MMA without waiting
- Manual synchronization with mbarrier
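
Taken together, the flow is: allocate Tensor Memory for the accumulator, launch the MMA asynchronously, wait on the mbarrier, then read the result back out. A minimal sketch, reusing the buffer names from the full example below:

```python
C_tmem = T.alloc_tmem([block_M, block_N], "float")  # accumulator lives in Tensor Memory
mbar = T.alloc_barrier(1)                           # completion barrier for the MMA
T.gemm(A_shared, B_shared, C_tmem, trans_A=False, trans_B=True,
       mbar=mbar, wg_wait=-1, clear_accum=True)     # launch asynchronously, do not wait
T.mbarrier_wait_parity(mbar, 0)                     # block until the MMA has completed
T.copy(C_tmem, C_local)                             # read the accumulator back into registers
```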

### 2. Manual mbarrier Synchronization
TCGEN5MMA is asynchronous and requires explicit synchronization:
```python
mbar = T.alloc_barrier(1) # expect-arrive-count = 1
T.gemm(A_shared, B_shared, C_tmem, trans_A, trans_B, mbar=mbar, wg_wait=-1, clear_accum=k==0)
T.mbarrier_wait_parity(mbar, k%2) # Manual phase calculation required
```

The wait parity alternates with `k % 2` because a single mbarrier is reused across iterations: its phase flips each time the expected arrival count is reached (once per TCGEN5MMA launch), so successive iterations wait on opposite parities.

## Examples

### TCGEN5MMA Example (`gemm_tcgen5mma.py`)
Demonstrates TCGEN5MMA operations with:
- Tensor Memory allocation
- Manual mbarrier synchronization
- TCGEN5MMA gemm operations

### Traditional MMA Example (`gemm_mma.py`)
Shows standard MMA operations that work across architectures for comparison.
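
For contrast, the conventional path keeps the accumulator in registers and needs no Tensor Memory or mbarrier. A condensed sketch of the inner loop in `gemm_mma.py` (same buffer names as in that file):

```python
C_local = T.alloc_fragment((block_M, block_N), "float")  # accumulator lives in registers
T.clear(C_local)
for ko in T.Pipelined(T.ceildiv(K, block_K), num_stages=0):
    T.copy(A[by * block_M, ko * block_K], A_shared)
    T.copy(B[bx * block_N, ko * block_K], B_shared)
    T.gemm(A_shared, B_shared, C_local, transpose_B=True)  # synchronous tile-level GEMM
T.copy(C_local, C[by * block_M, bx * block_N])  # write the accumulator to global memory
```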

## Code Example

The following code is based on `gemm_tcgen5mma.py`, demonstrating TCGEN5MMA matrix multiplication:

```python
import torch
import tilelang
import tilelang.language as T

@T.prim_func
def main(
        A: T.Tensor((M, K), "bfloat16"),
        B: T.Tensor((N, K), "bfloat16"),
        C: T.Tensor((M, N), "bfloat16"),
):
    with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=256) as (bx, by):
        # 1. Allocate memory buffers
        A_shared = T.alloc_shared((block_M, block_K), "bfloat16")  # A matrix shared memory
        B_shared = T.alloc_shared((block_N, block_K), "bfloat16")  # B matrix shared memory
        C_tmem = T.alloc_tmem([block_M, block_N], "float")  # TCGEN5MMA output to Tensor Memory
        mbar = T.alloc_barrier(1)  # mbarrier synchronization primitive

        C_local = T.alloc_fragment((block_M, block_N), "float")  # Register storage
        C_shared = T.alloc_shared((block_M, block_N), "bfloat16")  # Output shared memory

        # 2. Main computation loop
        for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=1):
            # Data loading: global memory to shared memory
            T.copy(A[by * block_M, k * block_K], A_shared)
            T.copy(B[bx * block_N, k * block_K], B_shared)

            # TCGEN5MMA computation: asynchronous launch, output to Tensor Memory
            T.gemm(A_shared, B_shared, C_tmem, trans_A=False, trans_B=True,
                   mbar=mbar, wg_wait=-1, clear_accum=k==0)

            # Critical: wait for TCGEN5MMA completion
            T.mbarrier_wait_parity(mbar, k%2)

        # 3. Output processing (only a subset of threads)
        if T.get_thread_binding() < 128:
            T.copy(C_tmem, C_local)  # Tensor Memory → registers
            T.copy(C_local, C_shared)  # registers → shared memory

        # 4. Write back to global memory
        T.copy(C_shared, C[by * block_M, bx * block_N])
```

### Compilation and Usage

```python
# Parameter setup
M, N, K = 4096, 4096, 8192
block_M, block_N, block_K = 128, 256, 128

# Compile kernel
jit_kernel = tilelang.compile(func, out_idx=[2], target="cuda", pass_configs={
tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True, # Required
tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True, # Required
})

# Run test
a = torch.randn(M, K, device="cuda", dtype=torch.bfloat16)
b = torch.randn(N, K, device="cuda", dtype=torch.bfloat16)
c = jit_kernel(a, b)

# Verify correctness
ref_c = (a.to(torch.float) @ b.T.to(torch.float)).to(torch.bfloat16)
torch.testing.assert_close(c, ref_c, rtol=1e-2, atol=1e-2)

# Performance benchmark
profiler = jit_kernel.get_profiler()
latency = profiler.do_bench()
print(f"Latency: {latency} ms")
print(f"Performance: {2 * M * N * K / (latency/1e3) / 1e12:.2f} TFLOPS")
```

Comment on lines +85 to +104
⚠️ Potential issue | 🟠 Major

Fix the example to compile main rather than an undefined symbol.

Lines 86-93 invoke tilelang.compile(func, ...), but the snippet only defines main. Copying the example as written will raise a NameError. Update the call (and any downstream references) to use the defined prim func so users can execute the example successfully.

-jit_kernel = tilelang.compile(func, out_idx=[2], target="cuda", pass_configs={
+jit_kernel = tilelang.compile(main, out_idx=[2], target="cuda", pass_configs={
🤖 Prompt for AI Agents
In examples/gemm_sm100/README.md around lines 85 to 104, the code calls
tilelang.compile(func, ...) but only defines a prim func named main; change the
compile invocation to tilelang.compile(main, ...) (and any downstream references
expecting that compiled object remain the same) so the example compiles the
defined symbol instead of the undefined name func; keep the existing
pass_configs, inputs, profiling and verification lines unchanged.


94 changes: 94 additions & 0 deletions examples/gemm_sm100/gemm_mma.py
@@ -0,0 +1,94 @@
import tilelang
import tilelang.language as T


# add decorator @tilelang.jit if you want to return a torch function
# @tilelang.jit
def matmul(M, N, K, block_M, block_N, block_K, dtype="float16", accum_dtype="float"):

    @T.prim_func
    def main(
            A: T.Tensor((M, K), dtype),
            B: T.Tensor((N, K), dtype),
            C: T.Tensor((M, N), dtype),
    ):
        # Initialize Kernel Context
        with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=256) as (bx, by):
            A_shared = T.alloc_shared((block_M, block_K), dtype)
            B_shared = T.alloc_shared((block_N, block_K), dtype)
            C_local = T.alloc_fragment((block_M, block_N), accum_dtype)

            # Clear local accumulation
            T.clear(C_local)

            for ko in T.Pipelined(T.ceildiv(K, block_K), num_stages=0):
                # Copy tile of A
                # This is a sugar syntax for parallelized copy
                # for i, k in T.Parallel(block_M, block_K):
                #     A_shared[i, k] = A[by * block_M + i, ko * block_K + k]
                T.copy(A[by * block_M, ko * block_K], A_shared)

                # Copy tile of B
                T.copy(B[bx * block_N, ko * block_K], B_shared)

                # Perform a tile-level GEMM on the shared buffers
                # Currently we dispatch to cute/hip on Nvidia/AMD GPUs
                T.gemm(A_shared, B_shared, C_local, transpose_B=True)

            # Copy result back to global memory
            T.copy(C_local, C[by * block_M, bx * block_N])

    return main


M = 128 # M = T.symbolic("m") if you want to use dynamic shape
N = 128
K = 32
block_M = 128
block_N = 128
block_K = 32

# 1. Define the kernel (matmul) and compile/lower it into an executable module
func = matmul(M, N, K, block_M, block_N, block_K)

# 2. Compile the kernel into a torch function
# out_idx specifies the index of the output buffer in the argument list
# if out_idx is specified, the tensor will be created during runtime
# target currently can be "cuda" or "hip" or "cpu".
jit_kernel = tilelang.compile(
func,
out_idx=[2],
target="cuda",
pass_configs={
tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True,
tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True,
})
print(jit_kernel.get_kernel_source())
# 3. Test the kernel in Python with PyTorch data
import torch

# Create random input tensors on the GPU
a = torch.randn(M, K, device="cuda", dtype=torch.float16)
b = torch.randn(N, K, device="cuda", dtype=torch.float16)

# Run the kernel through the Profiler
c = jit_kernel(a, b)

print(c)
# Reference multiplication using PyTorch
ref_c = a @ b.T

# Validate correctness
torch.testing.assert_close(c, ref_c, rtol=1e-2, atol=1e-2)
print("Kernel output matches PyTorch reference.")

# 4. Retrieve and inspect the generated CUDA source (optional)
# cuda_source = jit_kernel.get_kernel_source()
# print("Generated CUDA kernel:\n", cuda_source)

# 5. Profile kernel latency
profiler = jit_kernel.get_profiler(tensor_supply_type=tilelang.TensorSupplyType.Normal)

latency = profiler.do_bench()

print(f"Latency: {latency} ms")
94 changes: 94 additions & 0 deletions examples/gemm_sm100/gemm_tcgen5mma.py
@@ -0,0 +1,94 @@
import torch
import tilelang
import tilelang.language as T

tilelang.disable_cache()


def matmul(
    M,
    N,
    K,
    block_M,
    block_N,
    block_K,
    trans_A,
    trans_B,
    in_dtype,
    out_dtype,
    accum_dtype,
    num_stages,
    threads,
):
    A_shape = (K, M) if trans_A else (M, K)
    B_shape = (N, K) if trans_B else (K, N)
    A_shared_shape = (block_K, block_M) if trans_A else (block_M, block_K)
    B_shared_shape = (block_N, block_K) if trans_B else (block_K, block_N)

    @T.prim_func
    def main(
            A: T.Tensor(A_shape, in_dtype),
            B: T.Tensor(B_shape, in_dtype),
            C: T.Tensor((M, N), out_dtype),
    ):
        with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by):
            A_shared = T.alloc_shared(A_shared_shape, in_dtype)
            B_shared = T.alloc_shared(B_shared_shape, in_dtype)
            C_tmem = T.alloc_tmem([block_M, block_N], accum_dtype)
            mbar = T.alloc_barrier(1)  # the 1 here is the expect-arrive-count
            C_local = T.alloc_fragment((block_M, block_N), accum_dtype)
            C_shared = T.alloc_shared((block_M, block_N), out_dtype)

            for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=1):
                T.copy(A[by * block_M, k * block_K], A_shared)
                T.copy(B[bx * block_N, k * block_K], B_shared)
Comment on lines +42 to +44
⚠️ Potential issue | 🟠 Major

Respect the num_stages argument

matmul advertises a tunable num_stages, but the pipelined loop hardcodes num_stages=1, so callers (including the example, which passes 0) silently get a different staging depth than requested. That breaks tuning knobs and can desynchronize the mbarrier usage. Please plumb the parameter through.

-            for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=1):
+            for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=num_stages):
🤖 Prompt for AI Agents
In examples/gemm_sm100/gemm_tcgen5mma.py around lines 42 to 44, the pipelined
loop currently hardcodes num_stages=1 which ignores the matmul’s tunable
num_stages (callers may pass 0 or other values) and can desynchronize mbarrier
usage; replace the hardcoded 1 with the function/local parameter that holds the
requested staging depth (e.g., num_stages) so the call reads
T.Pipelined(T.ceildiv(K, block_K), num_stages=num_stages), ensuring the variable
is in scope and propagated from the matmul signature.

                T.gemm(
                    A_shared,
                    B_shared,
                    C_tmem,
                    trans_A,
                    trans_B,
                    mbar=mbar,
                    wg_wait=-1,
                    clear_accum=k == 0)
                T.mbarrier_wait_parity(mbar, k % 2)
Comment on lines +43 to +54
⚠️ Potential issue | 🔴 Critical

Fix TMEM tile loads for transpose flags

These two T.copy statements ignore trans_A/trans_B and always fetch tiles as if A were row-major and B were transposed. For trans_A=True, A_shape is (K, M) yet we still step the first dimension with by * block_M, which walks past the K extent; for trans_B=False, B_shape is (K, N) but we index its first dimension with bx * block_N. Anything outside the single (False, True) combination used in the demo will read the wrong region or fall off the tensor. The unit-test kernel in testing/python/kernel/test_tilelang_kernel_gemm.py fixes this by branching on the transpose flags—please mirror that logic here.

-                T.copy(A[by * block_M, k * block_K], A_shared)
-                T.copy(B[bx * block_N, k * block_K], B_shared)
+                if trans_A:
+                    T.copy(A[k * block_K, by * block_M], A_shared)
+                else:
+                    T.copy(A[by * block_M, k * block_K], A_shared)
+                if trans_B:
+                    T.copy(B[bx * block_N, k * block_K], B_shared)
+                else:
+                    T.copy(B[k * block_K, bx * block_N], B_shared)
🤖 Prompt for AI Agents
In examples/gemm_sm100/gemm_tcgen5mma.py around lines 43 to 54, the two T.copy
calls always index A and B as if A is (M,K) and B is (N,K) and ignore
trans_A/trans_B; this causes out-of-bounds or wrong tiles when transposed. Fix
by branching on trans_A and trans_B like the unit test: when trans_A is False
copy A[by * block_M, k * block_K] (tile dims M x K), but when trans_A is True
copy A[k * block_K, by * block_M] (tile dims K x M); similarly, when trans_B is
True copy B[bx * block_N, k * block_K] (tile dims N x K), else when trans_B is
False copy B[k * block_K, bx * block_N] (tile dims K x N). Mirror the
index-swapping and tile extents from
testing/python/kernel/test_tilelang_kernel_gemm.py so TMEM loads match the
declared shapes.


            if T.get_thread_binding() < 128:
                T.copy(C_tmem, C_local)
                T.copy(C_local, C_shared)

            T.copy(C_shared, C[by * block_M, bx * block_N])

    return main


M, N, K = 4096, 4096, 8192
block_M, block_N, block_K = 128, 256, 128
trans_A, trans_B = False, True
in_dtype, out_dtype, accum_dtype = "bfloat16", "bfloat16", "float"
num_stages = 0
threads = 256

func = matmul(M, N, K, block_M, block_N, block_K, trans_A, trans_B, in_dtype, out_dtype,
accum_dtype, num_stages, threads)
jit_kernel = tilelang.compile(
func,
out_idx=[2],
target="cuda",
pass_configs={
tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True,
tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True,
})

print(jit_kernel.get_kernel_source())

a = torch.randn(M, K, device="cuda", dtype=torch.bfloat16)
b = torch.randn(N, K, device="cuda", dtype=torch.bfloat16)
c = jit_kernel(a, b)
ref_c = (a.to(torch.float) @ b.T.to(torch.float)).to(torch.bfloat16)
torch.testing.assert_close(c, ref_c, rtol=1e-2, atol=1e-2)

profiler = jit_kernel.get_profiler()
latency = profiler.do_bench()
print(f"Latency: {latency} ms")
print(f"Flops: {2 * M * N * K / (latency/1e3) / 1e12} TFLOPS")
33 changes: 29 additions & 4 deletions src/layout/gemm_layouts.cc
@@ -13,7 +13,7 @@
namespace tvm {
namespace tl {

static IterVar make_itervar(std::string name, PrimExpr dom) {
IterVar make_itervar(std::string name, PrimExpr dom) {
Var var = Var(name, dom->dtype);
return IterVar(Range(0, dom), var, IterVarType::kDataPar);
}
@@ -749,16 +749,41 @@ Layout makeGemmABLayoutHopper(int mat_stride, int mat_continuous,
element_size);
}
int vector_size = 128 / element_size;
if (kfactor == 1 && element_size == 8) // int8 KxN
if (mat_continuous % (vector_size * 8) == 0)
return makeFullBankSwizzleLayout(mat_stride, mat_continuous, element_size);
else if (mat_continuous % (vector_size * 4) == 0)
return makeHalfBankSwizzleLayout(mat_stride, mat_continuous, element_size);
else if (mat_continuous % (vector_size * 2) == 0)
return makeQuarterBankSwizzleLayout(mat_stride, mat_continuous,
element_size);
else if (mat_continuous % (vector_size * 8) == 0)
else if (mat_continuous % vector_size == 0)
return makeGemmLayoutLinear(mat_stride, mat_continuous);
else
ICHECK(0) << "Unsupported layout for Hopper with stride=" << mat_stride
<< ", continuous=" << mat_continuous
<< ", element_size=" << element_size << ", kfactor=" << kfactor;
}
Comment on lines +752 to +765
⚠️ Potential issue | 🔴 Critical

Regression: Hopper layout now aborts for valid widths

For non power-of-two tiles (e.g., element_size == 8, mat_continuous == 48) we used to fall back to makeGemmABLayoutPadded, so Hopper kernels kept working. The new branch ends with ICHECK(0), which now fatals for those same shapes. That’s a correctness regression that will abort user programs. Please restore a padded/linear fallback instead of hard failing.

   else if (mat_continuous % (vector_size * 2) == 0)
     return makeQuarterBankSwizzleLayout(mat_stride, mat_continuous,
                                         element_size);
-  else if (mat_continuous % vector_size == 0)
-    return makeGemmLayoutLinear(mat_stride, mat_continuous);
-  else
-    ICHECK(0) << "Unsupported layout for Hopper with stride=" << mat_stride
-              << ", continuous=" << mat_continuous
-              << ", element_size=" << element_size << ", kfactor=" << kfactor;
+  else if (mat_continuous % vector_size == 0)
+    return makeGemmLayoutLinear(mat_stride, mat_continuous);
+  return makeGemmABLayoutPadded(mat_stride, mat_continuous, element_size);
 }
🤖 Prompt for AI Agents
In src/layout/gemm_layouts.cc around lines 752-765, the final ICHECK(0) causes
Hopper to abort for valid non-power-of-two widths (e.g., element_size==8,
mat_continuous==48); restore the previous padded/linear fallback instead of
fatally asserting by returning the padded GEMM A/B layout (same call used
before, e.g. makeGemmABLayoutPadded(mat_stride, mat_continuous, element_size,
kfactor) or, if that API isn’t available here, fall back to
makeGemmLayoutLinear/makeGemmABLayoutPadded as appropriate) so the code returns
a valid layout for those shapes rather than calling ICHECK(0).


Layout makeGemmABLayoutSm100(int mat_stride, int mat_continuous, int continuity,
int element_size, int kfactor) {
if (element_size == 64) {
ICHECK(0) << "float64 on sm100 is not supported now";
}
int vector_size = 128 / element_size;
if (mat_continuous % (vector_size * 8) == 0)
return makeFullBankSwizzleLayout(mat_stride, mat_continuous, element_size);
else if (mat_continuous % (vector_size * 4) == 0)
return makeHalfBankSwizzleLayout(mat_stride, mat_continuous, element_size);
else
else if (mat_continuous % (vector_size * 2) == 0)
return makeQuarterBankSwizzleLayout(mat_stride, mat_continuous,
element_size);
else if (mat_continuous % vector_size == 0)
return makeGemmLayoutLinear(mat_stride, mat_continuous);
else
ICHECK(0) << "Unsupported layout for sm100 with stride=" << mat_stride
<< ", continuous=" << mat_continuous
<< ", element_size=" << element_size << ", kfactor=" << kfactor;
__builtin_unreachable(); // to prevent compiler warning
Comment on lines +767 to +786
⚠️ Potential issue | 🔴 Critical

SM100 layout should degrade gracefully

The new SM100 path has the same issue: any case where mat_continuous isn’t a multiple of vector_size immediately hits the ICHECK(0). That blocks otherwise valid GEMM tiles (e.g., mat_continuous == 48 for int8) from compiling. Let’s mirror the Hopper fix and fall back to the padded layout instead of aborting.

   else if (mat_continuous % (vector_size * 2) == 0)
     return makeQuarterBankSwizzleLayout(mat_stride, mat_continuous,
                                         element_size);
-  else if (mat_continuous % vector_size == 0)
-    return makeGemmLayoutLinear(mat_stride, mat_continuous);
-  else
-    ICHECK(0) << "Unsupported layout for sm100 with stride=" << mat_stride
-              << ", continuous=" << mat_continuous
-              << ", element_size=" << element_size << ", kfactor=" << kfactor;
-  __builtin_unreachable(); // to prevent compiler warning
+  else if (mat_continuous % vector_size == 0)
+    return makeGemmLayoutLinear(mat_stride, mat_continuous);
+  return makeGemmABLayoutPadded(mat_stride, mat_continuous, element_size);
 }

}

Layout makeGemmABLayoutCDNA(int stride, int continuous, int element_size,
3 changes: 3 additions & 0 deletions src/layout/layout.h
@@ -131,6 +131,7 @@ class Fragment : public Layout {

Var InputPlaceholder(size_t idx);
Var ReplicationPlaceholder();
IterVar make_itervar(std::string name, PrimExpr dom);

Fragment makeGemmFragment8x8();
Fragment makeGemmFragment8x8Transposed();
@@ -166,6 +167,8 @@ Layout makeGemmABLayout(int mat_stride, int mat_continuous, int continuity,
int element_size, int kfactor);
Layout makeGemmABLayoutHopper(int mat_stride, int mat_continuous,
int continuity, int element_size, int kfactor);
Layout makeGemmABLayoutSm100(int mat_stride, int mat_continuous, int continuity,
int element_size, int kfactor);
Layout makeGemmABLayoutCDNA(int stride, int continuous, int element_size,
int kfactor);
