
Commit 4543f58

alexm-redhat authored and Alvant committed
[Kernel] Increase precision of GPTQ/AWQ Marlin kernel (vllm-project#6795)
Signed-off-by: Alvant <alvasian@yandex.ru>
1 parent 078266f commit 4543f58

6 files changed: +168 / -44 lines changed


benchmarks/kernels/benchmark_marlin.py

Lines changed: 18 additions & 5 deletions
@@ -10,7 +10,7 @@
     GPTQ_MARLIN_24_SUPPORTED_GROUP_SIZES, GPTQ_MARLIN_24_SUPPORTED_NUM_BITS)
 from vllm.model_executor.layers.quantization.utils.marlin_utils import (
     GPTQ_MARLIN_MAX_PARALLEL, GPTQ_MARLIN_MIN_THREAD_N,
-    GPTQ_MARLIN_SUPPORTED_GROUP_SIZES, GPTQ_MARLIN_SUPPORTED_NUM_BITS)
+    MARLIN_SUPPORTED_GROUP_SIZES, MARLIN_SUPPORTED_NUM_BITS)
 from vllm.model_executor.layers.quantization.utils.marlin_utils_test import (
     MarlinWorkspace, marlin_quantize)
 from vllm.model_executor.layers.quantization.utils.marlin_utils_test_24 import (
@@ -56,6 +56,8 @@ def bench_run(results: List[benchmark.Measurement], model: str,
     (marlin_24_w_ref, marlin_24_q_w_comp, marlin_24_meta,
      marlin_24_s) = marlin_24_quantize(b, num_bits, group_size)
 
+    marlin_zp = torch.empty(0, dtype=torch.int, device=b.device)
+
     # GPTQ quant
     (w_ref, q_w, s, g_idx,
      rand_perm) = quantize_weights(b, num_bits, group_size, act_order)
@@ -87,6 +89,7 @@ def bench_run(results: List[benchmark.Measurement], model: str,
         "marlin_w_ref": marlin_w_ref,
         "marlin_q_w": marlin_q_w,
         "marlin_s": marlin_s,
+        "marlin_zp": marlin_zp,
         "marlin_g_idx": marlin_g_idx,
         "marlin_sort_indices": marlin_sort_indices,
         "marlin_rand_perm": marlin_rand_perm,
@@ -125,11 +128,21 @@ def bench_run(results: List[benchmark.Measurement], model: str,
     results.append(
         benchmark.Timer(
             stmt=
-            "output = gptq_marlin_gemm(a, marlin_q_w, marlin_s, marlin_g_idx, marlin_sort_indices, marlin_workspace.scratch, num_bits, size_m, size_n, size_k, is_k_full)",  # noqa: E501
+            "output = gptq_marlin_gemm(a, marlin_q_w, marlin_s, marlin_zp, marlin_g_idx, marlin_sort_indices, marlin_workspace.scratch, num_bits, size_m, size_n, size_k, is_k_full, False, False)",  # noqa: E501
+            globals=globals,
+            label=label,
+            sub_label=sub_label,
+            description="gptq_marlin_gemm_fp16",
+        ).blocked_autorange(min_run_time=min_run_time))
+
+    results.append(
+        benchmark.Timer(
+            stmt=
+            "output = gptq_marlin_gemm(a, marlin_q_w, marlin_s, marlin_zp, marlin_g_idx, marlin_sort_indices, marlin_workspace.scratch, num_bits, size_m, size_n, size_k, is_k_full, False, True)",  # noqa: E501
             globals=globals,
             label=label,
             sub_label=sub_label,
-            description="gptq_marlin_gemm",
+            description="gptq_marlin_gemm_fp32",
         ).blocked_autorange(min_run_time=min_run_time))
 
     if (num_bits in GPTQ_MARLIN_24_SUPPORTED_NUM_BITS
@@ -183,12 +196,12 @@ def main(args):
                 ) > 0 and is_k_full not in args.limit_k_full:
             continue
 
-        for num_bits in GPTQ_MARLIN_SUPPORTED_NUM_BITS:
+        for num_bits in MARLIN_SUPPORTED_NUM_BITS:
             if len(args.limit_num_bits
                    ) > 0 and num_bits not in args.limit_num_bits:
                 continue
 
-            for group_size in GPTQ_MARLIN_SUPPORTED_GROUP_SIZES:
+            for group_size in MARLIN_SUPPORTED_GROUP_SIZES:
                 if len(
                         args.limit_group_size
                 ) > 0 and group_size not in args.limit_group_size:
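Note (not part of the diff): the two Timer entries above differ only in the final use_fp32_reduce argument passed to gptq_marlin_gemm. A minimal sketch of the same comparison as a helper, assuming the updated vLLM CUDA ops are built and the benchmark globals (a, marlin_q_w, marlin_s, marlin_zp, marlin_g_idx, marlin_sort_indices, marlin_workspace, num_bits, size_m, size_n, size_k, is_k_full, gptq_marlin_gemm) are prepared exactly as in bench_run:

```python
# Hedged sketch: time the fp16-reduce vs fp32-reduce paths of gptq_marlin_gemm.
# The stmt string mirrors the benchmark above; globals_dict must supply the names it uses.
import torch.utils.benchmark as benchmark

def time_variant(globals_dict, use_fp32_reduce: bool, min_run_time: float = 1.0):
    stmt = ("output = gptq_marlin_gemm(a, marlin_q_w, marlin_s, marlin_zp, "
            "marlin_g_idx, marlin_sort_indices, marlin_workspace.scratch, "
            "num_bits, size_m, size_n, size_k, is_k_full, False, "
            f"{use_fp32_reduce})")
    desc = "gptq_marlin_gemm_fp32" if use_fp32_reduce else "gptq_marlin_gemm_fp16"
    # blocked_autorange repeats the statement until min_run_time seconds are measured
    return benchmark.Timer(stmt=stmt, globals=globals_dict,
                           description=desc).blocked_autorange(min_run_time=min_run_time)
```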

csrc/ops.h

Lines changed: 2 additions & 1 deletion
@@ -93,7 +93,8 @@ torch::Tensor gptq_marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight,
                                torch::Tensor& g_idx, torch::Tensor& perm,
                                torch::Tensor& workspace, int64_t num_bits,
                                int64_t size_m, int64_t size_n, int64_t size_k,
-                               bool is_k_full, bool has_zp);
+                               bool is_k_full, bool has_zp,
+                               bool use_fp32_reduce);
 
 torch::Tensor gptq_marlin_repack(torch::Tensor& b_q_weight, torch::Tensor& perm,
                                  int64_t size_k, int64_t size_n,

csrc/quantization/gptq_marlin/gptq_marlin.cu

Lines changed: 122 additions & 28 deletions
@@ -59,14 +59,16 @@ __global__ void Marlin(
     const int4* __restrict__ A,  // fp16 input matrix of shape mxk
     const int4* __restrict__ B,  // 4bit quantized weight matrix of shape kxn
     int4* __restrict__ C,        // fp16 output buffer of shape mxn
+    int4* __restrict__ C_tmp,    // fp32 tmp output buffer (for reduce)
     const int4* __restrict__ scales_ptr,  // fp16 quantization scales of shape
                                           // (k/groupsize)xn
     const int* __restrict__ g_idx,        // int32 group indices of shape k
-    int num_groups,   // number of scale groups per output channel
-    int prob_m,       // batch dimension m
-    int prob_n,       // output dimension n
-    int prob_k,       // reduction dimension k
-    int* locks        // extra global storage for barrier synchronization
+    int num_groups,       // number of scale groups per output channel
+    int prob_m,           // batch dimension m
+    int prob_n,           // output dimension n
+    int prob_k,           // reduction dimension k
+    int* locks,           // extra global storage for barrier synchronization
+    bool use_fp32_reduce  // whether to use fp32 global reduce
 ) {}
 
 }  // namespace gptq_marlin
@@ -532,16 +534,18 @@ __global__ void Marlin(
     const int4* __restrict__ A,  // fp16 input matrix of shape mxk
     const int4* __restrict__ B,  // 4bit quantized weight matrix of shape kxn
     int4* __restrict__ C,        // fp16 output buffer of shape mxn
+    int4* __restrict__ C_tmp,    // fp32 tmp output buffer (for reduce)
     const int4* __restrict__ scales_ptr,  // fp16 quantization scales of shape
                                           // (k/groupsize)xn
     const int4* __restrict__ zp_ptr,      // 4bit packed zero-points of shape
                                           // (k/groupsize)x(n/pack_factor)
     const int* __restrict__ g_idx,        // int32 group indices of shape k
-    int num_groups,   // number of scale groups per output channel
-    int prob_m,       // batch dimension m
-    int prob_n,       // output dimension n
-    int prob_k,       // reduction dimension k
-    int* locks        // extra global storage for barrier synchronization
+    int num_groups,       // number of scale groups per output channel
+    int prob_m,           // batch dimension m
+    int prob_n,           // output dimension n
+    int prob_k,           // reduction dimension k
+    int* locks,           // extra global storage for barrier synchronization
+    bool use_fp32_reduce  // whether to use fp32 global reduce
 ) {
   // Each threadblock processes one "stripe" of the B matrix with (roughly) the
   // same size, which might involve multiple column "slices" (of width 16 *
@@ -595,13 +599,16 @@
   int slice_idx;  // index of threadblock in current slice; numbered bottom to
                   // top
 
+  int par_id = 0;
+
   // We can easily implement parallel problem execution by just remapping
   // indices and advancing global pointers
   if (slice_col_par >= n_tiles) {
     A += (slice_col_par / n_tiles) * 16 * thread_m_blocks * prob_k / 8;
     C += (slice_col_par / n_tiles) * 16 * thread_m_blocks * prob_n / 8;
     locks += (slice_col_par / n_tiles) * n_tiles;
     slice_col = slice_col_par % n_tiles;
+    par_id = slice_col_par / n_tiles;
   }
 
   // Compute all information about the current slice which is required for
@@ -632,6 +639,7 @@
       C += 16 * thread_m_blocks * prob_n / 8;
       locks += n_tiles;
       slice_col = 0;
+      par_id++;
     }
   };
   init_slice();
@@ -1321,7 +1329,7 @@
   // finally have to globally reduce over the results. As the striped
   // partitioning minimizes the number of such reductions and our outputs are
   // usually rather small, we perform this reduction serially in L2 cache.
-  auto global_reduce = [&](bool first = false, bool last = false) {
+  auto global_reduce_fp16 = [&](bool first = false, bool last = false) {
     // We are very careful here to reduce directly in the output buffer to
     // maximize L2 cache utilization in this step. To do this, we write out
     // results in FP16 (but still reduce with FP32 compute).
@@ -1382,6 +1390,53 @@
     }
   };
 
+  // Globally reduce over threadblocks that compute the same column block.
+  // We use a tmp C buffer to reduce in full fp32 precision.
+  auto global_reduce_fp32 = [&](bool first = false, bool last = false) {
+    constexpr int tb_m = thread_m_blocks * 16;
+    constexpr int tb_n = thread_n_blocks * 16;
+
+    constexpr int c_size = tb_m * tb_n * sizeof(float) / 16;
+
+    constexpr int active_threads = 32 * thread_n_blocks / 4;
+    bool is_th_active = threadIdx.x < active_threads;
+
+    int par_offset = c_size * n_tiles * par_id;
+    int slice_offset = c_size * slice_col;
+
+    constexpr int num_floats = thread_m_blocks * 4 * 2 * 4;
+    constexpr int th_size = num_floats * sizeof(float) / 16;
+
+    int c_cur_offset = par_offset + slice_offset;
+
+    if (!is_th_active) {
+      return;
+    }
+
+    if (!first) {
+      float* frag_c_ptr = reinterpret_cast<float*>(&frag_c);
+#pragma unroll
+      for (int k = 0; k < th_size; k++) {
+        sh[threadIdx.x] =
+            C_tmp[c_cur_offset + active_threads * k + threadIdx.x];
+
+        float* sh_c_ptr = reinterpret_cast<float*>(&sh[threadIdx.x]);
+#pragma unroll
+        for (int f = 0; f < 4; f++) {
+          frag_c_ptr[k * 4 + f] += sh_c_ptr[f];
+        }
+      }
+    }
+
+    if (!last) {
+      int4* frag_c_ptr = reinterpret_cast<int4*>(&frag_c);
+#pragma unroll
+      for (int k = 0; k < th_size; k++) {
+        C_tmp[c_cur_offset + active_threads * k + threadIdx.x] = frag_c_ptr[k];
+      }
+    }
+  };
+
   // Write out the reduce final result in the correct layout. We only actually
   // reshuffle matrix fragments in this step, the reduction above is performed
   // in fragment layout.
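Why the fp32 staging buffer helps: global_reduce_fp16 writes partial results back to C in fp16 between threadblocks, so every hand-off rounds to half precision, while global_reduce_fp32 keeps partials in the fp32 C_tmp buffer until the final write-out. A small NumPy-only sketch of the numerics (a model of the reduction order, not kernel code):

```python
# Hedged numerical sketch: fp16 round-trips between partial sums vs fp32 staging.
import numpy as np

rng = np.random.default_rng(0)
num_blocks = 16  # threadblocks contributing to the same output tile
partials = rng.standard_normal((num_blocks, 4096)).astype(np.float32)

# fp16 path: the running sum is rounded to fp16 each time it is written back to C
acc_fp16 = np.zeros(4096, dtype=np.float32)
for p in partials:
    acc_fp16 = (acc_fp16 + p).astype(np.float16).astype(np.float32)

# fp32 path: partials stay in fp32 (the C_tmp buffer) until the final write-out
acc_fp32 = partials.sum(axis=0, dtype=np.float32)

ref = partials.astype(np.float64).sum(axis=0)
print("fp16-reduce max error:", np.abs(acc_fp16 - ref).max())
print("fp32-reduce max error:", np.abs(acc_fp32 - ref).max())
```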
@@ -1606,7 +1661,11 @@
       if (slice_count > 1) {  // only globally reduce if there is more than one
                               // block in a slice
         barrier_acquire(&locks[slice_col], slice_idx);
-        global_reduce(slice_idx == 0, last);
+        if (use_fp32_reduce) {
+          global_reduce_fp32(slice_idx == 0, last);
+        } else {
+          global_reduce_fp16(slice_idx == 0, last);
+        }
         barrier_release(&locks[slice_col], last);
       }
       if (last)  // only the last block in a slice actually writes the result
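The first/last flags mirror the serial per-slice reduction protocol: the first threadblock in a slice (slice_idx == 0) only publishes its partials, intermediate blocks read, accumulate, and publish, and the last block keeps the running sum in registers for the write-out. A hedged Python model of that control flow (names are illustrative, not from the kernel):

```python
# Hedged sketch of the first/last reduction protocol; 'staging' models C_tmp (or C).
def global_reduce(staging, partial, first, last):
    acc = partial if first else staging + partial  # read previous partials unless first
    if not last:
        staging = acc  # publish for the next block in the slice
    return staging, acc

staging = 0.0
partials = [1.0, 2.0, 3.0, 4.0]  # one partial per threadblock in the slice
for i, p in enumerate(partials):
    staging, acc = global_reduce(staging, p,
                                 first=(i == 0), last=(i == len(partials) - 1))
print(acc)  # 10.0 -- only the last block holds the full sum for write-out
```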
@@ -1661,8 +1720,8 @@
           THREAD_N_BLOCKS, THREAD_K_BLOCKS, pipe_stages, HAS_ACT_ORDER, \
           HAS_ZP, GROUP_BLOCKS> \
           <<<blocks, NUM_THREADS, max_shared_mem, stream>>>( \
-              A_ptr, B_ptr, C_ptr, s_ptr, zp_ptr, g_idx_ptr, num_groups, \
-              prob_m, prob_n, prob_k, locks); \
+              A_ptr, B_ptr, C_ptr, C_tmp_ptr, s_ptr, zp_ptr, g_idx_ptr, \
+              num_groups, prob_m, prob_n, prob_k, locks, use_fp32_reduce); \
   }
 
 typedef struct {
@@ -1801,6 +1860,27 @@ bool is_valid_config(thread_config_t const& th_config, int max_m_blocks,
   return true;
 }
 
+int determine_reduce_max_m(int prob_m, int max_par) {
+  constexpr int tile_m_size = 16;
+
+  if (prob_m <= tile_m_size) {
+    return tile_m_size;
+
+  } else if (prob_m <= tile_m_size * 2) {
+    return tile_m_size * 2;
+
+  } else if (prob_m <= tile_m_size * 3) {
+    return tile_m_size * 3;
+
+  } else if (prob_m <= tile_m_size * 4) {
+    return tile_m_size * 4;
+
+  } else {
+    int cur_par = min(div_ceil(prob_m, tile_m_size * 4), max_par);
+    return tile_m_size * 4 * cur_par;
+  }
+}
+
 exec_config_t determine_thread_config(int prob_m, int prob_n, int prob_k,
                                       int num_bits, int group_size,
                                       bool has_act_order, bool is_k_full,
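For sizing the fp32 staging buffer, determine_reduce_max_m rounds the batch dimension up to whole 16-row tiles, capped at 4 tiles times max_par parallel problems. A Python model of the same rounding (treat max_par = 16 as an assumed example value rather than a quoted constant):

```python
# Hedged model of determine_reduce_max_m: rows needed by the fp32 C_tmp buffer.
import math

def determine_reduce_max_m(prob_m: int, max_par: int, tile_m_size: int = 16) -> int:
    if prob_m <= tile_m_size * 4:
        # round up to a whole number of 16-row tiles (at most 4)
        return tile_m_size * math.ceil(prob_m / tile_m_size)
    cur_par = min(math.ceil(prob_m / (tile_m_size * 4)), max_par)
    return tile_m_size * 4 * cur_par

for m in (1, 16, 33, 64, 257, 4096):
    print(m, "->", determine_reduce_max_m(m, max_par=16))
```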
@@ -1880,13 +1960,13 @@ exec_config_t determine_thread_config(int prob_m, int prob_n, int prob_k,
   __CALL_IF(NUM_BITS, 4, N_BLOCKS, K_BLOCKS, false, true, 8, NUM_THREADS)
 
 template <typename scalar_t>
-void marlin_mm_f16i4(const void* A, const void* B, void* C, void* s, void* zp,
-                     void* g_idx, void* perm, void* a_tmp, int prob_m,
-                     int prob_n, int prob_k, void* workspace, int num_bits,
-                     bool has_act_order, bool is_k_full, bool has_zp,
-                     int num_groups, int group_size, int dev,
+void marlin_mm_f16i4(const void* A, const void* B, void* C, void* C_tmp,
+                     void* s, void* zp, void* g_idx, void* perm, void* a_tmp,
+                     int prob_m, int prob_n, int prob_k, void* workspace,
+                     int num_bits, bool has_act_order, bool is_k_full,
+                     bool has_zp, int num_groups, int group_size, int dev,
                      cudaStream_t stream, int thread_k, int thread_n, int sms,
-                     int max_par) {
+                     int max_par, bool use_fp32_reduce) {
   TORCH_CHECK(num_bits == 4 || num_bits == 8,
               "num_bits must be 4 or 8. Got = ", num_bits);
   TORCH_CHECK(prob_m > 0 && prob_n > 0 && prob_k > 0, "Invalid MNK = [", prob_m,
@@ -1970,6 +2050,7 @@ void marlin_mm_f16i4(const void* A, const void* B, void* C, void* s, void* zp,
   const int4* A_ptr = (const int4*)A;
   const int4* B_ptr = (const int4*)B;
   int4* C_ptr = (int4*)C;
+  int4* C_tmp_ptr = (int4*)C_tmp;
   const int4* s_ptr = (const int4*)s;
   const int4* zp_ptr = (const int4*)zp;
   const int* g_idx_ptr = (const int*)g_idx;
@@ -2049,7 +2130,8 @@ torch::Tensor gptq_marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight,
                                torch::Tensor& g_idx, torch::Tensor& perm,
                                torch::Tensor& workspace, int64_t num_bits,
                                int64_t size_m, int64_t size_n, int64_t size_k,
-                               bool is_k_full, bool has_zp) {
+                               bool is_k_full, bool has_zp,
+                               bool use_fp32_reduce) {
   // Verify num_bits
   TORCH_CHECK(num_bits == 4 || num_bits == 8,
               "num_bits must be 4 or 8. Got = ", num_bits);
@@ -2099,6 +2181,17 @@ torch::Tensor gptq_marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight,
   torch::Tensor c = torch::empty({size_m, size_n}, options);
   torch::Tensor a_tmp = torch::empty({size_m, size_k}, options);
 
+  // Alloc C tmp buffer that is going to be used for the global reduce
+  int reduce_max_m = marlin::determine_reduce_max_m(size_m, marlin::max_par);
+  int reduce_n = size_n;
+  auto options_fp32 =
+      torch::TensorOptions().dtype(at::kFloat).device(a.device());
+  if (!use_fp32_reduce) {
+    reduce_max_m = 0;
+    reduce_n = 0;
+  }
+  torch::Tensor c_tmp = torch::empty({reduce_max_m, reduce_n}, options_fp32);
+
   // thread_k: `k` size of a thread_tile in `weights` (can usually be left as
   // auto -1)
   int thread_k = -1;
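The extra memory cost of use_fp32_reduce is the c_tmp allocation above: reduce_max_m x size_n fp32 elements, and an empty tensor when the flag is off. A self-contained sketch of the footprint, reusing the rounding rule modeled earlier (max_par = 16 assumed for illustration):

```python
# Hedged sketch: bytes of fp32 scratch allocated for c_tmp at a given problem size.
import math

def c_tmp_bytes(size_m: int, size_n: int, max_par: int = 16, tile_m: int = 16) -> int:
    if size_m <= tile_m * 4:
        reduce_max_m = tile_m * math.ceil(size_m / tile_m)
    else:
        reduce_max_m = tile_m * 4 * min(math.ceil(size_m / (tile_m * 4)), max_par)
    return reduce_max_m * size_n * 4  # 4 bytes per fp32 element

print(c_tmp_bytes(16, 4096))   # 262144 bytes (256 KiB)
print(c_tmp_bytes(128, 4096))  # 2097152 bytes (2 MiB)
```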
@@ -2171,20 +2264,21 @@ torch::Tensor gptq_marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight,
   if (a.scalar_type() == at::ScalarType::Half) {
     marlin::marlin_mm_f16i4<half>(
         a.data_ptr<at::Half>(), b_q_weight.data_ptr(), c.data_ptr<at::Half>(),
-        b_scales.data_ptr<at::Half>(), b_zeros.data_ptr(), g_idx.data_ptr(),
-        perm.data_ptr(), a_tmp.data_ptr<at::Half>(), size_m, size_n, size_k,
+        c_tmp.data_ptr<float>(), b_scales.data_ptr<at::Half>(),
+        b_zeros.data_ptr(), g_idx.data_ptr(), perm.data_ptr(),
+        a_tmp.data_ptr<at::Half>(), size_m, size_n, size_k,
         workspace.data_ptr(), num_bits, has_act_order, is_k_full, has_zp,
         num_groups, group_size, dev, at::cuda::getCurrentCUDAStream(dev),
-        thread_k, thread_n, sms, marlin::max_par);
+        thread_k, thread_n, sms, marlin::max_par, use_fp32_reduce);
   } else if (a.scalar_type() == at::ScalarType::BFloat16) {
     marlin::marlin_mm_f16i4<nv_bfloat16>(
         a.data_ptr<at::BFloat16>(), b_q_weight.data_ptr(),
-        c.data_ptr<at::BFloat16>(), b_scales.data_ptr<at::BFloat16>(),
-        b_zeros.data_ptr(), g_idx.data_ptr(), perm.data_ptr(),
-        a_tmp.data_ptr<at::BFloat16>(), size_m, size_n, size_k,
+        c.data_ptr<at::BFloat16>(), c_tmp.data_ptr<float>(),
+        b_scales.data_ptr<at::BFloat16>(), b_zeros.data_ptr(), g_idx.data_ptr(),
+        perm.data_ptr(), a_tmp.data_ptr<at::BFloat16>(), size_m, size_n, size_k,
         workspace.data_ptr(), num_bits, has_act_order, is_k_full, has_zp,
         num_groups, group_size, dev, at::cuda::getCurrentCUDAStream(dev),
-        thread_k, thread_n, sms, marlin::max_par);
+        thread_k, thread_n, sms, marlin::max_par, use_fp32_reduce);
   } else {
     TORCH_CHECK(false, "gpt_marlin_gemm only supports bfloat16 and float16");
   }
