Commit 444892f
[Dev][jit] Introduce jit for kernel functions (#12)

* instruction update
* replace link with TileLang/tile-lang
* [Dev][Adapter] Implement Torch DLPack Kernel Adapter and related utilities
* lint fix
* Implement JIT Compiler Components
* Documents update
* lint fix
* update logo
* install script fix

1 parent 9c578fa commit 444892f

26 files changed: +1242 −154 lines

README.md

Lines changed: 52 additions & 7 deletions
@@ -1,3 +1,5 @@
+<img src=./images/logo-row.svg />
+
 <div align="center">
 
 # Tile Language
@@ -57,7 +59,7 @@ pip install tilelang
 Alternatively, you can install directly from the GitHub repository:
 
 ```bash
-pip install git+https://github.com/TileLang/tile-lang
+pip install git+https://github.com/tile-ai/tilelang
 ```
 
 Or install locally:
@@ -82,6 +84,9 @@ In this section, you’ll learn how to write and execute a straightforward GEMM
 Below is an example that demonstrates more advanced features: layout annotation, parallelized copy, and swizzle for improved L2 cache locality. This snippet shows how to adapt your kernel to maximize performance on complex hardware.
 
 ```python
+# Copyright (c) Microsoft Corporation.
+# Licensed under the MIT License.
+import tilelang
 import tilelang.language as T
 # `make_mma_swizzle_layout` is a python-defined layout function
 # specifically designed for MMA operations
@@ -91,6 +96,7 @@ from tilelang.intrinsics import (
     make_mma_swizzle_layout as make_swizzle_layout,)
 
 def matmul(M, N, K, block_M, block_N, block_K, dtype="float16", accum_dtype="float"):
+    # add the @tilelang.jit decorator if you want to return a torch function
     @T.prim_func
     def main(
             A: T.Buffer((M, K), dtype),
@@ -105,13 +111,13 @@ def matmul(M, N, K, block_M, block_N, block_K, dtype="float16", accum_dtype="flo
 
             # Apply layout optimizations or define your own layout (Optional)
             # If not specified, we will deduce the layout automatically
-            T.annotate_layout({
-                A_shared: make_swizzle_layout(A_shared),
-                B_shared: make_swizzle_layout(B_shared),
-            })
+            # T.annotate_layout({
+            #     A_shared: make_swizzle_layout(A_shared),
+            #     B_shared: make_swizzle_layout(B_shared),
+            # })
 
             # Enable rasterization for better L2 cache locality (Optional)
-            T.use_swizzle(panel_size=10, enable=True)
+            # T.use_swizzle(panel_size=10, enable=True)
 
             # Clear local accumulation
             T.clear(C_local)
@@ -133,6 +139,45 @@ def matmul(M, N, K, block_M, block_N, block_K, dtype="float16", accum_dtype="flo
             T.copy(C_local, C[by * block_M, bx * block_N])
 
     return main
+
+
+# 1. Define the kernel (matmul) and compile/lower it into an executable module
+func = matmul(1024, 1024, 1024, 128, 128, 32)
+
+# 2. Compile the kernel into a torch function
+# out_idx specifies the index of the output buffer in the argument list;
+# if out_idx is specified, the tensor will be created at runtime.
+# target currently can be "cuda", "hip", or "cpu".
+jit_kernel = tilelang.JITKernel(func, out_idx=[2], target="cuda")
+
+# 3. Test the kernel in Python with PyTorch data
+import torch
+
+# Create random input tensors on the GPU
+a = torch.randn(1024, 1024, device="cuda", dtype=torch.float16)
+b = torch.randn(1024, 1024, device="cuda", dtype=torch.float16)
+
+# Run the kernel
+c = jit_kernel(a, b)
+
+# Reference multiplication using PyTorch
+ref_c = a @ b
+
+# Validate correctness
+torch.testing.assert_close(c, ref_c, rtol=1e-2, atol=1e-2)
+print("Kernel output matches PyTorch reference.")
+
+# 4. Retrieve and inspect the generated CUDA source (optional)
+cuda_source = jit_kernel.get_kernel_source()
+print("Generated CUDA kernel:\n", cuda_source)
+
+# 5. Profile kernel latency
+profiler = jit_kernel.get_profiler()
+
+latency = profiler.do_bench()
+
+print(f"Latency: {latency} ms")
 ```
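One detail of the snippet above that is easy to miss: because `out_idx=[2]` marks argument index 2 (the `C` buffer) as the output, the compiled kernel is called with only the two inputs and allocates the output for the caller. A small illustration of that calling convention, continuing from the names defined in the snippet (the shape and dtype follow from the buffer declaration):

```python
# Continuation of the example above: jit_kernel, a, and b are already defined.
# C (argument index 2) is created at runtime rather than passed in.
c = jit_kernel(a, b)

# The returned tensor matches main's C buffer declaration: (M, N) fp16 on the GPU.
assert c.shape == (1024, 1024)
assert c.dtype == torch.float16
assert c.device.type == "cuda"
```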
### Dive Deep into TileLang Beyond GEMM
@@ -152,4 +197,4 @@ TileLang has now been used in project [BitBLAS](https://github.com/microsoft/Bit
 
 ## Acknowledgements
 
-We learned a lot from the [TVM](https://github.com/apache/tvm) community and would like to thank them for their contributions.
+We learned a lot from the [TVM](https://github.com/apache/tvm) community and would like to thank them for their contributions. The initial version of this project was contributed mainly by [LeiWang1999](https://github.com/LeiWang1999), [chengyupku](https://github.com/chengyupku) and [nox-410](https://github.com/nox-410). Part of this work was carried out during internships at Microsoft Research, under the supervision of Dr. Lingxiao Ma, Dr. Yuqing Xia, Dr. Jilong Xue, and Dr. Fan Yang.

docker/Dockerfile.cu120

Lines changed: 1 addition & 1 deletion
@@ -22,7 +22,7 @@ RUN conda install pip cmake && conda clean --all
 
 RUN apt-get install -y python3 python3-dev python3-setuptools gcc libtinfo-dev zlib1g-dev build-essential cmake libedit-dev libxml2-dev
 
-RUN git clone https://github.com/TileLang/tile-lang.git --recursive -b main TileLang \
+RUN git clone https://github.com/tile-ai/tilelang.git --recursive -b main TileLang \
     && cd TileLang && ./install.sh
 
 CMD bash

docker/README.md

Lines changed: 1 addition & 1 deletion
@@ -1,7 +1,7 @@
 To ease the process of installing all the dependencies, we provide a Dockerfile and a simple guideline to build a Docker image with all of the above installed. The Docker image is built on top of Ubuntu 20.04 and contains all the dependencies required to run the experiments. We only provide a Dockerfile for NVIDIA GPUs; a Dockerfile for AMD GPUs will be provided upon request.
 
 ```bash
-git clone --recursive https://github.com/TileLang/tile-lang TileLang
+git clone --recursive https://github.com/tile-ai/tilelang TileLang
 cd TileLang/docker
 # build the image, this may take a while (around 10+ minutes on our test machine)
 docker build -t tilelang_cuda -f Dockerfile.cu120 .

docs/Installation.md

Lines changed: 6 additions & 6 deletions
@@ -9,7 +9,7 @@
 
 The easiest way to install TileLang is directly from PyPI using pip. To install the latest version, run the following command in your terminal.
 
-**Note**: Currently, the TileLang wheel is only supported on Ubuntu 20.04 or later, as we build the wheel files on that platform, and we only provide wheels for CUDA>=11.0 and Python>=3.8. **If you are using a different platform or environment, you may need to [build TileLang from source](https://github.com/TileLang/tile-lang/blob/main/docs/Installation.md#building-from-source).**
+**Note**: Currently, the TileLang wheel is only supported on Ubuntu 20.04 or later, as we build the wheel files on that platform, and we only provide wheels for CUDA>=11.0 and Python>=3.8. **If you are using a different platform or environment, you may need to [build TileLang from source](https://github.com/tile-ai/tilelang/blob/main/docs/Installation.md#building-from-source).**
 
 ```bash
 pip install tilelang
@@ -24,7 +24,7 @@ pip install tilelang-0.0.0.dev0+ubuntu.20.4.cu120-py3-none-any.whl
 To install the latest version of TileLang from the GitHub repository, you can run the following command:
 
 ```bash
-pip install git+https://github.com/TileLang/tile-lang.git
+pip install git+https://github.com/tile-ai/tilelang.git
 ```
 
 After installing TileLang, you can verify the installation by running:
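The hunk ends before the repository's own verification snippet, so that command is not reproduced here. As a stand-in, a bare import works as a minimal smoke test (a hypothetical check, not the one from the docs):

```python
# Hypothetical smoke test; the docs' actual verification command is not shown in this hunk.
import tilelang
print("tilelang imported from:", tilelang.__file__)
```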
@@ -56,7 +56,7 @@ sudo apt-get install -y python3 python3-dev python3-setuptools gcc libtinfo-dev
 After installing the prerequisites, you can clone the TileLang repository and install it using pip:
 
 ```bash
-git clone --recursive https://github.com/TileLang/tile-lang.git
+git clone --recursive https://github.com/tile-ai/tilelang.git TileLang
 cd TileLang
 pip install .  # Please be patient, this may take some time.
 ```
@@ -80,7 +80,7 @@ If you already have a compatible TVM installation, follow these steps:
 1. **Clone the Repository:**
 
    ```bash
-   git clone --recursive https://github.com/TileLang/tile-lang
+   git clone --recursive https://github.com/tile-ai/tilelang TileLang
    cd TileLang
    ```
 
@@ -114,7 +114,7 @@ If you prefer to use the built-in TVM version, follow these instructions:
 1. **Clone the Repository:**
 
    ```bash
-   git clone --recursive https://github.com/TileLang/tile-lang
+   git clone --recursive https://github.com/tile-ai/tilelang TileLang
    cd TileLang
    ```
 
@@ -152,7 +152,7 @@ For a simplified installation, use the provided script:
 1. **Clone the Repository:**
 
    ```bash
-   git clone --recursive https://github.com/TileLang/tile-lang
+   git clone --recursive https://github.com/tile-ai/tilelang TileLang
    cd TileLang
    ```

examples/quickstart.py

Lines changed: 94 additions & 0 deletions
@@ -0,0 +1,94 @@
+# Copyright (c) Microsoft Corporation.
+# Licensed under the MIT License.
+import tilelang
+import tilelang.language as T
+# `make_mma_swizzle_layout` is a python-defined layout function
+# specifically designed for MMA operations,
+# which ensures consistency with the NVIDIA CUTLASS library
+# to avoid bank conflicts and maximize performance.
+from tilelang.intrinsics import (
+    make_mma_swizzle_layout as make_swizzle_layout,)  # noqa: F401
+
+
+def matmul(M, N, K, block_M, block_N, block_K, dtype="float16", accum_dtype="float"):
+    # add the @tilelang.jit decorator if you want to return a torch function
+    @T.prim_func
+    def main(
+            A: T.Buffer((M, K), dtype),
+            B: T.Buffer((K, N), dtype),
+            C: T.Buffer((M, N), dtype),
+    ):
+        # Kernel configuration remains similar
+        with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as (bx, by):
+            A_shared = T.alloc_shared((block_M, block_K), dtype)
+            B_shared = T.alloc_shared((block_K, block_N), dtype)
+            C_local = T.alloc_fragment((block_M, block_N), accum_dtype)
+
+            # Apply layout optimizations or define your own layout (Optional)
+            # If not specified, we will deduce the layout automatically
+            # T.annotate_layout({
+            #     A_shared: make_swizzle_layout(A_shared),
+            #     B_shared: make_swizzle_layout(B_shared),
+            # })
+
+            # Enable rasterization for better L2 cache locality (Optional)
+            # T.use_swizzle(panel_size=10, enable=True)
+
+            # Clear local accumulation
+            T.clear(C_local)
+
+            for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=3):
+                # Copy a tile of A
+                # This is sugar syntax for a parallelized copy
+                T.copy(A[by * block_M, k * block_K], A_shared)
+
+                # Demonstrate a parallelized copy from global to shared for B
+                for ko, j in T.Parallel(block_K, block_N):
+                    B_shared[ko, j] = B[k * block_K + ko, bx * block_N + j]
+
+                # Perform a tile-level GEMM on the shared buffers
+                # Currently we dispatch to cute/hip on NVIDIA/AMD GPUs
+                T.gemm(A_shared, B_shared, C_local)
+
+            # Copy the result back to global memory
+            T.copy(C_local, C[by * block_M, bx * block_N])
+
+    return main
+
+
+# 1. Define the kernel (matmul) and compile/lower it into an executable module
+func = matmul(1024, 1024, 1024, 128, 128, 32)
+
+# 2. Compile the kernel into a torch function
+# out_idx specifies the index of the output buffer in the argument list;
+# if out_idx is specified, the tensor will be created at runtime.
+# target currently can be "cuda", "hip", or "cpu".
+jit_kernel = tilelang.JITKernel(func, out_idx=[2], target="cuda")
+
+# 3. Test the kernel in Python with PyTorch data
+import torch
+
+# Create random input tensors on the GPU
+a = torch.randn(1024, 1024, device="cuda", dtype=torch.float16)
+b = torch.randn(1024, 1024, device="cuda", dtype=torch.float16)
+
+# Run the kernel
+c = jit_kernel(a, b)
+
+# Reference multiplication using PyTorch
+ref_c = a @ b
+
+# Validate correctness
+torch.testing.assert_close(c, ref_c, rtol=1e-2, atol=1e-2)
+print("Kernel output matches PyTorch reference.")
+
+# 4. Retrieve and inspect the generated CUDA source (optional)
+cuda_source = jit_kernel.get_kernel_source()
+print("Generated CUDA kernel:\n", cuda_source)
+
+# 5. Profile kernel latency
+profiler = jit_kernel.get_profiler()
+
+latency = profiler.do_bench()
+
+print(f"Latency: {latency} ms")
