2 changes: 1 addition & 1 deletion 3rdparty/tvm
Submodule tvm updated from 5bf17a to 9cda9b
4 changes: 2 additions & 2 deletions examples/gdn/example_chunk_o_bwd.py
@@ -256,8 +256,8 @@ def kernel(
# for i_kv in T.Parallel(block_DK * block_DV):
# dg_last_fragment[i_kv] = h_shared[i_kv // block_DV, i_kv % block_DV] * dh_shared[i_kv // block_DV, i_kv % block_DV]
for i_kv in T.Parallel(block_DK * block_DV):
i_k, i_v = i_kv // block_DV, i_kv % block_DV
dg_last_fragment[i_kv] = h_shared[i_k, i_v] * dh_shared[i_k, i_v]
i_k, i_v_1 = i_kv // block_DV, i_kv % block_DV
dg_last_fragment[i_kv] = h_shared[i_k, i_v_1] * dh_shared[i_k, i_v_1]
T.reduce_sum(dg_last_fragment, dg_last_fragment_scalar, dim=-1, clear=False)
dg_last_local[0] += dg_last_fragment_scalar[0]
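For reference, the flattened T.Parallel loop plus T.reduce_sum in this hunk accumulates the elementwise product of the h and dh tiles into a scalar added onto dg_last. A minimal PyTorch sketch of that quantity (tile sizes below are illustrative, not taken from the example):

import torch

block_DK, block_DV = 64, 64              # illustrative tile sizes
h_tile = torch.randn(block_DK, block_DV)
dh_tile = torch.randn(block_DK, block_DV)

# Elementwise product over the (block_DK, block_DV) tile, reduced to one scalar
# and accumulated, mirroring dg_last_local[0] += dg_last_fragment_scalar[0].
dg_last = torch.zeros(1)
dg_last += (h_tile * dh_tile).sum()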

74 changes: 74 additions & 0 deletions testing/python/jit/test_tilelang_jit_parcompile.py
@@ -0,0 +1,74 @@
import tilelang.testing
import tilelang
import torch


@tilelang.jit(
out_idx=-1, # create the output tensor during runtime
verbose=True,
)
def matmul_kernel_jit(
M,
N,
K,
block_M,
block_N,
block_K,
trans_A=False,
trans_B=True,
in_dtype='float16',
out_dtype='float32',
accum_dtype='float32',
num_stages=2,
threads=128,
):
Comment on lines +6 to +24

🛠️ Refactor suggestion | 🟠 Major

Make the test portable across CUDA/MPS and select the correct backend for Metal.

The test currently hard-codes CUDA and the default backend; Metal requires execution_backend="torch". Choose the device/backend at import time and skip the test if neither device is available.

+import pytest
 import tilelang.testing
 import tilelang
 import torch
 
-@tilelang.jit(
-    out_idx=-1,  # create the output tensor during runtime
-    verbose=True,
-)
+# Device/backend selection
+USE_CUDA = torch.cuda.is_available()
+USE_MPS = hasattr(torch.backends, "mps") and torch.backends.mps.is_available()
+DEVICE = "cuda" if USE_CUDA else ("mps" if USE_MPS else None)
+EXEC_BACKEND = "torch" if (USE_MPS and not USE_CUDA) else "cython"
+
+@tilelang.jit(
+    out_idx=-1,  # create the output tensor during runtime
+    verbose=True,
+    execution_backend=EXEC_BACKEND,
+)
 def matmul_kernel_jit(
@@
 def test_par_compile():
+    if DEVICE is None:
+        pytest.skip("No CUDA or MPS device available for JIT test.")
     configs = [
         (1024, 1024, 1024, 128, 128, 32),
         (2048, 2048, 2048, 256, 256, 64),
         (4096, 4096, 4096, 64, 64, 128),
     ]
     kernels = matmul_kernel_jit.par_compile(configs)
     for (M, N, K, _, _, _), kernel in zip(configs, kernels):
-        A = torch.randn(M, K, dtype=torch.float16).cuda()
-        B = torch.randn(N, K, dtype=torch.float16).cuda()
+        A = torch.randn(M, K, dtype=torch.float16, device=DEVICE)
+        B = torch.randn(N, K, dtype=torch.float16, device=DEVICE)
         ref = (A @ B.T).float()
         C = kernel(A, B)
         tilelang.testing.torch_assert_close(C, ref, rtol=1e-2, atol=1e-2)

Also applies to: 58-71

A_shape = (K, M) if trans_A else (M, K)
B_shape = (N, K) if trans_B else (K, N)
A_shared_shape = (block_K, block_M) if trans_A else (block_M, block_K)
B_shared_shape = (block_N, block_K) if trans_B else (block_K, block_N)

import tilelang.language as T

@T.prim_func
def main(
A: T.Tensor(A_shape, in_dtype),
B: T.Tensor(B_shape, in_dtype),
C: T.Tensor((M, N), out_dtype),
):
with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by):
A_shared = T.alloc_shared(A_shared_shape, in_dtype)
B_shared = T.alloc_shared(B_shared_shape, in_dtype)
C_local = T.alloc_fragment((block_M, block_N), accum_dtype)
T.clear(C_local)
for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=num_stages):
if trans_A:
T.copy(A[k * block_K, by * block_M], A_shared)
else:
T.copy(A[by * block_M, k * block_K], A_shared)
if trans_B:
T.copy(B[bx * block_N, k * block_K], B_shared)
else:
T.copy(B[k * block_K, bx * block_N], B_shared)
T.gemm(A_shared, B_shared, C_local, trans_A, trans_B)
T.copy(C_local, C[by * block_M, bx * block_N])

return main


def test_par_compile():
configs = [
(1024, 1024, 1024, 128, 128, 32),
(2048, 2048, 2048, 256, 256, 64),
(4096, 4096, 4096, 64, 64, 128),
]
kernels = matmul_kernel_jit.par_compile(configs)
for (M, N, K, _, _, _), kernel in zip(configs, kernels):
A = torch.randn(M, K, dtype=torch.float16).cuda()
B = torch.randn(N, K, dtype=torch.float16).cuda()
ref = (A @ B.T).float()
C = kernel(A, B)
tilelang.testing.torch_assert_close(C, ref, rtol=1e-2, atol=1e-2)


if __name__ == "__main__":
tilelang.testing.main()
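For comparison with the batched par_compile call, a single configuration can be compiled by invoking the decorated function directly, which is the usual tilelang.jit pattern. A sketch assuming a CUDA device, reusing matmul_kernel_jit from the test above with its first config:

import torch
import tilelang.testing

kernel = matmul_kernel_jit(1024, 1024, 1024, 128, 128, 32)  # compile one config
A = torch.randn(1024, 1024, dtype=torch.float16).cuda()
B = torch.randn(1024, 1024, dtype=torch.float16).cuda()
C = kernel(A, B)  # out_idx=-1: the output tensor is created by the kernel
tilelang.testing.torch_assert_close(C, (A @ B.T).float(), rtol=1e-2, atol=1e-2)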
222 changes: 222 additions & 0 deletions testing/python/language/test_tilelang_language_dtype.py
@@ -0,0 +1,222 @@
import tilelang
import tilelang.language as T
import torch
import tilelang.testing
import tvm


def test_argument():

@T.prim_func
def test_argument(
t_1: T.bool,
t_2: T.short,
t_3: T.int,
t_4: T.long,
t_5: T.half,
t_6: T.float,
t_7: T.long,
t_8: T.int8,
t_9: T.int16,
t_10: T.int32,
t_11: T.int64,
t_12: T.uint8,
t_13: T.uint16,
t_14: T.uint32,
t_15: T.uint64,
t_16: T.float8_e4m3fn,
t_17: T.float8_e4m3fnuz,
t_18: T.float8_e5m2,
t_19: T.float8_e5m2fnuz,
t_20: T.float8_e8m0fnu,
t_21: T.float16,
t_22: T.bfloat16,
t_23: T.float32,
t_24: T.float64,
):
pass


def test_expr():
from tilelang.language.v2.dtypes import _all_dtypes
errors = []
for name in _all_dtypes:
dtype = getattr(T, name)
assert isinstance(dtype, tvm.DataType), f"{dtype} is not tvm.DataType"
try:
dtype(1.0)
dtype()
except TypeError:
pass
except Exception:
errors.append(name)
assert not errors


def test_var_decl_sugar():

@T.prim_func
def test_var_decl_sugar():
with T.Kernel(128, 128) as (bx, by):
var_1: T.bool = 1.0
var_2: T.short = 1.0
var_3: T.int = 1.0
var_4: T.long = 1.0
var_5: T.half = 1.0
var_6: T.float = 1.0
var_7: T.long = 1.0
var_8: T.int8 = 1.0
var_9: T.int16 = 1.0
var_10: T.int32 = 1.0
var_11: T.int64 = 1.0
var_12: T.uint8 = 1.0
var_13: T.uint16 = 1.0
var_14: T.uint32 = 1.0
var_15: T.uint64 = 1.0
var_16: T.float8_e4m3fn = 1.0
var_17: T.float8_e4m3fnuz = 1.0
var_18: T.float8_e5m2 = 1.0
var_19: T.float8_e5m2fnuz = 1.0
var_20: T.float8_e8m0fnu = 1.0
var_21: T.float16 = 1.0
var_22: T.bfloat16 = 1.0
var_23: T.float32 = 1.0
var_24: T.float64 = 1.0
var_1: T.bool = var_1
var_2: T.short = var_2
var_3: T.int = var_3
var_4: T.long = var_4
var_5: T.half = var_5
var_6: T.float = var_6
var_7: T.long = var_7
var_8: T.int8 = var_8
var_9: T.int16 = var_9
var_10: T.int32 = var_10
var_11: T.int64 = var_11
var_12: T.uint8 = var_12
var_13: T.uint16 = var_13
var_14: T.uint32 = var_14
var_15: T.uint64 = var_15
var_16: T.float8_e4m3fn = var_16
var_17: T.float8_e4m3fnuz = var_17
var_18: T.float8_e5m2 = var_18
var_19: T.float8_e5m2fnuz = var_19
var_20: T.float8_e8m0fnu = var_20
var_21: T.float16 = var_21
var_22: T.bfloat16 = var_22
var_23: T.float32 = var_23
var_24: T.float64 = var_24

s = test_var_decl_sugar.script()
for i in range(1, 25):
assert f'var_{i}_1' in s
assert 'tl.local_var_init' in s


def test_dtype_str_repr():

@T.prim_func
def test_str_repr():
buf_1 = T.alloc_buffer((1,), dtype=T.bool, scope='shared') # noqa F841
buf_2 = T.alloc_buffer((1,), dtype=T.short, scope='shared') # noqa F841
buf_3 = T.alloc_buffer((1,), dtype=T.int, scope='shared') # noqa F841
buf_4 = T.alloc_buffer((1,), dtype=T.long, scope='shared') # noqa F841
buf_5 = T.alloc_buffer((1,), dtype=T.half, scope='shared') # noqa F841
buf_6 = T.alloc_buffer((1,), dtype=T.float, scope='shared') # noqa F841
buf_7 = T.alloc_buffer((1,), dtype=T.long, scope='shared') # noqa F841
buf_8 = T.alloc_buffer((1,), dtype=T.int8, scope='shared') # noqa F841
buf_9 = T.alloc_buffer((1,), dtype=T.int16, scope='shared') # noqa F841
buf_10 = T.alloc_buffer((1,), dtype=T.int32, scope='shared') # noqa F841
buf_11 = T.alloc_buffer((1,), dtype=T.int64, scope='shared') # noqa F841
buf_12 = T.alloc_buffer((1,), dtype=T.uint8, scope='shared') # noqa F841
buf_13 = T.alloc_buffer((1,), dtype=T.uint16, scope='shared') # noqa F841
buf_14 = T.alloc_buffer((1,), dtype=T.uint32, scope='shared') # noqa F841
buf_15 = T.alloc_buffer((1,), dtype=T.uint64, scope='shared') # noqa F841
buf_16 = T.alloc_buffer((1,), dtype=T.float8_e4m3fn, scope='shared') # noqa F841
buf_17 = T.alloc_buffer((1,), dtype=T.float8_e4m3fnuz, scope='shared') # noqa F841
buf_18 = T.alloc_buffer((1,), dtype=T.float8_e5m2, scope='shared') # noqa F841
buf_19 = T.alloc_buffer((1,), dtype=T.float8_e5m2fnuz, scope='shared') # noqa F841
buf_20 = T.alloc_buffer((1,), dtype=T.float8_e8m0fnu, scope='shared') # noqa F841
buf_21 = T.alloc_buffer((1,), dtype=T.float16, scope='shared') # noqa F841
buf_22 = T.alloc_buffer((1,), dtype=T.bfloat16, scope='shared') # noqa F841
buf_23 = T.alloc_buffer((1,), dtype=T.float32, scope='shared') # noqa F841
buf_24 = T.alloc_buffer((1,), dtype=T.float64, scope='shared') # noqa F841


def test_torch_eq():
dtypes = [
T.bool,
T.short,
T.int,
T.long,
T.half,
T.float,
T.long,
T.int8,
T.int16,
T.int32,
T.int64,
T.uint8,
T.uint16,
T.uint32,
T.uint64,
T.float8_e4m3fn,
T.float8_e4m3fnuz,
T.float8_e5m2,
T.float8_e5m2fnuz,
T.float8_e8m0fnu,
T.float16,
T.bfloat16,
T.float32,
T.float64,
]
torch_dtypes = [
torch.bool,
torch.short,
torch.int,
torch.long,
torch.half,
torch.float,
torch.long,
torch.int8,
torch.int16,
torch.int32,
torch.int64,
torch.uint8,
torch.uint16,
torch.uint32,
torch.uint64,
torch.float8_e4m3fn,
torch.float8_e4m3fnuz,
torch.float8_e5m2,
torch.float8_e5m2fnuz,
torch.float8_e8m0fnu,
torch.float16,
torch.bfloat16,
torch.float32,
torch.float64,
]
for a, b in zip(dtypes, torch_dtypes):
assert a == b, f"{a} and {b} are not equal"
Comment on lines +147 to +200

⚠️ Potential issue | 🟡 Minor

Duplicate dtype comparison at indices 3 and 6.

Both T.long and torch.long appear twice in the test lists (lines 151/154 and 177/180). This appears to be unintentional duplication.

Apply this diff to test a different dtype at position 6:

         T.half,
         T.float,
-        T.long,
+        T.double,
         T.int8,
         torch.half,
         torch.float,
-        torch.long,
+        torch.double,
         torch.int8,
🤖 Prompt for AI Agents
In testing/python/language/test_tilelang_language_dtype.py around lines 147 to 200, the dtype pair at the 7th position is duplicated (T.long / torch.long each appear twice); replace the second occurrence (the entries at lines ~154 and ~180) with the intended dtype, updating both lists consistently so they stay aligned, and re-run the tests to confirm.
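
Worth noting while reviewing the lists: several repeated-looking entries are intentional aliases (short/int16, int/int32, long/int64, half/float16, float/float32), which is what makes the stray second long easy to miss. A quick check of the PyTorch aliases assumed here:

import torch

# Standard C-style aliases; each pair refers to the same dtype object.
assert torch.short is torch.int16
assert torch.int is torch.int32
assert torch.long is torch.int64
assert torch.half is torch.float16
assert torch.float is torch.float32
assert torch.double is torch.float64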



def test_var_assign():

@tilelang.jit(out_idx=-1)
@T.prim_func
def test_var_assign(A: T.Tensor((2,), T.int32)):
with T.Kernel(1) as _:
a: T.int32 = 1
b: T.int32 = a
a = 2
d: T.int32 = a
A[0] = b
A[1] = d

res = test_var_assign()()
assert res[0] == 1
assert res[1] == 2


if __name__ == '__main__':
tilelang.testing.main()
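The behaviour test_var_assign pins down mirrors ordinary Python binding: b snapshots the value of a at its declaration, so reassigning a afterwards changes d but not b. A plain-Python analogue of the expected result:

a = 1
b = a   # b captures the current value of a
a = 2   # plain assignment rebinds a only
d = a
assert (b, d) == (1, 2)  # matches res[0] == 1 and res[1] == 2 above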
@@ -15,8 +15,7 @@ def test_loop_tail_split(block_M, block_N, block_K, threads, vec_load_b, dtype):
N = tvm.te.var("n")
K = tvm.te.var("k")

@tvm.script.ir.ir_module
class Before:
def before():

@T.prim_func
def main(B: T.Tensor((K, N), dtype),):
@@ -38,8 +37,9 @@ def main(B: T.Tensor((K, N), dtype),):
(block_N // vec_load_b) * (block_N // vec_load_b) + vec],
T.float16(0))

@tvm.script.ir.ir_module
class After:
return tvm.IRModule({'main': main})

def after():

@T.prim_func
def main(B: T.Tensor((K, N), dtype),):
@@ -77,11 +77,13 @@ def main(B: T.Tensor((K, N), dtype),):
bx * block_N + t % (block_N // vec_load_b) *
(block_N // vec_load_b) + vec], T.float16(0))

return tvm.IRModule({'main': main})

with tvm.target.Target(auto_target):
mod = tvm.tir.transform.BindTarget(auto_target)(Before)
mod = tvm.tir.transform.BindTarget(auto_target)(before())
mod = tl.transform.LayoutInference()(mod)
mod = tvm.tir.transform.Simplify()(mod)
ref_mod = tvm.tir.transform.BindTarget(auto_target)(After)
ref_mod = tvm.tir.transform.BindTarget(auto_target)(after())
ref_mod = tvm.tir.transform.Simplify()(ref_mod)
# Note(tzj): The structures are equal except one more "for" loop after the LayoutInference pass
# This loop is "for vec in T.parallel(1)",
@@ -41,7 +41,7 @@ def assert_vectorize_access(M: int = 64, N: int = 64):
def issue_1013_buggy_kernel():
# NOTE: This kernel is mainly to test some corner cases in boundary check

num_tokens = T.dynamic('num_tokens')
num_tokens = T.Var('num_tokens', 'int32')
num_threads = 128

@T.prim_func
@@ -15,8 +15,7 @@ def test_loop_tail_split(block_M, block_N, block_K, threads, vec_load_b, dtype):
N = tvm.te.var("n")
K = tvm.te.var("k")

@tvm.script.ir.ir_module
class Before:
def before():

@T.prim_func
def main(B: T.Tensor((K, N), dtype),):
@@ -25,8 +24,9 @@ def main(B: T.Tensor((K, N), dtype),):
for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=3):
T.copy(B[k * block_K, bx * block_N], B_shared)

@tvm.script.ir.ir_module
class After:
return tvm.IRModule({'main': main})

def after():

@T.prim_func
def main(B: T.Tensor((K, N), dtype),):
@@ -64,11 +64,13 @@ def main(B: T.Tensor((K, N), dtype),):
bx * block_N + t % (block_N // vec_load_b) *
(block_N // vec_load_b) + vec], T.float16(0))

return tvm.IRModule({'main': main})

with tvm.transform.PassContext():
mod = tvm.tir.transform.BindTarget(auto_target)(Before)
mod = tvm.tir.transform.BindTarget(auto_target)(before())
mod = tl.transform.LowerTileOp()(mod)
mod = tvm.tir.transform.Simplify()(mod)
ref_mod = tvm.tir.transform.BindTarget(auto_target)(After)
ref_mod = tvm.tir.transform.BindTarget(auto_target)(after())
ref_mod = tvm.tir.transform.Simplify()(ref_mod)
# Note(tzj): The structures are equal except the argument in "T.reads" function.
# The difference is just between the first index and the indices range, which is totally equivalent