[feat] support gemm_sp for ampere arch #691
Merged
Commits (33):
7ca3299  [feat] add an example mma atom (botbw)
e309fc6  [fix] fix typo naming (botbw)
6ab6423  [feat] add a template to enable compilation (botbw)
e9c6bcb  [feat] add print util (botbw)
680d5f1  [WIP] pass on single block tile (botbw)
5e1e047  [feat] add sm80 metadata layout (botbw)
0827b0c  [chore] clean codebase (botbw)
faa5467  [CI] format.sh (botbw)
de1d2d6  [feat] add sm80 compress utils (botbw)
c566c50  [bugfix] fix C fragment layout (botbw)
60ca0e2  [refactor] use nvcc version instead of str (botbw)
d989dbe  [test] add test cases (botbw)
f7f920f  [chore] add a param check (botbw)
02396c1  [chore] format a bit (botbw)
a571f63  [chore] rename func to satisfy PEP 8 and appease gemini (botbw)
1df63f1  [chore] add check (botbw)
3eee2e0  [feat] support sm75 layout && add assertion && chore (botbw)
df621cc  [bug] fix illegal memory access when using two warps over N=32 (botbw)
e00a3e3  [chore] add example (botbw)
636352e  [chore] format (botbw)
1531762  [example] update benchmark (botbw)
c52aa3c  [bugfix] fix namespace and format (botbw)
0e3bd60  Merge remote-tracking branch 'upstream/main' into gemm_sp_sm80 (botbw)
fecadc8  [bugfix] fix incorrect param passing (botbw)
51655d7  [refactor] update variable declaration for clarity in gemm_layouts an… (LeiWang1999)
333e4d4  Merge branch 'main' of https://github.com/tile-ai/tilelang into gemm_… (LeiWang1999)
122bff0  [Cleanup] Remove unnecessary blank lines in metadata layout functions… (LeiWang1999)
b8e195f  [CI] fix arch (botbw)
de09434  [example] add torch sparse benchmark (botbw)
9962ec5  [misc] polish && add reference && apply review suggestionsi && format (botbw)
88435da  [CI] format with clang-tidy (botbw)
597b66f  [Cleanup] Format and align template struct definitions in half.hpp, c… (LeiWang1999)
27b4491  [Update] Modify CUDA version requirements in test_gemm_sp_sm80 and ma… (LeiWang1999)
New file (160 added lines):

# Copyright (c) Tile-AI Corporation.
# Licensed under the MIT License.
import argparse

import tilelang
import tilelang.language as T

from tilelang.layout import make_metadata_layout
from tilelang.utils.sparse import compress
from tilelang.contrib import nvcc
from triton.testing import do_bench

import torch

arch = nvcc.get_target_compute_version()

ARCH_INFO = {"8.0": (16, "int16"), "8.9": (16, "int16"), "9.0": (8, "uint8")}

default_config = {  # take best config from autotune script
    "4090": {
        'float': {
            'block_M': 128,
            'block_N': 64,
            'block_K': 64,
            'num_stages': 1,
            'thread_num': 128,
            'policy': T.GemmWarpPolicy.Square,
            'enable_rasterization': True
        },
        'float16': {
            'block_M': 256,
            'block_N': 128,
            'block_K': 64,
            'num_stages': 2,
            'thread_num': 128,
            'policy': T.GemmWarpPolicy.Square,
            'enable_rasterization': True
        }
    },
    "h20": {
        'float': {
            'block_M': 128,
            'block_N': 64,
            'block_K': 128,
            'num_stages': 3,
            'thread_num': 128,
            'policy': T.GemmWarpPolicy.Square,
            'enable_rasterization': True
        },
        'float16': {
            'block_M': 128,
            'block_N': 64,
            'block_K': 128,
            'num_stages': 3,
            'thread_num': 128,
            'policy': T.GemmWarpPolicy.Square,
            'enable_rasterization': True
        }
    }
}


def generate_sparse_tensor(M: int, K: int, dtype=torch.float16, device='cuda'):
    elem, group = 2, 4
    full_tensor = torch.randn((M, K), dtype=dtype, device=device).view(M, -1, group)
    indice = full_tensor.topk(elem, dim=-1).indices
    full_tensor.scatter_(-1, indice, 0)
    return full_tensor.view(M, K)


@tilelang.jit(out_idx=[-1])
def matmul_sp_fp16(M, N, K, accum_dtype, block_M, block_N, block_K, num_stages, thread_num, policy,
                   enable_rasterization):
    e_factor, e_dtype = ARCH_INFO[arch]

    @T.prim_func
    def gemm_sp_fp16(
            A_sparse: T.Tensor((M, K // 2), 'float16'),
            E: T.Tensor((M, K // e_factor), e_dtype),
            B: T.Tensor((K, N), 'float16'),
            C: T.Tensor((M, N), accum_dtype),
    ):
        with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=thread_num) as (bx, by):
            A_shared = T.alloc_shared((block_M, block_K // 2), 'float16')
            E_shared = T.alloc_shared((block_M, block_K // e_factor), e_dtype)
            B_shared = T.alloc_shared((block_K, block_N), 'float16')
            C_shared = T.alloc_shared((block_M, block_N), accum_dtype)
            C_local = T.alloc_fragment((block_M, block_N), accum_dtype)

            T.clear(C_local)
            T.disable_warp_group_reg_alloc()
            T.use_swizzle(panel_size=10, enable=enable_rasterization)
            T.annotate_layout({
                E:
                    make_metadata_layout(
                        E, mma_dtype="float16", backend="cutlass", block_k=block_K, arch=arch),
                E_shared:
                    make_metadata_layout(
                        E_shared,
                        mma_dtype="float16",
                        backend="cutlass",
                        block_k=block_K,
                        arch=arch),
            })
            for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=num_stages):
                T.copy(A_sparse[by * block_M, k * block_K // 2], A_shared)
                T.copy(E[by * block_M, k * block_K // e_factor], E_shared)
                T.copy(B[k * block_K, bx * block_N], B_shared)
                T.gemm_sp(A_shared, E_shared, B_shared, C_local, False, False, policy=policy)

            T.copy(C_local, C_shared)
            T.copy(C_shared, C[by * block_M, bx * block_N])

    return gemm_sp_fp16


def main():
    parser = argparse.ArgumentParser(description="Autotuned MatMul Benchmark")
    parser.add_argument("--m", type=int, default=16384, help="Matrix dimension M")
    parser.add_argument("--n", type=int, default=16384, help="Matrix dimension N")
    parser.add_argument("--k", type=int, default=16384, help="Matrix dimension K")
    parser.add_argument(
        "--accum_dtype",
        type=str,
        default="float",
        choices=["float", "float16"],
        help="Accumulation datatype")
    parser.add_argument("--cfg", type=str, choices=["4090", "h20"], required=True)
    args = parser.parse_args()
    kernel = matmul_sp_fp16(args.m, args.n, args.k, args.accum_dtype,
                            **default_config[args.cfg][args.accum_dtype])

    a = generate_sparse_tensor(args.m, args.k, device='cuda', dtype=torch.half)
    b = torch.randn(args.k, args.n, device='cuda', dtype=torch.half)

    a_sparse, e = compress(
        a,
        transposed=False,
        block_k=default_config[args.cfg][args.accum_dtype]['block_K'],
        arch=arch)
    c = kernel(a_sparse, e, b)

    ref_c = a @ b

    assert not c.isnan().any(), "Result contains NaNs, please report an issue"
    torch.testing.assert_close(c, ref_c.to(c.dtype), rtol=1e-2, atol=1e-2)
    print(f"Precision check passed. diff: {(c - ref_c).abs().mean()}")

    latency = do_bench(lambda: kernel(a_sparse, e, b))
    ref_latency = do_bench(lambda: a @ b)

    total_flops = 2 * args.m * args.n * args.k
    tflops = total_flops / latency / 1e9
    ref_tflops = total_flops / ref_latency / 1e9
    print(f"Sparse TFLOPS: {tflops:.2f}, Latency: {latency/1e3} s")
    print(f"Reference TFLOPS: {ref_tflops:.2f}, Latency: {ref_latency/1e3} s")


if __name__ == "__main__":
    main()
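
As a sanity check outside this PR, the 2:4 pattern produced by generate_sparse_tensor above (at most two nonzeros in every group of four consecutive elements along K) can be verified with a few lines of plain PyTorch. The helper below is an illustrative sketch; check_2_to_4 is a name made up for this note, not something defined in the repository.

import torch

def check_2_to_4(t: torch.Tensor, group: int = 4, elem: int = 2) -> bool:
    # View each row as groups of `group` consecutive elements along K and
    # count nonzeros per group; a 2:4-sparse tensor keeps at most `elem` of them.
    nnz_per_group = (t.view(t.shape[0], -1, group) != 0).sum(dim=-1)
    return bool((nnz_per_group <= elem).all())

# e.g. check_2_to_4(generate_sparse_tensor(256, 256)) is expected to return True.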
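The operand shapes fed to the kernel follow directly from the gemm_sp_fp16 signature: A_sparse keeps K // 2 values per row, and the metadata tensor E holds K // e_factor entries per row, with (e_factor, e_dtype) taken from ARCH_INFO (16 and int16 on compute capability 8.0/8.9, 8 and uint8 on 9.0). A small illustrative helper, again not part of the PR:

def sparse_operand_shapes(M: int, K: int, arch: str):
    # Shapes implied by the gemm_sp_fp16 signature above.
    e_factor, e_dtype = ARCH_INFO[arch]
    return (M, K // 2), (M, K // e_factor), e_dtype

# sparse_operand_shapes(16384, 16384, "8.0") -> ((16384, 8192), (16384, 1024), "int16")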
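To reproduce the benchmark, the script can be run directly with the arguments defined in main(); the file name below is only a placeholder, since the diff view here does not show the file path.

# Hypothetical file name; adjust to the actual path of this example in the repo.
python example_gemm_sp.py --m 16384 --n 16384 --k 16384 --accum_dtype float16 --cfg 4090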