
Commit 5658f48

mxz297 authored and facebook-github-bot committed
new tuning for fp8 rowwise (pytorch#838)
Summary:
X-link: pytorch#3756

Pull Request resolved: facebookresearch/FBGEMM#838

As title

Reviewed By: jwfromm

Differential Revision: D70494396

fbshipit-source-id: d1d77676138ead6c9653928c28e873fed5c56e59
1 parent fd7dc4d commit 5658f48

File tree

30 files changed: +1699 -255 lines changed

fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise/fp8_rowwise_gemm.hip

Lines changed: 73 additions & 1 deletion
@@ -327,14 +327,86 @@ static const std::map<int, RowwiseKernel> N_5120_K_640_dispatch_table = {
     { 5984, fp8_rowwise_256x224x256x128_16x16_7x8_8x32x1_8x32x1_1x32x1x8_8x8x1_1x2_intrawave_v3},
 };
 
+static const std::map<int, RowwiseKernel> N_4096_K_5120_dispatch_table = {
+    { 16, fp8_rowwise_128x16x32x512_16x16_1x1_32x4x1_32x4x1_1x16x1x8_4x4x1_1x1_interwave_v2},
+    { 32, fp8_rowwise_128x32x16x512_16x16_1x1_32x4x1_32x4x1_1x32x1x4_4x4x1_1x1_interwave_v2},
+    { 48, fp8_rowwise_256x16x64x512_16x16_1x1_32x8x1_32x8x1_1x16x1x16_4x4x1_1x1_intrawave_v3},
+    { 128, fp8_rowwise_256x32x64x512_16x16_1x2_32x8x1_32x8x1_1x32x1x8_8x8x1_1x2_intrawave_v3},
+    { 256, fp8_rowwise_256x64x64x512_32x32_1x1_32x8x1_32x8x1_1x32x1x8_8x8x1_1x1_intrawave_v3},
+    { 288, fp8_rowwise_256x32x128x256_32x32_1x1_16x16x1_16x16x1_1x32x1x8_8x8x1_1x1_intrawave_v3},
+    { 576, fp8_rowwise_256x64x128x256_32x32_1x2_16x16x1_16x16x1_1x32x1x8_8x8x1_1x1_intrawave_v3},
+    { 896, fp8_rowwise_256x128x96x128_16x16_4x3_8x32x1_8x32x1_1x64x1x4_8x8x1_2x1_intrawave_v3},
+    { 1152, fp8_rowwise_256x128x128x128_16x16_4x4_8x32x1_8x32x1_1x32x1x8_8x8x1_1x2_intrawave_v3},
+    { 1392, fp8_rowwise_256x128x160x128_16x16_4x5_8x32x1_8x32x1_1x64x1x4_8x8x1_2x1_intrawave_v3},
+    { 1440, fp8_rowwise_256x160x128x128_16x16_5x4_8x32x1_8x32x1_1x32x1x8_8x8x1_1x2_intrawave_v3},
+    { 1776, fp8_rowwise_256x128x96x128_16x16_4x3_8x32x1_8x32x1_1x64x1x4_8x8x1_2x1_intrawave_v3},
+    { 1824, fp8_rowwise_256x96x128x128_16x16_3x4_8x32x1_8x32x1_1x32x1x8_8x8x1_1x2_intrawave_v3},
+    { 2240, fp8_rowwise_256x160x96x128_16x16_5x3_8x32x1_8x32x1_1x32x1x8_4x4x1_1x1_intrawave_v3},
+    { 2496, fp8_rowwise_256x192x192x128_16x16_6x6_8x32x1_8x32x1_1x32x1x8_8x8x1_1x2_intrawave_v3},
+    { 2816, fp8_rowwise_256x256x160x128_16x16_8x5_8x32x1_8x32x1_1x64x1x4_8x8x1_2x1_intrawave_v3},
+    { 2896, fp8_rowwise_256x224x192x128_16x16_7x6_8x32x1_8x32x1_1x32x1x8_8x8x1_1x2_intrawave_v3},
+    { 3040, fp8_rowwise_256x160x256x128_16x16_5x8_8x32x1_8x32x1_1x32x1x8_8x8x1_1x2_intrawave_v3},
+    { 3072, fp8_rowwise_256x192x224x128_16x16_6x7_8x32x1_8x32x1_1x64x1x4_8x8x1_2x1_intrawave_v3},
+    { 3328, fp8_rowwise_256x256x192x128_16x16_8x6_8x32x1_8x32x1_1x32x1x8_8x8x1_1x2_intrawave_v3},
+    { 3648, fp8_rowwise_256x192x256x128_16x16_6x8_8x32x1_8x32x1_1x32x1x8_8x8x1_1x2_intrawave_v3},
+    { 4096, fp8_rowwise_256x256x224x128_16x16_8x7_8x32x1_8x32x1_1x64x1x4_8x8x1_2x1_intrawave_v3},
+    { 4256, fp8_rowwise_256x224x256x128_16x16_7x8_8x32x1_8x32x1_1x32x1x8_8x8x1_1x2_intrawave_v3},
+    { 4832, fp8_rowwise_256x256x256x64_32x32_4x4_4x64x1_4x64x1_1x32x1x8_8x8x1_1x1_intrawave_v4},
+    { 4864, fp8_rowwise_256x256x256x128_16x16_8x8_8x32x1_8x32x1_1x32x1x8_8x8x1_1x2_intrawave_v3},
+    { 5152, fp8_rowwise_256x224x160x128_16x16_7x5_8x32x1_8x32x1_1x32x1x8_4x4x1_1x1_intrawave_v3},
+    { 5184, fp8_rowwise_256x192x192x128_16x16_6x6_8x32x1_8x32x1_1x32x1x8_8x8x1_1x2_intrawave_v3},
+    { 5888, fp8_rowwise_256x256x160x128_16x16_8x5_8x32x1_8x32x1_1x64x1x4_8x8x1_2x1_intrawave_v3},
+    { 5920, fp8_rowwise_256x160x256x128_16x16_5x8_8x32x1_8x32x1_1x32x1x8_8x8x1_1x2_intrawave_v3},
+    { 5984, fp8_rowwise_256x224x192x128_16x16_7x6_8x32x1_8x32x1_1x32x1x8_8x8x1_1x2_intrawave_v3},
+};
+
+static const std::map<int, RowwiseKernel> N_5120_K_2048_dispatch_table = {
+    { 48, fp8_rowwise_256x16x64x512_16x16_1x1_32x8x1_32x8x1_1x16x1x16_4x4x1_1x1_intrawave_v2},
+    { 96, fp8_rowwise_256x32x64x512_16x16_1x2_32x8x1_32x8x1_1x32x1x8_8x8x1_1x2_intrawave_v3},
+    { 192, fp8_rowwise_256x64x64x512_32x32_1x1_32x8x1_32x8x1_1x32x1x8_8x8x1_1x1_intrawave_v3},
+    { 224, fp8_rowwise_256x32x128x256_32x32_1x1_16x16x1_16x16x1_1x32x1x8_8x8x1_1x1_intrawave_v3},
+    { 384, fp8_rowwise_256x128x64x256_32x32_2x1_16x16x1_16x16x1_1x32x1x8_8x8x1_1x1_intrawave_v3},
+    { 448, fp8_rowwise_256x64x128x256_32x32_1x2_16x16x1_16x16x1_1x32x1x8_8x8x1_1x1_intrawave_v3},
+    { 560, fp8_rowwise_256x80x128x256_16x16_5x2_16x16x1_16x16x1_1x16x1x16_8x8x1_1x2_intrawave_v3},
+    { 608, fp8_rowwise_256x128x96x128_16x16_4x3_8x32x1_8x32x1_1x64x1x4_8x8x1_2x1_intrawave_v3},
+    { 672, fp8_rowwise_256x96x128x128_16x16_3x4_8x32x1_8x32x1_1x32x1x8_8x8x1_1x2_intrawave_v3},
+    { 896, fp8_rowwise_256x128x128x128_16x16_4x4_8x32x1_8x32x1_1x32x1x8_8x8x1_1x2_intrawave_v3},
+    { 1008, fp8_rowwise_256x128x160x128_16x16_4x5_8x32x1_8x32x1_1x64x1x4_8x8x1_2x1_intrawave_v3},
+    { 1120, fp8_rowwise_256x160x128x128_16x16_5x4_8x32x1_8x32x1_1x32x1x8_8x8x1_1x2_intrawave_v3},
+    { 1408, fp8_rowwise_256x128x96x128_16x16_4x3_8x32x1_8x32x1_1x64x1x4_8x8x1_2x1_intrawave_v3},
+    { 1440, fp8_rowwise_256x96x128x128_16x16_3x4_8x32x1_8x32x1_1x32x1x8_8x8x1_1x2_intrawave_v3},
+    { 1536, fp8_rowwise_256x128x128x128_16x16_4x4_8x32x1_8x32x1_1x32x1x8_8x8x1_1x2_intrawave_v3},
+    { 1600, fp8_rowwise_256x160x96x128_16x16_5x3_8x32x1_8x32x1_1x32x1x8_4x4x1_1x1_intrawave_v3},
+    { 1920, fp8_rowwise_256x128x128x128_16x16_4x4_8x32x1_8x32x1_1x32x1x8_8x8x1_1x2_intrawave_v3},
+    { 2112, fp8_rowwise_256x192x192x128_16x16_6x6_8x32x1_8x32x1_1x32x1x8_8x8x1_1x2_intrawave_v3},
+    { 2400, fp8_rowwise_256x160x256x128_16x16_5x8_8x32x1_8x32x1_1x32x1x8_8x8x1_1x2_intrawave_v3},
+    { 2464, fp8_rowwise_256x224x192x128_16x16_7x6_8x32x1_8x32x1_1x32x1x8_8x8x1_1x2_intrawave_v3},
+    { 2496, fp8_rowwise_256x192x224x128_16x16_6x7_8x32x1_8x32x1_1x64x1x4_8x8x1_2x1_intrawave_v3},
+    { 2816, fp8_rowwise_256x256x192x128_16x16_8x6_8x32x1_8x32x1_1x32x1x8_8x8x1_1x2_intrawave_v3},
+    { 2880, fp8_rowwise_256x192x256x128_16x16_6x8_8x32x1_8x32x1_1x32x1x8_8x8x1_1x2_intrawave_v3},
+    { 3328, fp8_rowwise_256x256x224x128_16x16_8x7_8x32x1_8x32x1_1x64x1x4_8x8x1_2x1_intrawave_v3},
+    { 3360, fp8_rowwise_256x224x256x128_16x16_7x8_8x32x1_8x32x1_1x32x1x8_8x8x1_1x2_intrawave_v3},
+    { 3840, fp8_rowwise_256x256x256x128_16x16_8x8_8x32x1_8x32x1_1x32x1x8_8x8x1_1x2_intrawave_v3},
+    { 4224, fp8_rowwise_256x192x192x128_16x16_6x6_8x32x1_8x32x1_1x32x1x8_8x8x1_1x2_intrawave_v3},
+    { 4736, fp8_rowwise_256x128x128x128_16x16_4x4_8x32x1_8x32x1_1x32x1x8_8x8x1_1x2_intrawave_v3},
+    { 4864, fp8_rowwise_256x256x160x128_16x16_8x5_8x32x1_8x32x1_1x64x1x4_8x8x1_2x1_intrawave_v3},
+    { 4928, fp8_rowwise_256x224x192x128_16x16_7x6_8x32x1_8x32x1_1x32x1x8_8x8x1_1x2_intrawave_v3},
+    { 4992, fp8_rowwise_256x192x224x128_16x16_6x7_8x32x1_8x32x1_1x64x1x4_8x8x1_2x1_intrawave_v3},
+    { 5632, fp8_rowwise_256x256x192x128_16x16_8x6_8x32x1_8x32x1_1x32x1x8_8x8x1_1x2_intrawave_v3},
+    { 5760, fp8_rowwise_256x192x256x128_16x16_6x8_8x32x1_8x32x1_1x32x1x8_8x8x1_1x2_intrawave_v3},
+    { 5984, fp8_rowwise_256x256x224x128_16x16_8x7_8x32x1_8x32x1_1x64x1x4_8x8x1_2x1_intrawave_v3},
+};
+
 static const std::unordered_map<std::tuple<int, int>, NKLookupTableType, IntTupleHash> NK_lookup_table = {
     {{7168, 8192}, N_7168_K_8192_dispatch_table},
     {{8192, 3584}, N_8192_K_3584_dispatch_table},
     {{1024, 5120}, N_1024_K_5120_dispatch_table},
     {{5120, 1024}, N_5120_K_1024_dispatch_table},
     {{2048, 5120}, N_2048_K_5120_dispatch_table},
     {{896, 5120}, N_896_K_5120_dispatch_table},
-    {{5120, 640}, N_5120_K_640_dispatch_table}
+    {{5120, 640}, N_5120_K_640_dispatch_table},
+    {{4096, 5120}, N_4096_K_5120_dispatch_table},
+    {{5120, 2048}, N_5120_K_2048_dispatch_table}
 };
 
 RowwiseKernel rowwise_nk_lookup(int M, const NKLookupTableType& table) {
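Each new table maps an M threshold to the kernel tuned for that (N, K) shape, and rowwise_nk_lookup picks an entry from the table selected by (N, K). As a minimal sketch of how such a threshold table can be consumed (assuming, from the ascending keys, that each key is an upper bound on M and that the largest entry doubles as the fallback; rowwise_nk_lookup_sketch and the stand-in types below are illustrative, not the repo's exact implementation):

#include <map>

// Illustrative stand-ins: in the real code RowwiseKernel is a callable that
// launches a CK GEMM instance, and NKLookupTableType aliases the map type
// used in the dispatch tables above.
using RowwiseKernel = void (*)();
using NKLookupTableType = std::map<int, RowwiseKernel>;

// Sketch of an M-threshold lookup: lower_bound finds the first key >= M;
// if M exceeds every tuned size, fall back to the largest entry.
RowwiseKernel rowwise_nk_lookup_sketch(int M, const NKLookupTableType& table) {
  auto it = table.lower_bound(M);
  return (it != table.end()) ? it->second : table.rbegin()->second;
}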

fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise/kernels/fp8_rowwise_128x16x32x512_16x16_1x1_32x4x1_32x4x1_1x16x1x8_4x4x1_1x1_interwave_v2.hip

Lines changed: 45 additions & 20 deletions
@@ -15,25 +15,50 @@ fp8_rowwise_128x16x32x512_16x16_1x1_32x4x1_32x4x1_1x16x1x8_4x4x1_1x1_interwave_v
     at::Tensor x_scale,
     at::Tensor w_scale,
     at::Tensor Y) {
-  using DeviceGemmInstance = DeviceGemmHelper<
-      128,
-      16,
-      32,
-      512,
-      16,
-      16,
-      1,
-      1,
-      S<32, 4, 1>,
-      S<32, 4, 1>,
-      S<1, 16, 1, 8>,
-      S<4, 4, 1>,
-      1,
-      1,
-      ck::BlockGemmPipelineScheduler::Interwave,
-      ck::BlockGemmPipelineVersion::v2,
-      ck::tensor_operation::device::GemmSpecialization::Default>;
-  // Run kernel instance.
-  return f8f8bf16_rowwise_impl<DeviceGemmInstance>(XQ, WQ, x_scale, w_scale, Y);
+  int K = WQ.size(1);
+  bool kpad = (K % 512 != 0);
+  if (kpad) {
+    using DeviceGemmInstance = DeviceGemmHelper<
+        128,
+        16,
+        32,
+        512,
+        16,
+        16,
+        1,
+        1,
+        S<32, 4, 1>,
+        S<32, 4, 1>,
+        S<1, 16, 1, 8>,
+        S<4, 4, 1>,
+        1,
+        1,
+        ck::BlockGemmPipelineScheduler::Interwave,
+        ck::BlockGemmPipelineVersion::v2,
+        ck::tensor_operation::device::GemmSpecialization::KPadding>;
+    // Run kernel instance.
+    return f8f8bf16_rowwise_impl<DeviceGemmInstance>(XQ, WQ, x_scale, w_scale, Y);
+  } else {
+    using DeviceGemmInstance = DeviceGemmHelper<
+        128,
+        16,
+        32,
+        512,
+        16,
+        16,
+        1,
+        1,
+        S<32, 4, 1>,
+        S<32, 4, 1>,
+        S<1, 16, 1, 8>,
+        S<4, 4, 1>,
+        1,
+        1,
+        ck::BlockGemmPipelineScheduler::Interwave,
+        ck::BlockGemmPipelineVersion::v2,
+        ck::tensor_operation::device::GemmSpecialization::Default>;
+    // Run kernel instance.
+    return f8f8bf16_rowwise_impl<DeviceGemmInstance>(XQ, WQ, x_scale, w_scale, Y);
+  }
 }
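This kernel (and the others updated in this commit) now checks whether K is a multiple of its K tile (512 for this instance) and builds the CK instance with GemmSpecialization::KPadding when it is not, keeping the cheaper Default specialization for the evenly divisible case. A condensed sketch of that decision (the enum and pick_spec below are illustrative stand-ins; in the real kernels the specialization is a compile-time template argument, which is why the diff instantiates two full DeviceGemmHelper types and branches at runtime):

// Illustrative stand-in for ck::tensor_operation::device::GemmSpecialization.
enum class GemmSpecialization { Default, KPadding };

// Sketch: pad K only when it is ragged relative to the kernel's K tile,
// since padding support costs some performance on aligned shapes.
GemmSpecialization pick_spec(int K, int kTileK) {
  return (K % kTileK != 0) ? GemmSpecialization::KPadding
                           : GemmSpecialization::Default;
}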

fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise/kernels/fp8_rowwise_128x32x16x512_16x16_1x1_32x4x1_32x4x1_1x32x1x4_4x4x1_1x1_interwave_v2.hip

Lines changed: 45 additions & 20 deletions
@@ -15,25 +15,50 @@ fp8_rowwise_128x32x16x512_16x16_1x1_32x4x1_32x4x1_1x32x1x4_4x4x1_1x1_interwave_v
     at::Tensor x_scale,
     at::Tensor w_scale,
     at::Tensor Y) {
-  using DeviceGemmInstance = DeviceGemmHelper<
-      128,
-      32,
-      16,
-      512,
-      16,
-      16,
-      1,
-      1,
-      S<32, 4, 1>,
-      S<32, 4, 1>,
-      S<1, 32, 1, 4>,
-      S<4, 4, 1>,
-      1,
-      1,
-      ck::BlockGemmPipelineScheduler::Interwave,
-      ck::BlockGemmPipelineVersion::v2,
-      ck::tensor_operation::device::GemmSpecialization::Default>;
-  // Run kernel instance.
-  return f8f8bf16_rowwise_impl<DeviceGemmInstance>(XQ, WQ, x_scale, w_scale, Y);
+  int K = WQ.size(1);
+  bool kpad = (K % 512 != 0);
+  if (kpad) {
+    using DeviceGemmInstance = DeviceGemmHelper<
+        128,
+        32,
+        16,
+        512,
+        16,
+        16,
+        1,
+        1,
+        S<32, 4, 1>,
+        S<32, 4, 1>,
+        S<1, 32, 1, 4>,
+        S<4, 4, 1>,
+        1,
+        1,
+        ck::BlockGemmPipelineScheduler::Interwave,
+        ck::BlockGemmPipelineVersion::v2,
+        ck::tensor_operation::device::GemmSpecialization::KPadding>;
+    // Run kernel instance.
+    return f8f8bf16_rowwise_impl<DeviceGemmInstance>(XQ, WQ, x_scale, w_scale, Y);
+  } else {
+    using DeviceGemmInstance = DeviceGemmHelper<
+        128,
+        32,
+        16,
+        512,
+        16,
+        16,
+        1,
+        1,
+        S<32, 4, 1>,
+        S<32, 4, 1>,
+        S<1, 32, 1, 4>,
+        S<4, 4, 1>,
+        1,
+        1,
+        ck::BlockGemmPipelineScheduler::Interwave,
+        ck::BlockGemmPipelineVersion::v2,
+        ck::tensor_operation::device::GemmSpecialization::Default>;
+    // Run kernel instance.
+    return f8f8bf16_rowwise_impl<DeviceGemmInstance>(XQ, WQ, x_scale, w_scale, Y);
+  }
 }

fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise/kernels/fp8_rowwise_256x128x128x128_16x16_4x4_8x32x1_8x32x1_1x32x1x8_8x8x1_1x2_intrawave_v3.hip (new file; path inferred from the kernel naming convention)

Lines changed: 64 additions & 0 deletions

@@ -0,0 +1,64 @@
+/*
+ * Copyright (c) Meta Platforms, Inc. and affiliates.
+ * All rights reserved.
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+#include "fp8_rowwise_common.h"
+
+at::Tensor
+fp8_rowwise_256x128x128x128_16x16_4x4_8x32x1_8x32x1_1x32x1x8_8x8x1_1x2_intrawave_v3(
+    at::Tensor XQ,
+    at::Tensor WQ,
+    at::Tensor x_scale,
+    at::Tensor w_scale,
+    at::Tensor Y) {
+  int K = WQ.size(1);
+  bool kpad = (K % 128 != 0);
+  if (kpad) {
+    using DeviceGemmInstance = DeviceGemmHelper<
+        256,
+        128,
+        128,
+        128,
+        16,
+        16,
+        4,
+        4,
+        S<8, 32, 1>,
+        S<8, 32, 1>,
+        S<1, 32, 1, 8>,
+        S<8, 8, 1>,
+        1,
+        2,
+        ck::BlockGemmPipelineScheduler::Intrawave,
+        ck::BlockGemmPipelineVersion::v3,
+        ck::tensor_operation::device::GemmSpecialization::KPadding>;
+    // Run kernel instance.
+    return f8f8bf16_rowwise_impl<DeviceGemmInstance>(XQ, WQ, x_scale, w_scale, Y);
+  } else {
+    using DeviceGemmInstance = DeviceGemmHelper<
+        256,
+        128,
+        128,
+        128,
+        16,
+        16,
+        4,
+        4,
+        S<8, 32, 1>,
+        S<8, 32, 1>,
+        S<1, 32, 1, 8>,
+        S<8, 8, 1>,
+        1,
+        2,
+        ck::BlockGemmPipelineScheduler::Intrawave,
+        ck::BlockGemmPipelineVersion::v3,
+        ck::tensor_operation::device::GemmSpecialization::Default>;
+    // Run kernel instance.
+    return f8f8bf16_rowwise_impl<DeviceGemmInstance>(XQ, WQ, x_scale, w_scale, Y);
+  }
+}
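Reading this new file against its name shows the naming convention: each field of the kernel name restates a DeviceGemmHelper template argument, in order. The annotation below is a hedged decoding of that mapping; the field-to-argument correspondence is visible in the file itself, but the semantic labels are inferred from CK's usual tile nomenclature rather than stated in this diff.

// fp8_rowwise_256x128x128x128_16x16_4x4_8x32x1_8x32x1_1x32x1x8_8x8x1_1x2_intrawave_v3
//   256          -> block size (threads per workgroup)
//   128x128x128  -> M x N x K tile computed per block
//   16x16        -> XDL (MFMA) wave tile in M x N
//   4x4          -> XDL tiles per wave in M x N
//   8x32x1       -> A block-transfer thread cluster, S<8, 32, 1>
//   8x32x1       -> B block-transfer thread cluster, S<8, 32, 1>
//   1x32x1x8     -> C-shuffle thread cluster, S<1, 32, 1, 8>
//   8x8x1        -> C-shuffle scalars per vector, S<8, 8, 1>
//   1x2          -> the trailing "1, 2" shuffle-stage arguments
//   intrawave_v3 -> BlockGemmPipelineScheduler::Intrawave, PipelineVersion::v3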
fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise/kernels/fp8_rowwise_256x128x160x128_16x16_4x5_8x32x1_8x32x1_1x64x1x4_8x8x1_2x1_intrawave_v3.hip (new file; path inferred from the kernel naming convention)

Lines changed: 64 additions & 0 deletions

@@ -0,0 +1,64 @@
+/*
+ * Copyright (c) Meta Platforms, Inc. and affiliates.
+ * All rights reserved.
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+#include "fp8_rowwise_common.h"
+
+at::Tensor
+fp8_rowwise_256x128x160x128_16x16_4x5_8x32x1_8x32x1_1x64x1x4_8x8x1_2x1_intrawave_v3(
+    at::Tensor XQ,
+    at::Tensor WQ,
+    at::Tensor x_scale,
+    at::Tensor w_scale,
+    at::Tensor Y) {
+  int K = WQ.size(1);
+  bool kpad = (K % 128 != 0);
+  if (kpad) {
+    using DeviceGemmInstance = DeviceGemmHelper<
+        256,
+        128,
+        160,
+        128,
+        16,
+        16,
+        4,
+        5,
+        S<8, 32, 1>,
+        S<8, 32, 1>,
+        S<1, 64, 1, 4>,
+        S<8, 8, 1>,
+        2,
+        1,
+        ck::BlockGemmPipelineScheduler::Intrawave,
+        ck::BlockGemmPipelineVersion::v3,
+        ck::tensor_operation::device::GemmSpecialization::KPadding>;
+    // Run kernel instance.
+    return f8f8bf16_rowwise_impl<DeviceGemmInstance>(XQ, WQ, x_scale, w_scale, Y);
+  } else {
+    using DeviceGemmInstance = DeviceGemmHelper<
+        256,
+        128,
+        160,
+        128,
+        16,
+        16,
+        4,
+        5,
+        S<8, 32, 1>,
+        S<8, 32, 1>,
+        S<1, 64, 1, 4>,
+        S<8, 8, 1>,
+        2,
+        1,
+        ck::BlockGemmPipelineScheduler::Intrawave,
+        ck::BlockGemmPipelineVersion::v3,
+        ck::tensor_operation::device::GemmSpecialization::Default>;
+    // Run kernel instance.
+    return f8f8bf16_rowwise_impl<DeviceGemmInstance>(XQ, WQ, x_scale, w_scale, Y);
+  }
+}
