-
Notifications
You must be signed in to change notification settings - Fork 293
[Bugfix][Enhancement] Fix a bug in previous commit and enhance cuda backend #887
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| @@ -0,0 +1,106 @@ | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| # TileLang SM100 Support (Preview) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| This directory contains examples for TileLang's experimental SM100 architecture support. **This is a preview version** with limited functionality. | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| ## Current Limitations (Manual Implementation Required) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| ### 1. Manual TCGEN5.MMA Management | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| Users must manually handle TCGEN5MMA operations using: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| - `T.alloc_tmem()` - Allocate Tensor Memory | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| - `T.gemm()` with `wg_wait=-1` - Launch TCGEN5MMA without waiting | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| - Manual synchronization with mbarrier | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| ### 2. Manual mbarrier Synchronization | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| TCGEN5MMA is asynchronous and requires explicit synchronization: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| ```python | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| mbar = T.alloc_barrier(1) # expect-arrive-count = 1 | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| T.gemm(A_shared, B_shared, C_tmem, trans_A, trans_B, mbar=mbar, wg_wait=-1, clear_accum=k==0) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| T.mbarrier_wait_parity(mbar, k%2) # Manual phase calculation required | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| ``` | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| ## Examples | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| ### TCGEN5MMA Example (`gemm_tcgen5mma.py`) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| Demonstrates TCGEN5MMA operations with: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| - Tensor Memory allocation | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| - Manual mbarrier synchronization | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| - TCGEN5MMA gemm operations | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| ### Traditional MMA Example (`gemm_mma.py`) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| Shows standard MMA operations that work across architectures for comparison. | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| ## Code Example | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| The following code is based on `gemm_tcgen5mma.py`, demonstrating TCGEN5MMA matrix multiplication: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| ```python | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| import torch | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| import tilelang | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| import tilelang.language as T | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| @T.prim_func | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| def main( | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| A: T.Tensor((M, K), "bfloat16"), | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| B: T.Tensor((N, K), "bfloat16"), | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| C: T.Tensor((M, N), "bfloat16"), | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| ): | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=256) as (bx, by): | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| # 1. Allocate memory buffers | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| A_shared = T.alloc_shared((block_M, block_K), "bfloat16") # A matrix shared memory | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| B_shared = T.alloc_shared((block_N, block_K), "bfloat16") # B matrix shared memory | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| C_tmem = T.alloc_tmem([block_M, block_N], "float") # TCGEN5MMA output to Tensor Memory | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| mbar = T.alloc_barrier(1) # mbarrier synchronization primitive | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| C_local = T.alloc_fragment((block_M, block_N), "float") # Register storage | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| C_shared = T.alloc_shared((block_M, block_N), "bfloat16") # Output shared memory | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| # 2. Main computation loop | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=1): | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| # Data loading: global memory to shared memory | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| T.copy(A[by * block_M, k * block_K], A_shared) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| T.copy(B[bx * block_N, k * block_K], B_shared) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| # TCGEN5MMA computation: asynchronous launch, output to Tensor Memory | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| T.gemm(A_shared, B_shared, C_tmem, trans_A=False, trans_B=True, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| mbar=mbar, wg_wait=-1, clear_accum=k==0) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| # Critical: wait for TCGEN5MMA completion | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| T.mbarrier_wait_parity(mbar, k%2) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| # 3. Output processing (only subset of threads) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| T.copy(C_tmem, C_local) # Tensor Memory → registers | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| T.copy(C_local, C_shared) # registers → shared memory | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| # 4. Write back to global memory | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| T.copy(C_shared, C[by * block_M, bx * block_N]) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| ``` | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| ### Compilation and Usage | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| ```python | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| # Parameter setup | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| M, N, K = 4096, 4096, 8192 | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| block_M, block_N, block_K = 128, 256, 128 | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| # Compile kernel | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| jit_kernel = tilelang.compile(func, out_idx=[2], target="cuda", pass_configs={ | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True, # Required | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True, # Required | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| }) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| # Run test | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| a = torch.randn(M, K, device="cuda", dtype=torch.bfloat16) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| b = torch.randn(N, K, device="cuda", dtype=torch.bfloat16) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| c = jit_kernel(a, b) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| # Verify correctness | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| ref_c = (a.to(torch.float) @ b.T.to(torch.float)).to(torch.bfloat16) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| torch.testing.assert_close(c, ref_c, rtol=1e-2, atol=1e-2) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| # Performance benchmark | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| profiler = jit_kernel.get_profiler() | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| latency = profiler.do_bench() | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| print(f"Latency: {latency} ms") | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| print(f"Performance: {2 * M * N * K / (latency/1e3) / 1e12:.2f} TFLOPS") | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
Comment on lines
+85
to
+104
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Fix the example to compile Lines 86-93 invoke -jit_kernel = tilelang.compile(func, out_idx=[2], target="cuda", pass_configs={
+jit_kernel = tilelang.compile(main, out_idx=[2], target="cuda", pass_configs={📝 Committable suggestion
Suggested change
🤖 Prompt for AI Agents |
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| ``` | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,94 @@ | ||
| import tilelang | ||
| import tilelang.language as T | ||
|
|
||
|
|
||
| # add decorator @tilelang.jit if you want to return a torch function | ||
| # @tilelang.jit | ||
| def matmul(M, N, K, block_M, block_N, block_K, dtype="float16", accum_dtype="float"): | ||
|
|
||
| @T.prim_func | ||
| def main( | ||
| A: T.Tensor((M, K), dtype), | ||
| B: T.Tensor((N, K), dtype), | ||
| C: T.Tensor((M, N), dtype), | ||
| ): | ||
| # Initialize Kernel Context | ||
| with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=256) as (bx, by): | ||
| A_shared = T.alloc_shared((block_M, block_K), dtype) | ||
| B_shared = T.alloc_shared((block_N, block_K), dtype) | ||
| C_local = T.alloc_fragment((block_M, block_N), accum_dtype) | ||
|
|
||
| # Clear local accumulation | ||
| T.clear(C_local) | ||
|
|
||
| for ko in T.Pipelined(T.ceildiv(K, block_K), num_stages=0): | ||
| # Copy tile of A | ||
| # This is a sugar syntax for parallelized copy | ||
| # for i, k in T.Parallel(M, block_K): | ||
| # A_shared[i, k] = A[by * block_M + i, ko * block_K + k] | ||
| T.copy(A[by * block_M, ko * block_K], A_shared) | ||
|
|
||
| # Copy tile of B | ||
| T.copy(B[bx * block_N, ko * block_K], B_shared) | ||
|
|
||
| # Perform a tile-level GEMM on the shared buffers | ||
| # Currently we dispatch to the cute/hip on Nvidia/AMD GPUs | ||
| T.gemm(A_shared, B_shared, C_local, transpose_B=True) | ||
|
|
||
| # Copy result back to global memory | ||
| T.copy(C_local, C[by * block_M, bx * block_N]) | ||
|
|
||
| return main | ||
|
|
||
|
|
||
| M = 128 # M = T.symbolic("m") if you want to use dynamic shape | ||
| N = 128 | ||
| K = 32 | ||
| block_M = 128 | ||
| block_N = 128 | ||
| block_K = 32 | ||
|
|
||
| # 1. Define the kernel (matmul) and compile/lower it into an executable module | ||
| func = matmul(M, N, K, block_M, block_N, block_K) | ||
|
|
||
| # 2. Compile the kernel into a torch function | ||
| # out_idx specifies the index of the output buffer in the argument list | ||
| # if out_idx is specified, the tensor will be created during runtime | ||
| # target currently can be "cuda" or "hip" or "cpu". | ||
| jit_kernel = tilelang.compile( | ||
| func, | ||
| out_idx=[2], | ||
| target="cuda", | ||
| pass_configs={ | ||
| tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True, | ||
| tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True, | ||
| }) | ||
| print(jit_kernel.get_kernel_source()) | ||
| # 3. Test the kernel in Python with PyTorch data | ||
| import torch | ||
|
|
||
| # Create random input tensors on the GPU | ||
| a = torch.randn(M, K, device="cuda", dtype=torch.float16) | ||
| b = torch.randn(N, K, device="cuda", dtype=torch.float16) | ||
|
|
||
| # Run the kernel through the Profiler | ||
| c = jit_kernel(a, b) | ||
|
|
||
| print(c) | ||
| # Reference multiplication using PyTorch | ||
| ref_c = a @ b.T | ||
|
|
||
| # Validate correctness | ||
| torch.testing.assert_close(c, ref_c, rtol=1e-2, atol=1e-2) | ||
| print("Kernel output matches PyTorch reference.") | ||
|
|
||
| # 4. Retrieve and inspect the generated CUDA source (optional) | ||
| # cuda_source = jit_kernel.get_kernel_source() | ||
| # print("Generated CUDA kernel:\n", cuda_source) | ||
|
|
||
| # 5.Profile latency with kernel | ||
| profiler = jit_kernel.get_profiler(tensor_supply_type=tilelang.TensorSupplyType.Normal) | ||
|
|
||
| latency = profiler.do_bench() | ||
|
|
||
| print(f"Latency: {latency} ms") |
| Original file line number | Diff line number | Diff line change | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| @@ -0,0 +1,94 @@ | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| import torch | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| import tilelang | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| import tilelang.language as T | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| tilelang.disable_cache() | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| def matmul( | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| M, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| N, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| K, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| block_M, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| block_N, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| block_K, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| trans_A, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| trans_B, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| in_dtype, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| out_dtype, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| accum_dtype, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| num_stages, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| threads, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| ): | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| A_shape = (K, M) if trans_A else (M, K) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| B_shape = (N, K) if trans_B else (K, N) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| A_shared_shape = (block_K, block_M) if trans_A else (block_M, block_K) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| B_shared_shape = (block_N, block_K) if trans_B else (block_K, block_N) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| @T.prim_func | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| def main( | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| A: T.Tensor(A_shape, in_dtype), | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| B: T.Tensor(B_shape, in_dtype), | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| C: T.Tensor((M, N), out_dtype), | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| ): | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by): | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| A_shared = T.alloc_shared(A_shared_shape, in_dtype) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| B_shared = T.alloc_shared(B_shared_shape, in_dtype) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| C_tmem = T.alloc_tmem([block_M, block_N], accum_dtype) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| mbar = T.alloc_barrier(1) # 这里的 1 是 expect-arrive-count | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| C_local = T.alloc_fragment((block_M, block_N), accum_dtype) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| C_shared = T.alloc_shared((block_M, block_N), out_dtype) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=1): | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| T.copy(A[by * block_M, k * block_K], A_shared) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| T.copy(B[bx * block_N, k * block_K], B_shared) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
Comment on lines
+42
to
+44
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Respect the num_stages argument
- for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=1):
+ for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=num_stages):📝 Committable suggestion
Suggested change
🤖 Prompt for AI Agents |
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| T.gemm( | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| A_shared, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| B_shared, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| C_tmem, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| trans_A, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| trans_B, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| mbar=mbar, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| wg_wait=-1, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| clear_accum=k == 0) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| T.mbarrier_wait_parity(mbar, k % 2) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
Comment on lines
+43
to
+54
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Fix TMEM tile loads for transpose flags These two - T.copy(A[by * block_M, k * block_K], A_shared)
- T.copy(B[bx * block_N, k * block_K], B_shared)
+ if trans_A:
+ T.copy(A[k * block_K, by * block_M], A_shared)
+ else:
+ T.copy(A[by * block_M, k * block_K], A_shared)
+ if trans_B:
+ T.copy(B[bx * block_N, k * block_K], B_shared)
+ else:
+ T.copy(B[k * block_K, bx * block_N], B_shared)📝 Committable suggestion
Suggested change
🤖 Prompt for AI Agents |
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| if T.get_thread_binding() < 128: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| T.copy(C_tmem, C_local) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| T.copy(C_local, C_shared) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| T.copy(C_shared, C[by * block_M, bx * block_N]) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| return main | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| M, N, K = 4096, 4096, 8192 | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| block_M, block_N, block_K = 128, 256, 128 | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| trans_A, trans_B = False, True | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| in_dtype, out_dtype, accum_dtype = "bfloat16", "bfloat16", "float" | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| num_stages = 0 | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| threads = 256 | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| func = matmul(M, N, K, block_M, block_N, block_K, trans_A, trans_B, in_dtype, out_dtype, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| accum_dtype, num_stages, threads) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| jit_kernel = tilelang.compile( | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| func, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| out_idx=[2], | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| target="cuda", | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| pass_configs={ | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| }) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| print(jit_kernel.get_kernel_source()) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| a = torch.randn(M, K, device="cuda", dtype=torch.bfloat16) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| b = torch.randn(N, K, device="cuda", dtype=torch.bfloat16) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| c = jit_kernel(a, b) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| ref_c = (a.to(torch.float) @ b.T.to(torch.float)).to(torch.bfloat16) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| torch.testing.assert_close(c, ref_c, rtol=1e-2, atol=1e-2) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| profiler = jit_kernel.get_profiler() | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| latency = profiler.do_bench() | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| print(f"Latency: {latency} ms") | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| print(f"Flops: {2 * M * N * K / (latency/1e3) / 1e12} TFLOPS") | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| Original file line number | Diff line number | Diff line change | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
|
@@ -13,7 +13,7 @@ | |||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| namespace tvm { | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| namespace tl { | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| static IterVar make_itervar(std::string name, PrimExpr dom) { | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| IterVar make_itervar(std::string name, PrimExpr dom) { | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| Var var = Var(name, dom->dtype); | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| return IterVar(Range(0, dom), var, IterVarType::kDataPar); | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
@@ -749,16 +749,41 @@ Layout makeGemmABLayoutHopper(int mat_stride, int mat_continuous, | |||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| element_size); | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| int vector_size = 128 / element_size; | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| if (kfactor == 1 && element_size == 8) // int8 KxN | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| if (mat_continuous % (vector_size * 8) == 0) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| return makeFullBankSwizzleLayout(mat_stride, mat_continuous, element_size); | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| else if (mat_continuous % (vector_size * 4) == 0) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| return makeHalfBankSwizzleLayout(mat_stride, mat_continuous, element_size); | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| else if (mat_continuous % (vector_size * 2) == 0) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| return makeQuarterBankSwizzleLayout(mat_stride, mat_continuous, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| element_size); | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| else if (mat_continuous % (vector_size * 8) == 0) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| else if (mat_continuous % vector_size == 0) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| return makeGemmLayoutLinear(mat_stride, mat_continuous); | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| else | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| ICHECK(0) << "Unsupported layout for Hopper with stride=" << mat_stride | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| << ", continuous=" << mat_continuous | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| << ", element_size=" << element_size << ", kfactor=" << kfactor; | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
Comment on lines
+752
to
+765
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Regression: Hopper layout now aborts for valid widths For non power-of-two tiles (e.g., else if (mat_continuous % (vector_size * 2) == 0)
return makeQuarterBankSwizzleLayout(mat_stride, mat_continuous,
element_size);
- else if (mat_continuous % vector_size == 0)
- return makeGemmLayoutLinear(mat_stride, mat_continuous);
- else
- ICHECK(0) << "Unsupported layout for Hopper with stride=" << mat_stride
- << ", continuous=" << mat_continuous
- << ", element_size=" << element_size << ", kfactor=" << kfactor;
+ else if (mat_continuous % vector_size == 0)
+ return makeGemmLayoutLinear(mat_stride, mat_continuous);
+ return makeGemmABLayoutPadded(mat_stride, mat_continuous, element_size);
}📝 Committable suggestion
Suggested change
🤖 Prompt for AI Agents |
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| Layout makeGemmABLayoutSm100(int mat_stride, int mat_continuous, int continuity, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| int element_size, int kfactor) { | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| if (element_size == 64) { | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| ICHECK(0) << "float64 on sm100 is not supported now"; | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| int vector_size = 128 / element_size; | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| if (mat_continuous % (vector_size * 8) == 0) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| return makeFullBankSwizzleLayout(mat_stride, mat_continuous, element_size); | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| else if (mat_continuous % (vector_size * 4) == 0) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| return makeHalfBankSwizzleLayout(mat_stride, mat_continuous, element_size); | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| else | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| else if (mat_continuous % (vector_size * 2) == 0) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| return makeQuarterBankSwizzleLayout(mat_stride, mat_continuous, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| element_size); | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| else if (mat_continuous % vector_size == 0) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| return makeGemmLayoutLinear(mat_stride, mat_continuous); | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| else | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| ICHECK(0) << "Unsupported layout for sm100 with stride=" << mat_stride | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| << ", continuous=" << mat_continuous | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| << ", element_size=" << element_size << ", kfactor=" << kfactor; | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| __builtin_unreachable(); // to prevent compiler warning | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
Comment on lines
+767
to
+786
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. SM100 layout should degrade gracefully The new SM100 path has the same issue: any case where else if (mat_continuous % (vector_size * 2) == 0)
return makeQuarterBankSwizzleLayout(mat_stride, mat_continuous,
element_size);
- else if (mat_continuous % vector_size == 0)
- return makeGemmLayoutLinear(mat_stride, mat_continuous);
- else
- ICHECK(0) << "Unsupported layout for sm100 with stride=" << mat_stride
- << ", continuous=" << mat_continuous
- << ", element_size=" << element_size << ", kfactor=" << kfactor;
- __builtin_unreachable(); // to prevent compiler warning
+ else if (mat_continuous % vector_size == 0)
+ return makeGemmLayoutLinear(mat_stride, mat_continuous);
+ return makeGemmABLayoutPadded(mat_stride, mat_continuous, element_size);
}📝 Committable suggestion
Suggested change
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| Layout makeGemmABLayoutCDNA(int stride, int continuous, int element_size, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Undo the blanket disable of the bounds and virtual-call checks
These three checks were part of the “Retained categories” for a reason—they regularly flag genuine memory-safety and lifecycle bugs across our CUDA backends. Disabling them globally trades concrete signal for convenience and will let regressions slip through (e.g., silent out-of-bounds pointer arithmetic or virtual dispatch in constructors/destructors). If there are nuisance warnings in new SM100 codepaths, please suppress them locally with targeted
NOLINTor refactor the offending spots instead of removing the protections repo-wide.📝 Committable suggestion
🤖 Prompt for AI Agents