Skip to content

Commit a4ce24b

Browse files
zjing14 and facebook-github-bot
authored and committed
Retuned CK GMM fp8/bf16 with perf fixes (#3851)
Summary: Pull Request resolved: #3851 X-link: facebookresearch/FBGEMM#941 - Fixed launch bound for grouped gemm - Retuned fp8 gmm - Retuned fp8/bf16 GMM for 17Bx16/128 with auto-gen instances (D71528034) Reviewed By: mxz297 Differential Revision: D71140320
1 parent 6a6db7c commit a4ce24b

File tree

222 files changed

+9907
-10883
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

222 files changed

+9907
-10883
lines changed

fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/bf16_grouped/bf16_grouped_gemm.hip

Lines changed: 121 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -33,10 +33,126 @@ using ADataType = ck::bhalf_t;
3333
using BDataType = ck::bhalf_t;
3434
using CDataType = ck::bhalf_t;
3535

36-
GroupedKernel grouped_heuristic_dispatch(int M, int N, int K) {
36+
// Hash functor for the int-tuple keys of the kernel lookup tables.
//
// Element hashes are folded together with the boost::hash_combine mixing
// step rather than plain XOR. XOR combination is order-insensitive
// ({a,b} and {b,a} collide) and cancels identical elements
// (hash(x)^hash(x) == 0, so {16,16,N,K} collides with {32,32,N,K}) —
// both of which are real key patterns in the dispatch tables below.
struct IntTupleHash {
  // Mix one element into the running seed (boost::hash_combine recipe).
  static size_t combine(size_t seed, int v) {
    return seed ^
        (std::hash<int>{}(v) + 0x9e3779b9u + (seed << 6) + (seed >> 2));
  }
  size_t operator()(const std::tuple<int, int>& t) const {
    size_t seed = combine(0, std::get<0>(t));
    return combine(seed, std::get<1>(t));
  }
  size_t operator()(const std::tuple<int, int, int>& t) const {
    size_t seed = combine(0, std::get<0>(t));
    seed = combine(seed, std::get<1>(t));
    return combine(seed, std::get<2>(t));
  }
  size_t operator()(const std::tuple<int, int, int, int>& t) const {
    size_t seed = combine(0, std::get<0>(t));
    seed = combine(seed, std::get<1>(t));
    seed = combine(seed, std::get<2>(t));
    return combine(seed, std::get<3>(t));
  }
};
57+
58+
// For certain high priority shapes, we directly map to the best kernel rather
59+
// than use heuristics.
60+
static const std::unordered_map<std::tuple<int, int, int, int>, GroupedKernel, IntTupleHash> bf16_grouped_lookup_dispatch = {
61+
{{16,16,2048,5120},bf16_grouped_128x16x64x128_16x16_1x2_16x8x1_16x8x1_1x16x1x8_8x8x1_1x2_intrawave_v2},
62+
{{16,16,5120,1024},bf16_grouped_64x16x16x128_16x16_1x1_16x4x1_16x4x1_1x16x1x4_4x4x1_1x1_interwave_v2},
63+
{{16,16,16384,5120},bf16_grouped_64x16x32x128_16x16_1x2_16x4x1_16x4x1_1x16x1x4_8x8x1_1x2_intrawave_v2},
64+
{{16,16,5120,8192},bf16_grouped_128x16x64x128_16x16_1x2_16x8x1_16x8x1_1x16x1x8_8x8x1_1x2_interwave_v1},
65+
{{16,32,2048,5120},bf16_grouped_128x16x64x128_16x16_1x2_16x8x1_16x8x1_1x16x1x8_8x8x1_1x2_interwave_v2},
66+
{{16,32,5120,1024},bf16_grouped_256x16x64x128_16x16_1x1_16x16x1_16x16x1_1x16x1x16_4x4x1_1x1_interwave_v2},
67+
{{16,32,16384,5120},bf16_grouped_128x16x96x128_16x16_1x3_16x8x1_16x8x1_1x16x1x8_4x4x1_1x1_interwave_v1},
68+
{{16,32,5120,8192},bf16_grouped_128x16x96x64_16x16_1x3_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_intrawave_v1},
69+
{{16,64,2048,5120},bf16_grouped_256x16x64x128_16x16_1x1_16x16x1_16x16x1_1x16x1x16_4x4x1_1x1_interwave_v2},
70+
{{16,64,5120,1024},bf16_grouped_256x16x64x128_16x16_1x1_16x16x1_16x16x1_1x16x1x16_4x4x1_1x1_interwave_v2},
71+
{{16,64,16384,5120},bf16_grouped_128x16x96x128_16x16_1x3_16x8x1_16x8x1_1x16x1x8_4x4x1_1x1_interwave_v1},
72+
{{16,64,5120,8192},bf16_grouped_128x16x96x64_16x16_1x3_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_intrawave_v1},
73+
{{16,128,2048,5120},bf16_grouped_128x16x64x128_16x16_1x2_16x8x1_16x8x1_1x16x1x8_8x8x1_1x2_interwave_v2},
74+
{{16,128,5120,1024},bf16_grouped_256x16x64x128_16x16_1x1_16x16x1_16x16x1_1x16x1x16_4x4x1_1x1_interwave_v2},
75+
{{16,128,16384,5120},bf16_grouped_64x16x64x128_16x16_1x4_16x4x1_16x4x1_1x16x1x4_8x8x1_1x2_intrawave_v2},
76+
{{16,128,5120,8192},bf16_grouped_128x16x96x128_16x16_1x3_16x8x1_16x8x1_1x16x1x8_4x4x1_1x1_intrawave_v1},
77+
{{16,256,2048,5120},bf16_grouped_128x16x64x128_16x16_1x2_16x8x1_16x8x1_1x16x1x8_8x8x1_1x2_interwave_v2},
78+
{{16,256,5120,1024},bf16_grouped_128x16x32x128_16x16_1x1_16x8x1_16x8x1_1x16x1x8_4x4x1_1x1_interwave_v2},
79+
{{16,256,16384,5120},bf16_grouped_128x16x96x128_16x16_1x3_16x8x1_16x8x1_1x16x1x8_4x4x1_1x1_intrawave_v1},
80+
{{16,256,5120,8192},bf16_grouped_128x16x96x128_16x16_1x3_16x8x1_16x8x1_1x16x1x8_4x4x1_1x1_interwave_v1},
81+
{{16,512,2048,5120},bf16_grouped_128x32x64x128_32x32_1x1_16x8x1_16x8x1_1x16x1x8_8x8x1_1x1_interwave_v2},
82+
{{16,512,5120,1024},bf16_grouped_256x32x96x64_16x16_1x3_8x32x1_8x32x1_1x32x1x8_4x4x1_1x1_interwave_v2},
83+
{{16,512,16384,5120},bf16_grouped_128x32x96x128_16x16_2x3_16x8x1_16x8x1_1x32x1x4_8x8x1_2x1_intrawave_v2},
84+
{{16,512,5120,8192},bf16_grouped_256x32x96x64_16x16_1x3_8x32x1_8x32x1_1x32x1x8_4x4x1_1x1_interwave_v1},
85+
{{16,1024,2048,5120},bf16_grouped_256x64x128x128_32x32_2x1_16x16x1_16x16x1_1x16x1x16_8x8x1_1x1_intrawave_v3},
86+
{{16,1024,5120,1024},bf16_grouped_256x64x96x64_16x16_2x3_8x32x1_8x32x1_1x64x1x4_8x8x1_2x1_intrawave_v3},
87+
{{16,1024,16384,5120},bf16_grouped_256x64x192x128_16x16_4x3_16x16x1_16x16x1_1x32x1x8_8x8x1_2x1_intrawave_v3},
88+
{{16,1024,5120,8192},bf16_grouped_128x64x96x64_16x16_4x3_8x16x1_8x16x1_1x32x1x4_8x8x1_2x1_intrawave_v3},
89+
{{16,2048,2048,5120},bf16_grouped_256x128x128x128_32x32_2x2_16x16x1_16x16x1_1x32x1x8_8x8x1_1x1_intrawave_v3},
90+
{{16,2048,5120,1024},bf16_grouped_256x128x96x64_16x16_4x3_8x32x1_8x32x1_1x64x1x4_8x8x1_2x1_intrawave_v3},
91+
{{16,2048,16384,5120},bf16_grouped_256x128x224x64_16x16_4x7_8x32x1_8x32x1_1x64x1x4_8x8x1_2x1_intrawave_v3},
92+
{{16,2048,5120,8192},bf16_grouped_256x128x96x64_16x16_4x3_8x32x1_8x32x1_1x64x1x4_8x8x1_2x1_intrawave_v3},
93+
{{16,4096,2048,5120},bf16_grouped_256x128x256x64_32x32_4x2_8x32x1_8x32x1_1x16x1x16_8x8x1_1x1_intrawave_v3},
94+
{{16,4096,5120,1024},bf16_grouped_256x128x96x64_16x16_4x3_8x32x1_8x32x1_1x64x1x4_8x8x1_2x1_intrawave_v3},
95+
{{16,4096,16384,5120},bf16_grouped_256x256x224x64_16x16_8x7_8x32x1_8x32x1_1x64x1x4_8x8x1_2x1_intrawave_v3},
96+
{{16,4096,5120,8192},bf16_grouped_256x256x160x64_16x16_8x5_8x32x1_8x32x1_1x64x1x4_8x8x1_2x1_intrawave_v3},
97+
{{16,8192,2048,5120},bf16_grouped_256x256x256x64_32x32_4x4_8x32x1_8x32x1_1x32x1x8_8x8x1_1x1_intrawave_v3},
98+
{{16,8192,5120,1024},bf16_grouped_256x256x192x64_32x32_4x3_8x32x1_8x32x1_1x32x1x8_8x8x1_1x1_intrawave_v3},
99+
{{16,8192,16384,5120},bf16_grouped_256x256x224x64_16x16_8x7_8x32x1_8x32x1_1x64x1x4_8x8x1_2x1_intrawave_v3},
100+
{{16,8192,5120,8192},bf16_grouped_256x256x192x64_32x32_4x3_8x32x1_8x32x1_1x32x1x8_8x8x1_1x1_intrawave_v3},
101+
{{128,128,2048,5120},bf16_grouped_128x16x64x128_16x16_1x2_16x8x1_16x8x1_1x16x1x8_8x8x1_1x2_interwave_v2},
102+
{{128,128,5120,1024},bf16_grouped_256x16x64x128_16x16_1x1_16x16x1_16x16x1_1x16x1x16_4x4x1_1x1_interwave_v2},
103+
{{128,128,16384,5120},bf16_grouped_256x16x128x128_16x16_1x2_16x16x1_16x16x1_1x16x1x16_8x8x1_1x2_intrawave_v2},
104+
{{128,128,5120,8192},bf16_grouped_128x16x96x128_16x16_1x3_16x8x1_16x8x1_1x16x1x8_4x4x1_1x1_interwave_v2},
105+
{{128,256,2048,5120},bf16_grouped_128x16x96x128_16x16_1x3_16x8x1_16x8x1_1x16x1x8_4x4x1_1x1_intrawave_v1},
106+
{{128,256,5120,1024},bf16_grouped_256x16x64x128_16x16_1x1_16x16x1_16x16x1_1x16x1x16_4x4x1_1x1_interwave_v2},
107+
{{128,256,16384,5120},bf16_grouped_256x16x128x128_16x16_1x2_16x16x1_16x16x1_1x16x1x16_8x8x1_1x2_intrawave_v1},
108+
{{128,256,5120,8192},bf16_grouped_64x16x48x128_16x16_1x3_16x4x1_16x4x1_1x16x1x4_4x4x1_1x1_intrawave_v1},
109+
{{128,512,2048,5120},bf16_grouped_128x16x64x128_16x16_1x2_16x8x1_16x8x1_1x16x1x8_8x8x1_1x2_intrawave_v2},
110+
{{128,512,5120,1024},bf16_grouped_128x16x64x128_16x16_1x2_16x8x1_16x8x1_1x16x1x8_8x8x1_1x2_interwave_v2},
111+
{{128,512,16384,5120},bf16_grouped_256x16x128x128_16x16_1x2_16x16x1_16x16x1_1x16x1x16_8x8x1_1x2_intrawave_v2},
112+
{{128,512,5120,8192},bf16_grouped_128x16x96x128_16x16_1x3_16x8x1_16x8x1_1x16x1x8_4x4x1_1x1_interwave_v2},
113+
{{128,1024,2048,5120},bf16_grouped_128x16x96x128_16x16_1x3_16x8x1_16x8x1_1x16x1x8_4x4x1_1x1_intrawave_v2},
114+
{{128,1024,5120,1024},bf16_grouped_128x16x64x128_16x16_1x2_16x8x1_16x8x1_1x16x1x8_8x8x1_1x2_intrawave_v2},
115+
{{128,1024,16384,5120},bf16_grouped_256x16x128x128_16x16_1x2_16x16x1_16x16x1_1x16x1x16_8x8x1_1x2_intrawave_v2},
116+
{{128,1024,5120,8192},bf16_grouped_128x16x96x128_16x16_1x3_16x8x1_16x8x1_1x16x1x8_4x4x1_1x1_interwave_v2},
117+
{{128,2048,2048,5120},bf16_grouped_128x16x96x128_16x16_1x3_16x8x1_16x8x1_1x16x1x8_4x4x1_1x1_intrawave_v2},
118+
{{128,2048,5120,1024},bf16_grouped_128x16x32x128_16x16_1x1_16x8x1_16x8x1_1x16x1x8_4x4x1_1x1_interwave_v2},
119+
{{128,2048,16384,5120},bf16_grouped_256x16x128x128_16x16_1x2_16x16x1_16x16x1_1x16x1x16_8x8x1_1x2_intrawave_v2},
120+
{{128,2048,5120,8192},bf16_grouped_128x16x96x128_16x16_1x3_16x8x1_16x8x1_1x16x1x8_4x4x1_1x1_intrawave_v1},
121+
{{128,4096,2048,5120},bf16_grouped_128x32x64x128_32x32_1x1_16x8x1_16x8x1_1x16x1x8_8x8x1_1x1_intrawave_v1},
122+
{{128,4096,5120,1024},bf16_grouped_128x32x64x128_32x32_1x1_16x8x1_16x8x1_1x16x1x8_8x8x1_1x1_interwave_v2},
123+
{{128,4096,16384,5120},bf16_grouped_256x32x128x128_16x16_1x4_16x16x1_16x16x1_1x32x1x8_8x8x1_1x2_intrawave_v2},
124+
{{128,4096,5120,8192},bf16_grouped_256x32x224x64_16x16_1x7_8x32x1_8x32x1_1x32x1x8_4x4x1_1x1_intrawave_v1},
125+
{{128,8192,2048,5120},bf16_grouped_256x64x192x128_16x16_4x3_16x16x1_16x16x1_1x32x1x8_8x8x1_2x1_intrawave_v3},
126+
{{128,8192,5120,1024},bf16_grouped_128x64x128x64_32x32_2x2_8x16x1_8x16x1_1x16x1x8_8x8x1_1x1_intrawave_v3},
127+
{{128,8192,16384,5120},bf16_grouped_256x64x192x128_16x16_4x3_16x16x1_16x16x1_1x32x1x8_8x8x1_2x1_intrawave_v3},
128+
{{128,8192,5120,8192},bf16_grouped_256x64x192x128_16x16_4x3_16x16x1_16x16x1_1x32x1x8_8x8x1_2x1_intrawave_v3},
129+
};
130+
131+
132+
133+
// Return the smallest power of two >= num (num is rounded *up*; values
// that are already powers of two are returned unchanged). Used to bucket
// total_M onto the M values present in the dispatch lookup table.
//
// The shift is computed in unsigned arithmetic and saturated so that
// inputs above 2^30 cannot shift into (or past) the sign bit, which was
// undefined behavior in the plain `1 << n` formulation. Saturated
// results simply miss the lookup table and fall back to the heuristic.
static constexpr int nextPow2(unsigned int num)
{
  if (num <= 1)
    return 1;
  // num - 1 >= 1 here, so __builtin_clz's nonzero-argument precondition holds.
  const int lead = __builtin_clz(num - 1);
  // lead <= 1 means the next power of two is >= 2^31, which does not fit
  // in a positive int; saturate instead of overflowing.
  if (lead <= 1)
    return std::numeric_limits<int>::max();
  return static_cast<int>(1u << (CHAR_BIT * sizeof(num) - lead));
}
140+
141+
// Select the grouped-GEMM kernel for a problem of G groups with maximum
// shape (total_M, N, K).
//
// First tries an exact hit in the tuned lookup table (keyed on
// {G, padded M, N, K}); otherwise falls back to a single default kernel.
// The old M-bucketed heuristic chain that previously lived here was left
// behind under `#if 0`; it referenced the removed parameter `M` (so it
// could never be re-enabled as-is) and has been deleted as dead code.
GroupedKernel grouped_heuristic_dispatch(int G, int total_M, int N, int K) {
  // Pad M up to the next power of two so it can match the discrete M
  // values used as lookup-table keys.
  int padded_m = nextPow2(total_M);
  // Each group contributes at least one row, so M can never be below G.
  padded_m = padded_m < G ? G : padded_m;
  // First check if this shape is available in the direct lookup.
  auto it = bf16_grouped_lookup_dispatch.find({G, padded_m, N, K});
  // If we found an optimal kernel, use it.
  if (it != bf16_grouped_lookup_dispatch.end()) {
    return it->second;
  }
  // Default kernel for all other shapes.
  return bf16_grouped_256x128x128x64_32x32_2x2_8x32x1_8x32x1_1x32x1x8_8x8x1_1x1_intrawave_v3;
}
79196

80197
__global__ void set_kernel_args_kernel(
@@ -343,7 +460,7 @@ std::vector<at::Tensor> bf16bf16bf16_grouped(
343460
MaxN = max(MaxN, B[i].size(0));
344461
MaxK = max(MaxK, A[i].size(1));
345462
}
346-
GroupedKernel selected_kernel = grouped_heuristic_dispatch(MaxM, MaxN, MaxK);
463+
GroupedKernel selected_kernel = grouped_heuristic_dispatch(group_count, MaxM, MaxN, MaxK);
347464
return selected_kernel(A, B, kernel_args, Y);
348465
}
349466

@@ -418,7 +535,7 @@ at::Tensor bf16bf16bf16_grouped_dynamic(
418535
MaxN = max(MaxN, B[i].size(0));
419536
MaxK = max(MaxK, A[i].size(1));
420537
}
421-
GroupedKernel selected_kernel = grouped_heuristic_dispatch(MaxM, MaxN, MaxK);
538+
GroupedKernel selected_kernel = grouped_heuristic_dispatch(group_count, MaxM, MaxN, MaxK);
422539
// Run kernel to populate output.
423540
selected_kernel(A, B, kernel_args, Y);
424541
// Return coalesced view of output tensor.

fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/bf16_grouped/kernels/bf16_grouped_128x128x16x128_16x16_4x1_8x16x1_8x16x1_1x16x1x8_2x2x1_1x1_interwave_v2.hip

Lines changed: 0 additions & 70 deletions
This file was deleted.

fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/bf16_grouped/kernels/bf16_grouped_128x128x16x128_16x16_4x1_8x16x1_8x16x1_1x16x1x8_2x2x1_1x1_intrawave_v2.hip

Lines changed: 0 additions & 70 deletions
This file was deleted.

fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/bf16_grouped/kernels/bf16_grouped_128x128x16x128_16x16_4x1_8x16x1_8x16x1_1x16x1x8_8x8x1_1x1_interwave_v2.hip

Lines changed: 0 additions & 38 deletions
This file was deleted.

0 commit comments

Comments
 (0)