[Refactor] Add kernel selection option for GEMM v1 in environment settings #1200
Changes from all commits
.gitignore

```diff
@@ -105,3 +105,6 @@ cmake-build-*/
 
+# Git version for sdist
+.git_commit.txt
+
+# pre-commit cache
+.pre-commit-cache/*
```
```diff
@@ -80,7 +80,6 @@ def fused_chunk_linear_attn_fwd(
         T.atomic_add(
             O[i_b, i * chunk_size:(i + 1) * chunk_size, i_h, i_v * BV:(i_v + 1) * BV],
             o_shared)
-        #TODO: consider using vectorized atomic add or tma reduce for sm90
 
     # Output final state
     T.copy(h, final_state[i_b, i_h, i_k * BK:(i_k + 1) * BK, i_v * BV:(i_v + 1) * BV])
 
@@ -91,6 +90,7 @@ def fused_chunk_linear_attn_fwd(
 def tl_fused_chunk_fwd(q, k, v):
     B, S, H, D = q.shape
     kernel = tl_fused_chunk_fwd_kernel(B, S, H, D, D)
+    print(kernel.get_kernel_source())
```
Contributor

Make the kernel source printing conditional to avoid cluttering benchmark output. As written, the print statement executes every time `tl_fused_chunk_fwd` is called. Consider one of these alternatives:

Option 1: Make it conditional on an environment variable or flag (this also needs `import os` at the top of the file):

```diff
-    print(kernel.get_kernel_source())
+    if os.environ.get('TILELANG_DEBUG_KERNEL_SOURCE'):
+        print(kernel.get_kernel_source())
```

Option 2: Print once before benchmarking in `main()`:

```diff
 def tl_fused_chunk_fwd(q, k, v):
     B, S, H, D = q.shape
     kernel = tl_fused_chunk_fwd_kernel(B, S, H, D, D)
-    print(kernel.get_kernel_source())
     o = torch.zeros((B, S, H, D), device='cuda', dtype=torch.float32)
     h = kernel(q, k, v, o)
     return o, h
```

Then add in `main()` before benchmarking:

```python
def main(B=1, S=512, H=16, D=128):
    # ... setup code ...
    # Print the kernel source once for inspection
    kernel = tl_fused_chunk_fwd_kernel(B, S, H, D, D)
    print(kernel.get_kernel_source())
    o, h = tl_fused_chunk_fwd(q, k, v)
    # ... rest of main ...
```
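For reference, Option 1 makes the dump opt-in from the shell; assuming the example file is run directly (the script name below is a placeholder, not the file's actual name):

```bash
TILELANG_DEBUG_KERNEL_SOURCE=1 python example_fused_chunk_linear_attn.py
```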
```diff
     o = torch.zeros((B, S, H, D), device='cuda', dtype=torch.float32)
     h = kernel(q, k, v, o)
     return o, h
```
examples/warp_specialize/example_warp_specialize_gemm_barrierpipe_stage2.py

```diff
@@ -1,6 +1,8 @@
 import tilelang
 import tilelang.language as T
 
+tilelang.disable_cache()
+
```
Contributor

🛠️ Refactor suggestion | 🟠 Major

Avoid top-level side effects; keep a single cache toggle in `main()`.

Calling `tilelang.disable_cache()` at import time affects anyone who imports this module, and the call is duplicated at line 57 inside `main()`. Remove the top-level call and keep the single call in `main()`. Apply this minimal change:

```diff
-tilelang.disable_cache()
```

Optional verification (confirms a public `disable_cache` wrapper exists, so removing the top-level call causes no surprises at runtime):

```bash
#!/bin/bash
# Verify a public wrapper exists and locate its definition(s)
rg -nP -C2 'def\s+disable_cache\s*\('
rg -nP -C2 'disable_cache\s*=\s*.*\.disable_cache'
```

```bash
#!/bin/bash
set -e
# Locate the tilelang package __init__.py
find . -type f -path "./tilelang/__init__.py" -print
# Show the first 200 lines to check for a disable_cache wrapper
sed -n '1,200p' tilelang/__init__.py
```

```bash
#!/bin/bash
# Display lines 1-80 of the example to confirm the top-level call and its context
sed -n '1,80p' examples/warp_specialize/example_warp_specialize_gemm_barrierpipe_stage2.py
```

Conclusion: remove the top-level `disable_cache()` call.
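A minimal sketch of the suggested end state, assuming `tilelang.disable_cache()` is the public API already used in this example:

```python
import tilelang


def main(M=16384, N=16384, K=16384):
    # Single, explicit cache toggle, applied only when the example runs,
    # so importing this module has no side effects for other callers.
    tilelang.disable_cache()
    # ... build and benchmark the kernel as in the rest of the example ...


if __name__ == "__main__":
    main()
```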
```diff
 # add decorator @tilelang.jit if you want to return a torch function
 # @tilelang.jit
 
@@ -52,11 +54,14 @@ def main(
 
 def main(M=16384, N=16384, K=16384):
+    tilelang.disable_cache()
     block_M = 128
     block_N = 128
     block_K = 64
     jit_kernel = matmul(M, N, K, block_M, block_N, block_K)
 
+    print(jit_kernel.get_kernel_source())
+
     import torch
 
     a = torch.randn(M, K, device="cuda", dtype=torch.float16)
```
Possible CMake compatibility issue: `if(DEFINED CACHE{...})` requires CMake 3.14 or newer.

The `DEFINED CACHE{<name>}` form was only added in CMake 3.14; on older CMake versions this check fails during configuration. If the project's minimum required CMake version is below 3.14, check whether the cache entry exists with `get_property()` instead.
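A sketch of the `get_property()`-based check; the cache entry name `TILELANG_GEMM_V1_KERNEL` is hypothetical, standing in for whichever variable this PR introduces:

```cmake
# Check whether a cache entry exists without relying on if(DEFINED CACHE{...}).
# TILELANG_GEMM_V1_KERNEL is a hypothetical entry name used for illustration.
get_property(_entry_type CACHE TILELANG_GEMM_V1_KERNEL PROPERTY TYPE)
if(_entry_type)
  message(STATUS "Cache entry exists with type ${_entry_type}")
else()
  message(STATUS "Cache entry is not set")
endif()
```

On CMake 3.14 and newer, `if(DEFINED CACHE{TILELANG_GEMM_V1_KERNEL})` is the direct equivalent.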