[Kernel] optimize performance of gptq marlin kernel when n is small #14138

Merged

merged 13 commits into from Mar 7, 2025
62 changes: 46 additions & 16 deletions csrc/quantization/gptq_marlin/gptq_marlin.cu
@@ -538,6 +538,7 @@ __global__ void Marlin(
int prob_n, // output dimension n
int prob_k, // reduction dimension k
int* locks, // extra global storage for barrier synchronization
bool use_atomic_add, // whether to use atomic add to reduce
bool use_fp32_reduce // whether to use fp32 global reduce
) {
// Each threadblock processes one "stripe" of the B matrix with (roughly) the
@@ -1542,7 +1543,17 @@ __global__ void Marlin(
i < div_ceil(16 * thread_m_blocks, threads / (2 * thread_n_blocks));
i++) {
if (c_gl_wr < c_gl_wr_end) {
C[c_gl_wr] = sh_red[c_sh_rd];
if (use_atomic_add && slice_count > 1) {
scalar_t2* C_half2 = reinterpret_cast<scalar_t2*>(&C[c_gl_wr]);
scalar_t2* sh_red_half2 =
reinterpret_cast<scalar_t2*>(&sh_red[c_sh_rd]);
#pragma unroll
for (int a = 0; a < 4; a++) {
atomicAdd(&C_half2[a], sh_red_half2[a]);
}
} else {
C[c_gl_wr] = sh_red[c_sh_rd];
}
c_gl_wr += c_gl_wr_delta;
c_sh_rd += c_sh_rd_delta;
}
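Note on the hunk above: when use_atomic_add is set and a slice spans more than one threadblock, each block adds its partial tile straight into C via half2 atomicAdd instead of going through the lock-based global reduce, which is why C gets zero-initialized later in this diff. A minimal PyTorch sketch of the reduction semantics (not the kernel itself; shapes and slice count are illustrative):

```python
import torch

# Each k-slice stands in for the partial product one threadblock would
# atomicAdd into C; with this scheme C must start at zero rather than
# being left uninitialized. fp32 is used only so the sketch runs anywhere;
# the kernel accumulates half2/bfloat162 values in global memory.
m, n, k, num_slices = 16, 256, 4096, 4
a = torch.randn(m, k)
b = torch.randn(k, n)

c = torch.zeros(m, n)                              # zeros, not empty
for s in range(num_slices):                        # one "slice" per threadblock
    ks = slice(s * (k // num_slices), (s + 1) * (k // num_slices))
    c += a[:, ks] @ b[ks, :]                       # the atomicAdd of one partial sum

torch.testing.assert_close(c, a @ b, rtol=1e-3, atol=1e-3)
```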
@@ -1644,7 +1655,7 @@ __global__ void Marlin(
}
cp_async_fence();
} else {
if (last) {
if (last || use_atomic_add) {
if (s_sh_wr_pred) {
cp_async4(&sh_s[s_sh_wr], &scales_ptr[s_gl_rd]);
}
}

} else {
if (last) {
if (last || use_atomic_add) {
cp_async_wait<0>();
__syncthreads();
if (threadIdx.x / 32 < thread_n_blocks / 4) {
@@ -1703,8 +1714,8 @@
}
}

if (slice_count > 1) { // only globally reduce if there is more than one
// block in a slice
if (slice_count > 1 && !use_atomic_add) {
// only globally reduce if there is more than one block in a slice
barrier_acquire(&locks[slice_col], slice_idx);
if (use_fp32_reduce) {
global_reduce_fp32(slice_idx == 0, last);
}
barrier_release(&locks[slice_col], last);
}
if (last) // only the last block in a slice actually writes the result
if (last || use_atomic_add)
// only the last block in a slice actually writes the result
write_result();
slice_row = 0;
slice_col_par++;
@@ -1768,7 +1780,8 @@ __global__ void Marlin(
HAS_ZP, GROUP_BLOCKS, IS_ZP_FLOAT> \
<<<blocks, NUM_THREADS, max_shared_mem, stream>>>( \
A_ptr, B_ptr, C_ptr, C_tmp_ptr, s_ptr, zp_ptr, g_idx_ptr, \
num_groups, prob_m, prob_n, prob_k, locks, use_fp32_reduce); \
num_groups, prob_m, prob_n, prob_k, locks, use_atomic_add, \
use_fp32_reduce); \
} \
}

@@ -2062,7 +2075,8 @@ void marlin_mm(const void* A, const void* B, void* C, void* C_tmp, void* s,
vllm::ScalarType const& q_type, bool has_act_order,
bool is_k_full, bool has_zp, int num_groups, int group_size,
int dev, cudaStream_t stream, int thread_k, int thread_n,
int sms, int max_par, bool use_fp32_reduce, bool is_zp_float) {
int sms, int max_par, bool use_atomic_add, bool use_fp32_reduce,
bool is_zp_float) {
if (has_zp) {
TORCH_CHECK(
q_type == vllm::kU4 || q_type == vllm::kU8,
@@ -2243,7 +2257,7 @@ torch::Tensor gptq_marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight,
torch::Tensor& workspace,
vllm::ScalarTypeId const& b_q_type_id,
int64_t size_m, int64_t size_n, int64_t size_k,
bool is_k_full, bool has_zp,
bool is_k_full, bool has_zp, bool use_atomic_add,
bool use_fp32_reduce, bool is_zp_float) {
vllm::ScalarType const b_q_type = vllm::ScalarType::from_id(b_q_type_id);
if (has_zp) {
@@ -2306,19 +2320,34 @@ torch::Tensor gptq_marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight,
// Alloc buffers
const at::cuda::OptionalCUDAGuard device_guard(device_of(a));
auto options = torch::TensorOptions().dtype(a.dtype()).device(a.device());
torch::Tensor c = torch::empty({size_m, size_n}, options);
torch::Tensor a_tmp = torch::empty({size_m, size_k}, options);
torch::Tensor c;
if (use_atomic_add) {
c = torch::zeros({size_m, size_n}, options);
} else {
c = torch::empty({size_m, size_n}, options);
}

torch::Tensor a_tmp;
bool has_act_order = g_idx.size(0) != 0;
if (has_act_order) {
a_tmp = torch::empty({size_m, size_k}, options);
} else {
a_tmp = torch::empty({0}, options);
}

// Alloc C tmp buffer that is going to be used for the global reduce
torch::Tensor c_tmp;
int reduce_max_m = marlin::determine_reduce_max_m(size_m, marlin::max_par);
int reduce_n = size_n;
auto options_fp32 =
torch::TensorOptions().dtype(at::kFloat).device(a.device());
if (!use_fp32_reduce) {
if (use_fp32_reduce) {
c_tmp = torch::empty({reduce_max_m, reduce_n}, options_fp32);
} else {
reduce_max_m = 0;
reduce_n = 0;
c_tmp = torch::empty({0}, options_fp32);
}
torch::Tensor c_tmp = torch::empty({reduce_max_m, reduce_n}, options_fp32);

// thread_k: `k` size of a thread_tile in `weights` (can usually be left as
// auto -1)
// Detect groupsize and act_order
int num_groups = -1;
int group_size = -1;
bool has_act_order = g_idx.size(0) != 0;

int rank = b_scales.sizes().size();
TORCH_CHECK(rank == 2, "b_scales rank = ", rank, " is not 2");
@@ -2407,7 +2435,8 @@ torch::Tensor gptq_marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight,
a_tmp.data_ptr<at::Half>(), size_m, size_n, size_k,
workspace.data_ptr(), b_q_type, has_act_order, is_k_full, has_zp,
num_groups, group_size, dev, at::cuda::getCurrentCUDAStream(dev),
thread_k, thread_n, sms, marlin::max_par, use_fp32_reduce, is_zp_float);
thread_k, thread_n, sms, marlin::max_par, use_atomic_add,
use_fp32_reduce, is_zp_float);
} else if (a.scalar_type() == at::ScalarType::BFloat16) {
marlin::marlin_mm<nv_bfloat16>(
a.data_ptr<at::BFloat16>(), b_q_weight.data_ptr(),
perm.data_ptr(), a_tmp.data_ptr<at::BFloat16>(), size_m, size_n, size_k,
workspace.data_ptr(), b_q_type, has_act_order, is_k_full, has_zp,
num_groups, group_size, dev, at::cuda::getCurrentCUDAStream(dev),
thread_k, thread_n, sms, marlin::max_par, use_fp32_reduce, is_zp_float);
thread_k, thread_n, sms, marlin::max_par, use_atomic_add,
use_fp32_reduce, is_zp_float);
} else {
TORCH_CHECK(false, "gpt_marlin_gemm only supports bfloat16 and float16");
}
3 changes: 2 additions & 1 deletion csrc/torch_bindings.cpp
@@ -272,7 +272,8 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
"Tensor b_zeros, Tensor g_idx, Tensor perm, Tensor workspace, "
"int b_q_type, "
"SymInt size_m, SymInt size_n, SymInt size_k, bool is_k_full, "
"bool has_zp, bool use_fp32_reduce, bool is_zp_float) -> Tensor",
"bool has_zp, bool use_atomic_add, bool use_fp32_reduce, "
"bool is_zp_float) -> Tensor",
{stride_tag});
// conditionally compiled so impl registration is in source file

16 changes: 10 additions & 6 deletions tests/kernels/test_marlin_gemm.py
@@ -34,6 +34,7 @@

ACT_ORDER_OPTS = [False, True]
K_FULL_OPTS = [False, True]
USE_ATOMIC_ADD_OPTS = [False, True]
USE_FP32_REDUCE_OPTS = [False, True]

MARLIN_K_CHUNKS = [128]
Expand Down Expand Up @@ -194,6 +195,7 @@ def test_awq_marlin_repack(k_chunk, n_chunk, quant_type, group_size,
@pytest.mark.parametrize("mnk_factors", MNK_FACTORS)
@pytest.mark.parametrize("act_order", ACT_ORDER_OPTS)
@pytest.mark.parametrize("is_k_full", K_FULL_OPTS)
@pytest.mark.parametrize("use_atomic_add", USE_ATOMIC_ADD_OPTS)
@pytest.mark.parametrize("use_fp32_reduce", USE_FP32_REDUCE_OPTS)
def test_gptq_marlin_gemm(
k_chunk,
mnk_factors,
act_order,
is_k_full,
use_atomic_add,
use_fp32_reduce,
):
m_factor, n_factor, k_factor = mnk_factors
workspace = MarlinWorkspace(size_n, GPTQ_MARLIN_MIN_THREAD_N,
GPTQ_MARLIN_MAX_PARALLEL)

opcheck(
torch.ops._C.gptq_marlin_gemm,
(a_input, marlin_q_w, marlin_s, marlin_zp, g_idx, sort_indices,
workspace.scratch, quant_type.id, a_input.shape[0], b_weight.shape[1],
a_input.shape[1], is_k_full, False, use_fp32_reduce, False),
test_utils=DEFAULT_OPCHECK_TEST_UTILS)
opcheck(torch.ops._C.gptq_marlin_gemm,
(a_input, marlin_q_w, marlin_s, marlin_zp, g_idx, sort_indices,
workspace.scratch, quant_type.id, a_input.shape[0],
b_weight.shape[1], a_input.shape[1], is_k_full, False,
use_atomic_add, use_fp32_reduce, False),
test_utils=DEFAULT_OPCHECK_TEST_UTILS)

output = ops.gptq_marlin_gemm(
a_input,
a_input.shape[1],
is_k_full=is_k_full,
has_zp=False,
use_atomic_add=use_atomic_add,
use_fp32_reduce=use_fp32_reduce,
is_zp_float=False,
)
5 changes: 4 additions & 1 deletion vllm/_custom_ops.py
@@ -301,6 +301,7 @@ def _gptq_marlin_gemm_fake(a: torch.Tensor,
size_k: torch.SymInt,
is_k_full: bool,
has_zp: bool = False,
use_atomic_add: bool = False,
use_fp32_reduce: bool = False,
is_zp_float: bool = False) -> torch.Tensor:
return torch.empty((size_m, size_n), device=a.device, dtype=a.dtype)
@@ -713,12 +714,14 @@ def gptq_marlin_gemm(a: torch.Tensor,
size_k: int,
is_k_full: bool,
has_zp: bool = False,
use_atomic_add: bool = False,
use_fp32_reduce: bool = False,
is_zp_float: bool = False) -> torch.Tensor:
return torch.ops._C.gptq_marlin_gemm(a, b_q_weight, b_scales, b_zeros,
g_idx, perm, workspace, b_q_type.id,
size_m, size_n, size_k, is_k_full,
has_zp, use_fp32_reduce, is_zp_float)
has_zp, use_atomic_add,
use_fp32_reduce, is_zp_float)


# fp8 marlin
5 changes: 5 additions & 0 deletions vllm/envs.py
@@ -95,6 +95,7 @@
VLLM_DP_SIZE: int = 1
VLLM_DP_MASTER_IP: str = ""
VLLM_DP_MASTER_PORT: int = 0
VLLM_MARLIN_USE_ATOMIC_ADD: bool = False


def get_default_cache_root():
@@ -630,6 +631,10 @@ def maybe_convert_int(value: Optional[str]) -> Optional[int]:
# Whether to use S3 path for model loading in CI via RunAI Streamer
"VLLM_CI_USE_S3":
lambda: os.environ.get("VLLM_CI_USE_S3", "0") == "1",

# Whether to use atomicAdd reduce in gptq/awq marlin kernel.
"VLLM_MARLIN_USE_ATOMIC_ADD":
lambda: os.environ.get("VLLM_MARLIN_USE_ATOMIC_ADD", "0") == "1",
}

# end-env-vars-definition
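The new flag defaults to off. A sketch of how one might enable it from Python before building an engine; the checkpoint id is only an example:

```python
import os

# Opt in to the atomicAdd reduce path; the default ("0") keeps the
# existing lock-based global reduce.
os.environ["VLLM_MARLIN_USE_ATOMIC_ADD"] = "1"

from vllm import LLM  # assumes the usual vLLM entry point

# Any GPTQ/AWQ-quantized checkpoint works; this repo id is illustrative.
llm = LLM(model="TheBloke/Llama-2-7B-Chat-GPTQ")
outputs = llm.generate("Hello, my name is")
print(outputs[0].outputs[0].text)
```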
32 changes: 32 additions & 0 deletions vllm/model_executor/layers/quantization/utils/marlin_utils.py
@@ -5,6 +5,7 @@
import numpy
import torch

import vllm.envs as envs
from vllm import _custom_ops as ops
from vllm.model_executor.layers.linear import LinearBase
from vllm.platforms import current_platform
@@ -290,6 +291,23 @@ def moe_awq_to_marlin_zero_points(q_zp_packed: torch.Tensor, size_k: int,
return output


def should_use_atomic_add_reduce(m: int, n: int, k: int, device: torch.device,
dtype: torch.dtype) -> bool:
# disable atomicAdd reduce by default,
# one can enable it with VLLM_MARLIN_USE_ATOMIC_ADD=1
if not envs.VLLM_MARLIN_USE_ATOMIC_ADD or device.type != "cuda":
return False

# sm8x doesn't support atomicAdd + bfloat16 natively
device_capability = torch.cuda.get_device_capability(device)
if device_capability[0] < 9 and dtype == torch.bfloat16:
return False

# the performance of atomicAdd is better than global reduce
# only when m*n is small and k is large
return max(m, 64) * n < 64 * 2048 and k >= 2048
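The heuristic above reduces to a single inequality: treat m as at least 64 (presumably the kernel's effective minimum tile height) and take the atomicAdd path only when that effective m times n stays below 64 * 2048 and k is at least 2048, i.e. small-n GEMMs with a long reduction dimension. A standalone restatement of the check for a few illustrative shapes:

```python
def atomic_add_wins(m: int, n: int, k: int) -> bool:
    # Same inequality as should_use_atomic_add_reduce, minus the
    # env-var / device / dtype guards.
    return max(m, 64) * n < 64 * 2048 and k >= 2048

# Decode-style GEMM with a small output dimension: atomicAdd path.
print(atomic_add_wins(m=1, n=1024, k=4096))   # True:  64 * 1024 < 131072
# Wide output or short reduction dimension: keep the global reduce.
print(atomic_add_wins(m=1, n=4096, k=4096))   # False: 64 * 4096 >= 131072
print(atomic_add_wins(m=1, n=1024, k=1024))   # False: k < 2048
```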


def apply_gptq_marlin_linear(
input: torch.Tensor,
weight: torch.Tensor,
reshaped_x = input.reshape(-1, input.shape[-1])
out_shape = input.shape[:-1] + (output_size_per_partition, )

use_atomic_add = should_use_atomic_add_reduce(m=reshaped_x.size(0),
n=output_size_per_partition,
k=reshaped_x.size(1),
device=input.device,
dtype=input.dtype)

output = ops.gptq_marlin_gemm(reshaped_x,
weight,
weight_scale,
size_k=input_size_per_partition,
is_k_full=is_k_full,
has_zp=False,
use_atomic_add=use_atomic_add,
use_fp32_reduce=use_fp32_reduce,
is_zp_float=False)

reshaped_x = input.reshape(-1, input.shape[-1])
out_shape = input.shape[:-1] + (output_size_per_partition, )

use_atomic_add = should_use_atomic_add_reduce(m=reshaped_x.size(0),
n=output_size_per_partition,
k=reshaped_x.size(1),
device=input.device,
dtype=input.dtype)

output = ops.gptq_marlin_gemm(reshaped_x,
weight,
weight_scale,
size_k=input_size_per_partition,
is_k_full=True,
has_zp=True,
use_atomic_add=use_atomic_add,
use_fp32_reduce=use_fp32_reduce,
is_zp_float=False)
