Commit 045c27a

zjing14 authored and facebook-github-bot committed

Retuned CK GMM fp8/bf16 with perf fixes (pytorch#3851)

Summary:
Pull Request resolved: pytorch#3851
X-link: facebookresearch/FBGEMM#941

- Fixed the launch bound for grouped GEMM (see the sketch below)
- Retuned fp8 GMM
- Retuned fp8/bf16 GMM for 17Bx16/128 with auto-generated instances (D71528034)

Reviewed By: mxz297
Differential Revision: D71140320

1 parent 851815d · commit 045c27a
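Note on the launch-bound fix: in HIP (as in CUDA), a kernel declares its launch bound with the __launch_bounds__ qualifier, which promises the compiler a maximum threads-per-block so registers can be budgeted for that block size. A bound smaller than the block size actually used at launch makes the launch fail at runtime. The sketch below only illustrates the mechanism; the kernel name and sizes are hypothetical, not the FBGEMM code touched by this commit.

    #include <hip/hip_runtime.h>

    // Hypothetical kernel, not from this commit. __launch_bounds__(256) tells
    // the compiler this kernel is never launched with more than 256 threads
    // per block, so registers can be allocated for that block size. Declaring
    // a bound smaller than the real launch configuration causes a runtime
    // launch failure; an overly large bound wastes registers and occupancy.
    __global__ void __launch_bounds__(256) scale_kernel(
        const float* in, float* out, int n, float s) {
      int idx = blockIdx.x * blockDim.x + threadIdx.x;
      if (idx < n) {
        out[idx] = in[idx] * s;
      }
    }

    // A matching launch uses at most the declared 256 threads per block:
    //   hipLaunchKernelGGL(scale_kernel, dim3(num_blocks), dim3(256), 0, 0,
    //                      in, out, n, 2.0f);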

222 files changed: +9903 additions, -10978 deletions


fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/bf16_grouped/bf16_grouped_gemm.hip
119 additions, 39 deletions

@@ -33,48 +33,128 @@ using ADataType = ck::bhalf_t;
 using BDataType = ck::bhalf_t;
 using CDataType = ck::bhalf_t;
 
-GroupedKernel grouped_heuristic_dispatch(int M, int N, int K) {
-  // We use shape heuristics to find the best kernel.
-  // To do this, we divide by the size of M and find the best
-  // option within that grouping.
-  if (M <= 16) {
-    if (N < 8192 && K <= 8192) {
-      return bf16_grouped_64x16x16x256_16x16_1x1_16x4x1_16x4x1_1x4x1x16_4x4x1_1x1_intrawave_v1;
-    }
-    if (K <= 8192) {
-      return bf16_grouped_128x16x64x128_16x16_1x2_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_intrawave_v2;
-    }
-    return bf16_grouped_128x16x32x256_16x16_1x1_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_interwave_v2;
+// Define a custom hash function for std::tuple<int, int, int>
+struct IntTupleHash {
+  size_t operator()(const std::tuple<int, int>& t) const {
+    auto hash1 = std::hash<int>{}(std::get<0>(t));
+    auto hash2 = std::hash<int>{}(std::get<1>(t));
+    return hash1 ^ hash2;
   }
-  if (M <= 32) {
-    if (N < 8192 && K <= 8192) {
-      return bf16_grouped_128x32x64x128_32x32_1x1_8x16x1_8x16x1_1x16x1x8_8x8x1_1x1_interwave_v2;
-    }
-    if (K <= 8192) {
-      return bf16_grouped_128x32x128x128_32x32_1x2_8x16x1_8x16x1_1x16x1x8_8x8x1_1x1_interwave_v2;
-    }
-    return bf16_grouped_128x32x128x128_32x32_1x2_8x16x1_8x16x1_1x16x1x8_8x8x1_1x1_intrawave_v2;
-  }
-  if (M <= 64) {
-    return bf16_grouped_256x64x64x128_32x32_1x1_8x32x1_8x32x1_1x32x1x8_8x8x1_1x1_intrawave_v3;
+  size_t operator()(const std::tuple<int, int, int>& t) const {
+    auto hash1 = std::hash<int>{}(std::get<0>(t));
+    auto hash2 = std::hash<int>{}(std::get<1>(t));
+    auto hash3 = std::hash<int>{}(std::get<2>(t));
+    return hash1 ^ hash2 ^ hash3;
   }
-  if (M <= 128) {
-    if (N < 8192 && K <= 8192) {
-      return bf16_grouped_256x128x64x128_32x32_2x1_8x32x1_8x32x1_1x32x1x8_8x8x1_1x1_intrawave_v3;
-    }
-    return bf16_grouped_256x128x128x128_32x32_2x2_8x32x1_8x32x1_1x32x1x8_8x8x1_1x1_intrawave_v3;
+  size_t operator()(const std::tuple<int, int, int, int>& t) const {
+    auto hash1 = std::hash<int>{}(std::get<0>(t));
+    auto hash2 = std::hash<int>{}(std::get<1>(t));
+    auto hash3 = std::hash<int>{}(std::get<2>(t));
+    auto hash4 = std::hash<int>{}(std::get<3>(t));
+    return hash1 ^ hash2 ^ hash3 ^ hash4;
   }
-  if (M <= 256) {
-    return bf16_grouped_256x128x128x128_32x32_2x2_8x32x1_8x32x1_1x32x1x8_8x8x1_1x1_intrawave_v3;
-  }
-  if (M <= 512) {
-    if (K <= 8192) {
-      return bf16_grouped_256x128x128x128_32x32_2x2_8x32x1_8x32x1_1x32x1x8_8x8x1_1x1_interwave_v1;
-    }
-    return bf16_grouped_256x128x128x128_32x32_2x2_8x32x1_8x32x1_1x32x1x8_8x8x1_1x1_intrawave_v3;
+};
+
+// For certain high priority shapes, we directly map to the best kernel rather
+// than use heuristics.
+static const std::unordered_map<std::tuple<int, int, int, int>, GroupedKernel, IntTupleHash> bf16_grouped_lookup_dispatch = {
+  {{16,16,2048,5120},bf16_grouped_128x16x64x128_16x16_1x2_16x8x1_16x8x1_1x16x1x8_8x8x1_1x2_intrawave_v2},
+  {{16,16,5120,1024},bf16_grouped_64x16x16x128_16x16_1x1_16x4x1_16x4x1_1x16x1x4_4x4x1_1x1_interwave_v2},
+  {{16,16,16384,5120},bf16_grouped_64x16x32x128_16x16_1x2_16x4x1_16x4x1_1x16x1x4_8x8x1_1x2_intrawave_v2},
+  {{16,16,5120,8192},bf16_grouped_128x16x64x128_16x16_1x2_16x8x1_16x8x1_1x16x1x8_8x8x1_1x2_interwave_v1},
+  {{16,32,2048,5120},bf16_grouped_128x16x64x128_16x16_1x2_16x8x1_16x8x1_1x16x1x8_8x8x1_1x2_interwave_v2},
+  {{16,32,5120,1024},bf16_grouped_256x16x64x128_16x16_1x1_16x16x1_16x16x1_1x16x1x16_4x4x1_1x1_interwave_v2},
+  {{16,32,16384,5120},bf16_grouped_128x16x96x128_16x16_1x3_16x8x1_16x8x1_1x16x1x8_4x4x1_1x1_interwave_v1},
+  {{16,32,5120,8192},bf16_grouped_128x16x96x64_16x16_1x3_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_intrawave_v1},
+  {{16,64,2048,5120},bf16_grouped_256x16x64x128_16x16_1x1_16x16x1_16x16x1_1x16x1x16_4x4x1_1x1_interwave_v2},
+  {{16,64,5120,1024},bf16_grouped_256x16x64x128_16x16_1x1_16x16x1_16x16x1_1x16x1x16_4x4x1_1x1_interwave_v2},
+  {{16,64,16384,5120},bf16_grouped_128x16x96x128_16x16_1x3_16x8x1_16x8x1_1x16x1x8_4x4x1_1x1_interwave_v1},
+  {{16,64,5120,8192},bf16_grouped_128x16x96x64_16x16_1x3_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_intrawave_v1},
+  {{16,128,2048,5120},bf16_grouped_128x16x64x128_16x16_1x2_16x8x1_16x8x1_1x16x1x8_8x8x1_1x2_interwave_v2},
+  {{16,128,5120,1024},bf16_grouped_256x16x64x128_16x16_1x1_16x16x1_16x16x1_1x16x1x16_4x4x1_1x1_interwave_v2},
+  {{16,128,16384,5120},bf16_grouped_64x16x64x128_16x16_1x4_16x4x1_16x4x1_1x16x1x4_8x8x1_1x2_intrawave_v2},
+  {{16,128,5120,8192},bf16_grouped_128x16x96x128_16x16_1x3_16x8x1_16x8x1_1x16x1x8_4x4x1_1x1_intrawave_v1},
+  {{16,256,2048,5120},bf16_grouped_128x16x64x128_16x16_1x2_16x8x1_16x8x1_1x16x1x8_8x8x1_1x2_interwave_v2},
+  {{16,256,5120,1024},bf16_grouped_128x16x32x128_16x16_1x1_16x8x1_16x8x1_1x16x1x8_4x4x1_1x1_interwave_v2},
+  {{16,256,16384,5120},bf16_grouped_128x16x96x128_16x16_1x3_16x8x1_16x8x1_1x16x1x8_4x4x1_1x1_intrawave_v1},
+  {{16,256,5120,8192},bf16_grouped_128x16x96x128_16x16_1x3_16x8x1_16x8x1_1x16x1x8_4x4x1_1x1_interwave_v1},
+  {{16,512,2048,5120},bf16_grouped_128x32x64x128_32x32_1x1_16x8x1_16x8x1_1x16x1x8_8x8x1_1x1_interwave_v2},
+  {{16,512,5120,1024},bf16_grouped_256x32x96x64_16x16_1x3_8x32x1_8x32x1_1x32x1x8_4x4x1_1x1_interwave_v2},
+  {{16,512,16384,5120},bf16_grouped_128x32x96x128_16x16_2x3_16x8x1_16x8x1_1x32x1x4_8x8x1_2x1_intrawave_v2},
+  {{16,512,5120,8192},bf16_grouped_256x32x96x64_16x16_1x3_8x32x1_8x32x1_1x32x1x8_4x4x1_1x1_interwave_v1},
+  {{16,1024,2048,5120},bf16_grouped_256x64x128x128_32x32_2x1_16x16x1_16x16x1_1x16x1x16_8x8x1_1x1_intrawave_v3},
+  {{16,1024,5120,1024},bf16_grouped_256x64x96x64_16x16_2x3_8x32x1_8x32x1_1x64x1x4_8x8x1_2x1_intrawave_v3},
+  {{16,1024,16384,5120},bf16_grouped_256x64x192x128_16x16_4x3_16x16x1_16x16x1_1x32x1x8_8x8x1_2x1_intrawave_v3},
+  {{16,1024,5120,8192},bf16_grouped_128x64x96x64_16x16_4x3_8x16x1_8x16x1_1x32x1x4_8x8x1_2x1_intrawave_v3},
+  {{16,2048,2048,5120},bf16_grouped_256x128x128x128_32x32_2x2_16x16x1_16x16x1_1x32x1x8_8x8x1_1x1_intrawave_v3},
+  {{16,2048,5120,1024},bf16_grouped_256x128x96x64_16x16_4x3_8x32x1_8x32x1_1x64x1x4_8x8x1_2x1_intrawave_v3},
+  {{16,2048,16384,5120},bf16_grouped_256x128x224x64_16x16_4x7_8x32x1_8x32x1_1x64x1x4_8x8x1_2x1_intrawave_v3},
+  {{16,2048,5120,8192},bf16_grouped_256x128x96x64_16x16_4x3_8x32x1_8x32x1_1x64x1x4_8x8x1_2x1_intrawave_v3},
+  {{16,4096,2048,5120},bf16_grouped_256x128x256x64_32x32_4x2_8x32x1_8x32x1_1x16x1x16_8x8x1_1x1_intrawave_v3},
+  {{16,4096,5120,1024},bf16_grouped_256x128x96x64_16x16_4x3_8x32x1_8x32x1_1x64x1x4_8x8x1_2x1_intrawave_v3},
+  {{16,4096,16384,5120},bf16_grouped_256x256x224x64_16x16_8x7_8x32x1_8x32x1_1x64x1x4_8x8x1_2x1_intrawave_v3},
+  {{16,4096,5120,8192},bf16_grouped_256x256x160x64_16x16_8x5_8x32x1_8x32x1_1x64x1x4_8x8x1_2x1_intrawave_v3},
+  {{16,8192,2048,5120},bf16_grouped_256x256x256x64_32x32_4x4_8x32x1_8x32x1_1x32x1x8_8x8x1_1x1_intrawave_v3},
+  {{16,8192,5120,1024},bf16_grouped_256x256x192x64_32x32_4x3_8x32x1_8x32x1_1x32x1x8_8x8x1_1x1_intrawave_v3},
+  {{16,8192,16384,5120},bf16_grouped_256x256x224x64_16x16_8x7_8x32x1_8x32x1_1x64x1x4_8x8x1_2x1_intrawave_v3},
+  {{16,8192,5120,8192},bf16_grouped_256x256x192x64_32x32_4x3_8x32x1_8x32x1_1x32x1x8_8x8x1_1x1_intrawave_v3},
+  {{128,128,2048,5120},bf16_grouped_128x16x64x128_16x16_1x2_16x8x1_16x8x1_1x16x1x8_8x8x1_1x2_interwave_v2},
+  {{128,128,5120,1024},bf16_grouped_256x16x64x128_16x16_1x1_16x16x1_16x16x1_1x16x1x16_4x4x1_1x1_interwave_v2},
+  {{128,128,16384,5120},bf16_grouped_256x16x128x128_16x16_1x2_16x16x1_16x16x1_1x16x1x16_8x8x1_1x2_intrawave_v2},
+  {{128,128,5120,8192},bf16_grouped_128x16x96x128_16x16_1x3_16x8x1_16x8x1_1x16x1x8_4x4x1_1x1_interwave_v2},
+  {{128,256,2048,5120},bf16_grouped_128x16x96x128_16x16_1x3_16x8x1_16x8x1_1x16x1x8_4x4x1_1x1_intrawave_v1},
+  {{128,256,5120,1024},bf16_grouped_256x16x64x128_16x16_1x1_16x16x1_16x16x1_1x16x1x16_4x4x1_1x1_interwave_v2},
+  {{128,256,16384,5120},bf16_grouped_256x16x128x128_16x16_1x2_16x16x1_16x16x1_1x16x1x16_8x8x1_1x2_intrawave_v1},
+  {{128,256,5120,8192},bf16_grouped_64x16x48x128_16x16_1x3_16x4x1_16x4x1_1x16x1x4_4x4x1_1x1_intrawave_v1},
+  {{128,512,2048,5120},bf16_grouped_128x16x64x128_16x16_1x2_16x8x1_16x8x1_1x16x1x8_8x8x1_1x2_intrawave_v2},
+  {{128,512,5120,1024},bf16_grouped_128x16x64x128_16x16_1x2_16x8x1_16x8x1_1x16x1x8_8x8x1_1x2_interwave_v2},
+  {{128,512,16384,5120},bf16_grouped_256x16x128x128_16x16_1x2_16x16x1_16x16x1_1x16x1x16_8x8x1_1x2_intrawave_v2},
+  {{128,512,5120,8192},bf16_grouped_128x16x96x128_16x16_1x3_16x8x1_16x8x1_1x16x1x8_4x4x1_1x1_interwave_v2},
+  {{128,1024,2048,5120},bf16_grouped_128x16x96x128_16x16_1x3_16x8x1_16x8x1_1x16x1x8_4x4x1_1x1_intrawave_v2},
+  {{128,1024,5120,1024},bf16_grouped_128x16x64x128_16x16_1x2_16x8x1_16x8x1_1x16x1x8_8x8x1_1x2_intrawave_v2},
+  {{128,1024,16384,5120},bf16_grouped_256x16x128x128_16x16_1x2_16x16x1_16x16x1_1x16x1x16_8x8x1_1x2_intrawave_v2},
+  {{128,1024,5120,8192},bf16_grouped_128x16x96x128_16x16_1x3_16x8x1_16x8x1_1x16x1x8_4x4x1_1x1_interwave_v2},
+  {{128,2048,2048,5120},bf16_grouped_128x16x96x128_16x16_1x3_16x8x1_16x8x1_1x16x1x8_4x4x1_1x1_intrawave_v2},
+  {{128,2048,5120,1024},bf16_grouped_128x16x32x128_16x16_1x1_16x8x1_16x8x1_1x16x1x8_4x4x1_1x1_interwave_v2},
+  {{128,2048,16384,5120},bf16_grouped_256x16x128x128_16x16_1x2_16x16x1_16x16x1_1x16x1x16_8x8x1_1x2_intrawave_v2},
+  {{128,2048,5120,8192},bf16_grouped_128x16x96x128_16x16_1x3_16x8x1_16x8x1_1x16x1x8_4x4x1_1x1_intrawave_v1},
+  {{128,4096,2048,5120},bf16_grouped_128x32x64x128_32x32_1x1_16x8x1_16x8x1_1x16x1x8_8x8x1_1x1_intrawave_v1},
+  {{128,4096,5120,1024},bf16_grouped_128x32x64x128_32x32_1x1_16x8x1_16x8x1_1x16x1x8_8x8x1_1x1_interwave_v2},
+  {{128,4096,16384,5120},bf16_grouped_256x32x128x128_16x16_1x4_16x16x1_16x16x1_1x32x1x8_8x8x1_1x2_intrawave_v2},
+  {{128,4096,5120,8192},bf16_grouped_256x32x224x64_16x16_1x7_8x32x1_8x32x1_1x32x1x8_4x4x1_1x1_intrawave_v1},
+  {{128,8192,2048,5120},bf16_grouped_256x64x192x128_16x16_4x3_16x16x1_16x16x1_1x32x1x8_8x8x1_2x1_intrawave_v3},
+  {{128,8192,5120,1024},bf16_grouped_128x64x128x64_32x32_2x2_8x16x1_8x16x1_1x16x1x8_8x8x1_1x1_intrawave_v3},
+  {{128,8192,16384,5120},bf16_grouped_256x64x192x128_16x16_4x3_16x16x1_16x16x1_1x32x1x8_8x8x1_2x1_intrawave_v3},
+  {{128,8192,5120,8192},bf16_grouped_256x64x192x128_16x16_4x3_16x16x1_16x16x1_1x32x1x8_8x8x1_2x1_intrawave_v3},
+};
+
+
+
+// Helper function to return the next largest power of 2
+static constexpr int nextPow2(unsigned int num)
+{
+  if (num <= 1)
+    return 1;
+  return 1 << (CHAR_BIT * sizeof(num) - __builtin_clz(num - 1));
+}
+
+GroupedKernel grouped_heuristic_dispatch(int G, int total_M, int N, int K) {
+  // We use shape heuristics to find the best kernel.
+  // To do this, we divide by the size of M and find the best
+  // option within that grouping.
+
+  // First check if this shape is available in the direct lookup.
+  int padded_m = nextPow2(total_M);
+  padded_m = padded_m < G ? G : padded_m;
+  padded_m = padded_m > 8192 ? 8192 : padded_m;
+  auto it = bf16_grouped_lookup_dispatch.find({G, padded_m, N, K});
+  // If we found an optimal kernel, use it.
+  if (it != bf16_grouped_lookup_dispatch.end()) {
+    return it->second;
   }
+
   // Default kernel for all other shapes.
-  return bf16_grouped_256x128x128x128_32x32_2x2_8x32x1_8x32x1_1x32x1x8_8x8x1_1x1_interwave_v1;
+  return bf16_grouped_256x128x128x64_32x32_2x2_8x32x1_8x32x1_1x32x1x8_8x8x1_1x1_interwave_v1;
 }
 
 __global__ void set_kernel_args_kernel(

@@ -343,7 +423,7 @@ std::vector<at::Tensor> bf16bf16bf16_grouped(
     MaxN = max(MaxN, B[i].size(0));
     MaxK = max(MaxK, A[i].size(1));
   }
-  GroupedKernel selected_kernel = grouped_heuristic_dispatch(MaxM, MaxN, MaxK);
+  GroupedKernel selected_kernel = grouped_heuristic_dispatch(group_count, MaxM, MaxN, MaxK);
   return selected_kernel(A, B, kernel_args, Y);
 }
 

@@ -418,7 +498,7 @@ at::Tensor bf16bf16bf16_grouped_dynamic(
     MaxN = max(MaxN, B[i].size(0));
     MaxK = max(MaxK, A[i].size(1));
   }
-  GroupedKernel selected_kernel = grouped_heuristic_dispatch(MaxM, MaxN, MaxK);
+  GroupedKernel selected_kernel = grouped_heuristic_dispatch(group_count, MaxM, MaxN, MaxK);
   // Run kernel to populate output.
   selected_kernel(A, B, kernel_args, Y);
   // Return coalesced view of output tensor.
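To make the new dispatch path concrete, the sketch below reproduces its shape-bucketing logic in standalone C++ under simplified assumptions: total_M is rounded up to the next power of two, clamped to the [G, 8192] range the table covers, and the resulting {G, padded_M, N, K} key is looked up in the tuned table, with a heuristic default as fallback. The two table entries and the strings standing in for GroupedKernel function pointers are placeholders for illustration, not the real FBGEMM bindings.

    #include <climits>
    #include <cstdio>
    #include <functional>
    #include <string>
    #include <tuple>
    #include <unordered_map>

    // Same XOR-combine tuple hash shape as IntTupleHash in the diff above.
    struct TupleHash {
      size_t operator()(const std::tuple<int, int, int, int>& t) const {
        return std::hash<int>{}(std::get<0>(t)) ^ std::hash<int>{}(std::get<1>(t)) ^
            std::hash<int>{}(std::get<2>(t)) ^ std::hash<int>{}(std::get<3>(t));
      }
    };

    // Placeholder table mapping {G, padded_M, N, K} to a kernel identifier.
    static const std::unordered_map<std::tuple<int, int, int, int>, std::string, TupleHash>
        lookup_dispatch = {
            {{16, 16, 2048, 5120}, "bf16_grouped_128x16x64x128_..._intrawave_v2"},
            {{128, 128, 2048, 5120}, "bf16_grouped_128x16x64x128_..._interwave_v2"},
    };

    // Round up to the next power of two (mirrors nextPow2 in the diff).
    constexpr int nextPow2(unsigned int num) {
      if (num <= 1) return 1;
      return 1 << (CHAR_BIT * sizeof(num) - __builtin_clz(num - 1));
    }

    std::string dispatch(int G, int total_M, int N, int K) {
      // Pad total_M so nearby shapes share one tuned entry, then clamp to the
      // range the table actually covers: at least G, at most 8192.
      int padded_m = nextPow2(total_M);
      padded_m = padded_m < G ? G : padded_m;
      padded_m = padded_m > 8192 ? 8192 : padded_m;
      auto it = lookup_dispatch.find({G, padded_m, N, K});
      if (it != lookup_dispatch.end()) return it->second;
      return "bf16_grouped_256x128x128x64_..._interwave_v1"; // default kernel
    }

    int main() {
      // total_M = 13 pads up to 16, matching the first table entry.
      std::printf("%s\n", dispatch(16, 13, 2048, 5120).c_str());
      // An uncovered shape falls back to the default kernel.
      std::printf("%s\n", dispatch(16, 13, 4096, 4096).c_str());
    }

One design note: because the hash combines the tuple elements with plain XOR, permutations of the same values collide, but the map resolves collisions by full key equality, so lookups stay correct.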

fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/bf16_grouped/kernels/bf16_grouped_128x128x16x128_16x16_4x1_8x16x1_8x16x1_1x16x1x8_2x2x1_1x1_interwave_v2.hip

Lines changed: 0 additions & 70 deletions
This file was deleted.

fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/bf16_grouped/kernels/bf16_grouped_128x128x16x128_16x16_4x1_8x16x1_8x16x1_1x16x1x8_2x2x1_1x1_intrawave_v2.hip

Lines changed: 0 additions & 70 deletions
This file was deleted.

fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/bf16_grouped/kernels/bf16_grouped_128x128x16x128_16x16_4x1_8x16x1_8x16x1_1x16x1x8_8x8x1_1x1_interwave_v2.hip

Lines changed: 0 additions & 38 deletions
This file was deleted.
