
Commit a773027

[Language] Recommend using T.dynamic instead of T.symbolic (#1076)
* recommend using T.dynamic instead of T.symbolic
* lint fix
* lint fix
1 parent fd6cec5 commit a773027

19 files changed: +128 −162 lines
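The change is mechanical across all touched files: every dimension previously declared with `T.symbolic` is now declared with `T.dynamic`. A minimal sketch of the recommended spelling, assuming the post-commit API where `T.dynamic` is available (the GEMM-style shape names below are illustrative, not taken from any one file in this diff):

```python
import tilelang.language as T

# Old spelling (as it appears in pre-commit code):
#   M = T.symbolic("m")
# Recommended spelling after this commit:
M = T.dynamic("m")  # runtime-sized dimension named "m"
N = 1024            # static dimensions remain plain Python ints
K = 1024
```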

README.md

Lines changed: 1 addition & 1 deletion
@@ -178,7 +178,7 @@ def matmul(M, N, K, block_M, block_N, block_K, dtype="float16", accum_dtype="flo
 return matmul_relu_kernel


-M = 1024 # M = T.symbolic("m") if you want to use dynamic shape
+M = 1024 # M = T.dynamic("m") if you want to use dynamic shape
 N = 1024
 K = 1024
 block_M = 128

docs/deeplearning_operators/elementwise.md

Lines changed: 1 addition & 1 deletion
@@ -89,7 +89,7 @@ def elementwise_add(
 In the compilation process above, a fixed shape was used. However, in practical usage, we often want the kernel to support dynamic shapes. So, how can we compile a kernel in TileLang to handle dynamic shapes? In TileLang, we can replace the target size with a dynamic symbolic value, making the dimension dynamic. The following example illustrates this:

 ```python
-program = elementwise_add(T.symbolic("N"), threads=256, dtype="bfloat16")
+program = elementwise_add(T.dynamic("N"), threads=256, dtype="bfloat16")
 kernel = tilelang.compile(program, out_idx=-1, target="cuda", execution_backend="cython")
 ```
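Building on the elementwise.md hunk above, the point of declaring `N` with `T.dynamic` is that one compiled kernel can serve inputs of different lengths. A usage sketch under stated assumptions: `elementwise_add` is the program factory defined earlier in that doc, and the compiled kernel takes two 1-D bfloat16 CUDA tensors and returns their sum via `out_idx=-1`; the exact call signature is illustrative.

```python
import torch
import tilelang
import tilelang.language as T

# Compile once with a dynamic length N, as in the diff above.
program = elementwise_add(T.dynamic("N"), threads=256, dtype="bfloat16")
kernel = tilelang.compile(program, out_idx=-1, target="cuda", execution_backend="cython")

# Reuse the same compiled kernel for several runtime sizes.
for n in (1024, 4096, 65536):
    a = torch.randn(n, dtype=torch.bfloat16, device="cuda")
    b = torch.randn(n, dtype=torch.bfloat16, device="cuda")
    c = kernel(a, b)  # output tensor allocated and returned because out_idx=-1
```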

examples/blocksparse_attention/example_tilelang_sparse_gqa_decode_paged.py

Lines changed: 3 additions & 3 deletions
@@ -223,12 +223,12 @@ def __init__(self, batch, heads, heads_kv, dim, dim_v, page_block_size, block_N,
 block_N=block_N,
 block_H=self.block_H,
 page_block_size=page_block_size,
-num_split=T.symbolic("num_split"),
+num_split=T.dynamic("num_split"),
 num_stages=2,
 threads=128,
 num_pages=num_pages,
-max_num_blocks_per_seq=T.symbolic("max_num_blocks_per_seq"),
-max_selected_blocks=T.symbolic("max_selected_blocks"),
+max_num_blocks_per_seq=T.dynamic("max_num_blocks_per_seq"),
+max_selected_blocks=T.dynamic("max_selected_blocks"),
 )

 props = torch.cuda.get_device_properties(torch.device("cuda:0"))

examples/blocksparse_attention/example_tilelang_sparse_gqa_decode_varlen_indice.py

Lines changed: 6 additions & 6 deletions
@@ -206,11 +206,11 @@ def __init__(self, batch, heads, heads_kv, dim, dim_v, block_size):
 self.kernel = flashattn(batch, heads, heads_kv, dim, dim_v)(
 block_N=block_size,
 block_H=self.block_H,
-num_split=T.symbolic("num_split"),
+num_split=T.dynamic("num_split"),
 num_stages=2,
 threads=128,
-max_cache_seqlen=T.symbolic("max_cache_seqlen"),
-max_selected_blocks=T.symbolic("max_selected_blocks"))
+max_cache_seqlen=T.dynamic("max_cache_seqlen"),
+max_selected_blocks=T.dynamic("max_selected_blocks"))

 props = torch.cuda.get_device_properties(torch.device("cuda:0"))
 self.num_sm = props.multi_processor_count
@@ -301,11 +301,11 @@ def sparse_gqa_decode_varlen_indice(query, key, value, block_indices, cache_seql
 kernel = flashattn(batch, heads, heads_kv, dim, dim_v)(
 block_N=block_size,
 block_H=block_H,
-num_split=T.symbolic("num_split"),
+num_split=T.dynamic("num_split"),
 num_stages=2,
 threads=128,
-max_cache_seqlen=T.symbolic("max_cache_seqlen"),
-max_selected_blocks=T.symbolic("max_selected_blocks"))
+max_cache_seqlen=T.dynamic("max_cache_seqlen"),
+max_selected_blocks=T.dynamic("max_selected_blocks"))

 output = kernel(query, key, value, block_indices, cache_seqlens, glse, Output_partial)
 return output

examples/blocksparse_attention/example_tilelang_sparse_gqa_decode_varlen_mask.py

Lines changed: 6 additions & 6 deletions
@@ -193,11 +193,11 @@ def __init__(self, batch, heads, heads_kv, dim, dim_v, block_size):
 self.kernel = flashattn(batch, heads, heads_kv, dim, dim_v)(
 block_N=block_size,
 block_H=self.block_H,
-num_split=T.symbolic("num_split"),
+num_split=T.dynamic("num_split"),
 num_stages=2,
 threads=128,
-max_cache_seqlen=T.symbolic("max_cache_seqlen"),
-num_blocks=T.symbolic("num_blocks"))
+max_cache_seqlen=T.dynamic("max_cache_seqlen"),
+num_blocks=T.dynamic("num_blocks"))

 props = torch.cuda.get_device_properties(torch.device("cuda:0"))
 self.num_sm = props.multi_processor_count
@@ -282,11 +282,11 @@ def sparse_gqa_decode_varlen_mask(query, key, value, block_mask, cache_seqlens,
 kernel = flashattn(batch, heads, heads_kv, dim, dim_v)(
 block_N=block_size,
 block_H=block_H,
-num_split=T.symbolic("num_split"),
+num_split=T.dynamic("num_split"),
 num_stages=2,
 threads=128,
-max_cache_seqlen=T.symbolic("max_cache_seqlen"),
-num_blocks=T.symbolic("num_blocks"))
+max_cache_seqlen=T.dynamic("max_cache_seqlen"),
+num_blocks=T.dynamic("num_blocks"))
 glse = torch.empty((batch, heads, num_split), dtype=torch.float32, device='cuda')
 Output_partial = torch.empty((batch, heads, num_split, dim_v),
 dtype=torch.float32,

examples/deepseek_v32/fp8_lighting_indexer.py

Lines changed: 4 additions & 4 deletions
@@ -103,8 +103,8 @@ def mqa_attn_return_logits(
 accum_dtype = "float"
 index_dtype = "int32"

-seq_len = T.symbolic("seq_len")
-seq_len_kv = T.symbolic("seq_len_kv")
+seq_len = T.dynamic("seq_len")
+seq_len_kv = T.dynamic("seq_len_kv")

 index_q_shape = [seq_len * heads, index_dim]
 index_k_shape = [seq_len_kv, index_dim]
@@ -182,8 +182,8 @@ def clean_logits_(
 threads: int = 512,
 block_K: int = 4096,
 ):
-seq_len = T.symbolic("seq_len")
-seq_len_kv = T.symbolic("seq_len_kv")
+seq_len = T.dynamic("seq_len")
+seq_len_kv = T.dynamic("seq_len_kv")

 dtype = "float"
 indices_dtype = "int32"

examples/deepseek_v32/inference/kernel.py

Lines changed: 5 additions & 5 deletions
@@ -34,7 +34,7 @@ def fast_round_scale(amax, fp8_max_inv):

 @tilelang.jit(pass_configs=pass_configs)
 def act_quant_kernel(N, in_dtype=BF16, out_dtype=FP8, scale_dtype=FP32, round_scale=False):
-M = T.symbolic("M")
+M = T.dynamic("M")
 fp8_min = -448.0
 fp8_max = 448.0
 fp8_max_inv = 1 / fp8_max
@@ -110,7 +110,7 @@ def act_quant(x: torch.Tensor,
 def fp8_gemm_kernel(N, K, out_dtype=BF16, accum_dtype="float32"):
 assert out_dtype in [BF16, "float32"]

-M = T.symbolic("M")
+M = T.dynamic("M")
 group_size = 128
 block_M = 32
 block_N = 128
@@ -192,9 +192,9 @@ def fp8_gemm(a: torch.Tensor, a_s: torch.Tensor, b: torch.Tensor,

 @tilelang.jit(out_idx=[4], pass_configs=pass_configs)
 def fp8_index_kernel(h: int, d: int):
-b = T.symbolic("b")
-m = T.symbolic("m")
-n = T.symbolic("n")
+b = T.dynamic("b")
+m = T.dynamic("m")
+n = T.dynamic("n")

 blk_n1 = 512
 blk_n2 = 128

examples/deepseek_v32/sparse_mla_fwd.py

Lines changed: 3 additions & 3 deletions
@@ -37,9 +37,9 @@ def sparse_mla_fwd(
 else:
 sm_scale = sm_scale * 1.44269504 # log2(e)

-batch = T.symbolic("batch")
-seq_len = T.symbolic("seq_len")
-seq_len_kv = T.symbolic("seq_len_kv")
+batch = T.dynamic("batch")
+seq_len = T.dynamic("seq_len")
+seq_len_kv = T.dynamic("seq_len_kv")

 head_kv = heads // kv_group
 q_shape = [batch, seq_len, heads, dim + tail_dim]

examples/deepseek_v32/topk_selector.py

Lines changed: 2 additions & 2 deletions
@@ -26,8 +26,8 @@ def convert_to_uint32(x):

 @tilelang.jit(pass_configs=pass_configs)
 def tl_topk_impl(topk, in_dtype="float32", out_dtype="int32"):
-batch = T.symbolic("batch")
-seq_len = T.symbolic("seq_len")
+batch = T.dynamic("batch")
+seq_len = T.dynamic("seq_len")
 RADIX = 1 << 8
 BLOCK_SIZE = 1024
 SMEM_INPUT_SIZE = 4096 # assume the threshold bucket size after first pass is less than 4K

examples/gemm_sm100/gemm_mma.py

Lines changed: 1 addition & 1 deletion
@@ -41,7 +41,7 @@ def main(
 return main


-M = 128 # M = T.symbolic("m") if you want to use dynamic shape
+M = 128 # M = T.dynamic("m") if you want to use dynamic shape
 N = 128
 K = 32
 block_M = 128
