
Commit 8bc173a

Merge branch 'main' of https://github.com/microsoft/TileLang into main
2 parents: 89a0b06 + eb7c598


90 files changed: +14,436 additions, −11,093 deletions


.github/workflows/ci.yml

Lines changed: 70 additions & 0 deletions
@@ -0,0 +1,70 @@
+name: CI
+
+on: [pull_request]
+
+jobs:
+  format-check:
+    runs-on: self-hosted
+
+    steps:
+      - name: Checkout repository
+        uses: actions/checkout@v2
+        with:
+          fetch-depth: 0
+
+      - name: Set up Python
+        uses: actions/setup-python@v2
+        with:
+          python-version: '3.9'
+
+      - name: Create virtual environment
+        run: python -m venv bitblas_ci
+
+      - name: Activate virtual environment and install dependencies
+        run: |
+          source bitblas_ci/bin/activate
+          python -m pip install --upgrade pip
+          if [ -f requirements-dev.txt ]; then python -m pip install -r requirements-dev.txt; fi
+
+      - name: Update submodules recursively
+        run: git submodule update --init --recursive
+
+      - name: Run format check
+        run: |
+          source bitblas_ci/bin/activate
+          ./format.sh
+
+  build-test:
+    runs-on: self-hosted
+    needs: format-check
+
+    steps:
+      - name: Checkout repository
+        uses: actions/checkout@v2
+        with:
+          fetch-depth: 0
+
+      - name: Set up Python
+        uses: actions/setup-python@v2
+        with:
+          python-version: '3.9'
+
+      - name: Create virtual environment
+        run: python -m venv bitblas_ci
+
+      - name: Activate virtual environment and install dependencies
+        run: |
+          source bitblas_ci/bin/activate
+          python -m pip install --upgrade pip
+          if [ -f requirements-test.txt ]; then python -m pip install -r requirements-test.txt; fi
+
+      - name: Install project in wheel mode
+        run: |
+          source bitblas_ci/bin/activate
+          python -m pip install .
+
+      - name: Run tests
+        run: |
+          source bitblas_ci/bin/activate
+          cd testing/python
+          python -m pytest

README.md

Lines changed: 1 addition & 3 deletions
@@ -201,7 +201,7 @@ def matmul(M, N, K, block_M, block_N, block_K, dtype="float16", accum_dtype="flo
 
 In addition to GEMM, we provide a variety of examples to showcase the versatility and power of TileLang, including:
 
-- [Dequantize GEMM](./examples/dequantize_gemm/): Achieve high-performance dequantization by **fine-grained control over per-thread operations**, with many features now adopted as default behaviors in [BitBLAS](https://github.com/microsoft/BitBLAS), which utilzing magic layout transformation and intrins to accelerate dequantize gemm.
+- [Dequantize GEMM](./examples/dequantize_gemm/): Achieve high-performance dequantization by **fine-grained control over per-thread operations**, with many features now adopted as default behaviors in [BitBLAS](https://github.com/microsoft/BitBLAS), which utilizing magic layout transformation and intrins to accelerate dequantize gemm.
 - [FlashAttention](./examples/flash_attention/): Enable cross-operator fusion with simple and intuitive syntax, and we also provide an example of auto tuning.
 - [LinearAttention](./examples/linear_attention/): Examples include RetNet and Mamba implementations.
 - [Convolution](./examples/convolution/): Implementations of Convolution with IM2Col.
@@ -227,5 +227,3 @@ This project may contain trademarks or logos for projects, products, or services
 ## Acknowledgements
 
 We learned a lot from the [TVM](https://github.com/apache/tvm) community and would like to thank them for their contributions.
-
-This project was initiated by [yining shi](https://github.com/nox-410), and continued by [lei wang](https://github.com/LeiWang1999) and [yu cheng](https://github.com/chengyupku). It was completed under the guidance of [yuqing xia](https://github.com/xiayuqing0622), [lingxiao ma](https://github.com/xysmlx) and [jilong xue](https://github.com/jlxue) from [MSRA System Research Group](https://www.microsoft.com/en-us/research/group/systems-and-networking-research-group-asia/).

docs/Installation.md

Lines changed: 1 addition & 1 deletion
@@ -7,7 +7,7 @@
 - **Python Version**: >= 3.8
 - **CUDA Version**: >= 11.0
 
-The easiest way to install TileLang is direcly from the PyPi using pip. To install the latest version, run the following command in your terminal.
+The easiest way to install TileLang is directly from the PyPi using pip. To install the latest version, run the following command in your terminal.
 
 **Note**: Currently, TileLang whl is only supported on Ubuntu 20.04 or later version as we build the whl files on this platform. Currently we only provide whl files for CUDA>=11.0 and with Python>=3.8. **If you are using a different platform or environment, you may need to [build TileLang from source](https://github.com/microsoft/TileLang/blob/main/docs/Installation.md#building-from-source).**
 
examples/dequantize_gemm/example_dequant_gemm_fp4_hopper.py

Lines changed: 68 additions & 41 deletions
@@ -12,6 +12,7 @@
 import argparse
 from functools import partial
 
+
 def _tir_u8_to_f4_to_f16(nbit: int, val: tir.PrimExpr, pos: tir.PrimExpr, dtype: str):
     assert nbit == 4
     assert dtype == "float16"
@@ -24,12 +25,15 @@ def _tir_u8_to_f4_to_f16(nbit: int, val: tir.PrimExpr, pos: tir.PrimExpr, dtype:
     s = f4 >> tir.const(3, "uint16")
     e_f4 = f4 & tir.const(7, "uint16")
     e_f16 = e_f4 | tir.const(8, "uint16")
-    val_f16 = tir.reinterpret("float16",
-                              ((e_f16 | (s << tir.const(5, "uint16"))) << tir.const(10, "uint16")).astype("uint16"))
+    val_f16 = tir.reinterpret(
+        "float16",
+        ((e_f16 | (s << tir.const(5, "uint16"))) << tir.const(10, "uint16")).astype("uint16"))
     # return tir.Select(e_f4 == tir.const(0, "uint32"), tir.const(0, "float16"), val_f16)
     return val_f16
 
+
 def torch_convert(tensor):
+
     def print_bit(name, val):
         val_cpu = val.cpu().item()
         binary_repr = f'{val_cpu:032b}'
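For context on the reformatted hunk above: `_tir_u8_to_f4_to_f16` treats each 4-bit value as 1 sign bit plus a 3-bit exponent (no mantissa) and rebuilds the matching float16 bit pattern. A minimal standalone sketch of the same shift/mask logic follows; the nibble extraction is an assumption, since this hunk only shows the lines from `s = f4 >> 3` onward.

import numpy as np

def fp4_nibble_to_fp16(byte_val: int, pos: int) -> np.float16:
    # Assumed nibble extraction (not shown in the hunk): pos 0 = low nibble, 1 = high.
    f4 = (byte_val >> (pos * 4)) & 0xF
    s = f4 >> 3                                 # sign bit
    e_f4 = f4 & 7                               # 3-bit exponent
    e_f16 = e_f4 | 8                            # re-bias into the fp16 exponent range
    bits = ((e_f16 | (s << 5)) << 10) & 0xFFFF  # sign -> bit 15, exponent -> bits 14..10
    return np.array([bits], dtype=np.uint16).view(np.float16)[0]

packed = (0b1111 << 4) | 0b0111     # high nibble encodes -1.0, low nibble +1.0
print(fp4_nibble_to_fp16(packed, 0), fp4_nibble_to_fp16(packed, 1))  # 1.0 -1.0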
@@ -46,7 +50,7 @@ def _convert(val, pos):
         val_f16 = ((e_f16 | (s << 5)) << 10) & 0xFFFF
         lower_16_bits = (val_f16 & 0xFFFF).to(torch.uint16)
         return lower_16_bits.view(torch.float16)
-
+
     N = tensor.shape[0]
     K = tensor.shape[1]
     new_tensor = torch.empty(N, K * 2, dtype=torch.float16, device=tensor.device)
@@ -55,6 +59,7 @@ def _convert(val, pos):
             new_tensor[i][j] = _convert(tensor[i][j // 2], j % 2)
     return new_tensor
 
+
 def test_convert(N, K, block_N, block_K, in_dtype, num_bits=4, threads=128):
     num_elems_per_byte = 8 // num_bits
     storage_dtype = "uint8"
@@ -64,18 +69,15 @@ def test_convert(N, K, block_N, block_K, in_dtype, num_bits=4, threads=128):
 
     @T.prim_func
     def main(
-        B: T.Buffer(B_shape, storage_dtype),
-        C: T.Buffer((N, K), in_dtype),
+            B: T.Buffer(B_shape, storage_dtype),
+            C: T.Buffer((N, K), in_dtype),
     ):
         with T.Kernel(T.ceildiv(N, block_N), threads=threads) as (bx):
             B_shared = T.alloc_shared(B_shared_shape, storage_dtype)
             B_local = T.alloc_fragment(B_shared_shape, storage_dtype)
             B_dequantize_local = T.alloc_fragment(B_dequantize_shared_shape, in_dtype)
 
-            for k in T.Pipelined(
-                T.ceildiv(K, block_K),
-                num_stages=1
-            ):
+            for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=1):
                 T.copy(B[bx * block_N, k * block_K // num_elems_per_byte], B_shared)
                 T.copy(B_shared, B_local)
                 for i, j in T.Parallel(block_N, block_K):
@@ -89,12 +91,13 @@ def main(
 
     return main
 
+
 def test_fp4_fp16_convert_close():
     N, K = 256, 256
     block_N, block_K = 64, 64
     program = test_convert(
         N,
-        K,
+        K,
         block_N,
         block_K,
         "float16",
@@ -109,6 +112,7 @@ def test_fp4_fp16_convert_close():
     assert torch.allclose(tl_out, ref_out, rtol=0.01, atol=0.01), (tl_out, ref_out)
     print("Pass")
 
+
 def get_configs():
     block_M = [128]
     block_N = [128, 256]
@@ -118,13 +122,19 @@ def get_configs():
     splits = [1]
     _configs = list(itertools.product(block_M, block_N, block_K, num_stages, threads, splits))
 
-    configs = [
-        {'block_M': c[0], 'block_N': c[1], 'block_K': c[2], 'num_stages': c[3], 'threads': c[4], 'split': c[5]}
-        for c in _configs
-    ]
+    configs = [{
+        'block_M': c[0],
+        'block_N': c[1],
+        'block_K': c[2],
+        'num_stages': c[3],
+        'threads': c[4],
+        'split': c[5]
+    } for c in _configs]
     return configs
 
+
 def matmul(M, N, K, in_dtype, out_dtype, accum_dtype, num_bits=4, tune=False):
+
     def kernel_func(block_M, block_N, block_K, num_stages, threads, split=1):
         num_elems_per_byte = 8 // num_bits
         storage_dtype = "uint8"
@@ -142,10 +152,13 @@ def main_split(
                 B: T.Buffer(B_shape, storage_dtype),
                 Ct: T.Buffer((N, M), out_dtype),
         ):
-            SplitC = T.alloc_buffer(
-                [split, (N + block_N - 1) // block_N * block_N, (M + block_M - 1) // block_M * block_M], out_dtype
-            )
-            with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), split, threads=threads) as (bx, by, bz):
+            SplitC = T.alloc_buffer([
+                split, (N + block_N - 1) // block_N * block_N,
+                (M + block_M - 1) // block_M * block_M
+            ], out_dtype)
+            with T.Kernel(
+                    T.ceildiv(N, block_N), T.ceildiv(M, block_M), split,
+                    threads=threads) as (bx, by, bz):
                 A_shared = T.alloc_shared(A_shared_shape, in_dtype)
                 B_shared = T.alloc_shared(B_shared_shape, storage_dtype)
                 B_local = T.alloc_fragment(B_shared_shape, storage_dtype)
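The padded SplitC extents above use the usual round-up-to-multiple idiom; a one-line check of the arithmetic (values chosen arbitrarily for illustration):

# (N + block_N - 1) // block_N * block_N rounds N up to the next multiple of block_N.
N, block_N = 1000, 128
padded = (N + block_N - 1) // block_N * block_N
assert padded == 1024 and padded % block_N == 0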
@@ -154,12 +167,10 @@ def main_split(
                 Ct_local = T.alloc_fragment((block_N, block_M), accum_dtype)
                 Ct_shared = T.alloc_shared((block_N, block_M), out_dtype)
 
-                T.annotate_layout(
-                    {
-                        B_shared: tilelang.layout.make_swizzled_layout(B_shared),
-                        Ct_shared: tilelang.layout.make_swizzled_layout(Ct_shared),
-                    }
-                )
+                T.annotate_layout({
+                    B_shared: tilelang.layout.make_swizzled_layout(B_shared),
+                    Ct_shared: tilelang.layout.make_swizzled_layout(Ct_shared),
+                })
 
                 T.clear(Ct_local)
                 for k in T.Pipelined(K // (block_K * split), num_stages=num_stages):
@@ -175,22 +186,24 @@ def main_split(
                     )
                     T.copy(B_dequantize_local, B_dequantize_prev_local)
                     T.gemm(B_dequantize_prev_local, A_shared, Ct_local, transpose_B=True)
-                    T.copy(Ct_local, SplitC[bz, bx * block_N : (bx + 1) * block_N, by * block_M : (by + 1) * block_M])
+                    T.copy(Ct_local, SplitC[bz, bx * block_N:(bx + 1) * block_N,
+                                            by * block_M:(by + 1) * block_M])
             with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M)) as (bx, by):
                 acc = T.alloc_fragment((block_N, block_M), out_dtype)
                 T.clear(acc)
                 for k in range(split):
                     for i, j in T.Parallel(block_N, block_M):
                         acc[i, j] += SplitC[k, bx * block_N + i, by * block_M + j]
                 T.copy(acc, Ct[bx * block_N, by * block_M])
-
+
         @T.prim_func
         def main(
                 A: T.Buffer(A_shape, in_dtype),
                 B: T.Buffer(B_shape, storage_dtype),
                 Ct: T.Buffer((N, M), out_dtype),
         ):
-            with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by):
+            with T.Kernel(
+                    T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by):
                 A_shared = T.alloc_shared(A_shared_shape, in_dtype)
                 B_shared = T.alloc_shared(B_shared_shape, storage_dtype)
                 B_local = T.alloc_fragment(B_shared_shape, storage_dtype)
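The split-K structure in main_split above computes partial GEMMs over disjoint K chunks into SplitC and then sums them in a second kernel. The same reduction in plain NumPy, as a reference for what the two kernels compute together (shapes chosen arbitrarily; this is an analogue, not TileLang code):

import numpy as np

M, N, K, split = 64, 64, 256, 4
A = np.random.randn(M, K).astype(np.float32)
B = np.random.randn(N, K).astype(np.float32)      # B is stored (N, K), as in the kernel

chunk = K // split
partials = [B[:, i * chunk:(i + 1) * chunk] @ A[:, i * chunk:(i + 1) * chunk].T
            for i in range(split)]                 # analogue of the SplitC[bz] slices
Ct = np.sum(partials, axis=0)                      # analogue of the acc reduction kernel

assert np.allclose(Ct, B @ A.T, rtol=1e-3, atol=1e-3)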
@@ -199,12 +212,10 @@ def main(
                 Ct_local = T.alloc_fragment((block_N, block_M), accum_dtype)
                 Ct_shared = T.alloc_shared((block_N, block_M), out_dtype)
 
-                T.annotate_layout(
-                    {
-                        B_shared: tilelang.layout.make_swizzled_layout(B_shared),
-                        Ct_shared: tilelang.layout.make_swizzled_layout(Ct_shared),
-                    }
-                )
+                T.annotate_layout({
+                    B_shared: tilelang.layout.make_swizzled_layout(B_shared),
+                    Ct_shared: tilelang.layout.make_swizzled_layout(Ct_shared),
+                })
 
                 T.clear(Ct_local)
                 for k in T.Pipelined(K // block_K, num_stages=num_stages):
@@ -221,38 +232,51 @@ def main(
                     T.copy(B_dequantize_local, B_dequantize_prev_local)
                     T.gemm(B_dequantize_prev_local, A_shared, Ct_local, transpose_B=True)
                     T.copy(Ct_local, Ct_shared)
-                    T.copy(Ct_shared, Ct[bx * block_N : (bx + 1) * block_N, by * block_M : (by + 1) * block_M])
+                    T.copy(Ct_shared, Ct[bx * block_N:(bx + 1) * block_N,
+                                         by * block_M:(by + 1) * block_M])
 
         if split == 1:
             return main
         else:
             return main_split
-
+
     if tune:
+
         @autotune(
             configs=get_configs(),
             keys=["block_M", "block_N", "block_K", "num_stages", "threads", "split"],
             warmup=10,
-            rep=10
-        )
-        @jit(out_idx=[2], supply_type=tilelang.TensorSupplyType.Integer, ref_prog=None, profiler="auto")
-        def kernel(block_M=None, block_N=None, block_K=None, num_stages=None, threads=None, split=None):
+            rep=10)
+        @jit(
+            out_idx=[2],
+            supply_type=tilelang.TensorSupplyType.Integer,
+            ref_prog=None,
+            profiler="auto")
+        def kernel(block_M=None,
+                   block_N=None,
+                   block_K=None,
+                   num_stages=None,
+                   threads=None,
+                   split=None):
             return kernel_func(block_M, block_N, block_K, num_stages, threads, split)
 
         return kernel()
     else:
+
         def kernel(block_M, block_N, block_K, num_stages, threads, split=1):
             return kernel_func(block_M, block_N, block_K, num_stages, threads, split)
 
         return kernel
 
+
 def ref_program(A, qB):
     dtypeC = "float16"
     B = torch_convert(qB)
     C = torch.matmul(A.to(torch.float), B.T.to(torch.float))
     C = C.to(torch.__getattribute__(dtypeC))
     return C.transpose(0, 1)
 
+
 if __name__ == "__main__":
     parser = argparse.ArgumentParser()
     parser.add_argument('--m', type=int, default=256, help='M')
@@ -264,7 +288,9 @@ def ref_program(A, qB):
     total_flops = 2 * M * N * K
 
     if (not args.tune):
-        program = matmul(M, N, K, "float16", "float16", "float32", num_bits=4, tune=args.tune)(block_M=128, block_N=128, block_K=128, num_stages=2, threads=256, split=1)
+        program = matmul(
+            M, N, K, "float16", "float16", "float32", num_bits=4, tune=args.tune)(
+                block_M=128, block_N=128, block_K=128, num_stages=2, threads=256, split=1)
         mod, params = tilelang.lower(program)
         mod = Profiler(mod, params, [2], tilelang.TensorSupplyType.Integer)
         mod.assert_allclose(ref_program, rtol=0.01, atol=0.01)
@@ -276,7 +302,8 @@ def ref_program(A, qB):
         print("Tile-lang: {:.2f} ms".format(latency))
         print("Tile-lang: {:.2f} TFlops".format(total_flops / latency * 1e-9))
     else:
-        best_latency, best_config, ref_latency = matmul(M, N, K, "float16", "float16", "float32", num_bits=4, tune=args.tune)
+        best_latency, best_config, ref_latency = matmul(
+            M, N, K, "float16", "float16", "float32", num_bits=4, tune=args.tune)
         print(f"Best latency: {best_latency}")
         print(f"Best TFlops: {total_flops / best_latency * 1e-9}")
         print(f"Best config: {best_config}")

examples/gemm/example_gemm.py

Lines changed: 6 additions & 9 deletions
@@ -6,18 +6,15 @@
 import tilelang.language as T
 
 
-def matmul(
-    M, N, K, block_M, block_N, block_K, dtype="float16", accum_dtype="float"
-):
+def matmul(M, N, K, block_M, block_N, block_K, dtype="float16", accum_dtype="float"):
+
     @T.prim_func
     def main(
-        A: T.Buffer((M, K), dtype),
-        B: T.Buffer((K, N), dtype),
-        C: T.Buffer((M, N), dtype),
+            A: T.Buffer((M, K), dtype),
+            B: T.Buffer((K, N), dtype),
+            C: T.Buffer((M, N), dtype),
     ):
-        with T.Kernel(
-            T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128
-        ) as (bx, by):
+        with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as (bx, by):
             A_shared = T.alloc_shared((block_M, block_K), dtype)
             B_shared = T.alloc_shared((block_K, block_N), dtype)
             C_local = T.alloc_fragment((block_M, block_N), accum_dtype)
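The hunk above ends at the buffer allocations because of the hunk boundary; the rest of example_gemm.py is not shown in this commit view. As a self-contained sketch of how such a TileLang GEMM is typically completed, lowered, and checked, assembled from the patterns visible in the dequantize example earlier in this commit (the kernel continuation, the block/stage numbers, and the Profiler import path are assumptions, not the actual file contents):

import torch
import tilelang
import tilelang.language as T
from tilelang import Profiler  # import path assumed; the dequantize example above uses Profiler directly


def matmul(M, N, K, block_M, block_N, block_K, dtype="float16", accum_dtype="float"):

    @T.prim_func
    def main(
            A: T.Buffer((M, K), dtype),
            B: T.Buffer((K, N), dtype),
            C: T.Buffer((M, N), dtype),
    ):
        with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as (bx, by):
            A_shared = T.alloc_shared((block_M, block_K), dtype)
            B_shared = T.alloc_shared((block_K, block_N), dtype)
            C_local = T.alloc_fragment((block_M, block_N), accum_dtype)

            # Hypothetical continuation, mirroring the pipelined copy / gemm / copy-out
            # structure of the dequantize example in this commit.
            T.clear(C_local)
            for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=3):
                T.copy(A[by * block_M, k * block_K], A_shared)
                T.copy(B[k * block_K, bx * block_N], B_shared)
                T.gemm(A_shared, B_shared, C_local)
            T.copy(C_local, C[by * block_M, bx * block_N])

    return main


def ref_program(A, B):
    # Plain PyTorch reference for the kernel output C = A @ B.
    return (A.to(torch.float) @ B.to(torch.float)).to(torch.float16)


program = matmul(1024, 1024, 1024, 128, 128, 32)
mod, params = tilelang.lower(program)
profiler = Profiler(mod, params, [2], tilelang.TensorSupplyType.Integer)
profiler.assert_allclose(ref_program, rtol=0.01, atol=0.01)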
