
Commit 2e2da01

jinzhen-lin authored and minpeter committed
[Kernel] fp4 marlin kernel (vllm-project#17687)
Signed-off-by: Jinzhen Lin <linjinzhen@hotmail.com>
Signed-off-by: minpeter <kali2005611@gmail.com>
1 parent 0154c6f commit 2e2da01

File tree

21 files changed: +1215 -330 lines


csrc/core/scalar_type.hpp

Lines changed: 3 additions & 0 deletions

@@ -315,6 +315,8 @@ static inline constexpr auto kS8 = ScalarType::int_(8);
 static inline constexpr auto kU8 = ScalarType::uint(8);
 static inline constexpr auto kU8B128 = ScalarType::uint(8, 128);
 
+static inline constexpr auto kFE2M1f =
+    ScalarType::float_(2, 1, true, ScalarType::NAN_NONE);
 static inline constexpr auto kFE3M2f =
     ScalarType::float_(3, 2, true, ScalarType::NAN_NONE);
 static inline constexpr auto kFE4M3fn =
@@ -332,6 +334,7 @@ static inline constexpr auto kInt8 = kS8;
 static inline constexpr auto kUint8 = kU8;
 static inline constexpr auto kUint8b128 = kU8B128;
 
+static inline constexpr auto kFloat4_e2m1f = kFE2M1f;
 static inline constexpr auto kFloat6_e3m2f = kFE3M2f;
 static inline constexpr auto kFloat8_e4m3fn = kFE4M3fn;
 static inline constexpr auto kFloat8_e5m2 = kFE5M2;
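Note on the new type: kFE2M1f is the 4-bit "e2m1" format used by nvfp4 — 1 sign bit, 2 exponent bits, 1 mantissa bit, and (per ScalarType::NAN_NONE) no NaN or infinity encodings, so all 16 codes are finite. A minimal illustrative decoder, not code from this commit:

#include <cstdint>

// The 8 e2m1 magnitudes; the top bit of the nibble is the sign.
inline float decode_e2m1(uint8_t nibble) {
  static const float kMagnitude[8] = {0.0f, 0.5f, 1.0f, 1.5f,
                                      2.0f, 3.0f, 4.0f, 6.0f};
  float m = kMagnitude[nibble & 0x7];
  return (nibble & 0x8) ? -m : m;
}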

csrc/moe/marlin_moe_wna16/generate_kernels.py

Lines changed: 11 additions & 2 deletions

@@ -31,15 +31,18 @@
 
 # int8 with zero point case (vllm::kU8) is also supported,
 # we don't add it to reduce wheel size.
-SCALAR_TYPES = ["vllm::kU4", "vllm::kU4B8", "vllm::kU8B128", "vllm::kFE4M3fn"]
+SCALAR_TYPES = [
+    "vllm::kU4", "vllm::kU4B8", "vllm::kU8B128", "vllm::kFE4M3fn",
+    "vllm::kFE2M1f"
+]
 THREAD_CONFIGS = [(128, 128, 256), (64, 256, 256), (64, 128, 128)]
 
 THREAD_M_BLOCKS = [0.5, 1, 2, 3, 4]
 # group_blocks:
 #   = 0 : act order case
 #   = -1 : channelwise quantization
 #   > 0 : group_size=16*group_blocks
-GROUP_BLOCKS = [0, -1, 2, 4, 8]
+GROUP_BLOCKS = [0, -1, 1, 2, 4, 8]
 DTYPES = ["fp16", "bf16"]
 
 
@@ -72,6 +75,12 @@ def generate_new_kernels():
         # for fp8
         if scalar_type == "vllm::kFE4M3fn" and group_blocks not in [-1, 8]:
             continue
+        # nvfp4 only supports group_size == 16
+        if scalar_type == "vllm::kFE2M1f" and group_blocks not in [1, 2]:
+            continue
+        # other quantization methods don't support group_size = 16
+        if scalar_type != "vllm::kFE2M1f" and group_blocks == 1:
+            continue
 
         k_blocks = thread_configs[0] // 16
         n_blocks = thread_configs[1] // 16
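Reading of the generator's convention (see the group_blocks comment above): a positive group_blocks encodes group_size = 16 * group_blocks, so the newly added value 1 corresponds exactly to nvfp4's group_size of 16, and 2 to group_size 32. A small sketch of that mapping, with a helper name of my own choosing:

// group_blocks == 0  -> act-order case
// group_blocks == -1 -> channelwise: one group spanning all of K
// group_blocks  > 0  -> group_size = 16 * group_blocks
constexpr int group_size_from_blocks(int group_blocks, int prob_k) {
  return group_blocks == -1 ? prob_k : 16 * group_blocks;
}
static_assert(group_size_from_blocks(1, 4096) == 16, "nvfp4 native group size");
static_assert(group_size_from_blocks(2, 4096) == 32, "also emitted for kFE2M1f");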

csrc/moe/marlin_moe_wna16/kernel.h

Lines changed: 12 additions & 11 deletions

@@ -7,17 +7,18 @@
 #include "quantization/gptq_marlin/marlin_dtypes.cuh"
 #include "core/scalar_type.hpp"
 
-#define MARLIN_KERNEL_PARAMS                                                 \
-  const int4 *__restrict__ A, const int4 *__restrict__ B,                    \
-      int4 *__restrict__ C, int4 *__restrict__ C_tmp,                        \
-      const int4 *__restrict__ scales_ptr, const int4 *__restrict__ zp_ptr,  \
-      const int *__restrict__ g_idx,                                         \
-      const int32_t *__restrict__ sorted_token_ids_ptr,                      \
-      const int32_t *__restrict__ expert_ids_ptr,                            \
-      const int32_t *__restrict__ num_tokens_past_padded_ptr,                \
-      const float *__restrict__ topk_weights_ptr, int top_k,                 \
-      bool mul_topk_weights, bool is_ep, int num_groups, int prob_m,         \
-      int prob_n, int prob_k, int *locks, bool use_atomic_add,               \
+#define MARLIN_KERNEL_PARAMS                                           \
+  const int4 *__restrict__ A, const int4 *__restrict__ B,              \
+      int4 *__restrict__ C, int4 *__restrict__ C_tmp,                  \
+      const int4 *__restrict__ scales_ptr,                             \
+      const uint16_t *__restrict__ scale2_ptr,                         \
+      const int4 *__restrict__ zp_ptr, const int *__restrict__ g_idx,  \
+      const int32_t *__restrict__ sorted_token_ids_ptr,                \
+      const int32_t *__restrict__ expert_ids_ptr,                      \
+      const int32_t *__restrict__ num_tokens_past_padded_ptr,          \
+      const float *__restrict__ topk_weights_ptr, int top_k,           \
+      bool mul_topk_weights, bool is_ep, int num_groups, int prob_m,   \
+      int prob_n, int prob_k, int *locks, bool use_atomic_add,         \
       bool use_fp32_reduce, int max_shared_mem
 
 namespace MARLIN_NAMESPACE_NAME {

csrc/moe/marlin_moe_wna16/marlin_template.h

Lines changed: 99 additions & 40 deletions

@@ -301,9 +301,11 @@ __global__ void Marlin(
     int4* __restrict__ C_tmp,             // fp32 tmp output buffer (for reduce)
     const int4* __restrict__ scales_ptr,  // fp16 quantization scales of shape
                                           // (k/groupsize)xn
-    const int4* __restrict__ zp_ptr,      // 4bit packed zero-points of shape
-                                          // (k/groupsize)x(n/pack_factor)
-    const int* __restrict__ g_idx,        // int32 group indices of shape k
+    const uint16_t* __restrict__ scale2_ptr,  // fp16 global scale (for nvfp4
+                                              // only)
+    const int4* __restrict__ zp_ptr,      // 4bit packed zero-points of shape
+                                          // (k/groupsize)x(n/pack_factor)
+    const int* __restrict__ g_idx,        // int32 group indices of shape k
     const int32_t* __restrict__ sorted_token_ids_ptr,  // moe sorted_ids
     const int32_t* __restrict__ expert_ids_ptr,        // moe expert ids
     const int32_t* __restrict__ num_tokens_past_padded_ptr,  // moe num tokens
@@ -341,14 +343,25 @@ __global__ void Marlin(
   extern __shared__ int4 sh[];
   static constexpr auto w_type = vllm::ScalarType::from_id(w_type_id);
   constexpr bool has_zp = w_type == vllm::kU4 || w_type == vllm::kU8;
+  constexpr bool is_int_type = w_type == vllm::kU4 || w_type == vllm::kU8 ||
+                               w_type == vllm::kU4B8 || w_type == vllm::kU8B128;
+  // see comments of dequant.h for more details
+  constexpr bool dequant_skip_flop =
+      !is_int_type ||
+      has_zp && !is_zp_float && !std::is_same<scalar_t, nv_bfloat16>::value ||
+      has_zp && !is_zp_float && !(w_type == vllm::kU8);
+
+  scalar_t2 global_scale;
+
   constexpr bool has_act_order = group_blocks == 0;
 
   constexpr int pack_factor = 32 / w_type.size_bits();
   static_assert(thread_m_blocks == 1 || !m_block_size_8);
   constexpr int moe_block_size = m_block_size_8 ? 8 : (16 * thread_m_blocks);
   const int group_size =
       (!has_act_order && group_blocks == -1) ? prob_k : prob_k / num_groups;
-  const int scales_expert_stride = prob_n * prob_k / group_size / 8;
+  const int scales_expert_stride =
+      prob_n * prob_k / group_size / (w_type == vllm::kFE2M1f ? 16 : 8);
   const int zp_expert_stride =
       is_zp_float ? prob_n * prob_k / group_size / 8
                   : prob_n * prob_k / group_size / (pack_factor * 4);
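Why the scale stride divisor becomes 16 for kFE2M1f (my reading, consistent with the dequant_fp8_scales calls later in this diff): scales travel as 16-byte int4 vectors, and nvfp4 group scales are 1-byte fp8 values rather than 2-byte fp16, so twice as many fit per vector. A sketch under that assumption, with a hypothetical helper name:

// fp16 scales: 2 bytes each -> 8 per 16-byte int4.
// nvfp4 (kFE2M1f) fp8 group scales: 1 byte each -> 16 per int4.
inline int scales_expert_stride_int4(int prob_n, int prob_k, int group_size,
                                     bool is_nvfp4) {
  int scales_per_int4 = is_nvfp4 ? 16 : 8;
  return prob_n * prob_k / group_size / scales_per_int4;
}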
@@ -460,9 +473,16 @@ __global__ void Marlin(
     if (mul_topk_weights) {
   #pragma unroll
       for (int i = 0; i < 4; i++) {
-        sh_block_topk_weights[tid4 * 4 + i] =
-            Dtype::num2num2(Dtype::float2num(
-                topk_weights_ptr[sh_block_sorted_ids[tid4 * 4 + i]]));
+        if constexpr (w_type == vllm::kFE2M1f) {
+          sh_block_topk_weights[tid4 * 4 + i] = __hmul2(
+              global_scale,
+              Dtype::num2num2(Dtype::float2num(
+                  topk_weights_ptr[sh_block_sorted_ids[tid4 * 4 + i]])));
+        } else {
+          sh_block_topk_weights[tid4 * 4 + i] =
+              Dtype::num2num2(Dtype::float2num(
+                  topk_weights_ptr[sh_block_sorted_ids[tid4 * 4 + i]]));
+        }
       }
     }
   }
@@ -493,6 +513,11 @@ __global__ void Marlin(
       expert_id = expert_ids_ptr[block_id];
     }
 
+    if constexpr (w_type == vllm::kFE2M1f) {
+      uint16_t val = scale2_ptr[expert_id];
+      global_scale = Dtype::num2num2(*reinterpret_cast<scalar_t*>(&val));
+    }
+
     B_expert_off = expert_id * prob_n * prob_k / (pack_factor * 4);
     scales_ptr += (expert_id - old_expert_id) * scales_expert_stride;
     if constexpr (has_zp) {
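The per-expert global scale arrives as raw uint16_t bits so one pointer type serves both fp16 and bf16 builds; the kernel reinterprets the bits as scalar_t and broadcasts them into both lanes of a scalar_t2. A standalone sketch of the same pattern for the fp16 case (hypothetical helper, not from the diff):

#include <cuda_fp16.h>

__device__ inline half2 load_expert_global_scale(const uint16_t* scale2_ptr,
                                                 int expert_id) {
  half h = __ushort_as_half(scale2_ptr[expert_id]);  // reinterpret, not convert
  return __half2half2(h);                            // broadcast to both lanes
}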
@@ -606,7 +631,7 @@ __global__ void Marlin(
   constexpr int s_sh_stride = 16 * thread_n_blocks / 8;
   constexpr int s_tb_groups =
       !has_act_order && group_blocks != -1 && group_blocks < thread_k_blocks
-          ? thread_k_blocks / group_blocks
+          ? thread_k_blocks / group_blocks / (w_type == vllm::kFE2M1f ? 2 : 1)
           : 1;
   constexpr int s_sh_stage = s_tb_groups * s_sh_stride;
   int s_gl_rd_delta = s_gl_stride;
@@ -664,7 +689,8 @@ __global__ void Marlin(
     if constexpr (group_blocks == -1) {
       s_gl_rd = s_sh_stride * slice_col + threadIdx.x;
     } else {
-      s_gl_rd = s_gl_stride * ((thread_k_blocks * slice_row) / group_blocks) +
+      s_gl_rd = s_gl_stride * ((thread_k_blocks * slice_row) / group_blocks) /
+                    (w_type == vllm::kFE2M1f ? 2 : 1) +
                 s_sh_stride * slice_col + threadIdx.x;
     }
   }
@@ -688,10 +714,20 @@ __global__ void Marlin(
   // we scale a `half2` tile in column-major layout in the former and in
   // row-major in the latter case.
   int s_sh_rd;
-  if constexpr (group_blocks != -1)
+  if constexpr (group_blocks != -1 && w_type == vllm::kFE2M1f) {
+    auto warp_id = threadIdx.x / 32;
+    int n_warps = thread_n_blocks / 4;
+    int warp_row = warp_id / n_warps;
+
     s_sh_rd = 8 * ((threadIdx.x / 32) % (thread_n_blocks / 4)) +
               (threadIdx.x % 32) / 4;
-  else if constexpr (group_blocks == -1 && (m_block_size_8 || has_zp))
+    s_sh_rd = s_sh_rd * 2 + warp_row % 2;
+
+  } else if constexpr (group_blocks != -1)
+    s_sh_rd = 8 * ((threadIdx.x / 32) % (thread_n_blocks / 4)) +
+              (threadIdx.x % 32) / 4;
+  else if constexpr (group_blocks == -1 &&
+                     (m_block_size_8 || (has_zp && !dequant_skip_flop)))
     s_sh_rd = 8 * ((threadIdx.x / 32) % (thread_n_blocks / 4)) +
               (threadIdx.x % 32) / 8;
   else
@@ -801,7 +837,7 @@ __global__ void Marlin(
     sh_first_group_id = first_group_id;
     sh_num_groups = last_group_id - first_group_id + 1;
 
-    if (sh_num_groups < act_s_max_num_groups) {
+    if (sh_num_groups > act_s_max_num_groups) {
       sh_num_groups = act_s_max_num_groups;
     }
 
@@ -1021,12 +1057,19 @@ __global__ void Marlin(
       cur_k += k_iter_size * (k % b_sh_wr_iters);
 
       int k_blocks = cur_k / 16;
-      int cur_group_id = k_blocks / group_blocks;
+      int cur_group_id =
+          k_blocks / (group_blocks * (w_type == vllm::kFE2M1f ? 2 : 1));
 
       int4* sh_s_stage = sh_s + s_sh_stage * pipe;
 
-      reinterpret_cast<int4*>(&frag_s[k % 2])[0] =
-          sh_s_stage[s_sh_rd + cur_group_id * s_sh_stride];
+      if constexpr (w_type_id != vllm::kFE2M1f.id()) {
+        reinterpret_cast<int4*>(&frag_s[k % 2])[0] =
+            sh_s_stage[s_sh_rd + cur_group_id * s_sh_stride];
+      } else {
+        reinterpret_cast<int2*>(&frag_s[k % 2])[0] =
+            reinterpret_cast<int2*>(
+                sh_s_stage)[s_sh_rd + cur_group_id * (2 * s_sh_stride)];
+      }
     }
   }
 
@@ -1199,22 +1242,7 @@ __global__ void Marlin(
   };
 
   auto dequant_data = [&](int q, scalar_t2* frag_b_ptr) {
-    if constexpr (has_zp && is_zp_float || !has_zp) {
-      dequant<scalar_t2, w_type_id>(q, frag_b_ptr);
-    } else {
-      static_assert(has_zp && !is_zp_float);
-      static_assert(w_type_id == vllm::kU4.id() || w_type_id == vllm::kU8.id());
-      // If (has_zp && !is_zp_float),
-      // we use not-zp version `dequant` function
-      // to improve numerical accuracy.
-      // Since both weight and zero point are dequanted using this logic,
-      // the final dequanted weight would be correct.
-      if constexpr (w_type_id == vllm::kU4.id()) {
-        dequant<scalar_t2, vllm::kU4B8.id()>(q, frag_b_ptr);
-      } else if constexpr (w_type_id == vllm::kU8.id()) {
-        dequant<scalar_t2, vllm::kU8B128.id()>(q, frag_b_ptr);
-      }
-    }
+    dequant<scalar_t2, w_type_id, dequant_skip_flop>(q, frag_b_ptr);
   };
 
   // Execute the actual tensor core matmul of a sub-tile.
@@ -1244,13 +1272,23 @@ __global__ void Marlin(
         dequant_data(zp_quant_1, reinterpret_cast<scalar_t2*>(&frag_zp) + 2);
       }
     }
-    if constexpr (has_zp && is_zp_float) {
+    if constexpr (!dequant_skip_flop && has_zp && is_zp_float) {
       if (is_new_zp) {
         reinterpret_cast<int4*>(&frag_zp)[0] =
             reinterpret_cast<int4*>(&frag_zpf[k2])[0];
       }
     }
 
+    if constexpr (w_type == vllm::kFE2M1f) {
+      int s_quant_0 = reinterpret_cast<int*>(frag_s[k2])[0];
+      int s_quant_1 = reinterpret_cast<int*>(frag_s[k2])[1];
+
+      dequant_fp8_scales<scalar_t2>(s_quant_0,
+                                    reinterpret_cast<scalar_t2*>(&frag_s[k2]));
+      dequant_fp8_scales<scalar_t2>(
+          s_quant_1, reinterpret_cast<scalar_t2*>(&frag_s[k2]) + 2);
+    }
+
   // We have the m dimension as the inner loop in order to encourage overlapping
   // dequantization and matmul operations.
 #pragma unroll
@@ -1259,7 +1297,10 @@ __global__ void Marlin(
     FragB frag_b1;
     int b_quant_0, b_quant_1;
 
-    if constexpr (w_type.size_bits() == 4) {
+    if constexpr (w_type_id == vllm::kFE2M1f.id()) {
+      b_quant_1 = frag_b_quant[k2][0][j];
+      b_quant_0 = b_quant_1 << 8;
+    } else if constexpr (w_type.size_bits() == 4) {
       b_quant_0 = frag_b_quant[k2][0][j];
       b_quant_1 = b_quant_0 >> 8;
     } else {
@@ -1272,22 +1313,28 @@ __global__ void Marlin(
     dequant_data(b_quant_0, reinterpret_cast<scalar_t2*>(&frag_b0));
     dequant_data(b_quant_1, reinterpret_cast<scalar_t2*>(&frag_b1));
 
+    if constexpr (dequant_skip_flop && has_zp && !is_zp_float) {
+      sub_zp<scalar_t>(frag_b0, frag_zp[j], 0);
+      sub_zp<scalar_t>(frag_b1, frag_zp[j], 1);
+    }
+
     // Apply scale to frag_b0
     if constexpr (has_act_order) {
       static_assert(group_blocks != -1);
       scale4<scalar_t>(frag_b0, act_frag_s[k2][0][j], act_frag_s[k2][1][j],
                        act_frag_s[k2][2][j], act_frag_s[k2][3][j], 0);
       scale4<scalar_t>(frag_b1, act_frag_s[k2][0][j], act_frag_s[k2][1][j],
                        act_frag_s[k2][2][j], act_frag_s[k2][3][j], 1);
-    } else if constexpr (has_zp && !is_zp_float && group_blocks == -1) {
+    } else if constexpr (!dequant_skip_flop && has_zp && !is_zp_float &&
+                         group_blocks == -1) {
      int idx = (threadIdx.x / 4) % 2;
      scalar_t2 s2 = Dtype::nums2num2(
          reinterpret_cast<scalar_t*>(&frag_s[j / 2][j % 2 * 2 + 0])[idx],
          reinterpret_cast<scalar_t*>(&frag_s[j / 2][j % 2 * 2 + 1])[idx]);
      if (is_new_zp) frag_zp[j] = __hmul2(frag_zp[j], s2);
      scale_and_sub<scalar_t>(frag_b0, s2.x, frag_zp[j].x);
      scale_and_sub<scalar_t>(frag_b1, s2.y, frag_zp[j].y);
-    } else if constexpr (has_zp && group_blocks != -1) {
+    } else if constexpr (!dequant_skip_flop && has_zp && group_blocks != -1) {
      if (is_new_zp)
        frag_zp[j] = __hmul2(frag_zp[j],
                             *reinterpret_cast<scalar_t2*>(&frag_s[k2][j]));
@@ -1554,10 +1601,17 @@ __global__ void Marlin(
     // For per-column quantization we finally apply the scale here (only for
     // 4-bit)
     if constexpr (!has_act_order && group_blocks == -1 &&
-                  w_type.size_bits() == 4 && !has_zp) {
+                  w_type.size_bits() == 4 &&
+                  (has_zp && dequant_skip_flop || !has_zp)) {
       res = __hmul2(res, s[0]);
     }
 
+    if constexpr (w_type == vllm::kFE2M1f) {
+      if (!mul_topk_weights) {
+        res = __hmul2(res, global_scale);
+      }
+    }
+
     if constexpr (m_block_size_8) {
       ((scalar_t*)sh_red)[idx] = res.x;
       ((scalar_t*)sh_red)[idx + 8 * c_sh_stride] = res.y;
@@ -1648,7 +1702,9 @@ __global__ void Marlin(
       if constexpr (has_zp && !is_zp_float && group_blocks == -1) {
         if (i == 0) {
           fetch_col_zp_to_shared();
-          fetch_col_scale_to_shared();
+          if constexpr (!dequant_skip_flop) {
+            fetch_col_scale_to_shared();
+          }
         }
       }
       fetch_to_shared(i, i, i < slice_iters, i);
@@ -1737,7 +1793,8 @@ __global__ void Marlin(
     bool last = slice_idx == slice_count - 1;
     // For per-column scales, we only fetch them here in the final step before
     // write-out
-    if constexpr (!has_act_order && group_blocks == -1 && !has_zp) {
+    if constexpr (!has_act_order && group_blocks == -1 &&
+                  (has_zp && dequant_skip_flop || !has_zp)) {
       if (w_type.size_bits() == 8 || (last || use_atomic_add)) {
         if (s_sh_wr_pred) {
           cp_async4(&sh_s[s_sh_wr], &scales_ptr[s_gl_rd]);
@@ -1747,7 +1804,8 @@ __global__ void Marlin(
     }
 
     thread_block_reduce();
-    if constexpr (!has_act_order && group_blocks == -1 && !has_zp) {
+    if constexpr (!has_act_order && group_blocks == -1 &&
+                  (has_zp && dequant_skip_flop || !has_zp)) {
       if (w_type.size_bits() == 8 || (last || use_atomic_add)) {
         cp_async_wait<0>();
         __syncthreads();
@@ -1771,7 +1829,8 @@ __global__ void Marlin(
     // that converts the fp32 results to fp16 (so that we avoid possible
     // overflow in fp16)
     if constexpr (!has_act_order && group_blocks == -1 &&
-                  w_type.size_bits() == 8 && !has_zp) {
+                  w_type.size_bits() == 8 &&
+                  (has_zp && dequant_skip_flop || !has_zp)) {
       if (threadIdx.x / 32 < thread_n_blocks / 4) {
 #pragma unroll
         for (int i = 0; i < thread_m_blocks; i++) {
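Taking the hunks above together, the nvfp4 path applies three factors: the e2m1 weight code, an fp8 scale shared by each group of 16 weights (dequant_fp8_scales), and one fp16 global scale per expert (scale2_ptr), folded either into the output or into the top-k weights. A scalar reference sketch of that arithmetic under those assumptions — illustrative only, not kernel code:

#include <cstdint>

// weight ~= e2m1(code) * group_scale (fp8, per 16 weights)
//                      * global_scale (fp16, per expert)
inline float nvfp4_dequant_ref(uint8_t code, float group_scale,
                               float global_scale) {
  static const float kE2M1[8] = {0.0f, 0.5f, 1.0f, 1.5f,
                                 2.0f, 3.0f, 4.0f, 6.0f};
  float w = ((code & 0x8) ? -1.0f : 1.0f) * kE2M1[code & 0x7];
  return w * group_scale * global_scale;
}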
