Skip to content

Commit f5fca05

Browse files
[CUDA] Init support for sm_120 (#716)
* Init support for sm120
* fmt
* resolve comments
* unify mma gemm
* fmt

---------

Co-authored-by: LeiWang1999 <leiwang1999@outlook.com>
1 parent 6610c7b commit f5fca05

File tree

8 files changed

+477
-881
lines changed

8 files changed

+477
-881
lines changed

src/op/gemm.cc

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -370,7 +370,8 @@ LayoutMap Gemm::InferLayout(const LayoutInferArgs &T, InferLevel level) {
370370
results.Set(B, makeGemmVoltaABLayout(*as_const_int(B->shape[dim_B - 2]),
371371
*as_const_int(B->shape[dim_B - 1]),
372372
false, trans_B ? 2 : 1));
373-
} else if (TargetIsAmpere(T.target) || TargetIsTuring(T.target)) {
373+
} else if (TargetIsAmpere(T.target) || TargetIsTuring(T.target) ||
374+
TargetIsSM120(T.target)) {
374375
auto fragment =
375376
makeGemmFragmentC(M, N, M / warp_m, N / warp_n, C->dtype.bits());
376377
results.Set(C, fragment->BindThreadRange(thread_range));

src/target/utils.cc

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,14 @@ bool TargetIsHopper(Target target) {
5050
if (!TargetIsCuda(target))
5151
return false;
5252
int arch = GetArchInt(target);
53-
return arch >= 90;
53+
return arch >= 90 && arch < 100;
54+
}
55+
56+
bool TargetIsSM120(Target target) {
57+
if (!TargetIsCuda(target))
58+
return false;
59+
int arch = GetArchInt(target);
60+
return arch >= 120 && arch < 130;
5461
}
5562

5663
bool TargetIsCDNA(Target target) {

src/target/utils.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@ bool TargetIsVolta(Target target);
1919
bool TargetIsTuring(Target target);
2020
bool TargetIsAmpere(Target target);
2121
bool TargetIsHopper(Target target);
22+
bool TargetIsSM120(Target target);
2223
bool TargetIsCDNA(Target target);
2324

2425
bool TargetHasAsyncCopy(Target target);

src/tl_templates/cuda/gemm.h

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
#pragma once
2-
#if (defined(__CUDA_ARCH_LIST__) && (__CUDA_ARCH_LIST__ >= 900))
2+
#if (defined(__CUDA_ARCH_LIST__) && (__CUDA_ARCH_LIST__ >= 1200))
3+
#include "gemm_sm120.h"
4+
#elif (defined(__CUDA_ARCH_LIST__) && (__CUDA_ARCH_LIST__ >= 900))
35
#include "gemm_sm90.h"
46
#elif (defined(__CUDA_ARCH_LIST__) && (__CUDA_ARCH_LIST__ >= 890))
57
#include "gemm_sm89.h"

src/tl_templates/cuda/gemm_mma.h

Lines changed: 458 additions & 0 deletions
Large diffs are not rendered by default.

src/tl_templates/cuda/gemm_sm120.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
#pragma once
2+
3+
#include "gemm_mma.h"

0 commit comments

Comments
 (0)