20
20
21
21
namespace marlin_24 {

// On CUDA earlier than 12.5, the ordered_metadata version of this instruction
// is not supported. On later versions of CUDA the version without ordered
// metadata results in the following warning:
// | Advisory: Modifier '.sp::ordered_metadata' should be used on instruction
// | 'mma' instead of modifier '.sp' as it is expected to have substantially
// | reduced performance on some future architectures
//
// NOTE: CUDA_VERSION encodes major * 1000 + minor * 10, so CUDA 12.5 is
// 12050 (a threshold of 12500 would never match any real toolkit and would
// silently disable the ordered_metadata path).
#if defined CUDA_VERSION && CUDA_VERSION >= 12050
  #define MMA_SP_INST \
    "mma.sp::ordered_metadata.sync.aligned.m16n8k32.row.col.f32.f16.f16.f32 "
#else
  #define MMA_SP_INST "mma.sp.sync.aligned.m16n8k32.row.col.f32.f16.f16.f32 "
#endif
23
36
// m16n8k32 sparse tensor core mma instruction with fp16 inputs and fp32
// output/accumulation.
//
// Issues two sparse MMAs (one per half of the B fragment / C accumulator).
// `e` (from frag_m) supplies the 2:4 sparsity metadata as operand %16; `psel`
// picks the metadata selector immediate (0x0 vs 0x1) trailing the instruction.
// Requires SM80+ (mma.sp with this shape).
//
// NOTE(review): the tail of the parameter list and the `a0` declaration fall
// inside a diff-hunk gap in the reviewed source; they are reconstructed from
// the parallel a1/b/e lines and the standard marlin_24 signature — confirm
// against the full file.
__device__ inline void mma_sp(const FragB& a_frag0, const FragB& a_frag1,
                              const FragA& frag_b, FragC& frag_c, FragM& frag_m,
                              const int psel) {
  const uint32_t* a0 = reinterpret_cast<const uint32_t*>(&a_frag0);
  const uint32_t* a1 = reinterpret_cast<const uint32_t*>(&a_frag1);
  const uint32_t* b = reinterpret_cast<const uint32_t*>(&frag_b);
  const uint32_t* e = reinterpret_cast<const uint32_t*>(&frag_m);

  float* c = reinterpret_cast<float*>(&frag_c);
  if (psel == 0) {
    asm volatile(MMA_SP_INST
                 "{%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9, %10,%11}, "
                 "{%12,%13,%14,%15}, %16, 0x0;\n"
                 : "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3])
                 : "r"(a0[0]), "r"(a1[0]), "r"(a0[1]), "r"(a1[1]), "r"(b[0]),
                   "r"(b[2]), "r"(b[4]), "r"(b[6]), "f"(c[0]), "f"(c[1]),
                   "f"(c[2]), "f"(c[3]), "r"(e[0]));
    asm volatile(MMA_SP_INST
                 "{%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9, %10,%11}, "
                 "{%12,%13,%14,%15}, %16, 0x0;\n"
                 : "=f"(c[4]), "=f"(c[5]), "=f"(c[6]), "=f"(c[7])
                 : "r"(a0[0]), "r"(a1[0]), "r"(a0[1]), "r"(a1[1]), "r"(b[1]),
                   "r"(b[3]), "r"(b[5]), "r"(b[7]), "f"(c[4]), "f"(c[5]),
                   "f"(c[6]), "f"(c[7]), "r"(e[0]));
  } else {
    asm volatile(MMA_SP_INST
                 "{%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9, %10,%11}, "
                 "{%12,%13,%14,%15}, %16, 0x1;\n"
                 : "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3])
                 : "r"(a0[0]), "r"(a1[0]), "r"(a0[1]), "r"(a1[1]), "r"(b[0]),
                   "r"(b[2]), "r"(b[4]), "r"(b[6]), "f"(c[0]), "f"(c[1]),
                   "f"(c[2]), "f"(c[3]), "r"(e[0]));
    asm volatile(MMA_SP_INST
                 "{%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9, %10,%11}, "
                 "{%12,%13,%14,%15}, %16, 0x1;\n"
                 : "=f"(c[4]), "=f"(c[5]), "=f"(c[6]), "=f"(c[7])
                 : "r"(a0[0]), "r"(a1[0]), "r"(a0[1]), "r"(a1[1]), "r"(b[1]),
                   "r"(b[3]), "r"(b[5]), "r"(b[7]), "f"(c[4]), "f"(c[5]),
                   "f"(c[6]), "f"(c[7]), "r"(e[0]));
  }
}
69
79
0 commit comments