2 changes: 1 addition & 1 deletion README.md
@@ -178,7 +178,7 @@ def matmul(M, N, K, block_M, block_N, block_K, dtype="float16", accum_dtype="flo
return matmul_relu_kernel


M = 1024 # M = T.symbolic("m") if you want to use dynamic shape
M = 1024 # M = T.dynamic("m") if you want to use dynamic shape
N = 1024
K = 1024
block_M = 128
2 changes: 1 addition & 1 deletion docs/deeplearning_operators/elementwise.md
@@ -89,7 +89,7 @@ def elementwise_add(
In the compilation above, a fixed shape was used. In practice, however, we often want a kernel to support dynamic shapes. To compile such a kernel in TileLang, replace the fixed size with a dynamic symbolic value, which makes that dimension dynamic. The following example illustrates this:

```python
program = elementwise_add(T.symbolic("N"), threads=256, dtype="bfloat16")
program = elementwise_add(T.dynamic("N"), threads=256, dtype="bfloat16")
kernel = tilelang.compile(program, out_idx=-1, target="cuda", execution_backend="cython")
```
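
Once compiled with a dynamic `N`, the same kernel object should handle inputs of different lengths without recompilation. A minimal usage sketch, assuming `elementwise_add` takes two bfloat16 CUDA tensors and that `out_idx=-1` makes the compiled kernel return its output:

```python
import torch

# Hedged sketch: assumes the compiled kernel accepts two bfloat16 CUDA inputs and
# returns the result tensor (out_idx=-1). Because N is dynamic, the same binary
# serves different lengths without recompiling.
for n in (1024, 4096):
    a = torch.randn(n, device="cuda", dtype=torch.bfloat16)
    b = torch.randn(n, device="cuda", dtype=torch.bfloat16)
    c = kernel(a, b)
    torch.testing.assert_close(c, a + b)
```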

@@ -223,12 +223,12 @@ def __init__(self, batch, heads, heads_kv, dim, dim_v, page_block_size, block_N,
block_N=block_N,
block_H=self.block_H,
page_block_size=page_block_size,
num_split=T.symbolic("num_split"),
num_split=T.dynamic("num_split"),
num_stages=2,
threads=128,
num_pages=num_pages,
max_num_blocks_per_seq=T.symbolic("max_num_blocks_per_seq"),
max_selected_blocks=T.symbolic("max_selected_blocks"),
max_num_blocks_per_seq=T.dynamic("max_num_blocks_per_seq"),
max_selected_blocks=T.dynamic("max_selected_blocks"),
)

props = torch.cuda.get_device_properties(torch.device("cuda:0"))
@@ -206,11 +206,11 @@ def __init__(self, batch, heads, heads_kv, dim, dim_v, block_size):
self.kernel = flashattn(batch, heads, heads_kv, dim, dim_v)(
block_N=block_size,
block_H=self.block_H,
num_split=T.symbolic("num_split"),
num_split=T.dynamic("num_split"),
num_stages=2,
threads=128,
max_cache_seqlen=T.symbolic("max_cache_seqlen"),
max_selected_blocks=T.symbolic("max_selected_blocks"))
max_cache_seqlen=T.dynamic("max_cache_seqlen"),
max_selected_blocks=T.dynamic("max_selected_blocks"))

props = torch.cuda.get_device_properties(torch.device("cuda:0"))
self.num_sm = props.multi_processor_count
@@ -301,11 +301,11 @@ def sparse_gqa_decode_varlen_indice(query, key, value, block_indices, cache_seql
kernel = flashattn(batch, heads, heads_kv, dim, dim_v)(
block_N=block_size,
block_H=block_H,
num_split=T.symbolic("num_split"),
num_split=T.dynamic("num_split"),
num_stages=2,
threads=128,
max_cache_seqlen=T.symbolic("max_cache_seqlen"),
max_selected_blocks=T.symbolic("max_selected_blocks"))
max_cache_seqlen=T.dynamic("max_cache_seqlen"),
max_selected_blocks=T.dynamic("max_selected_blocks"))

output = kernel(query, key, value, block_indices, cache_seqlens, glse, Output_partial)
return output
@@ -193,11 +193,11 @@ def __init__(self, batch, heads, heads_kv, dim, dim_v, block_size):
self.kernel = flashattn(batch, heads, heads_kv, dim, dim_v)(
block_N=block_size,
block_H=self.block_H,
num_split=T.symbolic("num_split"),
num_split=T.dynamic("num_split"),
num_stages=2,
threads=128,
max_cache_seqlen=T.symbolic("max_cache_seqlen"),
num_blocks=T.symbolic("num_blocks"))
max_cache_seqlen=T.dynamic("max_cache_seqlen"),
num_blocks=T.dynamic("num_blocks"))

props = torch.cuda.get_device_properties(torch.device("cuda:0"))
self.num_sm = props.multi_processor_count
@@ -282,11 +282,11 @@ def sparse_gqa_decode_varlen_mask(query, key, value, block_mask, cache_seqlens,
kernel = flashattn(batch, heads, heads_kv, dim, dim_v)(
block_N=block_size,
block_H=block_H,
num_split=T.symbolic("num_split"),
num_split=T.dynamic("num_split"),
num_stages=2,
threads=128,
max_cache_seqlen=T.symbolic("max_cache_seqlen"),
num_blocks=T.symbolic("num_blocks"))
max_cache_seqlen=T.dynamic("max_cache_seqlen"),
num_blocks=T.dynamic("num_blocks"))
glse = torch.empty((batch, heads, num_split), dtype=torch.float32, device='cuda')
Output_partial = torch.empty((batch, heads, num_split, dim_v),
dtype=torch.float32,
8 changes: 4 additions & 4 deletions examples/deepseek_v32/fp8_lighting_indexer.py
@@ -103,8 +103,8 @@ def mqa_attn_return_logits(
accum_dtype = "float"
index_dtype = "int32"

seq_len = T.symbolic("seq_len")
seq_len_kv = T.symbolic("seq_len_kv")
seq_len = T.dynamic("seq_len")
seq_len_kv = T.dynamic("seq_len_kv")

index_q_shape = [seq_len * heads, index_dim]
index_k_shape = [seq_len_kv, index_dim]
@@ -182,8 +182,8 @@ def clean_logits_(
threads: int = 512,
block_K: int = 4096,
):
seq_len = T.symbolic("seq_len")
seq_len_kv = T.symbolic("seq_len_kv")
seq_len = T.dynamic("seq_len")
seq_len_kv = T.dynamic("seq_len_kv")

dtype = "float"
indices_dtype = "int32"
10 changes: 5 additions & 5 deletions examples/deepseek_v32/inference/kernel.py
@@ -34,7 +34,7 @@ def fast_round_scale(amax, fp8_max_inv):

@tilelang.jit(pass_configs=pass_configs)
def act_quant_kernel(N, in_dtype=BF16, out_dtype=FP8, scale_dtype=FP32, round_scale=False):
M = T.symbolic("M")
M = T.dynamic("M")
fp8_min = -448.0
fp8_max = 448.0
fp8_max_inv = 1 / fp8_max
@@ -110,7 +110,7 @@ def act_quant(x: torch.Tensor,
def fp8_gemm_kernel(N, K, out_dtype=BF16, accum_dtype="float32"):
assert out_dtype in [BF16, "float32"]

M = T.symbolic("M")
M = T.dynamic("M")
group_size = 128
block_M = 32
block_N = 128
@@ -192,9 +192,9 @@ def fp8_gemm(a: torch.Tensor, a_s: torch.Tensor, b: torch.Tensor,

@tilelang.jit(out_idx=[4], pass_configs=pass_configs)
def fp8_index_kernel(h: int, d: int):
b = T.symbolic("b")
m = T.symbolic("m")
n = T.symbolic("n")
b = T.dynamic("b")
m = T.dynamic("m")
n = T.dynamic("n")

blk_n1 = 512
blk_n2 = 128
6 changes: 3 additions & 3 deletions examples/deepseek_v32/sparse_mla_fwd.py
@@ -37,9 +37,9 @@ def sparse_mla_fwd(
else:
sm_scale = sm_scale * 1.44269504 # log2(e)

batch = T.symbolic("batch")
seq_len = T.symbolic("seq_len")
seq_len_kv = T.symbolic("seq_len_kv")
batch = T.dynamic("batch")
seq_len = T.dynamic("seq_len")
seq_len_kv = T.dynamic("seq_len_kv")

head_kv = heads // kv_group
q_shape = [batch, seq_len, heads, dim + tail_dim]
4 changes: 2 additions & 2 deletions examples/deepseek_v32/topk_selector.py
@@ -26,8 +26,8 @@ def convert_to_uint32(x):

@tilelang.jit(pass_configs=pass_configs)
def tl_topk_impl(topk, in_dtype="float32", out_dtype="int32"):
batch = T.symbolic("batch")
seq_len = T.symbolic("seq_len")
batch = T.dynamic("batch")
seq_len = T.dynamic("seq_len")
RADIX = 1 << 8
BLOCK_SIZE = 1024
SMEM_INPUT_SIZE = 4096 # assume the threshold bucket size after first pass is less than 4K
2 changes: 1 addition & 1 deletion examples/gemm_sm100/gemm_mma.py
@@ -41,7 +41,7 @@ def main(
return main


M = 128 # M = T.symbolic("m") if you want to use dynamic shape
M = 128 # M = T.dynamic("m") if you want to use dynamic shape
N = 128
K = 32
block_M = 128
2 changes: 1 addition & 1 deletion examples/quickstart.py
@@ -48,7 +48,7 @@ def matmul_relu_kernel(
return matmul_relu_kernel


M = 1024 # M = T.symbolic("m") if you want to use dynamic shape
M = 1024 # M = T.dynamic("m") if you want to use dynamic shape
N = 1024
K = 1024
block_M = 128
2 changes: 1 addition & 1 deletion testing/python/issue/test_tilelang_issue_830.py
@@ -24,7 +24,7 @@ def test_empty_kernel_lowering():

@tilelang.jit
def _empty_with_dead_code_kernel():
num_tokens = T.symbolic("num_tokens")
num_tokens = T.dynamic("num_tokens")

@T.prim_func
def buggy_kernel(x: T.Tensor[(num_tokens,), "float32"]):
6 changes: 3 additions & 3 deletions testing/python/jit/test_tilelang_jit_gemm_ctypes.py
@@ -395,14 +395,14 @@ def run_ctypes_dynamic_shape(M,

def test_ctypes_dynamic_shape():
run_ctypes_dynamic_shape(
T.symbolic("m"), 1024, 768, False, False, "float16", "float16", "float16", 128, 256, 32, 2)
T.dynamic("m"), 1024, 768, False, False, "float16", "float16", "float16", 128, 256, 32, 2)

run_ctypes_dynamic_shape(
T.symbolic("m"), T.symbolic("n"), 768, False, False, "float16", "float16", "float16", 128,
T.dynamic("m"), T.dynamic("n"), 768, False, False, "float16", "float16", "float16", 128,
256, 32, 2)

run_ctypes_dynamic_shape(
T.symbolic("m"), T.symbolic("n"), T.symbolic("k"), False, False, "float16", "float16",
T.dynamic("m"), T.dynamic("n"), T.dynamic("k"), False, False, "float16", "float16",
"float16", 128, 256, 32, 2)


8 changes: 4 additions & 4 deletions testing/python/jit/test_tilelang_jit_gemm_cython.py
@@ -404,14 +404,14 @@ def run_cython_dynamic_shape(M,

def test_cython_dynamic_shape():
run_cython_dynamic_shape(
T.symbolic("m"), 1024, 768, False, False, "float16", "float16", "float16", 128, 256, 32, 2)
T.dynamic("m"), 1024, 768, False, False, "float16", "float16", "float16", 128, 256, 32, 2)

run_cython_dynamic_shape(
T.symbolic("m"), T.symbolic("n"), 768, False, False, "float16", "float16", "float16", 128,
T.dynamic("m"), T.dynamic("n"), 768, False, False, "float16", "float16", "float16", 128,
256, 32, 2)

run_cython_dynamic_shape(
T.symbolic("m"), T.symbolic("n"), T.symbolic("k"), False, False, "float16", "float16",
T.dynamic("m"), T.dynamic("n"), T.dynamic("k"), False, False, "float16", "float16",
"float16", 128, 256, 32, 2)


@@ -473,7 +473,7 @@ def run_cython_dynamic_shape_with_out_idx(M,

def test_cython_dynamic_shape_with_out_idx():
run_cython_dynamic_shape_with_out_idx(
T.symbolic("m"), 1024, 768, False, False, "float16", "float16", "float16", 128, 256, 32, 2)
T.dynamic("m"), 1024, 768, False, False, "float16", "float16", "float16", 128, 256, 32, 2)


def matmul_int_variable(
2 changes: 1 addition & 1 deletion testing/python/language/test_tilelang_language_copy.py
@@ -83,7 +83,7 @@ def run_tilelang_copy_with_stride(M=1024,

def test_tilelang_copy_with_stride():
run_tilelang_copy_with_stride(M=1024, N=1024, NN=2048, block_M=128, block_N=128)
run_tilelang_copy_with_stride(M=1024, N=1024, NN=T.symbolic("NN"), block_M=128, block_N=128)
run_tilelang_copy_with_stride(M=1024, N=1024, NN=T.dynamic("NN"), block_M=128, block_N=128)


def tilelang_copy_bufferload(num_tokens, dtype="float16"):
@@ -41,7 +41,7 @@ def assert_vectorize_access(M: int = 64, N: int = 64):
def issue_1013_buggy_kernel():
# NOTE: This kernel is mainly to test some corner cases in boundary check

num_tokens = T.symbolic('num_tokens')
num_tokens = T.dynamic('num_tokens')
num_threads = 128

@T.prim_func
124 changes: 5 additions & 119 deletions tilelang/language/__init__.py
@@ -1,6 +1,6 @@
"""The language interface for tl programs."""

from typing import Optional, Callable, Dict
from typing import Optional
# from .parser import *
# now is fully compatible with the upstream
# tir script
@@ -84,124 +84,10 @@

from .utils import index_to_coordinates # noqa: F401


def symbolic(name: str, dtype: str = "int32"):
"""
Create a TIR symbolic variable.

Parameters:
name (str): Identifier for the variable in generated TIR.
dtype (str): Data type string for the variable (e.g., "int32"). Defaults to "int32".

Returns:
tir.Var: A TIR variable with the given name and dtype for use in TIR/TensorIR kernels.
"""
return tir.Var(name, dtype)


def use_swizzle(panel_size: int, order: str = "row", enable: bool = True):
# If order is row, use rasterization2DRow, otherwise use rasterization2DColumn
# The panel size is the number of threads in a warp
# Use to improve the L2 Cache Locality
device_func = ("rasterization2DRow" if order == "row" else "rasterization2DColumn")
return attr(None, "threadblock_swizzle_pattern",
f"tl::{device_func}<{panel_size}>") if enable else None


def annotate_layout(layout_map: Dict):
"""Annotate the layout of the buffer

Args:
layout_map (Dict): a dictionary of buffer to layout

Returns:
block_attr: a block attribute

Example:
@T.prim_func
def main(
A: T.Tensor((M, N), dtype),
B: T.Tensor((M, N), dtype),
):
# Initialize Kernel Context
with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as (bx, by):
A_shared = T.alloc_shared((block_M, block_N), dtype)

T.annotate_layout({A_shared: layout})
for i, j in T.Parallel(block_M, block_N):
A_shared[i, j] = A[by * block_M + i, bx * block_N + j]

for i, j in T.Parallel(block_M, block_N):
B[by * block_M + i, bx * block_N + j] = A_shared[i, j]

return main
"""
# layout_map is a dictionary of buffer to layout
_layout_map = {}
for buffer, layout in layout_map.items():
if isinstance(layout, Layout):
_layout_map[buffer.data] = layout
elif isinstance(layout, Callable):
_layout_map[buffer.data] = Layout(buffer.shape, layout)
else:
raise ValueError(f"Invalid layout: {layout}")

return block_attr({"layout_map": _layout_map})


def annotate_safe_value(safe_value_map: Dict):
"""Annotate the safe value of the buffer.

A safe value of a buffer is the value that will be used when the
buffer is accessed out of bounds.

Args:
safe_value_map (dict): a dictionary of buffer to safe value

Returns:
block_attr: a block attribute

Example:
@T.prim_func
def main(
A: T.Tensor((M, N), dtype),
B: T.Tensor((M, N), dtype),
):
# Initialize Kernel Context
with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as (bx, by):
A_shared = T.alloc_shared((block_M, block_N), dtype)

T.annotate_safe_value({A: safe_value})
for i, j in T.Parallel(block_M, block_N):
A_shared[i, j] = A[by * block_M + i - 10, bx * block_N + j]

for i, j in T.Parallel(block_M, block_N):
B[by * block_M + i, bx * block_N + j] = A_shared[i, j]

return main
"""
# safe_value_map is a dictionary of buffer to safe value
_safe_value_map = {}
for buffer, safe_value in safe_value_map.items():
_safe_value_map[buffer.data] = safe_value
return block_attr({"safe_value_map": _safe_value_map})


def annotate_l2_hit_ratio(l2_hit_ratio_map: Dict):
"""Annotate the L2 hit ratio of the buffer, detailed explanation please refer to:
https://docs.nvidia.com/cuda/cuda-c-programming-guide/#l2-policy-for-persisting-accesses

Args:
l2_hit_ratio_map (dict): a dictionary of buffer to L2 hit ratio value
Example:
# 0.5 is the hit ratio
T.annotate_l2_hit_ratio({A: 0.5})
"""
_l2_hit_ratio_map = {}
for buffer, hit_ratio in l2_hit_ratio_map.items():
assert buffer.scope() == "global", "persistent L2 can only be applied to global buffers"
_l2_hit_ratio_map[buffer.data] = float(hit_ratio)
return block_attr({"l2_hit_ratio_map": _l2_hit_ratio_map})
from .symbolics import dynamic, symbolic # noqa: F401
from .annotations import ( # noqa: F401
use_swizzle, annotate_layout, annotate_safe_value, annotate_l2_hit_ratio,
)


def import_source(source: Optional[str] = None):
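
The body of the new `tilelang/language/symbolics.py` module is not shown in this diff. A plausible minimal sketch, assuming `dynamic` keeps the behavior of the removed `symbolic` helper (which simply returned a `tir.Var`), might look like this:

```python
# Hypothetical sketch of tilelang/language/symbolics.py -- the actual module is
# not part of this diff. It assumes dynamic() mirrors the removed symbolic() helper.
from tvm import tir


def dynamic(name: str, dtype: str = "int32") -> tir.Var:
    """Create a TIR variable that represents a dynamic (symbolic) dimension."""
    return tir.Var(name, dtype)


# Backward-compatible alias so existing call sites that use T.symbolic(...) keep working.
symbolic = dynamic
```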