@@ -301,9 +301,11 @@ __global__ void Marlin(
     int4* __restrict__ C_tmp,          // fp32 tmp output buffer (for reduce)
     const int4* __restrict__ scales_ptr,  // fp16 quantization scales of shape
                                           // (k/groupsize)xn
-    const int4* __restrict__ zp_ptr,      // 4bit packed zero-points of shape
-                                          // (k/groupsize)x(n/pack_factor)
-    const int* __restrict__ g_idx,        // int32 group indices of shape k
+    const uint16_t* __restrict__ scale2_ptr,  // fp16 global scale (for nvfp4
+                                              // only)
+    const int4* __restrict__ zp_ptr,      // 4bit packed zero-points of shape
+                                          // (k/groupsize)x(n/pack_factor)
+    const int* __restrict__ g_idx,        // int32 group indices of shape k
     const int32_t* __restrict__ sorted_token_ids_ptr,  // moe sorted_ids
     const int32_t* __restrict__ expert_ids_ptr,        // moe expert ids
     const int32_t* __restrict__ num_tokens_past_padded_ptr,  // moe num tokens
@@ -341,14 +343,25 @@ __global__ void Marlin(
   extern __shared__ int4 sh[];
   static constexpr auto w_type = vllm::ScalarType::from_id(w_type_id);
   constexpr bool has_zp = w_type == vllm::kU4 || w_type == vllm::kU8;
+  constexpr bool is_int_type = w_type == vllm::kU4 || w_type == vllm::kU8 ||
+                               w_type == vllm::kU4B8 || w_type == vllm::kU8B128;
+  // see comments of dequant.h for more details
+  constexpr bool dequant_skip_flop =
+      !is_int_type ||
+      has_zp && !is_zp_float && !std::is_same<scalar_t, nv_bfloat16>::value ||
+      has_zp && !is_zp_float && !(w_type == vllm::kU8);
+
+  scalar_t2 global_scale;
+
   constexpr bool has_act_order = group_blocks == 0;
 
   constexpr int pack_factor = 32 / w_type.size_bits();
   static_assert(thread_m_blocks == 1 || !m_block_size_8);
   constexpr int moe_block_size = m_block_size_8 ? 8 : (16 * thread_m_blocks);
   const int group_size =
       (!has_act_order && group_blocks == -1) ? prob_k : prob_k / num_groups;
-  const int scales_expert_stride = prob_n * prob_k / group_size / 8;
+  const int scales_expert_stride =
+      prob_n * prob_k / group_size / (w_type == vllm::kFE2M1f ? 16 : 8);
   const int zp_expert_stride =
       is_zp_float ? prob_n * prob_k / group_size / 8
                   : prob_n * prob_k / group_size / (pack_factor * 4);
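Reviewer note: the new `dequant_skip_flop` flag decides when dequantization may skip the fused scale/zero-point flops. Below is a minimal host-side restatement of the predicate with plain bools (the names are hypothetical stand-ins, not vLLM identifiers), compilable on its own to sanity-check the intended truth table:

```cpp
// Hypothetical restatement of dequant_skip_flop; && binds tighter than ||,
// exactly as in the kernel expression above.
#include <cstdio>

constexpr bool skip_flop(bool is_int_type, bool has_zp, bool is_zp_float,
                         bool is_bf16, bool is_u8) {
  return !is_int_type ||
         (has_zp && !is_zp_float && !is_bf16) ||
         (has_zp && !is_zp_float && !is_u8);
}

int main() {
  // Non-integer weight formats (fp4/fp8) always skip.
  std::printf("float weights:      %d\n", skip_flop(false, false, false, false, false));
  // u4 with integer zero-points skips for any activation dtype.
  std::printf("u4 + int zp + bf16: %d\n", skip_flop(true, true, false, true, false));
  // u8 with integer zero-points and bf16 activations keeps the extra flops.
  std::printf("u8 + int zp + bf16: %d\n", skip_flop(true, true, false, true, true));
  // Symmetric int types without zero-points also keep them.
  std::printf("u4b8, no zp:        %d\n", skip_flop(true, false, false, false, false));
  return 0;
}
```

The third alternative in the predicate is what lets 4-bit zero-point weights take the faster path even with bf16 activations.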
@@ -460,9 +473,16 @@ __global__ void Marlin(
       if (mul_topk_weights) {
   #pragma unroll
         for (int i = 0; i < 4; i++) {
-          sh_block_topk_weights[tid4 * 4 + i] =
-              Dtype::num2num2(Dtype::float2num(
-                  topk_weights_ptr[sh_block_sorted_ids[tid4 * 4 + i]]));
+          if constexpr (w_type == vllm::kFE2M1f) {
+            sh_block_topk_weights[tid4 * 4 + i] = __hmul2(
+                global_scale,
+                Dtype::num2num2(Dtype::float2num(
+                    topk_weights_ptr[sh_block_sorted_ids[tid4 * 4 + i]])));
+          } else {
+            sh_block_topk_weights[tid4 * 4 + i] =
+                Dtype::num2num2(Dtype::float2num(
+                    topk_weights_ptr[sh_block_sorted_ids[tid4 * 4 + i]]));
+          }
         }
       }
     }
@@ -493,6 +513,11 @@ __global__ void Marlin(
       expert_id = expert_ids_ptr[block_id];
     }
 
+    if constexpr (w_type == vllm::kFE2M1f) {
+      uint16_t val = scale2_ptr[expert_id];
+      global_scale = Dtype::num2num2(*reinterpret_cast<scalar_t*>(&val));
+    }
+
     B_expert_off = expert_id * prob_n * prob_k / (pack_factor * 4);
     scales_ptr += (expert_id - old_expert_id) * scales_expert_stride;
     if constexpr (has_zp) {
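Reviewer note: for `kFE2M1f` (nvfp4) each expert contributes one extra fp16 global scale, loaded from `scale2_ptr[expert_id]` and broadcast into a `scalar_t2`. The patch folds it into the top-k routing weights when `mul_topk_weights` is set and otherwise applies it once at write-out; a scalar sketch with hypothetical float values of why the placement does not matter:

```cpp
// Hypothetical scalar check: the per-expert global scale commutes with the
// per-group scale and the routing weight, so one multiply anywhere in the
// chain gives the same result.
#include <cassert>
#include <cmath>

int main() {
  float q = 3.0f;             // dequantized fp4 value
  float group_scale = 0.25f;  // per-group scale
  float global_scale = 1.5f;  // per-expert global scale (scale2_ptr)
  float topk_w = 0.8f;        // MoE routing weight

  // Global scale applied to the accumulated result.
  float late = (q * group_scale * global_scale) * topk_w;
  // Global scale pre-folded into the shared routing weight.
  float early = (q * group_scale) * (global_scale * topk_w);

  assert(std::fabs(late - early) < 1e-6f);
  return 0;
}
```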
@@ -606,7 +631,7 @@ __global__ void Marlin(
   constexpr int s_sh_stride = 16 * thread_n_blocks / 8;
   constexpr int s_tb_groups =
       !has_act_order && group_blocks != -1 && group_blocks < thread_k_blocks
-          ? thread_k_blocks / group_blocks
+          ? thread_k_blocks / group_blocks / (w_type == vllm::kFE2M1f ? 2 : 1)
           : 1;
   constexpr int s_sh_stage = s_tb_groups * s_sh_stride;
   int s_gl_rd_delta = s_gl_stride;
@@ -664,7 +689,8 @@ __global__ void Marlin(
     if constexpr (group_blocks == -1) {
       s_gl_rd = s_sh_stride * slice_col + threadIdx.x;
     } else {
-      s_gl_rd = s_gl_stride * ((thread_k_blocks * slice_row) / group_blocks) +
+      s_gl_rd = s_gl_stride * ((thread_k_blocks * slice_row) / group_blocks) /
+                    (w_type == vllm::kFE2M1f ? 2 : 1) +
                 s_sh_stride * slice_col + threadIdx.x;
     }
   }
@@ -688,10 +714,20 @@ __global__ void Marlin(
   // we scale a `half2` tile in column-major layout in the former and in
   // row-major in the latter case.
   int s_sh_rd;
-  if constexpr (group_blocks != -1)
+  if constexpr (group_blocks != -1 && w_type == vllm::kFE2M1f) {
+    auto warp_id = threadIdx.x / 32;
+    int n_warps = thread_n_blocks / 4;
+    int warp_row = warp_id / n_warps;
+
     s_sh_rd = 8 * ((threadIdx.x / 32) % (thread_n_blocks / 4)) +
               (threadIdx.x % 32) / 4;
-  else if constexpr (group_blocks == -1 && (m_block_size_8 || has_zp))
+    s_sh_rd = s_sh_rd * 2 + warp_row % 2;
+
+  } else if constexpr (group_blocks != -1)
+    s_sh_rd = 8 * ((threadIdx.x / 32) % (thread_n_blocks / 4)) +
+              (threadIdx.x % 32) / 4;
+  else if constexpr (group_blocks == -1 &&
+                     (m_block_size_8 || (has_zp && !dequant_skip_flop)))
     s_sh_rd = 8 * ((threadIdx.x / 32) % (thread_n_blocks / 4)) +
               (threadIdx.x % 32) / 8;
   else
@@ -801,7 +837,7 @@ __global__ void Marlin(
     sh_first_group_id = first_group_id;
     sh_num_groups = last_group_id - first_group_id + 1;
 
-    if (sh_num_groups < act_s_max_num_groups) {
+    if (sh_num_groups > act_s_max_num_groups) {
       sh_num_groups = act_s_max_num_groups;
     }
 
@@ -1021,12 +1057,19 @@ __global__ void Marlin(
       cur_k += k_iter_size * (k % b_sh_wr_iters);
 
       int k_blocks = cur_k / 16;
-      int cur_group_id = k_blocks / group_blocks;
+      int cur_group_id =
+          k_blocks / (group_blocks * (w_type == vllm::kFE2M1f ? 2 : 1));
 
       int4* sh_s_stage = sh_s + s_sh_stage * pipe;
 
-      reinterpret_cast<int4*>(&frag_s[k % 2])[0] =
-          sh_s_stage[s_sh_rd + cur_group_id * s_sh_stride];
+      if constexpr (w_type_id != vllm::kFE2M1f.id()) {
+        reinterpret_cast<int4*>(&frag_s[k % 2])[0] =
+            sh_s_stage[s_sh_rd + cur_group_id * s_sh_stride];
+      } else {
+        reinterpret_cast<int2*>(&frag_s[k % 2])[0] =
+            reinterpret_cast<int2*>(
+                sh_s_stage)[s_sh_rd + cur_group_id * (2 * s_sh_stride)];
+      }
     }
   }
 
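Reviewer note: the group-scale bookkeeping changes by the same factor of two throughout because, for `kFE2M1f`, the per-group scales appear to be stored as 1-byte (fp8) values rather than 2-byte fp16: eight of them fit in an `int2` (8 B) instead of an `int4` (16 B), and two scale groups occupy one fp16-sized slot. A small arithmetic sketch under that assumption (sizes are illustrative only):

```cpp
// Hypothetical stride arithmetic for fp16 vs fp8 group scales.
#include <cstdint>
#include <cstdio>

int main() {
  const int prob_k = 4096, prob_n = 4096, group_size = 16;
  const int n_scales = prob_k / group_size * prob_n;  // scales per expert

  // fp16 scales: 2 bytes each, 8 per 16-byte int4 -> stride divided by 8.
  const int expert_stride_fp16 = n_scales / 8;
  // assumed fp8 scales: 1 byte each, 16 per int4  -> stride divided by 16.
  const int expert_stride_fp8 = n_scales / 16;

  std::printf("per-expert scale stride in int4s: fp16=%d fp8=%d\n",
              expert_stride_fp16, expert_stride_fp8);

  // The same 2x packing shows up in s_tb_groups / 2, the halved cur_group_id,
  // and indexing the shared stage in int2 units (cur_group_id * 2 * s_sh_stride).
  static_assert(8 * sizeof(uint16_t) == 16, "8 fp16 scales == one int4");
  static_assert(8 * sizeof(uint8_t) == 8, "8 fp8 scales == one int2");
  return 0;
}
```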
@@ -1199,22 +1242,7 @@ __global__ void Marlin(
   };
 
   auto dequant_data = [&](int q, scalar_t2* frag_b_ptr) {
-    if constexpr (has_zp && is_zp_float || !has_zp) {
-      dequant<scalar_t2, w_type_id>(q, frag_b_ptr);
-    } else {
-      static_assert(has_zp && !is_zp_float);
-      static_assert(w_type_id == vllm::kU4.id() || w_type_id == vllm::kU8.id());
-      // If (has_zp && !is_zp_float),
-      // we use not-zp version `dequant` function
-      // to improve numerical accuracy.
-      // Since both weight and zero point are dequanted using this logic,
-      // the final dequanted weight would be correct.
-      if constexpr (w_type_id == vllm::kU4.id()) {
-        dequant<scalar_t2, vllm::kU4B8.id()>(q, frag_b_ptr);
-      } else if constexpr (w_type_id == vllm::kU8.id()) {
-        dequant<scalar_t2, vllm::kU8B128.id()>(q, frag_b_ptr);
-      }
-    }
+    dequant<scalar_t2, w_type_id, dequant_skip_flop>(q, frag_b_ptr);
   };
 
   // Execute the actual tensor core matmul of a sub-tile.
@@ -1244,13 +1272,23 @@ __global__ void Marlin(
         dequant_data(zp_quant_1, reinterpret_cast<scalar_t2*>(&frag_zp) + 2);
       }
     }
-    if constexpr (has_zp && is_zp_float) {
+    if constexpr (!dequant_skip_flop && has_zp && is_zp_float) {
       if (is_new_zp) {
         reinterpret_cast<int4*>(&frag_zp)[0] =
             reinterpret_cast<int4*>(&frag_zpf[k2])[0];
       }
     }
 
+    if constexpr (w_type == vllm::kFE2M1f) {
+      int s_quant_0 = reinterpret_cast<int*>(frag_s[k2])[0];
+      int s_quant_1 = reinterpret_cast<int*>(frag_s[k2])[1];
+
+      dequant_fp8_scales<scalar_t2>(s_quant_0,
+                                    reinterpret_cast<scalar_t2*>(&frag_s[k2]));
+      dequant_fp8_scales<scalar_t2>(
+          s_quant_1, reinterpret_cast<scalar_t2*>(&frag_s[k2]) + 2);
+    }
+
     // We have the m dimension as the inner loop in order to encourage overlapping
     // dequantization and matmul operations.
 #pragma unroll
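Reviewer note: on the `kFE2M1f` path `frag_s[k2]` still holds packed 8-bit scale codes at this point, and `dequant_fp8_scales` expands each 32-bit word into `scalar_t2` pairs in place before the scale step below. The sketch that follows covers only the unpacking half, under the assumption of four 8-bit codes per word; the byte order is illustrative and the actual fp8-to-half numeric conversion lives in dequant.h and is not reproduced:

```cpp
// Hypothetical unpack-only sketch: split one packed word into four 8-bit
// scale codes, in the order two scalar_t2 outputs would consume them.
#include <cstdint>

inline void unpack_scale_codes(uint32_t packed, uint8_t out[4]) {
  out[0] = packed & 0xFFu;
  out[1] = (packed >> 8) & 0xFFu;
  out[2] = (packed >> 16) & 0xFFu;
  out[3] = (packed >> 24) & 0xFFu;
  // A real implementation would now convert each code from fp8 to the
  // activation half type; see dequant_fp8_scales in dequant.h.
}
```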
@@ -1259,7 +1297,10 @@ __global__ void Marlin(
     FragB frag_b1;
     int b_quant_0, b_quant_1;
 
-    if constexpr (w_type.size_bits() == 4) {
+    if constexpr (w_type_id == vllm::kFE2M1f.id()) {
+      b_quant_1 = frag_b_quant[k2][0][j];
+      b_quant_0 = b_quant_1 << 8;
+    } else if constexpr (w_type.size_bits() == 4) {
       b_quant_0 = frag_b_quant[k2][0][j];
       b_quant_1 = b_quant_0 >> 8;
     } else {
@@ -1272,22 +1313,28 @@ __global__ void Marlin(
     dequant_data(b_quant_0, reinterpret_cast<scalar_t2*>(&frag_b0));
     dequant_data(b_quant_1, reinterpret_cast<scalar_t2*>(&frag_b1));
 
+    if constexpr (dequant_skip_flop && has_zp && !is_zp_float) {
+      sub_zp<scalar_t>(frag_b0, frag_zp[j], 0);
+      sub_zp<scalar_t>(frag_b1, frag_zp[j], 1);
+    }
+
     // Apply scale to frag_b0
     if constexpr (has_act_order) {
       static_assert(group_blocks != -1);
       scale4<scalar_t>(frag_b0, act_frag_s[k2][0][j], act_frag_s[k2][1][j],
                        act_frag_s[k2][2][j], act_frag_s[k2][3][j], 0);
       scale4<scalar_t>(frag_b1, act_frag_s[k2][0][j], act_frag_s[k2][1][j],
                        act_frag_s[k2][2][j], act_frag_s[k2][3][j], 1);
-    } else if constexpr (has_zp && !is_zp_float && group_blocks == -1) {
+    } else if constexpr (!dequant_skip_flop && has_zp && !is_zp_float &&
+                         group_blocks == -1) {
       int idx = (threadIdx.x / 4) % 2;
       scalar_t2 s2 = Dtype::nums2num2(
           reinterpret_cast<scalar_t*>(&frag_s[j / 2][j % 2 * 2 + 0])[idx],
           reinterpret_cast<scalar_t*>(&frag_s[j / 2][j % 2 * 2 + 1])[idx]);
       if (is_new_zp) frag_zp[j] = __hmul2(frag_zp[j], s2);
       scale_and_sub<scalar_t>(frag_b0, s2.x, frag_zp[j].x);
       scale_and_sub<scalar_t>(frag_b1, s2.y, frag_zp[j].y);
-    } else if constexpr (has_zp && group_blocks != -1) {
+    } else if constexpr (!dequant_skip_flop && has_zp && group_blocks != -1) {
       if (is_new_zp)
         frag_zp[j] = __hmul2(frag_zp[j],
                              *reinterpret_cast<scalar_t2*>(&frag_s[k2][j]));
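Reviewer note: on the skip-flop path the zero-point is removed with `sub_zp` right after dequantization and the group scale is applied separately, while the fused `scale_and_sub` path computes `q*s - z*s` in one step. A scalar sketch (hypothetical float stand-ins for the half2 fragments) of why the two orderings agree:

```cpp
// Hypothetical scalar comparison of the fused and split dequant orderings.
#include <cassert>
#include <cmath>

int main() {
  float q = 11.0f;  // dequantized weight code
  float z = 8.0f;   // dequantized zero-point
  float s = 0.05f;  // per-group scale

  float fused = q * s - z * s;  // scale_and_sub: scale, then subtract scaled zp
  float split = (q - z) * s;    // sub_zp first, apply the scale afterwards

  assert(std::fabs(fused - split) < 1e-6f);
  return 0;
}
```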
@@ -1554,10 +1601,17 @@ __global__ void Marlin(
     // For per-column quantization we finally apply the scale here (only for
     // 4-bit)
     if constexpr (!has_act_order && group_blocks == -1 &&
-                  w_type.size_bits() == 4 && !has_zp) {
+                  w_type.size_bits() == 4 &&
+                  (has_zp && dequant_skip_flop || !has_zp)) {
       res = __hmul2(res, s[0]);
     }
 
+    if constexpr (w_type == vllm::kFE2M1f) {
+      if (!mul_topk_weights) {
+        res = __hmul2(res, global_scale);
+      }
+    }
+
     if constexpr (m_block_size_8) {
       ((scalar_t*)sh_red)[idx] = res.x;
       ((scalar_t*)sh_red)[idx + 8 * c_sh_stride] = res.y;
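Reviewer note: the per-column scale condition at write-out is widened from `!has_zp` to `(has_zp && dequant_skip_flop || !has_zp)`, which is just `!has_zp || dequant_skip_flop`, so zero-point formats on the skip-flop path now also take the column scale here; the nvfp4 global scale is applied at this point only when it was not already folded into the top-k weights. A standalone compile-time check of the boolean simplification (not kernel code):

```cpp
// Hypothetical compile-time check: the widened epilogue condition is the same
// as (!has_zp || dequant_skip_flop) for all four combinations.
constexpr bool same(bool has_zp, bool dequant_skip_flop) {
  return ((has_zp && dequant_skip_flop) || !has_zp) ==
         (!has_zp || dequant_skip_flop);
}
static_assert(same(false, false) && same(false, true) && same(true, false) &&
                  same(true, true),
              "equivalent to !has_zp || dequant_skip_flop");
```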
@@ -1648,7 +1702,9 @@ __global__ void Marlin(
       if constexpr (has_zp && !is_zp_float && group_blocks == -1) {
         if (i == 0) {
           fetch_col_zp_to_shared();
-          fetch_col_scale_to_shared();
+          if constexpr (!dequant_skip_flop) {
+            fetch_col_scale_to_shared();
+          }
         }
       }
       fetch_to_shared(i, i, i < slice_iters, i);
@@ -1737,7 +1793,8 @@ __global__ void Marlin(
     bool last = slice_idx == slice_count - 1;
     // For per-column scales, we only fetch them here in the final step before
     // write-out
-    if constexpr (!has_act_order && group_blocks == -1 && !has_zp) {
+    if constexpr (!has_act_order && group_blocks == -1 &&
+                  (has_zp && dequant_skip_flop || !has_zp)) {
       if (w_type.size_bits() == 8 || (last || use_atomic_add)) {
         if (s_sh_wr_pred) {
           cp_async4(&sh_s[s_sh_wr], &scales_ptr[s_gl_rd]);
@@ -1747,7 +1804,8 @@ __global__ void Marlin(
     }
 
     thread_block_reduce();
-    if constexpr (!has_act_order && group_blocks == -1 && !has_zp) {
+    if constexpr (!has_act_order && group_blocks == -1 &&
+                  (has_zp && dequant_skip_flop || !has_zp)) {
       if (w_type.size_bits() == 8 || (last || use_atomic_add)) {
         cp_async_wait<0>();
         __syncthreads();
@@ -1771,7 +1829,8 @@ __global__ void Marlin(
     // that converts the fp32 results to fp16 (so that we avoid possible
     // overflow in fp16)
     if constexpr (!has_act_order && group_blocks == -1 &&
-                  w_type.size_bits() == 8 && !has_zp) {
+                  w_type.size_bits() == 8 &&
+                  (has_zp && dequant_skip_flop || !has_zp)) {
       if (threadIdx.x / 32 < thread_n_blocks / 4) {
 #pragma unroll
         for (int i = 0; i < thread_m_blocks; i++) {