Merged
32 commits
82b7800
Add kernel selection option for GEMM v1 in environment settings
LeiWang1999 Nov 6, 2025
093d237
bug fix
LeiWang1999 Nov 6, 2025
7089b00
Merge branch 'main' of https://github.com/tile-ai/tilelang into v2_1106
LeiWang1999 Nov 6, 2025
03af3e7
Add kernel selection option for GEMM in environment settings
LeiWang1999 Nov 6, 2025
ca4416f
Refactor GEMM macro generator to use BufferRegion instead of Buffer
LeiWang1999 Nov 7, 2025
aba33f4
Merge branch 'main' of https://github.com/tile-ai/tilelang into v2_1106
LeiWang1999 Nov 7, 2025
b7b9a6f
Refactor GEMM functions to utilize BufferRegion for improved memory h…
LeiWang1999 Nov 7, 2025
32ed22f
Refactor GEMM code for improved readability and consistency
LeiWang1999 Nov 7, 2025
0127322
Update GEMM correctness evaluation and macro generator for improved f…
LeiWang1999 Nov 7, 2025
c0c45d6
Refactor GEMM and intrinsic files for improved clarity and functionality
LeiWang1999 Nov 7, 2025
bbc68ce
Enhance GEMM and MMA functionality for FP64 support
LeiWang1999 Nov 7, 2025
7aeb963
lint fix
LeiWang1999 Nov 7, 2025
27ba821
Refactor GEMM correctness evaluation and shared memory alignment hand…
LeiWang1999 Nov 8, 2025
09e3722
Enhance GEMM and MMA implementations with region-based memory handling
LeiWang1999 Nov 8, 2025
36d8e0e
Refactor test_tilelang_example_deepseek_v32.py to improve import stru…
LeiWang1999 Nov 8, 2025
15035cd
Enhance OnArrayDeclaration method to handle repeated buffer declarations
LeiWang1999 Nov 8, 2025
71f4284
Add abbreviation for bfloat16 data type in mfma_macro_generator.py
LeiWang1999 Nov 8, 2025
6f4b1c6
Refactor CodeGenTileLangHIP to enhance dtype handling and mfma call g…
LeiWang1999 Nov 8, 2025
005ffe9
lint fix
LeiWang1999 Nov 8, 2025
219b9e8
Enhance backend configuration in CMakeLists.txt and improve dtype han…
LeiWang1999 Nov 9, 2025
683d479
Update bfloat16 handling in CodeGenTileLangHIP and mfma_macro_generat…
LeiWang1999 Nov 9, 2025
05b68d0
Change logging level from WARNING to DLOG in LegalizeNegativeIndex fo…
LeiWang1999 Nov 9, 2025
4a74b62
Refactor attention sink examples to simplify index calculations
LeiWang1999 Nov 9, 2025
c2e3f08
Refactor attention sink examples to streamline index calculations
LeiWang1999 Nov 10, 2025
60e65d6
Merge branch 'main' of https://github.com/tile-ai/tilelang into v2_1106
LeiWang1999 Nov 10, 2025
f7fe22d
lint fix
LeiWang1999 Nov 10, 2025
a6bab65
bugfix
LeiWang1999 Nov 10, 2025
cfa62ac
Refactor reduce operation handling in CUDA and Python
LeiWang1999 Nov 10, 2025
87316cb
Fix ReduceOpNode to correctly compute AbsMax by using absolute values…
LeiWang1999 Nov 10, 2025
3f211ae
Enhance unit loop handling by refining annotation checks
LeiWang1999 Nov 10, 2025
23ef354
Merge branch 'main' of https://github.com/tile-ai/tilelang into v2_1106
LeiWang1999 Nov 11, 2025
502d71f
clean clode
LeiWang1999 Nov 11, 2025
3 changes: 3 additions & 0 deletions .gitignore
@@ -105,3 +105,6 @@ cmake-build-*/

# Git version for sdist
.git_commit.txt

# pre-commit cache
.pre-commit-cache/*
101 changes: 81 additions & 20 deletions CMakeLists.txt
@@ -65,9 +65,50 @@ else()
endif()

# Configs
set(USE_CUDA OFF)
set(USE_ROCM OFF)
set(USE_METAL OFF)
set(TILELANG_BACKENDS CUDA ROCM METAL)

set(TILELANG_BACKEND_DOC_CUDA "Enable CUDA backend (ON/OFF/or CUDA SDK path)")
set(TILELANG_BACKEND_DOC_ROCM "Enable ROCm backend (ON/OFF/or ROCm SDK path)")
set(TILELANG_BACKEND_DOC_METAL "Enable Metal backend")

# TVM's config.cmake redefines USE_* options later, so we cache the user's choice
# (including explicit -DUSE_XXX arguments) before we include TVM and restore it
# afterwards.

macro(tilelang_define_backend_option BACKEND)
set(_backend_var "USE_${BACKEND}")
set(_doc "${TILELANG_BACKEND_DOC_${BACKEND}}")
set(_user_override_var "TILELANG_USER_OVERRIDE_${_backend_var}")

set(_user_override OFF)
if(DEFINED ${_user_override_var})
set(_user_override "${${_user_override_var}}")
endif()

if(DEFINED CACHE{${_backend_var}})
get_property(_cache_type CACHE ${_backend_var} PROPERTY TYPE)
if(_cache_type STREQUAL "UNINITIALIZED")
set(_user_override ON)
endif()
endif()
Comment on lines +88 to +93
⚠️ Potential issue | 🔴 Critical

Invalid CMake syntax: DEFINED CACHE{...} is not a recognized construct.

CMake does not support DEFINED CACHE{...} syntax. To check if a cache entry exists and its type, use get_property() instead. For example:

if(DEFINED ${_backend_var})
  get_property(_cache_type CACHE ${_backend_var} PROPERTY TYPE)
  if(_cache_type STREQUAL "UNINITIALIZED")
    set(_user_override ON)
  endif()
endif()

This code will likely fail during CMake configuration.

- if(DEFINED CACHE{${_backend_var}})
+ if(DEFINED ${_backend_var})
    get_property(_cache_type CACHE ${_backend_var} PROPERTY TYPE)
🤖 Prompt for AI Agents
In CMakeLists.txt around lines 88 to 93, the conditional uses invalid CMake
syntax `DEFINED CACHE{...}`; replace it by first checking the variable name
existence and/or directly using get_property to query the CACHE entry type.
Specifically, call get_property(CACHE_VAR_TYPE CACHE ${_backend_var} PROPERTY
TYPE RESULT_VARIABLE) or use if(DEFINED ${_backend_var}) before calling
get_property, then test if the returned type equals "UNINITIALIZED" and
set(_user_override ON) accordingly so CMake config will not fail.


set(_default OFF)
if(DEFINED ${_backend_var})
set(_default "${${_backend_var}}")
endif()

option(${_backend_var} "${_doc}" "${_default}")
# Remember if the user explicitly set this option so that later logic
# won't auto-toggle backends they configured on the command line.
set(${_user_override_var} ${_user_override} CACHE INTERNAL
"User explicitly set ${_backend_var} during configuration" FORCE)
set(TILELANG_OPTION_${_backend_var} "${${_backend_var}}")
endmacro()

foreach(BACKEND IN LISTS TILELANG_BACKENDS)
tilelang_define_backend_option(${BACKEND})
endforeach()

set(PREBUILD_CYTHON ON)
# Configs end

@@ -78,6 +119,14 @@ if(EXISTS ${TVM_SOURCE}/cmake/config.cmake)
else()
message(FATAL_ERROR "Nor tvm provided or submodule checkout-ed.")
endif()
# Re-apply TileLang's preferred backend settings after TVM's config may have
# overridden the USE_* cache entries.
foreach(BACKEND IN LISTS TILELANG_BACKENDS)
set(_backend_var "USE_${BACKEND}")
set(_doc "${TILELANG_BACKEND_DOC_${BACKEND}}")
set(${_backend_var} ${TILELANG_OPTION_${_backend_var}} CACHE STRING "${_doc}" FORCE)
set(${_backend_var} ${TILELANG_OPTION_${_backend_var}})
endforeach()

# Include directories for TileLang
set(TILE_LANG_INCLUDES ${TVM_INCLUDES})
@@ -95,23 +144,35 @@ file(GLOB TILE_LANG_SRCS
src/target/intrin_rule*.cc
)

# Backend-specific checks and configs
if($ENV{USE_METAL})
set(USE_METAL ON)
elseif(APPLE)
message(STATUS "Enable Metal support by default.")
set(USE_METAL ON)
elseif($ENV{USE_ROCM})
set(USE_ROCM ON)
else()
if($ENV{USE_CUDA})
set(USE_CUDA ON)
elseif(DEFINED ENV{USE_CUDA} AND NOT $ENV{USE_CUDA})
# Build CPU-only when we explicitly disable CUDA
set(USE_CUDA OFF)
# Track if the user explicitly selected a backend via cache options.
set(TILELANG_BACKEND_USER_SELECTED OFF)
foreach(BACKEND IN LISTS TILELANG_BACKENDS)
set(_backend_var "USE_${BACKEND}")
set(_override_var "TILELANG_USER_OVERRIDE_${_backend_var}")
if(${_backend_var} OR ${_override_var})
set(TILELANG_BACKEND_USER_SELECTED ON)
endif()
endforeach()

# Only auto-select a backend when the user didn't specify one explicitly.
if(NOT TILELANG_BACKEND_USER_SELECTED)
if($ENV{USE_METAL})
set(USE_METAL ON)
elseif(APPLE)
message(STATUS "Enable Metal support by default.")
set(USE_METAL ON)
elseif($ENV{USE_ROCM})
set(USE_ROCM ON)
else()
message(STATUS "Enable CUDA support by default.")
set(USE_CUDA ON)
if($ENV{USE_CUDA})
set(USE_CUDA ON)
elseif(DEFINED ENV{USE_CUDA} AND NOT $ENV{USE_CUDA})
# Build CPU-only when we explicitly disable CUDA
set(USE_CUDA OFF)
else()
message(STATUS "Enable CUDA support by default.")
set(USE_CUDA ON)
endif()
endif()
endif()

@@ -125,7 +186,7 @@ if(USE_METAL)
elseif(USE_ROCM)
set(CMAKE_HIP_STANDARD 17)
include(${TVM_SOURCE}/cmake/utils/FindROCM.cmake)
find_rocm($ENV{USE_ROCM})
find_rocm(${USE_ROCM})
add_compile_definitions(__HIP_PLATFORM_AMD__ __HIP_PLATFORM_HCC__=1)

file(GLOB TILE_LANG_HIP_SRCS
22 changes: 8 additions & 14 deletions examples/attention_sink/example_gqa_sink_bwd_bhsd.py
@@ -81,13 +81,10 @@ def flash_fwd(
sinks[i] = Sinks[by]

end = T.min(T.ceildiv(seq_len, block_N), T.ceildiv((bx + 1) * block_M, block_N))
start = T.alloc_local([1], 'int32')
if window_size is not None:
start[0] = T.max(0, (bx * block_M - window_size) // block_N)
else:
start[0] = 0
start = T.max(0,
(bx * block_M - window_size) // block_N) if window_size is not None else 0

for k in T.Pipelined(start[0], end, num_stages=num_stages):
for k in T.Pipelined(start, end, num_stages=num_stages):
T.copy(K[bz, by // groups, k * block_N:(k + 1) * block_N, :], K_shared)
for i, j in T.Parallel(block_M, block_N):
q_idx = bx * block_M + i
@@ -266,14 +263,11 @@ def flash_bwd(
T.clear(dk)

loop_st = T.floordiv(by * block_M, block_N)
loop_ed = T.alloc_local([1], 'int32')
if window_size is not None:
loop_ed[0] = T.min(
T.ceildiv((by + 1) * block_M + window_size, block_N),
T.ceildiv(seq_len, block_N))
else:
loop_ed[0] = T.ceildiv(seq_len, block_N)
for k in T.Pipelined(loop_st, loop_ed[0], num_stages=num_stages):
loop_ed = T.min(
T.ceildiv((by + 1) * block_M + window_size, block_N), T.ceildiv(
seq_len, block_N)) if window_size is not None else T.ceildiv(seq_len, block_N)

for k in T.Pipelined(loop_st, loop_ed, num_stages=num_stages):
T.copy(Q[bz, bx, k * block_N:(k + 1) * block_N, :], q)
T.clear(qkT)
T.gemm(K_shared, q, qkT, transpose_B=True, policy=T.GemmWarpPolicy.FullRow)
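Editor's note: the same simplification recurs in the other attention-sink examples below. Because window_size is a host-side Python value (an int or None) at trace time, the windowed loop bound can be written as an ordinary Python conditional instead of being staged through a one-element T.alloc_local buffer. A minimal plain-Python sketch of the bound computation (an illustrative stand-in only, not the kernel itself; block_M, block_N, and window_size are example values):

def start_block(bx, block_M, block_N, window_size):
    # Mirrors the rewritten kernel code:
    #   start = T.max(0, (bx * block_M - window_size) // block_N) if window_size is not None else 0
    return max(0, (bx * block_M - window_size) // block_N) if window_size is not None else 0

print(start_block(3, 128, 64, 256))   # windowed case -> 2
print(start_block(3, 128, 64, None))  # no window -> start from block 0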
@@ -172,14 +172,11 @@ def main(
end = T.min(
T.ceildiv(seq_kv, block_N), T.ceildiv((bx + 1) * block_M + past_len, block_N))

start = T.alloc_local([1], 'int32')
if window_size is not None:
start[0] = T.max(0, (bx * block_M + past_len - window_size) // block_N)
else:
start[0] = 0
start = T.max(0, (bx * block_M + past_len - window_size) //
block_N) if window_size is not None else 0

for k in T.Pipelined(
start[0],
start,
end,
num_stages=num_stages,
order=[-1, 0, 3, 1, -1, 2],
21 changes: 7 additions & 14 deletions examples/attention_sink/example_mha_sink_bwd_bhsd.py
@@ -78,13 +78,10 @@ def flash_fwd(
sinks[i] = Sinks[by]

end = T.min(T.ceildiv(seq_len, block_N), T.ceildiv((bx + 1) * block_M, block_N))
start = T.alloc_local([1], 'int32')
if window_size is not None:
start[0] = T.max(0, (bx * block_M - window_size) // block_N)
else:
start[0] = 0
start = T.max(0,
(bx * block_M - window_size) // block_N) if window_size is not None else 0

for k in T.Pipelined(start[0], end, num_stages=num_stages):
for k in T.Pipelined(start, end, num_stages=num_stages):
T.copy(K[bz, by, k * block_N:(k + 1) * block_N, :], K_shared)
for i, j in T.Parallel(block_M, block_N):
q_idx = bx * block_M + i
@@ -267,14 +264,10 @@ def flash_bwd(
T.clear(dk)

loop_st = T.floordiv(by * block_M, block_N)
loop_ed = T.alloc_local([1], 'int32')
if window_size is not None:
loop_ed[0] = T.min(
T.ceildiv((by + 1) * block_M + window_size, block_N),
T.ceildiv(seq_len, block_N))
else:
loop_ed[0] = T.ceildiv(seq_len, block_N)
for k in T.Pipelined(loop_st, loop_ed[0], num_stages=num_stages):
loop_ed = T.min(
T.ceildiv((by + 1) * block_M + window_size, block_N), T.ceildiv(
seq_len, block_N)) if window_size is not None else T.ceildiv(seq_len, block_N)
for k in T.Pipelined(loop_st, loop_ed, num_stages=num_stages):
T.copy(Q[bz, bx, k * block_N:(k + 1) * block_N, :], q)
T.clear(qkT)
T.gemm(K_shared, q, qkT, transpose_B=True, policy=T.GemmWarpPolicy.FullRow)
9 changes: 3 additions & 6 deletions examples/attention_sink/example_mha_sink_fwd_bhsd.py
@@ -162,13 +162,10 @@ def main(
end = T.min(
T.ceildiv(seq_kv, block_N), T.ceildiv((bx + 1) * block_M + past_len, block_N))

start = T.alloc_local([1], 'int32')
if window_size is not None:
start[0] = T.max(0, (bx * block_M + past_len - window_size) // block_N)
else:
start[0] = 0
start = T.max(0, (bx * block_M + past_len - window_size) //
block_N) if window_size is not None else 0

for k in T.Pipelined(start[0], end, num_stages=num_stages):
for k in T.Pipelined(start, end, num_stages=num_stages):
MMA0(K, Q_shared, K_shared, acc_s, k, bx, by, bz)
Softmax(acc_s, acc_s_cast, scores_max, scores_max_prev, scores_scale, scores_sum,
logsum)
@@ -165,14 +165,11 @@ def main(
end = T.min(
T.ceildiv(seq_kv, block_N), T.ceildiv((bx + 1) * block_M + past_len, block_N))

start = T.alloc_local([1], 'int32')
if window_size is not None:
start[0] = T.max(0, (bx * block_M + past_len - window_size) // block_N)
else:
start[0] = 0
start = T.max(0, (bx * block_M + past_len - window_size) //
block_N) if window_size is not None else 0

for k in T.Pipelined(
start[0],
start,
end,
num_stages=num_stages,
order=[-1, 0, 3, 1, -1, 2],
20 changes: 10 additions & 10 deletions examples/deepseek_v32/test_tilelang_example_deepseek_v32.py
@@ -1,41 +1,41 @@
# ruff: noqa
import tilelang.testing

from topk_selector import test_topk_selector
from fp8_lighting_indexer import test_fp8_lighting_indexer
from sparse_mla_fwd import test_sparse_mla_fwd
from sparse_mla_fwd_pipelined import test_sparse_mla_fwd_pipelined
from sparse_mla_bwd import test_sparse_mla_bwd
import topk_selector
import fp8_lighting_indexer
import sparse_mla_fwd
import sparse_mla_fwd_pipelined
import sparse_mla_bwd


def test_example_topk_selector():
test_topk_selector()
topk_selector.test_topk_selector()


def test_example_fp8_lighting_indexer():
test_fp8_lighting_indexer(S=512, SKV=1024, H=32, HKV=1, D=64, kv_stride=1)
fp8_lighting_indexer.test_fp8_lighting_indexer(S=512, SKV=1024, H=32, HKV=1, D=64, kv_stride=1)


@tilelang.testing.requires_cuda
@tilelang.testing.requires_cuda_compute_version_ge(9, 0)
def test_example_sparse_mla_fwd():
# small shapes for testing
test_sparse_mla_fwd(
sparse_mla_fwd.test_sparse_mla_fwd(
S=256, SKV=1024, H=64, HKV=1, DQK=576, DV=512, topk=256, check_correctness=False)


@tilelang.testing.requires_cuda
@tilelang.testing.requires_cuda_compute_version_ge(9, 0)
def test_example_sparse_mla_fwd_pipelined():
# small shapes for testing
test_sparse_mla_fwd_pipelined(
sparse_mla_fwd_pipelined.test_sparse_mla_fwd_pipelined(
S=256, SKV=512, H=64, HKV=1, DQK=576, DV=512, topk=256, check_correctness=False)


@tilelang.testing.requires_cuda
@tilelang.testing.requires_cuda_compute_version_ge(9, 0)
def test_example_sparse_mla_bwd():
test_sparse_mla_bwd(
sparse_mla_bwd.test_sparse_mla_bwd(
S=256, SKV=512, H=64, HKV=1, DQKV=576, DV=512, topk=256, check_correctness=False)


2 changes: 1 addition & 1 deletion examples/linear_attention/example_linear_attn_fwd.py
@@ -80,7 +80,6 @@ def fused_chunk_linear_attn_fwd(
T.atomic_add(
O[i_b, i * chunk_size:(i + 1) * chunk_size, i_h, i_v * BV:(i_v + 1) * BV],
o_shared)
#TODO: consider using vectorized atomic add or tma reduce for sm90

# Output final state
T.copy(h, final_state[i_b, i_h, i_k * BK:(i_k + 1) * BK, i_v * BV:(i_v + 1) * BV])
@@ -91,6 +90,7 @@ def tl_fused_chunk_fwd(q, k, v):
def tl_fused_chunk_fwd(q, k, v):
B, S, H, D = q.shape
kernel = tl_fused_chunk_fwd_kernel(B, S, H, D, D)
print(kernel.get_kernel_source())

⚠️ Potential issue | 🟡 Minor

Make kernel source printing conditional to avoid cluttering benchmark output.

The print statement will execute every time tl_fused_chunk_fwd is called, including during the do_bench iterations at line 141. This will flood the output with kernel source text and make the benchmark results unreadable.

Consider one of these alternatives:

Option 1: Make it conditional on an environment variable or flag

-    print(kernel.get_kernel_source())
+    if os.environ.get('TILELANG_DEBUG_KERNEL_SOURCE'):
+        print(kernel.get_kernel_source())

Option 2: Print once before benchmarking in main()

 def tl_fused_chunk_fwd(q, k, v):
     B, S, H, D = q.shape
     kernel = tl_fused_chunk_fwd_kernel(B, S, H, D, D)
-    print(kernel.get_kernel_source())
     o = torch.zeros((B, S, H, D), device='cuda', dtype=torch.float32)
     h = kernel(q, k, v, o)
     return o, h

Then add in main() before benchmarking:

def main(B=1, S=512, H=16, D=128):
    # ... setup code ...
    
    # Print kernel source once for inspection
    kernel = tl_fused_chunk_fwd_kernel(B, S, H, D, D)
    print(kernel.get_kernel_source())
    
    o, h = tl_fused_chunk_fwd(q, k, v)
    # ... rest of main ...
🤖 Prompt for AI Agents
In examples/linear_attention/example_linear_attn_fwd.py around line 93, the
unconditional print(kernel.get_kernel_source()) inside tl_fused_chunk_fwd is
flooding benchmark output; remove that print from the hot path and instead
either (A) print the kernel source once before benchmarking in main() after
constructing the kernel, or (B) wrap the print in a conditional guard (e.g.,
check an environment variable or a debug flag) so it only emits when explicitly
enabled; update main() to print the kernel source once if you choose option A,
and ensure no prints occur inside the timed/iterative code paths.

o = torch.zeros((B, S, H, D), device='cuda', dtype=torch.float32)
h = kernel(q, k, v, o)
return o, h
7 changes: 0 additions & 7 deletions examples/linear_attention/example_retention_fwd.py
@@ -51,13 +51,6 @@ def chunk_retention_fwd(
o = T.alloc_fragment([chunk_size, BV], accum_dtype)
T.clear(h)

T.annotate_layout({
q: tl.layout.make_swizzled_layout(q),
k: tl.layout.make_swizzled_layout(k),
v: tl.layout.make_swizzled_layout(v),
h_shared: tl.layout.make_swizzled_layout(h_shared),
s_shared: tl.layout.make_swizzled_layout(s_shared),
})
T.use_swizzle(10)

for i in T.Pipelined(0, NT):
examples/warp_specialize/example_warp_specialize_gemm_barrierpipe_stage2.py
@@ -1,6 +1,8 @@
import tilelang
import tilelang.language as T

tilelang.disable_cache()

🛠️ Refactor suggestion | 🟠 Major

🧩 Analysis chain

Avoid top-level side effects; keep a single cache toggle in main().

Calling tilelang.disable_cache() at import time affects anyone importing this module. It’s also duplicated at Line 57. Remove the top-level call and keep a single, gated call in main().

Apply this minimal change:

- tilelang.disable_cache()

Optional verification (ensures a top-level wrapper exists and avoids surprises at runtime):


🏁 Script executed:

#!/bin/bash
# Verify a public wrapper exists and locate its definition(s)
rg -nP -C2 'def\s+disable_cache\s*\('
rg -nP -C2 'disable_cache\s*=\s*.*\.disable_cache'

Length of output: 195


🏁 Script executed:

#!/bin/bash
set -e
# Locate tilelang package __init__.py
find . -type f -path "./tilelang/__init__.py" -print
# Show first 200 lines to check for disable_cache wrapper
sed -n '1,200p' tilelang/__init__.py

Length of output: 3702


🏁 Script executed:

#!/bin/bash
# Display lines 1–80 of the example to confirm top-level call and context
sed -n '1,80p' examples/warp_specialize/example_warp_specialize_gemm_barrierpipe_stage2.py

Length of output: 2815


Remove top-level disable_cache call
Delete the import-time tilelang.disable_cache() (line 4) and rely solely on the call inside main() to avoid unexpected side effects on import.
Optional: Gate the print(jit_kernel.get_kernel_source()) behind a flag to prevent unneeded verbose output.

🤖 Prompt for AI Agents
In examples/warp_specialize/example_warp_specialize_gemm_barrierpipe_stage2.py
around line 4, remove the top-level tilelang.disable_cache() call so importing
the module has no side effects and rely on the existing disable_cache()
invocation inside main(); additionally, optionally wrap or gate the
print(jit_kernel.get_kernel_source()) behind a verbosity or debug flag so kernel
source is only printed when explicitly requested.
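
Editor's note: for the optional gating suggested above, a minimal sketch of a guarded print (the TILELANG_PRINT_KERNEL_SOURCE variable name is illustrative, not an existing TileLang flag; jit_kernel is assumed to expose get_kernel_source() as the jitted kernel in this example does):

import os

def maybe_print_kernel_source(jit_kernel):
    # Dump the generated source only when explicitly requested, e.g.
    #   TILELANG_PRINT_KERNEL_SOURCE=1 python example_warp_specialize_gemm_barrierpipe_stage2.py
    if os.environ.get("TILELANG_PRINT_KERNEL_SOURCE"):
        print(jit_kernel.get_kernel_source())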



# add decorator @tilelang.jit if you want to return a torch function
# @tilelang.jit
@@ -52,11 +54,14 @@ def main(


def main(M=16384, N=16384, K=16384):
tilelang.disable_cache()
block_M = 128
block_N = 128
block_K = 64
jit_kernel = matmul(M, N, K, block_M, block_N, block_K)

print(jit_kernel.get_kernel_source())

import torch

a = torch.randn(M, K, device="cuda", dtype=torch.float16)