
Commit 6e32e69

Merge branch 'main' of https://github.com/tile-ai/tilelang
2 parents: 724f8f0 + 30d8ded


51 files changed: +2031 −1320 lines

.gitignore

Lines changed: 3 additions & 0 deletions
@@ -105,3 +105,6 @@ cmake-build-*/
 
 # Git version for sdist
 .git_commit.txt
+
+# pre-commit cache
+.pre-commit-cache/*

CMakeLists.txt

Lines changed: 81 additions & 20 deletions
@@ -65,9 +65,50 @@ else()
 endif()
 
 # Configs
-set(USE_CUDA OFF)
-set(USE_ROCM OFF)
-set(USE_METAL OFF)
+set(TILELANG_BACKENDS CUDA ROCM METAL)
+
+set(TILELANG_BACKEND_DOC_CUDA "Enable CUDA backend (ON/OFF/or CUDA SDK path)")
+set(TILELANG_BACKEND_DOC_ROCM "Enable ROCm backend (ON/OFF/or ROCm SDK path)")
+set(TILELANG_BACKEND_DOC_METAL "Enable Metal backend")
+
+# TVM's config.cmake redefines USE_* options later, so we cache the user's choice
+# (including explicit -DUSE_XXX arguments) before we include TVM and restore it
+# afterwards.
+
+macro(tilelang_define_backend_option BACKEND)
+  set(_backend_var "USE_${BACKEND}")
+  set(_doc "${TILELANG_BACKEND_DOC_${BACKEND}}")
+  set(_user_override_var "TILELANG_USER_OVERRIDE_${_backend_var}")
+
+  set(_user_override OFF)
+  if(DEFINED ${_user_override_var})
+    set(_user_override "${${_user_override_var}}")
+  endif()
+
+  if(DEFINED CACHE{${_backend_var}})
+    get_property(_cache_type CACHE ${_backend_var} PROPERTY TYPE)
+    if(_cache_type STREQUAL "UNINITIALIZED")
+      set(_user_override ON)
+    endif()
+  endif()
+
+  set(_default OFF)
+  if(DEFINED ${_backend_var})
+    set(_default "${${_backend_var}}")
+  endif()
+
+  option(${_backend_var} "${_doc}" "${_default}")
+  # Remember if the user explicitly set this option so that later logic
+  # won't auto-toggle backends they configured on the command line.
+  set(${_user_override_var} ${_user_override} CACHE INTERNAL
+      "User explicitly set ${_backend_var} during configuration" FORCE)
+  set(TILELANG_OPTION_${_backend_var} "${${_backend_var}}")
+endmacro()
+
+foreach(BACKEND IN LISTS TILELANG_BACKENDS)
+  tilelang_define_backend_option(${BACKEND})
+endforeach()
+
 set(PREBUILD_CYTHON ON)
 # Configs end
 
@@ -78,6 +119,14 @@ if(EXISTS ${TVM_SOURCE}/cmake/config.cmake)
 else()
   message(FATAL_ERROR "Nor tvm provided or submodule checkout-ed.")
 endif()
+# Re-apply TileLang's preferred backend settings after TVM's config may have
+# overridden the USE_* cache entries.
+foreach(BACKEND IN LISTS TILELANG_BACKENDS)
+  set(_backend_var "USE_${BACKEND}")
+  set(_doc "${TILELANG_BACKEND_DOC_${BACKEND}}")
+  set(${_backend_var} ${TILELANG_OPTION_${_backend_var}} CACHE STRING "${_doc}" FORCE)
+  set(${_backend_var} ${TILELANG_OPTION_${_backend_var}})
+endforeach()
 
 # Include directories for TileLang
 set(TILE_LANG_INCLUDES ${TVM_INCLUDES})
 
@@ -95,23 +144,35 @@ file(GLOB TILE_LANG_SRCS
   src/target/intrin_rule*.cc
 )
 
-# Backend-specific checks and configs
-if($ENV{USE_METAL})
-  set(USE_METAL ON)
-elseif(APPLE)
-  message(STATUS "Enable Metal support by default.")
-  set(USE_METAL ON)
-elseif($ENV{USE_ROCM})
-  set(USE_ROCM ON)
-else()
-  if($ENV{USE_CUDA})
-    set(USE_CUDA ON)
-  elseif(DEFINED ENV{USE_CUDA} AND NOT $ENV{USE_CUDA})
-    # Build CPU-only when we explicitly disable CUDA
-    set(USE_CUDA OFF)
+# Track if the user explicitly selected a backend via cache options.
+set(TILELANG_BACKEND_USER_SELECTED OFF)
+foreach(BACKEND IN LISTS TILELANG_BACKENDS)
+  set(_backend_var "USE_${BACKEND}")
+  set(_override_var "TILELANG_USER_OVERRIDE_${_backend_var}")
+  if(${_backend_var} OR ${_override_var})
+    set(TILELANG_BACKEND_USER_SELECTED ON)
+  endif()
+endforeach()
+
+# Only auto-select a backend when the user didn't specify one explicitly.
+if(NOT TILELANG_BACKEND_USER_SELECTED)
+  if($ENV{USE_METAL})
+    set(USE_METAL ON)
+  elseif(APPLE)
+    message(STATUS "Enable Metal support by default.")
+    set(USE_METAL ON)
+  elseif($ENV{USE_ROCM})
+    set(USE_ROCM ON)
   else()
-    message(STATUS "Enable CUDA support by default.")
-    set(USE_CUDA ON)
+    if($ENV{USE_CUDA})
+      set(USE_CUDA ON)
+    elseif(DEFINED ENV{USE_CUDA} AND NOT $ENV{USE_CUDA})
+      # Build CPU-only when we explicitly disable CUDA
+      set(USE_CUDA OFF)
+    else()
+      message(STATUS "Enable CUDA support by default.")
+      set(USE_CUDA ON)
+    endif()
   endif()
 endif()
 
@@ -125,7 +186,7 @@ if(USE_METAL)
 elseif(USE_ROCM)
   set(CMAKE_HIP_STANDARD 17)
   include(${TVM_SOURCE}/cmake/utils/FindROCM.cmake)
-  find_rocm($ENV{USE_ROCM})
+  find_rocm(${USE_ROCM})
   add_compile_definitions(__HIP_PLATFORM_AMD__ __HIP_PLATFORM_HCC__=1)
 
   file(GLOB TILE_LANG_HIP_SRCS
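
Taken together, the CMakeLists.txt changes give backend selection a fixed precedence: an explicit -DUSE_CUDA/-DUSE_ROCM/-DUSE_METAL from the user always wins and survives TVM's config.cmake, and only when no backend was chosen explicitly does the build fall back to environment variables and platform defaults. Below is a minimal Python sketch of that precedence, for illustration only; the function name, flag handling, and truthiness rules are assumptions, and the CMake above is the authoritative logic.

import os
import sys

_TRUTHY = {"1", "ON", "TRUE", "YES", "Y"}  # rough stand-in for CMake's truthiness rules


def select_backend(user_flags=None, env=None, is_apple=None):
    """Sketch of the selection order: explicit -DUSE_* flags win, then the
    USE_METAL env var, then Apple hosts, then USE_ROCM, then USE_CUDA, with
    CUDA as the final default (or a CPU-only build if USE_CUDA is explicitly off)."""
    env = os.environ if env is None else env
    is_apple = (sys.platform == "darwin") if is_apple is None else is_apple

    if user_flags:  # explicit -DUSE_<BACKEND>=... disables auto-selection entirely
        return {f"USE_{name}": value for name, value in user_flags.items()}

    if env.get("USE_METAL", "").upper() in _TRUTHY:
        return {"USE_METAL": "ON"}
    if is_apple:
        return {"USE_METAL": "ON"}        # Metal by default on Apple hosts
    if env.get("USE_ROCM", "").upper() in _TRUTHY:
        return {"USE_ROCM": "ON"}
    if env.get("USE_CUDA", "").upper() in _TRUTHY:
        return {"USE_CUDA": "ON"}
    if "USE_CUDA" in env:                 # defined but falsy -> CPU-only build
        return {}
    return {"USE_CUDA": "ON"}             # CUDA by default everywhere else


if __name__ == "__main__":
    print(select_backend(env={}, is_apple=False))                     # {'USE_CUDA': 'ON'}
    print(select_backend(user_flags={"ROCM": "/opt/rocm"}, env={}))   # explicit choice wins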

examples/attention_sink/example_gqa_sink_bwd_bhsd.py

Lines changed: 8 additions & 14 deletions
@@ -81,13 +81,10 @@ def flash_fwd(
                 sinks[i] = Sinks[by]
 
             end = T.min(T.ceildiv(seq_len, block_N), T.ceildiv((bx + 1) * block_M, block_N))
-            start = T.alloc_local([1], 'int32')
-            if window_size is not None:
-                start[0] = T.max(0, (bx * block_M - window_size) // block_N)
-            else:
-                start[0] = 0
+            start = T.max(0,
+                          (bx * block_M - window_size) // block_N) if window_size is not None else 0
 
-            for k in T.Pipelined(start[0], end, num_stages=num_stages):
+            for k in T.Pipelined(start, end, num_stages=num_stages):
                 T.copy(K[bz, by // groups, k * block_N:(k + 1) * block_N, :], K_shared)
                 for i, j in T.Parallel(block_M, block_N):
                     q_idx = bx * block_M + i
@@ -266,14 +263,11 @@ def flash_bwd(
             T.clear(dk)
 
             loop_st = T.floordiv(by * block_M, block_N)
-            loop_ed = T.alloc_local([1], 'int32')
-            if window_size is not None:
-                loop_ed[0] = T.min(
-                    T.ceildiv((by + 1) * block_M + window_size, block_N),
-                    T.ceildiv(seq_len, block_N))
-            else:
-                loop_ed[0] = T.ceildiv(seq_len, block_N)
-            for k in T.Pipelined(loop_st, loop_ed[0], num_stages=num_stages):
+            loop_ed = T.min(
+                T.ceildiv((by + 1) * block_M + window_size, block_N), T.ceildiv(
+                    seq_len, block_N)) if window_size is not None else T.ceildiv(seq_len, block_N)
+
+            for k in T.Pipelined(loop_st, loop_ed, num_stages=num_stages):
                 T.copy(Q[bz, bx, k * block_N:(k + 1) * block_N, :], q)
                 T.clear(qkT)
                 T.gemm(K_shared, q, qkT, transpose_B=True, policy=T.GemmWarpPolicy.FullRow)
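
The same simplification recurs across the attention-sink examples in this commit: a one-element T.alloc_local buffer plus an if/else is replaced by a plain Python conditional expression, since window_size is a Python-level constant at trace time and the bound can be folded directly into T.Pipelined. Below is a host-side sketch of the block-range arithmetic these kernels compute per tile; the helper names are illustrative stand-ins, not TileLang APIs (ceildiv plays the role of T.ceildiv).

def ceildiv(a: int, b: int) -> int:
    # mirrors T.ceildiv for positive block sizes
    return -(-a // b)


def fwd_block_range(bx, block_M, block_N, seq_len, window_size=None):
    """KV-block range [start, end) visited by query tile bx in the forward pass."""
    end = min(ceildiv(seq_len, block_N), ceildiv((bx + 1) * block_M, block_N))
    start = max(0, (bx * block_M - window_size) // block_N) if window_size is not None else 0
    return start, end


def bwd_block_range(by, block_M, block_N, seq_len, window_size=None):
    """Query-block range [loop_st, loop_ed) visited by KV tile by in the backward pass."""
    loop_st = (by * block_M) // block_N
    loop_ed = (min(ceildiv((by + 1) * block_M + window_size, block_N), ceildiv(seq_len, block_N))
               if window_size is not None else ceildiv(seq_len, block_N))
    return loop_st, loop_ed


if __name__ == "__main__":
    # e.g. tile 3 of a 4096-token sequence, 128x128 tiles, sliding window of 1024
    print(fwd_block_range(bx=3, block_M=128, block_N=128, seq_len=4096, window_size=1024))  # (0, 4)
    print(bwd_block_range(by=3, block_M=128, block_N=128, seq_len=4096, window_size=1024))  # (3, 12)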

examples/attention_sink/example_gqa_sink_fwd_bhsd_wgmma_pipelined.py

Lines changed: 3 additions & 6 deletions
@@ -172,14 +172,11 @@ def main(
             end = T.min(
                 T.ceildiv(seq_kv, block_N), T.ceildiv((bx + 1) * block_M + past_len, block_N))
 
-            start = T.alloc_local([1], 'int32')
-            if window_size is not None:
-                start[0] = T.max(0, (bx * block_M + past_len - window_size) // block_N)
-            else:
-                start[0] = 0
+            start = T.max(0, (bx * block_M + past_len - window_size) //
+                          block_N) if window_size is not None else 0
 
             for k in T.Pipelined(
-                    start[0],
+                    start,
                     end,
                     num_stages=num_stages,
                     order=[-1, 0, 3, 1, -1, 2],

examples/attention_sink/example_mha_sink_bwd_bhsd.py

Lines changed: 7 additions & 14 deletions
@@ -78,13 +78,10 @@ def flash_fwd(
                 sinks[i] = Sinks[by]
 
             end = T.min(T.ceildiv(seq_len, block_N), T.ceildiv((bx + 1) * block_M, block_N))
-            start = T.alloc_local([1], 'int32')
-            if window_size is not None:
-                start[0] = T.max(0, (bx * block_M - window_size) // block_N)
-            else:
-                start[0] = 0
+            start = T.max(0,
+                          (bx * block_M - window_size) // block_N) if window_size is not None else 0
 
-            for k in T.Pipelined(start[0], end, num_stages=num_stages):
+            for k in T.Pipelined(start, end, num_stages=num_stages):
                 T.copy(K[bz, by, k * block_N:(k + 1) * block_N, :], K_shared)
                 for i, j in T.Parallel(block_M, block_N):
                     q_idx = bx * block_M + i
@@ -267,14 +264,10 @@ def flash_bwd(
             T.clear(dk)
 
             loop_st = T.floordiv(by * block_M, block_N)
-            loop_ed = T.alloc_local([1], 'int32')
-            if window_size is not None:
-                loop_ed[0] = T.min(
-                    T.ceildiv((by + 1) * block_M + window_size, block_N),
-                    T.ceildiv(seq_len, block_N))
-            else:
-                loop_ed[0] = T.ceildiv(seq_len, block_N)
-            for k in T.Pipelined(loop_st, loop_ed[0], num_stages=num_stages):
+            loop_ed = T.min(
+                T.ceildiv((by + 1) * block_M + window_size, block_N), T.ceildiv(
+                    seq_len, block_N)) if window_size is not None else T.ceildiv(seq_len, block_N)
+            for k in T.Pipelined(loop_st, loop_ed, num_stages=num_stages):
                 T.copy(Q[bz, bx, k * block_N:(k + 1) * block_N, :], q)
                 T.clear(qkT)
                 T.gemm(K_shared, q, qkT, transpose_B=True, policy=T.GemmWarpPolicy.FullRow)

examples/attention_sink/example_mha_sink_fwd_bhsd.py

Lines changed: 3 additions & 6 deletions
@@ -162,13 +162,10 @@ def main(
             end = T.min(
                 T.ceildiv(seq_kv, block_N), T.ceildiv((bx + 1) * block_M + past_len, block_N))
 
-            start = T.alloc_local([1], 'int32')
-            if window_size is not None:
-                start[0] = T.max(0, (bx * block_M + past_len - window_size) // block_N)
-            else:
-                start[0] = 0
+            start = T.max(0, (bx * block_M + past_len - window_size) //
+                          block_N) if window_size is not None else 0
 
-            for k in T.Pipelined(start[0], end, num_stages=num_stages):
+            for k in T.Pipelined(start, end, num_stages=num_stages):
                 MMA0(K, Q_shared, K_shared, acc_s, k, bx, by, bz)
                 Softmax(acc_s, acc_s_cast, scores_max, scores_max_prev, scores_scale, scores_sum,
                         logsum)

examples/attention_sink/example_mha_sink_fwd_bhsd_wgmma_pipelined.py

Lines changed: 3 additions & 6 deletions
@@ -165,14 +165,11 @@ def main(
             end = T.min(
                 T.ceildiv(seq_kv, block_N), T.ceildiv((bx + 1) * block_M + past_len, block_N))
 
-            start = T.alloc_local([1], 'int32')
-            if window_size is not None:
-                start[0] = T.max(0, (bx * block_M + past_len - window_size) // block_N)
-            else:
-                start[0] = 0
+            start = T.max(0, (bx * block_M + past_len - window_size) //
+                          block_N) if window_size is not None else 0
 
             for k in T.Pipelined(
-                    start[0],
+                    start,
                     end,
                     num_stages=num_stages,
                     order=[-1, 0, 3, 1, -1, 2],

examples/deepseek_v32/test_tilelang_example_deepseek_v32.py

Lines changed: 10 additions & 10 deletions
@@ -1,41 +1,41 @@
 # ruff: noqa
 import tilelang.testing
 
-from topk_selector import test_topk_selector
-from fp8_lighting_indexer import test_fp8_lighting_indexer
-from sparse_mla_fwd import test_sparse_mla_fwd
-from sparse_mla_fwd_pipelined import test_sparse_mla_fwd_pipelined
-from sparse_mla_bwd import test_sparse_mla_bwd
+import topk_selector
+import fp8_lighting_indexer
+import sparse_mla_fwd
+import sparse_mla_fwd_pipelined
+import sparse_mla_bwd
 
 
 def test_example_topk_selector():
-    test_topk_selector()
+    topk_selector.test_topk_selector()
 
 
 def test_example_fp8_lighting_indexer():
-    test_fp8_lighting_indexer(S=512, SKV=1024, H=32, HKV=1, D=64, kv_stride=1)
+    fp8_lighting_indexer.test_fp8_lighting_indexer(S=512, SKV=1024, H=32, HKV=1, D=64, kv_stride=1)
 
 
 @tilelang.testing.requires_cuda
 @tilelang.testing.requires_cuda_compute_version_ge(9, 0)
 def test_example_sparse_mla_fwd():
     # small shapes for testing
-    test_sparse_mla_fwd(
+    sparse_mla_fwd.test_sparse_mla_fwd(
         S=256, SKV=1024, H=64, HKV=1, DQK=576, DV=512, topk=256, check_correctness=False)
 
 
 @tilelang.testing.requires_cuda
 @tilelang.testing.requires_cuda_compute_version_ge(9, 0)
 def test_example_sparse_mla_fwd_pipelined():
     # small shapes for testing
-    test_sparse_mla_fwd_pipelined(
+    sparse_mla_fwd_pipelined.test_sparse_mla_fwd_pipelined(
         S=256, SKV=512, H=64, HKV=1, DQK=576, DV=512, topk=256, check_correctness=False)
 
 
 @tilelang.testing.requires_cuda
 @tilelang.testing.requires_cuda_compute_version_ge(9, 0)
 def test_example_sparse_mla_bwd():
-    test_sparse_mla_bwd(
+    sparse_mla_bwd.test_sparse_mla_bwd(
         S=256, SKV=512, H=64, HKV=1, DQKV=576, DV=512, topk=256, check_correctness=False)
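
The test file drops "from <module> import test_*" in favor of module-qualified calls. A plausible reason (not stated in the commit) is pytest collection: a test_* function imported into this module's namespace gets collected here a second time, while "import module" keeps it namespaced, so only the wrapper tests defined in this file are picked up. A self-contained illustration with a hypothetical helpers module:

import sys
import types

# Hypothetical sibling module standing in for topk_selector / sparse_mla_fwd etc.
helpers = types.ModuleType("helpers")
helpers.test_helper_case = lambda: None
sys.modules["helpers"] = helpers

# Old style: the imported function becomes a module-level test_* name here too,
# so pytest would collect it both in helpers and in this file.
from helpers import test_helper_case
print("test_helper_case" in globals())       # True

# New style: only wrappers defined in this file carry a collectable test_* name,
# and the helper stays reachable through the module.
del test_helper_case
import helpers  # noqa: E402
print("test_helper_case" in globals())       # False
print(callable(helpers.test_helper_case))    # True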

examples/linear_attention/example_linear_attn_fwd.py

Lines changed: 1 addition & 1 deletion
@@ -80,7 +80,6 @@ def fused_chunk_linear_attn_fwd(
                 T.atomic_add(
                     O[i_b, i * chunk_size:(i + 1) * chunk_size, i_h, i_v * BV:(i_v + 1) * BV],
                     o_shared)
-                #TODO: consider using vectorized atomic add or tma reduce for sm90
 
             # Output final state
             T.copy(h, final_state[i_b, i_h, i_k * BK:(i_k + 1) * BK, i_v * BV:(i_v + 1) * BV])
@@ -91,6 +90,7 @@ def fused_chunk_linear_attn_fwd(
 def tl_fused_chunk_fwd(q, k, v):
     B, S, H, D = q.shape
     kernel = tl_fused_chunk_fwd_kernel(B, S, H, D, D)
+    print(kernel.get_kernel_source())
     o = torch.zeros((B, S, H, D), device='cuda', dtype=torch.float32)
     h = kernel(q, k, v, o)
     return o, h
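
The added print(kernel.get_kernel_source()) dumps the generated source of the compiled kernel each time the wrapper runs, which is useful for inspecting what TileLang emits. Below is a small standalone sketch of the same inspection, assuming the example module is importable and a CUDA device is available; the shapes and input dtype are arbitrary placeholders, not values from the commit.

import torch

# tl_fused_chunk_fwd_kernel and the call signature come from the example above.
from example_linear_attn_fwd import tl_fused_chunk_fwd_kernel

B, S, H, D = 1, 256, 4, 64
kernel = tl_fused_chunk_fwd_kernel(B, S, H, D, D)

# Inspect the generated kernel source before launching anything.
print(kernel.get_kernel_source())

q = torch.randn(B, S, H, D, device="cuda", dtype=torch.float16)  # dtype is an assumption
k = torch.randn_like(q)
v = torch.randn_like(q)
o = torch.zeros(B, S, H, D, device="cuda", dtype=torch.float32)  # matches the wrapper above
h = kernel(q, k, v, o)
print(o.shape, h.shape)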

examples/linear_attention/example_retention_fwd.py

Lines changed: 0 additions & 7 deletions
@@ -51,13 +51,6 @@ def chunk_retention_fwd(
             o = T.alloc_fragment([chunk_size, BV], accum_dtype)
             T.clear(h)
 
-            T.annotate_layout({
-                q: tl.layout.make_swizzled_layout(q),
-                k: tl.layout.make_swizzled_layout(k),
-                v: tl.layout.make_swizzled_layout(v),
-                h_shared: tl.layout.make_swizzled_layout(h_shared),
-                s_shared: tl.layout.make_swizzled_layout(s_shared),
-            })
             T.use_swizzle(10)
 
             for i in T.Pipelined(0, NT):
