Skip to content

[Misc][Kernel]: Add GPTQAllSpark Quantization #12931

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 5 commits into from
Mar 1, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 16 additions & 0 deletions CMakeLists.txt
100755 → 100644
Original file line number Diff line number Diff line change
Expand Up @@ -317,6 +317,22 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
" in CUDA target architectures")
endif()

# Only build AllSpark kernels if we are building for at least some compatible archs.
cuda_archs_loose_intersection(ALLSPARK_ARCHS "8.0;8.6;8.7;8.9" "${CUDA_ARCHS}")
if (ALLSPARK_ARCHS)
set(ALLSPARK_SRCS
"csrc/quantization/gptq_allspark/allspark_repack.cu"
"csrc/quantization/gptq_allspark/allspark_qgemm_w8a16.cu")
set_gencode_flags_for_srcs(
SRCS "${ALLSPARK_SRCS}"
CUDA_ARCHS "${ALLSPARK_ARCHS}")
list(APPEND VLLM_EXT_SRC "${ALLSPARK_SRCS}")
message(STATUS "Building AllSpark kernels for archs: ${ALLSPARK_ARCHS}")
else()
message(STATUS "Not building AllSpark kernels as no compatible archs found"
" in CUDA target architectures")
endif()

# The cutlass_scaled_mm kernels for Hopper (c3x, i.e. CUTLASS 3.x) require
# CUDA 12.0 or later (and only work on Hopper, 9.0a for now).
cuda_archs_loose_intersection(SCALED_MM_3X_ARCHS "9.0a" "${CUDA_ARCHS}")
Expand Down
47 changes: 45 additions & 2 deletions benchmarks/kernels/benchmark_marlin.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@
from vllm.model_executor.layers.quantization.gptq_marlin_24 import (
GPTQ_MARLIN_24_MAX_PARALLEL, GPTQ_MARLIN_24_MIN_THREAD_N,
GPTQ_MARLIN_24_SUPPORTED_GROUP_SIZES, GPTQ_MARLIN_24_SUPPORTED_QUANT_TYPES)
from vllm.model_executor.layers.quantization.utils.allspark_utils import (
ALLSPARK_AMPERE_M_CUBLAS_THRESHOLD, ALLSPARK_SUPPORTED_QUANT_TYPES)
from vllm.model_executor.layers.quantization.utils.marlin_utils import (
GPTQ_MARLIN_MAX_PARALLEL, GPTQ_MARLIN_MIN_THREAD_N,
MARLIN_SUPPORTED_GROUP_SIZES, query_marlin_supported_quant_types)
Expand All @@ -18,12 +20,12 @@
from vllm.model_executor.layers.quantization.utils.marlin_utils_test_24 import (
marlin_24_quantize)
from vllm.model_executor.layers.quantization.utils.quant_utils import (
gptq_pack, gptq_quantize_weights, sort_weights)
gptq_pack, gptq_quantize_weights, quantize_weights, sort_weights)
from vllm.scalar_type import ScalarType
from vllm.utils import FlexibleArgumentParser

DEFAULT_MODELS = ["meta-llama/Llama-2-7b-hf/TP1"]
DEFAULT_BATCH_SIZES = [1, 16, 32, 64, 128, 256, 512]
DEFAULT_BATCH_SIZES = [1, 16, 32, 64, 128, 256, 512, 1024, 2048, 4096, 8192]

ACT_ORDER_OPTS = [False, True]
K_FULL_OPTS = [False, True]
Expand Down Expand Up @@ -81,6 +83,27 @@ def bench_run(results: List[benchmark.Measurement], model: str,
GPTQ_MARLIN_24_MAX_PARALLEL)
marlin_zp = torch.zeros_like(marlin_s, dtype=torch.int)

# AllSpark W8A16 quant
as_supported_case = (quant_type in ALLSPARK_SUPPORTED_QUANT_TYPES
and group_size == -1 and not act_order and is_k_full)
if as_supported_case:
properties = torch.cuda.get_device_properties(b.device.index)
sm_count = properties.multi_processor_count
sm_version = properties.major * 10 + properties.minor

supported_arch = (sm_version >= 80 and sm_version < 90)
as_supported_case = as_supported_case and supported_arch
if supported_arch:
has_zp = False
w_ref, qw, s, zp = quantize_weights(b, quant_type, group_size,
has_zp)
qw = qw.to(torch.uint8)

qw_reorder, s_reorder, zp_reorder = \
ops.allspark_repack_weight(
qw, s, zp, has_zp)
CUBLAS_M_THRESHOLD = ALLSPARK_AMPERE_M_CUBLAS_THRESHOLD

globals = {
# Gen params
"quant_type": quant_type,
Expand Down Expand Up @@ -109,10 +132,19 @@ def bench_run(results: List[benchmark.Measurement], model: str,
# GPTQ params
"q_w_gptq": q_w_gptq,
"repack_sort_indices": repack_sort_indices,
# AllSpark W8A16 params
"qw_reorder": qw_reorder if as_supported_case else None,
"s_reorder": s_reorder if as_supported_case else None,
"zp_reorder": zp_reorder if as_supported_case else None,
"sm_count": sm_count if as_supported_case else None,
"sm_version": sm_version if as_supported_case else None,
"CUBLAS_M_THRESHOLD":
CUBLAS_M_THRESHOLD if as_supported_case else None,
# Kernels
"gptq_marlin_gemm": ops.gptq_marlin_gemm,
"gptq_marlin_24_gemm": ops.gptq_marlin_24_gemm,
"gptq_marlin_repack": ops.gptq_marlin_repack,
"allspark_w8a16_gemm": ops.allspark_w8a16_gemm,
}

min_run_time = 1
Expand Down Expand Up @@ -172,6 +204,17 @@ def bench_run(results: List[benchmark.Measurement], model: str,
description="gptq_marlin_repack",
).blocked_autorange(min_run_time=min_run_time))

if as_supported_case:
results.append(
benchmark.Timer(
stmt=
"output = allspark_w8a16_gemm(a, qw_reorder, s_reorder, zp_reorder, size_n, group_size, sm_count, sm_version, CUBLAS_M_THRESHOLD, False, True)", # noqa: E501
globals=globals,
label=label,
sub_label=sub_label,
description="allspark_w8a16_gemm_fp32",
).blocked_autorange(min_run_time=min_run_time))


def main(args):
print("Benchmarking models:")
Expand Down
Loading