@@ -33,48 +33,128 @@ using ADataType = ck::bhalf_t;
 using BDataType = ck::bhalf_t;
 using CDataType = ck::bhalf_t;

-GroupedKernel grouped_heuristic_dispatch(int M, int N, int K) {
-  // We use shape heuristics to find the best kernel.
-  // To do this, we divide by the size of M and find the best
-  // option within that grouping.
-  if (M <= 16) {
-    if (N < 8192 && K <= 8192) {
-      return bf16_grouped_64x16x16x256_16x16_1x1_16x4x1_16x4x1_1x4x1x16_4x4x1_1x1_intrawave_v1;
-    }
-    if (K <= 8192) {
-      return bf16_grouped_128x16x64x128_16x16_1x2_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_intrawave_v2;
-    }
-    return bf16_grouped_128x16x32x256_16x16_1x1_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_interwave_v2;
+// Define a custom hash function for std::tuple<int, int, int>
+struct IntTupleHash {
+  size_t operator()(const std::tuple<int, int>& t) const {
+    auto hash1 = std::hash<int>{}(std::get<0>(t));
+    auto hash2 = std::hash<int>{}(std::get<1>(t));
+    return hash1 ^ hash2;
   }
-  if (M <= 32) {
-    if (N < 8192 && K <= 8192) {
-      return bf16_grouped_128x32x64x128_32x32_1x1_8x16x1_8x16x1_1x16x1x8_8x8x1_1x1_interwave_v2;
-    }
-    if (K <= 8192) {
-      return bf16_grouped_128x32x128x128_32x32_1x2_8x16x1_8x16x1_1x16x1x8_8x8x1_1x1_interwave_v2;
-    }
-    return bf16_grouped_128x32x128x128_32x32_1x2_8x16x1_8x16x1_1x16x1x8_8x8x1_1x1_intrawave_v2;
-  }
-  if (M <= 64) {
-    return bf16_grouped_256x64x64x128_32x32_1x1_8x32x1_8x32x1_1x32x1x8_8x8x1_1x1_intrawave_v3;
+  size_t operator()(const std::tuple<int, int, int>& t) const {
+    auto hash1 = std::hash<int>{}(std::get<0>(t));
+    auto hash2 = std::hash<int>{}(std::get<1>(t));
+    auto hash3 = std::hash<int>{}(std::get<2>(t));
+    return hash1 ^ hash2 ^ hash3;
   }
-  if (M <= 128) {
-    if (N < 8192 && K <= 8192) {
-      return bf16_grouped_256x128x64x128_32x32_2x1_8x32x1_8x32x1_1x32x1x8_8x8x1_1x1_intrawave_v3;
-    }
-    return bf16_grouped_256x128x128x128_32x32_2x2_8x32x1_8x32x1_1x32x1x8_8x8x1_1x1_intrawave_v3;
+  size_t operator()(const std::tuple<int, int, int, int>& t) const {
+    auto hash1 = std::hash<int>{}(std::get<0>(t));
+    auto hash2 = std::hash<int>{}(std::get<1>(t));
+    auto hash3 = std::hash<int>{}(std::get<2>(t));
+    auto hash4 = std::hash<int>{}(std::get<3>(t));
+    return hash1 ^ hash2 ^ hash3 ^ hash4;
   }
-  if (M <= 256) {
-    return bf16_grouped_256x128x128x128_32x32_2x2_8x32x1_8x32x1_1x32x1x8_8x8x1_1x1_intrawave_v3;
-  }
-  if (M <= 512) {
-    if (K <= 8192) {
-      return bf16_grouped_256x128x128x128_32x32_2x2_8x32x1_8x32x1_1x32x1x8_8x8x1_1x1_interwave_v1;
-    }
-    return bf16_grouped_256x128x128x128_32x32_2x2_8x32x1_8x32x1_1x32x1x8_8x8x1_1x1_intrawave_v3;
+};
+
+// For certain high priority shapes, we directly map to the best kernel rather
+// than use heuristics.
+static const std::unordered_map<std::tuple<int, int, int, int>, GroupedKernel, IntTupleHash> bf16_grouped_lookup_dispatch = {
+  {{16,16,2048,5120},bf16_grouped_128x16x64x128_16x16_1x2_16x8x1_16x8x1_1x16x1x8_8x8x1_1x2_intrawave_v2},
+  {{16,16,5120,1024},bf16_grouped_64x16x16x128_16x16_1x1_16x4x1_16x4x1_1x16x1x4_4x4x1_1x1_interwave_v2},
+  {{16,16,16384,5120},bf16_grouped_64x16x32x128_16x16_1x2_16x4x1_16x4x1_1x16x1x4_8x8x1_1x2_intrawave_v2},
+  {{16,16,5120,8192},bf16_grouped_128x16x64x128_16x16_1x2_16x8x1_16x8x1_1x16x1x8_8x8x1_1x2_interwave_v1},
+  {{16,32,2048,5120},bf16_grouped_128x16x64x128_16x16_1x2_16x8x1_16x8x1_1x16x1x8_8x8x1_1x2_interwave_v2},
+  {{16,32,5120,1024},bf16_grouped_256x16x64x128_16x16_1x1_16x16x1_16x16x1_1x16x1x16_4x4x1_1x1_interwave_v2},
+  {{16,32,16384,5120},bf16_grouped_128x16x96x128_16x16_1x3_16x8x1_16x8x1_1x16x1x8_4x4x1_1x1_interwave_v1},
+  {{16,32,5120,8192},bf16_grouped_128x16x96x64_16x16_1x3_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_intrawave_v1},
+  {{16,64,2048,5120},bf16_grouped_256x16x64x128_16x16_1x1_16x16x1_16x16x1_1x16x1x16_4x4x1_1x1_interwave_v2},
+  {{16,64,5120,1024},bf16_grouped_256x16x64x128_16x16_1x1_16x16x1_16x16x1_1x16x1x16_4x4x1_1x1_interwave_v2},
+  {{16,64,16384,5120},bf16_grouped_128x16x96x128_16x16_1x3_16x8x1_16x8x1_1x16x1x8_4x4x1_1x1_interwave_v1},
+  {{16,64,5120,8192},bf16_grouped_128x16x96x64_16x16_1x3_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_intrawave_v1},
+  {{16,128,2048,5120},bf16_grouped_128x16x64x128_16x16_1x2_16x8x1_16x8x1_1x16x1x8_8x8x1_1x2_interwave_v2},
+  {{16,128,5120,1024},bf16_grouped_256x16x64x128_16x16_1x1_16x16x1_16x16x1_1x16x1x16_4x4x1_1x1_interwave_v2},
+  {{16,128,16384,5120},bf16_grouped_64x16x64x128_16x16_1x4_16x4x1_16x4x1_1x16x1x4_8x8x1_1x2_intrawave_v2},
+  {{16,128,5120,8192},bf16_grouped_128x16x96x128_16x16_1x3_16x8x1_16x8x1_1x16x1x8_4x4x1_1x1_intrawave_v1},
+  {{16,256,2048,5120},bf16_grouped_128x16x64x128_16x16_1x2_16x8x1_16x8x1_1x16x1x8_8x8x1_1x2_interwave_v2},
+  {{16,256,5120,1024},bf16_grouped_128x16x32x128_16x16_1x1_16x8x1_16x8x1_1x16x1x8_4x4x1_1x1_interwave_v2},
+  {{16,256,16384,5120},bf16_grouped_128x16x96x128_16x16_1x3_16x8x1_16x8x1_1x16x1x8_4x4x1_1x1_intrawave_v1},
+  {{16,256,5120,8192},bf16_grouped_128x16x96x128_16x16_1x3_16x8x1_16x8x1_1x16x1x8_4x4x1_1x1_interwave_v1},
+  {{16,512,2048,5120},bf16_grouped_128x32x64x128_32x32_1x1_16x8x1_16x8x1_1x16x1x8_8x8x1_1x1_interwave_v2},
+  {{16,512,5120,1024},bf16_grouped_256x32x96x64_16x16_1x3_8x32x1_8x32x1_1x32x1x8_4x4x1_1x1_interwave_v2},
+  {{16,512,16384,5120},bf16_grouped_128x32x96x128_16x16_2x3_16x8x1_16x8x1_1x32x1x4_8x8x1_2x1_intrawave_v2},
+  {{16,512,5120,8192},bf16_grouped_256x32x96x64_16x16_1x3_8x32x1_8x32x1_1x32x1x8_4x4x1_1x1_interwave_v1},
+  {{16,1024,2048,5120},bf16_grouped_256x64x128x128_32x32_2x1_16x16x1_16x16x1_1x16x1x16_8x8x1_1x1_intrawave_v3},
+  {{16,1024,5120,1024},bf16_grouped_256x64x96x64_16x16_2x3_8x32x1_8x32x1_1x64x1x4_8x8x1_2x1_intrawave_v3},
+  {{16,1024,16384,5120},bf16_grouped_256x64x192x128_16x16_4x3_16x16x1_16x16x1_1x32x1x8_8x8x1_2x1_intrawave_v3},
+  {{16,1024,5120,8192},bf16_grouped_128x64x96x64_16x16_4x3_8x16x1_8x16x1_1x32x1x4_8x8x1_2x1_intrawave_v3},
+  {{16,2048,2048,5120},bf16_grouped_256x128x128x128_32x32_2x2_16x16x1_16x16x1_1x32x1x8_8x8x1_1x1_intrawave_v3},
+  {{16,2048,5120,1024},bf16_grouped_256x128x96x64_16x16_4x3_8x32x1_8x32x1_1x64x1x4_8x8x1_2x1_intrawave_v3},
+  {{16,2048,16384,5120},bf16_grouped_256x128x224x64_16x16_4x7_8x32x1_8x32x1_1x64x1x4_8x8x1_2x1_intrawave_v3},
+  {{16,2048,5120,8192},bf16_grouped_256x128x96x64_16x16_4x3_8x32x1_8x32x1_1x64x1x4_8x8x1_2x1_intrawave_v3},
+  {{16,4096,2048,5120},bf16_grouped_256x128x256x64_32x32_4x2_8x32x1_8x32x1_1x16x1x16_8x8x1_1x1_intrawave_v3},
+  {{16,4096,5120,1024},bf16_grouped_256x128x96x64_16x16_4x3_8x32x1_8x32x1_1x64x1x4_8x8x1_2x1_intrawave_v3},
+  {{16,4096,16384,5120},bf16_grouped_256x256x224x64_16x16_8x7_8x32x1_8x32x1_1x64x1x4_8x8x1_2x1_intrawave_v3},
+  {{16,4096,5120,8192},bf16_grouped_256x256x160x64_16x16_8x5_8x32x1_8x32x1_1x64x1x4_8x8x1_2x1_intrawave_v3},
+  {{16,8192,2048,5120},bf16_grouped_256x256x256x64_32x32_4x4_8x32x1_8x32x1_1x32x1x8_8x8x1_1x1_intrawave_v3},
+  {{16,8192,5120,1024},bf16_grouped_256x256x192x64_32x32_4x3_8x32x1_8x32x1_1x32x1x8_8x8x1_1x1_intrawave_v3},
+  {{16,8192,16384,5120},bf16_grouped_256x256x224x64_16x16_8x7_8x32x1_8x32x1_1x64x1x4_8x8x1_2x1_intrawave_v3},
+  {{16,8192,5120,8192},bf16_grouped_256x256x192x64_32x32_4x3_8x32x1_8x32x1_1x32x1x8_8x8x1_1x1_intrawave_v3},
+  {{128,128,2048,5120},bf16_grouped_128x16x64x128_16x16_1x2_16x8x1_16x8x1_1x16x1x8_8x8x1_1x2_interwave_v2},
+  {{128,128,5120,1024},bf16_grouped_256x16x64x128_16x16_1x1_16x16x1_16x16x1_1x16x1x16_4x4x1_1x1_interwave_v2},
+  {{128,128,16384,5120},bf16_grouped_256x16x128x128_16x16_1x2_16x16x1_16x16x1_1x16x1x16_8x8x1_1x2_intrawave_v2},
+  {{128,128,5120,8192},bf16_grouped_128x16x96x128_16x16_1x3_16x8x1_16x8x1_1x16x1x8_4x4x1_1x1_interwave_v2},
+  {{128,256,2048,5120},bf16_grouped_128x16x96x128_16x16_1x3_16x8x1_16x8x1_1x16x1x8_4x4x1_1x1_intrawave_v1},
+  {{128,256,5120,1024},bf16_grouped_256x16x64x128_16x16_1x1_16x16x1_16x16x1_1x16x1x16_4x4x1_1x1_interwave_v2},
+  {{128,256,16384,5120},bf16_grouped_256x16x128x128_16x16_1x2_16x16x1_16x16x1_1x16x1x16_8x8x1_1x2_intrawave_v1},
+  {{128,256,5120,8192},bf16_grouped_64x16x48x128_16x16_1x3_16x4x1_16x4x1_1x16x1x4_4x4x1_1x1_intrawave_v1},
+  {{128,512,2048,5120},bf16_grouped_128x16x64x128_16x16_1x2_16x8x1_16x8x1_1x16x1x8_8x8x1_1x2_intrawave_v2},
+  {{128,512,5120,1024},bf16_grouped_128x16x64x128_16x16_1x2_16x8x1_16x8x1_1x16x1x8_8x8x1_1x2_interwave_v2},
+  {{128,512,16384,5120},bf16_grouped_256x16x128x128_16x16_1x2_16x16x1_16x16x1_1x16x1x16_8x8x1_1x2_intrawave_v2},
+  {{128,512,5120,8192},bf16_grouped_128x16x96x128_16x16_1x3_16x8x1_16x8x1_1x16x1x8_4x4x1_1x1_interwave_v2},
+  {{128,1024,2048,5120},bf16_grouped_128x16x96x128_16x16_1x3_16x8x1_16x8x1_1x16x1x8_4x4x1_1x1_intrawave_v2},
+  {{128,1024,5120,1024},bf16_grouped_128x16x64x128_16x16_1x2_16x8x1_16x8x1_1x16x1x8_8x8x1_1x2_intrawave_v2},
+  {{128,1024,16384,5120},bf16_grouped_256x16x128x128_16x16_1x2_16x16x1_16x16x1_1x16x1x16_8x8x1_1x2_intrawave_v2},
+  {{128,1024,5120,8192},bf16_grouped_128x16x96x128_16x16_1x3_16x8x1_16x8x1_1x16x1x8_4x4x1_1x1_interwave_v2},
+  {{128,2048,2048,5120},bf16_grouped_128x16x96x128_16x16_1x3_16x8x1_16x8x1_1x16x1x8_4x4x1_1x1_intrawave_v2},
+  {{128,2048,5120,1024},bf16_grouped_128x16x32x128_16x16_1x1_16x8x1_16x8x1_1x16x1x8_4x4x1_1x1_interwave_v2},
+  {{128,2048,16384,5120},bf16_grouped_256x16x128x128_16x16_1x2_16x16x1_16x16x1_1x16x1x16_8x8x1_1x2_intrawave_v2},
+  {{128,2048,5120,8192},bf16_grouped_128x16x96x128_16x16_1x3_16x8x1_16x8x1_1x16x1x8_4x4x1_1x1_intrawave_v1},
+  {{128,4096,2048,5120},bf16_grouped_128x32x64x128_32x32_1x1_16x8x1_16x8x1_1x16x1x8_8x8x1_1x1_intrawave_v1},
+  {{128,4096,5120,1024},bf16_grouped_128x32x64x128_32x32_1x1_16x8x1_16x8x1_1x16x1x8_8x8x1_1x1_interwave_v2},
+  {{128,4096,16384,5120},bf16_grouped_256x32x128x128_16x16_1x4_16x16x1_16x16x1_1x32x1x8_8x8x1_1x2_intrawave_v2},
+  {{128,4096,5120,8192},bf16_grouped_256x32x224x64_16x16_1x7_8x32x1_8x32x1_1x32x1x8_4x4x1_1x1_intrawave_v1},
+  {{128,8192,2048,5120},bf16_grouped_256x64x192x128_16x16_4x3_16x16x1_16x16x1_1x32x1x8_8x8x1_2x1_intrawave_v3},
+  {{128,8192,5120,1024},bf16_grouped_128x64x128x64_32x32_2x2_8x16x1_8x16x1_1x16x1x8_8x8x1_1x1_intrawave_v3},
+  {{128,8192,16384,5120},bf16_grouped_256x64x192x128_16x16_4x3_16x16x1_16x16x1_1x32x1x8_8x8x1_2x1_intrawave_v3},
+  {{128,8192,5120,8192},bf16_grouped_256x64x192x128_16x16_4x3_16x16x1_16x16x1_1x32x1x8_8x8x1_2x1_intrawave_v3},
+};
+
+
+
+// Helper function to return the next largest power of 2
+static constexpr int nextPow2(unsigned int num)
+{
+  if (num <= 1)
+    return 1;
+  return 1 << (CHAR_BIT * sizeof(num) - __builtin_clz(num - 1));
+}
+
+GroupedKernel grouped_heuristic_dispatch(int G, int total_M, int N, int K) {
+  // We use shape heuristics to find the best kernel.
+  // To do this, we divide by the size of M and find the best
+  // option within that grouping.
+
+  // First check if this shape is available in the direct lookup.
+  int padded_m = nextPow2(total_M);
+  padded_m = padded_m < G ? G : padded_m;
+  padded_m = padded_m > 8192 ? 8192 : padded_m;
+  auto it = bf16_grouped_lookup_dispatch.find({G, padded_m, N, K});
+  // If we found an optimal kernel, use it.
+  if (it != bf16_grouped_lookup_dispatch.end()) {
+    return it->second;
   }
+
   // Default kernel for all other shapes.
-  return bf16_grouped_256x128x128x128_32x32_2x2_8x32x1_8x32x1_1x32x1x8_8x8x1_1x1_interwave_v1;
+  return bf16_grouped_256x128x128x64_32x32_2x2_8x32x1_8x32x1_1x32x1x8_8x8x1_1x1_interwave_v1;
 }

 __global__ void set_kernel_args_kernel(
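
The new dispatcher above combines two pieces: a std::unordered_map keyed on {G, padded_M, N, K} tuples, hashed by XOR-combining the per-component std::hash values, and a rounding step that takes total_M up to the next power of two and clamps it into the [G, 8192] range before the lookup, with the heuristic default as a fallback when no entry matches. The following standalone sketch reproduces that pattern outside the patch so it can be compiled and inspected in isolation; the kernel values are replaced by placeholder strings, and DemoTupleHash, demo_lookup, and demo_dispatch are made-up names for illustration, not part of the FBGEMM source.

// Standalone sketch (not part of the patch): the same lookup-then-fallback
// dispatch pattern, with placeholder strings instead of GroupedKernel pointers.
// Build with -std=c++17 or newer on GCC/Clang (__builtin_clz is a compiler builtin).
#include <climits>
#include <cstdio>
#include <functional>
#include <string>
#include <tuple>
#include <unordered_map>

// XOR-combine the per-component hashes of a 4-int key, mirroring IntTupleHash.
struct DemoTupleHash {
  size_t operator()(const std::tuple<int, int, int, int>& t) const {
    return std::hash<int>{}(std::get<0>(t)) ^ std::hash<int>{}(std::get<1>(t)) ^
           std::hash<int>{}(std::get<2>(t)) ^ std::hash<int>{}(std::get<3>(t));
  }
};

// Next power of 2 >= num (returns 1 for num <= 1), same formula as nextPow2 in the patch.
static constexpr int nextPow2(unsigned int num) {
  if (num <= 1) {
    return 1;
  }
  return 1 << (CHAR_BIT * sizeof(num) - __builtin_clz(num - 1));
}

// Tiny stand-in for bf16_grouped_lookup_dispatch, keyed on {G, padded_M, N, K}.
static const std::unordered_map<std::tuple<int, int, int, int>, std::string, DemoTupleHash>
    demo_lookup = {
        {{16, 1024, 5120, 1024}, "tuned_kernel_A"},
        {{128, 2048, 16384, 5120}, "tuned_kernel_B"},
};

// Round total_M into the same buckets the dispatcher uses, then try the direct lookup.
std::string demo_dispatch(int G, int total_M, int N, int K) {
  int padded_m = nextPow2(total_M);
  padded_m = padded_m < G ? G : padded_m;        // never smaller than the group count
  padded_m = padded_m > 8192 ? 8192 : padded_m;  // clamp very large totals into one bucket
  auto it = demo_lookup.find({G, padded_m, N, K});
  return it != demo_lookup.end() ? it->second : "default_kernel";
}

int main() {
  // total_M = 1000 rounds up to the 1024 bucket, so the tuned entry is hit.
  std::printf("%s\n", demo_dispatch(16, 1000, 5120, 1024).c_str());  // tuned_kernel_A
  // No table entry for the 4096 bucket at this N/K, so the default is returned.
  std::printf("%s\n", demo_dispatch(16, 3000, 5120, 1024).c_str());  // default_kernel
  return 0;
}

Rounding total_M up to a power of two is what lets a few dozen table entries cover arbitrary runtime M values: every total_M in (512, 1024], for example, lands in the 1024 bucket and reuses the same tuned kernel.
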
@@ -343,7 +423,7 @@ std::vector<at::Tensor> bf16bf16bf16_grouped(
     MaxN = max(MaxN, B[i].size(0));
     MaxK = max(MaxK, A[i].size(1));
   }
-  GroupedKernel selected_kernel = grouped_heuristic_dispatch(MaxM, MaxN, MaxK);
+  GroupedKernel selected_kernel = grouped_heuristic_dispatch(group_count, MaxM, MaxN, MaxK);
   return selected_kernel(A, B, kernel_args, Y);
 }

@@ -418,7 +498,7 @@ at::Tensor bf16bf16bf16_grouped_dynamic(
     MaxN = max(MaxN, B[i].size(0));
     MaxK = max(MaxK, A[i].size(1));
   }
-  GroupedKernel selected_kernel = grouped_heuristic_dispatch(MaxM, MaxN, MaxK);
+  GroupedKernel selected_kernel = grouped_heuristic_dispatch(group_count, MaxM, MaxN, MaxK);
   // Run kernel to populate output.
   selected_kernel(A, B, kernel_args, Y);
   // Return coalesced view of output tensor.
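
Both call sites (bf16bf16bf16_grouped and bf16bf16bf16_grouped_dynamic) now pass group_count as the first argument, since the lookup keys include G and padded_m is clamped to be at least the group count. The hunks only show the max-reductions for MaxN and MaxK, so the sketch below assumes MaxM is reduced the same way; it reuses demo_dispatch from the earlier sketch and plain shape structs in place of at::Tensor, purely to illustrate the call pattern.

// Standalone sketch of the caller-side change: reduce the group's shapes, then
// dispatch once for the whole group. GemmShape and pick_kernel_for_group are
// illustrative names; demo_dispatch is the stand-in dispatcher from the sketch above.
#include <algorithm>
#include <string>
#include <vector>

// Declared in the previous sketch.
std::string demo_dispatch(int G, int total_M, int N, int K);

struct GemmShape {
  int M;  // per-group M, N, K of the i-th GEMM in the group
  int N;
  int K;
};

std::string pick_kernel_for_group(const std::vector<GemmShape>& problems) {
  int group_count = static_cast<int>(problems.size());
  int MaxM = 0;
  int MaxN = 0;
  int MaxK = 0;
  for (const auto& p : problems) {
    MaxM = std::max(MaxM, p.M);  // assumed max-reduction; the MaxM line is outside the visible hunks
    MaxN = std::max(MaxN, p.N);
    MaxK = std::max(MaxK, p.K);
  }
  // One kernel instance is selected for the whole group based on these bounds.
  return demo_dispatch(group_count, MaxM, MaxN, MaxK);
}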