@@ -33,10 +33,126 @@ using ADataType = ck::bhalf_t;
using BDataType = ck::bhalf_t;
using CDataType = ck::bhalf_t;

- GroupedKernel grouped_heuristic_dispatch(int M, int N, int K) {
+ // Define a custom hash function for std::tuple keys of 2, 3, or 4 ints.
+ struct IntTupleHash {
+   size_t operator()(const std::tuple<int, int>& t) const {
+     auto hash1 = std::hash<int>{}(std::get<0>(t));
+     auto hash2 = std::hash<int>{}(std::get<1>(t));
+     return hash1 ^ hash2;
+   }
+   size_t operator()(const std::tuple<int, int, int>& t) const {
+     auto hash1 = std::hash<int>{}(std::get<0>(t));
+     auto hash2 = std::hash<int>{}(std::get<1>(t));
+     auto hash3 = std::hash<int>{}(std::get<2>(t));
+     return hash1 ^ hash2 ^ hash3;
+   }
+   size_t operator()(const std::tuple<int, int, int, int>& t) const {
+     auto hash1 = std::hash<int>{}(std::get<0>(t));
+     auto hash2 = std::hash<int>{}(std::get<1>(t));
+     auto hash3 = std::hash<int>{}(std::get<2>(t));
+     auto hash4 = std::hash<int>{}(std::get<3>(t));
+     return hash1 ^ hash2 ^ hash3 ^ hash4;
+   }
+ };
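+ // Note: plain XOR as a combiner is order-insensitive (hash(a, b) == hash(b, a))
+ // and equal elements cancel out, so collisions are easy to construct. That is
+ // acceptable here because the table below is small and lookups are exact-match;
+ // a mixing combiner (e.g. a boost-style hash_combine) could be swapped in if
+ // collision pressure ever mattered.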
+
+ // For certain high priority shapes, we directly map to the best kernel rather
+ // than use heuristics.
+ static const std::unordered_map<std::tuple<int, int, int, int>, GroupedKernel, IntTupleHash> bf16_grouped_lookup_dispatch = {
+   {{16, 16, 2048, 5120}, bf16_grouped_128x16x64x128_16x16_1x2_16x8x1_16x8x1_1x16x1x8_8x8x1_1x2_intrawave_v2},
+   {{16, 16, 5120, 1024}, bf16_grouped_64x16x16x128_16x16_1x1_16x4x1_16x4x1_1x16x1x4_4x4x1_1x1_interwave_v2},
+   {{16, 16, 16384, 5120}, bf16_grouped_64x16x32x128_16x16_1x2_16x4x1_16x4x1_1x16x1x4_8x8x1_1x2_intrawave_v2},
+   {{16, 16, 5120, 8192}, bf16_grouped_128x16x64x128_16x16_1x2_16x8x1_16x8x1_1x16x1x8_8x8x1_1x2_interwave_v1},
+   {{16, 32, 2048, 5120}, bf16_grouped_128x16x64x128_16x16_1x2_16x8x1_16x8x1_1x16x1x8_8x8x1_1x2_interwave_v2},
+   {{16, 32, 5120, 1024}, bf16_grouped_256x16x64x128_16x16_1x1_16x16x1_16x16x1_1x16x1x16_4x4x1_1x1_interwave_v2},
+   {{16, 32, 16384, 5120}, bf16_grouped_128x16x96x128_16x16_1x3_16x8x1_16x8x1_1x16x1x8_4x4x1_1x1_interwave_v1},
+   {{16, 32, 5120, 8192}, bf16_grouped_128x16x96x64_16x16_1x3_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_intrawave_v1},
+   {{16, 64, 2048, 5120}, bf16_grouped_256x16x64x128_16x16_1x1_16x16x1_16x16x1_1x16x1x16_4x4x1_1x1_interwave_v2},
+   {{16, 64, 5120, 1024}, bf16_grouped_256x16x64x128_16x16_1x1_16x16x1_16x16x1_1x16x1x16_4x4x1_1x1_interwave_v2},
+   {{16, 64, 16384, 5120}, bf16_grouped_128x16x96x128_16x16_1x3_16x8x1_16x8x1_1x16x1x8_4x4x1_1x1_interwave_v1},
+   {{16, 64, 5120, 8192}, bf16_grouped_128x16x96x64_16x16_1x3_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_intrawave_v1},
+   {{16, 128, 2048, 5120}, bf16_grouped_128x16x64x128_16x16_1x2_16x8x1_16x8x1_1x16x1x8_8x8x1_1x2_interwave_v2},
+   {{16, 128, 5120, 1024}, bf16_grouped_256x16x64x128_16x16_1x1_16x16x1_16x16x1_1x16x1x16_4x4x1_1x1_interwave_v2},
+   {{16, 128, 16384, 5120}, bf16_grouped_64x16x64x128_16x16_1x4_16x4x1_16x4x1_1x16x1x4_8x8x1_1x2_intrawave_v2},
+   {{16, 128, 5120, 8192}, bf16_grouped_128x16x96x128_16x16_1x3_16x8x1_16x8x1_1x16x1x8_4x4x1_1x1_intrawave_v1},
+   {{16, 256, 2048, 5120}, bf16_grouped_128x16x64x128_16x16_1x2_16x8x1_16x8x1_1x16x1x8_8x8x1_1x2_interwave_v2},
+   {{16, 256, 5120, 1024}, bf16_grouped_128x16x32x128_16x16_1x1_16x8x1_16x8x1_1x16x1x8_4x4x1_1x1_interwave_v2},
+   {{16, 256, 16384, 5120}, bf16_grouped_128x16x96x128_16x16_1x3_16x8x1_16x8x1_1x16x1x8_4x4x1_1x1_intrawave_v1},
+   {{16, 256, 5120, 8192}, bf16_grouped_128x16x96x128_16x16_1x3_16x8x1_16x8x1_1x16x1x8_4x4x1_1x1_interwave_v1},
+   {{16, 512, 2048, 5120}, bf16_grouped_128x32x64x128_32x32_1x1_16x8x1_16x8x1_1x16x1x8_8x8x1_1x1_interwave_v2},
+   {{16, 512, 5120, 1024}, bf16_grouped_256x32x96x64_16x16_1x3_8x32x1_8x32x1_1x32x1x8_4x4x1_1x1_interwave_v2},
+   {{16, 512, 16384, 5120}, bf16_grouped_128x32x96x128_16x16_2x3_16x8x1_16x8x1_1x32x1x4_8x8x1_2x1_intrawave_v2},
+   {{16, 512, 5120, 8192}, bf16_grouped_256x32x96x64_16x16_1x3_8x32x1_8x32x1_1x32x1x8_4x4x1_1x1_interwave_v1},
+   {{16, 1024, 2048, 5120}, bf16_grouped_256x64x128x128_32x32_2x1_16x16x1_16x16x1_1x16x1x16_8x8x1_1x1_intrawave_v3},
+   {{16, 1024, 5120, 1024}, bf16_grouped_256x64x96x64_16x16_2x3_8x32x1_8x32x1_1x64x1x4_8x8x1_2x1_intrawave_v3},
+   {{16, 1024, 16384, 5120}, bf16_grouped_256x64x192x128_16x16_4x3_16x16x1_16x16x1_1x32x1x8_8x8x1_2x1_intrawave_v3},
+   {{16, 1024, 5120, 8192}, bf16_grouped_128x64x96x64_16x16_4x3_8x16x1_8x16x1_1x32x1x4_8x8x1_2x1_intrawave_v3},
+   {{16, 2048, 2048, 5120}, bf16_grouped_256x128x128x128_32x32_2x2_16x16x1_16x16x1_1x32x1x8_8x8x1_1x1_intrawave_v3},
+   {{16, 2048, 5120, 1024}, bf16_grouped_256x128x96x64_16x16_4x3_8x32x1_8x32x1_1x64x1x4_8x8x1_2x1_intrawave_v3},
+   {{16, 2048, 16384, 5120}, bf16_grouped_256x128x224x64_16x16_4x7_8x32x1_8x32x1_1x64x1x4_8x8x1_2x1_intrawave_v3},
+   {{16, 2048, 5120, 8192}, bf16_grouped_256x128x96x64_16x16_4x3_8x32x1_8x32x1_1x64x1x4_8x8x1_2x1_intrawave_v3},
+   {{16, 4096, 2048, 5120}, bf16_grouped_256x128x256x64_32x32_4x2_8x32x1_8x32x1_1x16x1x16_8x8x1_1x1_intrawave_v3},
+   {{16, 4096, 5120, 1024}, bf16_grouped_256x128x96x64_16x16_4x3_8x32x1_8x32x1_1x64x1x4_8x8x1_2x1_intrawave_v3},
+   {{16, 4096, 16384, 5120}, bf16_grouped_256x256x224x64_16x16_8x7_8x32x1_8x32x1_1x64x1x4_8x8x1_2x1_intrawave_v3},
+   {{16, 4096, 5120, 8192}, bf16_grouped_256x256x160x64_16x16_8x5_8x32x1_8x32x1_1x64x1x4_8x8x1_2x1_intrawave_v3},
+   {{16, 8192, 2048, 5120}, bf16_grouped_256x256x256x64_32x32_4x4_8x32x1_8x32x1_1x32x1x8_8x8x1_1x1_intrawave_v3},
+   {{16, 8192, 5120, 1024}, bf16_grouped_256x256x192x64_32x32_4x3_8x32x1_8x32x1_1x32x1x8_8x8x1_1x1_intrawave_v3},
+   {{16, 8192, 16384, 5120}, bf16_grouped_256x256x224x64_16x16_8x7_8x32x1_8x32x1_1x64x1x4_8x8x1_2x1_intrawave_v3},
+   {{16, 8192, 5120, 8192}, bf16_grouped_256x256x192x64_32x32_4x3_8x32x1_8x32x1_1x32x1x8_8x8x1_1x1_intrawave_v3},
+   {{128, 128, 2048, 5120}, bf16_grouped_128x16x64x128_16x16_1x2_16x8x1_16x8x1_1x16x1x8_8x8x1_1x2_interwave_v2},
+   {{128, 128, 5120, 1024}, bf16_grouped_256x16x64x128_16x16_1x1_16x16x1_16x16x1_1x16x1x16_4x4x1_1x1_interwave_v2},
+   {{128, 128, 16384, 5120}, bf16_grouped_256x16x128x128_16x16_1x2_16x16x1_16x16x1_1x16x1x16_8x8x1_1x2_intrawave_v2},
+   {{128, 128, 5120, 8192}, bf16_grouped_128x16x96x128_16x16_1x3_16x8x1_16x8x1_1x16x1x8_4x4x1_1x1_interwave_v2},
+   {{128, 256, 2048, 5120}, bf16_grouped_128x16x96x128_16x16_1x3_16x8x1_16x8x1_1x16x1x8_4x4x1_1x1_intrawave_v1},
+   {{128, 256, 5120, 1024}, bf16_grouped_256x16x64x128_16x16_1x1_16x16x1_16x16x1_1x16x1x16_4x4x1_1x1_interwave_v2},
+   {{128, 256, 16384, 5120}, bf16_grouped_256x16x128x128_16x16_1x2_16x16x1_16x16x1_1x16x1x16_8x8x1_1x2_intrawave_v1},
+   {{128, 256, 5120, 8192}, bf16_grouped_64x16x48x128_16x16_1x3_16x4x1_16x4x1_1x16x1x4_4x4x1_1x1_intrawave_v1},
+   {{128, 512, 2048, 5120}, bf16_grouped_128x16x64x128_16x16_1x2_16x8x1_16x8x1_1x16x1x8_8x8x1_1x2_intrawave_v2},
+   {{128, 512, 5120, 1024}, bf16_grouped_128x16x64x128_16x16_1x2_16x8x1_16x8x1_1x16x1x8_8x8x1_1x2_interwave_v2},
+   {{128, 512, 16384, 5120}, bf16_grouped_256x16x128x128_16x16_1x2_16x16x1_16x16x1_1x16x1x16_8x8x1_1x2_intrawave_v2},
+   {{128, 512, 5120, 8192}, bf16_grouped_128x16x96x128_16x16_1x3_16x8x1_16x8x1_1x16x1x8_4x4x1_1x1_interwave_v2},
+   {{128, 1024, 2048, 5120}, bf16_grouped_128x16x96x128_16x16_1x3_16x8x1_16x8x1_1x16x1x8_4x4x1_1x1_intrawave_v2},
+   {{128, 1024, 5120, 1024}, bf16_grouped_128x16x64x128_16x16_1x2_16x8x1_16x8x1_1x16x1x8_8x8x1_1x2_intrawave_v2},
+   {{128, 1024, 16384, 5120}, bf16_grouped_256x16x128x128_16x16_1x2_16x16x1_16x16x1_1x16x1x16_8x8x1_1x2_intrawave_v2},
+   {{128, 1024, 5120, 8192}, bf16_grouped_128x16x96x128_16x16_1x3_16x8x1_16x8x1_1x16x1x8_4x4x1_1x1_interwave_v2},
+   {{128, 2048, 2048, 5120}, bf16_grouped_128x16x96x128_16x16_1x3_16x8x1_16x8x1_1x16x1x8_4x4x1_1x1_intrawave_v2},
+   {{128, 2048, 5120, 1024}, bf16_grouped_128x16x32x128_16x16_1x1_16x8x1_16x8x1_1x16x1x8_4x4x1_1x1_interwave_v2},
+   {{128, 2048, 16384, 5120}, bf16_grouped_256x16x128x128_16x16_1x2_16x16x1_16x16x1_1x16x1x16_8x8x1_1x2_intrawave_v2},
+   {{128, 2048, 5120, 8192}, bf16_grouped_128x16x96x128_16x16_1x3_16x8x1_16x8x1_1x16x1x8_4x4x1_1x1_intrawave_v1},
+   {{128, 4096, 2048, 5120}, bf16_grouped_128x32x64x128_32x32_1x1_16x8x1_16x8x1_1x16x1x8_8x8x1_1x1_intrawave_v1},
+   {{128, 4096, 5120, 1024}, bf16_grouped_128x32x64x128_32x32_1x1_16x8x1_16x8x1_1x16x1x8_8x8x1_1x1_interwave_v2},
+   {{128, 4096, 16384, 5120}, bf16_grouped_256x32x128x128_16x16_1x4_16x16x1_16x16x1_1x32x1x8_8x8x1_1x2_intrawave_v2},
+   {{128, 4096, 5120, 8192}, bf16_grouped_256x32x224x64_16x16_1x7_8x32x1_8x32x1_1x32x1x8_4x4x1_1x1_intrawave_v1},
+   {{128, 8192, 2048, 5120}, bf16_grouped_256x64x192x128_16x16_4x3_16x16x1_16x16x1_1x32x1x8_8x8x1_2x1_intrawave_v3},
+   {{128, 8192, 5120, 1024}, bf16_grouped_128x64x128x64_32x32_2x2_8x16x1_8x16x1_1x16x1x8_8x8x1_1x1_intrawave_v3},
+   {{128, 8192, 16384, 5120}, bf16_grouped_256x64x192x128_16x16_4x3_16x16x1_16x16x1_1x32x1x8_8x8x1_2x1_intrawave_v3},
+   {{128, 8192, 5120, 8192}, bf16_grouped_256x64x192x128_16x16_4x3_16x16x1_16x16x1_1x32x1x8_8x8x1_2x1_intrawave_v3},
+ };
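+ // The kernel names encode the CK instance configuration (roughly: block size,
+ // per-block M/N/K tile, wave tile, transfer cluster shapes, and the pipeline
+ // scheduler/version such as intrawave_v3); see the generated kernel headers
+ // for the exact template parameters.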
+
+ // Helper function to return the next largest power of 2.
+ static constexpr int nextPow2(unsigned int num) {
+   if (num <= 1)
+     return 1;
+   return 1 << (CHAR_BIT * sizeof(num) - __builtin_clz(num - 1));
+ }
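+ // e.g. nextPow2(0) == nextPow2(1) == 1, nextPow2(5) == 8, nextPow2(16) == 16.
+ // CHAR_BIT comes from <climits>; __builtin_clz is a GCC/Clang builtin and is
+ // undefined for 0, which the num <= 1 guard above rules out.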
+
+ GroupedKernel grouped_heuristic_dispatch(int G, int total_M, int N, int K) {
  // We use shape heuristics to find the best kernel.
  // To do this, we divide by the size of M and find the best
  // option within that grouping.
+
+   // First check if this shape is available in the direct lookup.
+   int padded_m = nextPow2(total_M);
+   padded_m = padded_m < G ? G : padded_m;
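+   // Bucketing M up to a power of two (clamped to at least G) lets nearby
+   // shapes share a lookup entry instead of requiring an exact total_M match.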
+   auto it = bf16_grouped_lookup_dispatch.find({G, padded_m, N, K});
+   // If we found an optimal kernel, use it.
+   if (it != bf16_grouped_lookup_dispatch.end()) {
+     return it->second;
+   }
+
+ #if 0
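+ // Legacy M-based heuristics, disabled in favor of the lookup table above and
+ // kept for reference. (This block still refers to the old M parameter, which
+ // only compiles because the preprocessor discards it.)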
  if (M <= 16) {
    if (N < 8192 && K <= 8192) {
      return bf16_grouped_64x16x16x256_16x16_1x1_16x4x1_16x4x1_1x4x1x16_4x4x1_1x1_intrawave_v1;
@@ -73,8 +189,9 @@ GroupedKernel grouped_heuristic_dispatch(int M, int N, int K) {
    }
    return bf16_grouped_256x128x128x128_32x32_2x2_8x32x1_8x32x1_1x32x1x8_8x8x1_1x1_intrawave_v3;
  }
+ #endif
  // Default kernel for all other shapes.
-   return bf16_grouped_256x128x128x128_32x32_2x2_8x32x1_8x32x1_1x32x1x8_8x8x1_1x1_interwave_v1;
+   return bf16_grouped_256x128x128x64_32x32_2x2_8x32x1_8x32x1_1x32x1x8_8x8x1_1x1_intrawave_v3;
}

__global__ void set_kernel_args_kernel(
@@ -343,7 +460,7 @@ std::vector<at::Tensor> bf16bf16bf16_grouped(
    MaxN = max(MaxN, B[i].size(0));
    MaxK = max(MaxK, A[i].size(1));
  }
-   GroupedKernel selected_kernel = grouped_heuristic_dispatch(MaxM, MaxN, MaxK);
+   GroupedKernel selected_kernel = grouped_heuristic_dispatch(group_count, MaxM, MaxN, MaxK);
  return selected_kernel(A, B, kernel_args, Y);
}

@@ -418,7 +535,7 @@ at::Tensor bf16bf16bf16_grouped_dynamic(
    MaxN = max(MaxN, B[i].size(0));
    MaxK = max(MaxK, A[i].size(1));
  }
-   GroupedKernel selected_kernel = grouped_heuristic_dispatch(MaxM, MaxN, MaxK);
+   GroupedKernel selected_kernel = grouped_heuristic_dispatch(group_count, MaxM, MaxN, MaxK);
  // Run kernel to populate output.
  selected_kernel(A, B, kernel_args, Y);
  // Return coalesced view of output tensor.