Skip to content
This repository was archived by the owner on Oct 11, 2024. It is now read-only.

Commit 0ccb117

Browse files
tlrmchlsmth authored and Robert Shaw committed
[Kernel] Suppress mma.sp warning on CUDA 12.5 and later (vllm-project#5401)
1 parent deee747 commit 0ccb117

File tree

1 file changed

+42
-32
lines changed
  • csrc/quantization/marlin/sparse/common

1 file changed

+42
-32
lines changed

csrc/quantization/marlin/sparse/common/mma.h

Lines changed: 42 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,19 @@
2020

2121
namespace marlin_24 {
2222

23+
// On CUDA earlier than 12.5, the ordered_metadata version of this instruction
24+
// is not supported. On later versions of CUDA the version without ordered
25+
// metadata results in the following warning:
26+
// | Advisory: Modifier ‘.sp::ordered_metadata’ should be used on instruction
27+
// | ‘mma’ instead of modifier ‘.sp’ as it is expected to have substantially
28+
// | reduced performance on some future architectures
29+
#if defined CUDA_VERSION && CUDA_VERSION >= 12500
30+
#define MMA_SP_INST \
31+
"mma.sp::ordered_metadata.sync.aligned.m16n8k32.row.col.f32.f16.f16.f32 "
32+
#else
33+
#define MMA_SP_INST "mma.sp.sync.aligned.m16n8k32.row.col.f32.f16.f16.f32 "
34+
#endif
35+
2336
// m16n8k32 sparse tensor core mma instruction with fp16 inputs and fp32
2437
// output/accumulation.
2538
__device__ inline void mma_sp(const FragB& a_frag0, const FragB& a_frag1,
@@ -29,41 +42,38 @@ __device__ inline void mma_sp(const FragB& a_frag0, const FragB& a_frag1,
2942
const uint32_t* a1 = reinterpret_cast<const uint32_t*>(&a_frag1);
3043
const uint32_t* b = reinterpret_cast<const uint32_t*>(&frag_b);
3144
const uint32_t* e = reinterpret_cast<const uint32_t*>(&frag_m);
45+
3246
float* c = reinterpret_cast<float*>(&frag_c);
3347
if (psel == 0) {
34-
asm volatile(
35-
"mma.sp.sync.aligned.m16n8k32.row.col.f32.f16.f16.f32 "
36-
"{%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9, %10,%11}, "
37-
"{%12,%13,%14,%15}, %16, 0x0;\n"
38-
: "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3])
39-
: "r"(a0[0]), "r"(a1[0]), "r"(a0[1]), "r"(a1[1]), "r"(b[0]), "r"(b[2]),
40-
"r"(b[4]), "r"(b[6]), "f"(c[0]), "f"(c[1]), "f"(c[2]), "f"(c[3]),
41-
"r"(e[0]));
42-
asm volatile(
43-
"mma.sp.sync.aligned.m16n8k32.row.col.f32.f16.f16.f32 "
44-
"{%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9, %10,%11}, "
45-
"{%12,%13,%14,%15}, %16, 0x0;\n"
46-
: "=f"(c[4]), "=f"(c[5]), "=f"(c[6]), "=f"(c[7])
47-
: "r"(a0[0]), "r"(a1[0]), "r"(a0[1]), "r"(a1[1]), "r"(b[1]), "r"(b[3]),
48-
"r"(b[5]), "r"(b[7]), "f"(c[4]), "f"(c[5]), "f"(c[6]), "f"(c[7]),
49-
"r"(e[0]));
48+
asm volatile(MMA_SP_INST
49+
"{%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9, %10,%11}, "
50+
"{%12,%13,%14,%15}, %16, 0x0;\n"
51+
: "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3])
52+
: "r"(a0[0]), "r"(a1[0]), "r"(a0[1]), "r"(a1[1]), "r"(b[0]),
53+
"r"(b[2]), "r"(b[4]), "r"(b[6]), "f"(c[0]), "f"(c[1]),
54+
"f"(c[2]), "f"(c[3]), "r"(e[0]));
55+
asm volatile(MMA_SP_INST
56+
"{%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9, %10,%11}, "
57+
"{%12,%13,%14,%15}, %16, 0x0;\n"
58+
: "=f"(c[4]), "=f"(c[5]), "=f"(c[6]), "=f"(c[7])
59+
: "r"(a0[0]), "r"(a1[0]), "r"(a0[1]), "r"(a1[1]), "r"(b[1]),
60+
"r"(b[3]), "r"(b[5]), "r"(b[7]), "f"(c[4]), "f"(c[5]),
61+
"f"(c[6]), "f"(c[7]), "r"(e[0]));
5062
} else {
51-
asm volatile(
52-
"mma.sp.sync.aligned.m16n8k32.row.col.f32.f16.f16.f32 "
53-
"{%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9, %10,%11}, "
54-
"{%12,%13,%14,%15}, %16, 0x1;\n"
55-
: "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3])
56-
: "r"(a0[0]), "r"(a1[0]), "r"(a0[1]), "r"(a1[1]), "r"(b[0]), "r"(b[2]),
57-
"r"(b[4]), "r"(b[6]), "f"(c[0]), "f"(c[1]), "f"(c[2]), "f"(c[3]),
58-
"r"(e[0]));
59-
asm volatile(
60-
"mma.sp.sync.aligned.m16n8k32.row.col.f32.f16.f16.f32 "
61-
"{%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9, %10,%11}, "
62-
"{%12,%13,%14,%15}, %16, 0x1;\n"
63-
: "=f"(c[4]), "=f"(c[5]), "=f"(c[6]), "=f"(c[7])
64-
: "r"(a0[0]), "r"(a1[0]), "r"(a0[1]), "r"(a1[1]), "r"(b[1]), "r"(b[3]),
65-
"r"(b[5]), "r"(b[7]), "f"(c[4]), "f"(c[5]), "f"(c[6]), "f"(c[7]),
66-
"r"(e[0]));
63+
asm volatile(MMA_SP_INST
64+
"{%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9, %10,%11}, "
65+
"{%12,%13,%14,%15}, %16, 0x1;\n"
66+
: "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3])
67+
: "r"(a0[0]), "r"(a1[0]), "r"(a0[1]), "r"(a1[1]), "r"(b[0]),
68+
"r"(b[2]), "r"(b[4]), "r"(b[6]), "f"(c[0]), "f"(c[1]),
69+
"f"(c[2]), "f"(c[3]), "r"(e[0]));
70+
asm volatile(MMA_SP_INST
71+
"{%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9, %10,%11}, "
72+
"{%12,%13,%14,%15}, %16, 0x1;\n"
73+
: "=f"(c[4]), "=f"(c[5]), "=f"(c[6]), "=f"(c[7])
74+
: "r"(a0[0]), "r"(a1[0]), "r"(a0[1]), "r"(a1[1]), "r"(b[1]),
75+
"r"(b[3]), "r"(b[5]), "r"(b[7]), "f"(c[4]), "f"(c[5]),
76+
"f"(c[6]), "f"(c[7]), "r"(e[0]));
6777
}
6878
}
6979

0 commit comments

Comments
 (0)