[Kernel] Add marlin_24 unit tests (vllm-project#4901)
alexm-neuralmagic authored and robertgshaw2-neuralmagic committed Jun 8, 2024
1 parent 9fe9187 commit e69d23b
Showing 6 changed files with 649 additions and 103 deletions.
87 changes: 74 additions & 13 deletions tests/kernels/test_marlin_gemm.py
@@ -7,38 +7,46 @@

from vllm import _custom_ops as ops
from vllm.model_executor.layers.quantization.gptq_marlin import (
GPTQ_MARLIN_MAX_PARALLEL, GPTQ_MARLIN_MIN_THREAD_N,
GPTQ_MARLIN_SUPPORTED_GROUP_SIZES, GPTQ_MARLIN_SUPPORTED_NUM_BITS)
from vllm.model_executor.layers.quantization.gptq_marlin_24 import (
GPTQ_MARLIN_24_MAX_PARALLEL, GPTQ_MARLIN_24_MIN_THREAD_N,
GPTQ_MARLIN_24_SUPPORTED_GROUP_SIZES, GPTQ_MARLIN_24_SUPPORTED_NUM_BITS)
from vllm.model_executor.layers.quantization.utils.marlin_perms import (
marlin_perm)
from vllm.model_executor.layers.quantization.utils.marlin_utils import (
MarlinWorkspace, is_marlin_supported, marlin_quantize, marlin_weights)
MarlinWorkspace, compute_max_diff, is_marlin_supported, marlin_24_quantize,
marlin_quantize, marlin_weights)
from vllm.model_executor.layers.quantization.utils.quant_utils import (
gptq_pack, quantize_weights, sort_weights)

ACT_ORDER_OPTS = [False, True]
K_FULL_OPTS = [False, True]

K_CHUNKS = [128, 256]
N_CHUNKS = [64, 128, 256]
MARLIN_K_CHUNKS = [128]
MARLIN_N_CHUNKS = [64, 128, 256]

MARLIN_24_K_CHUNKS = [128]
MARLIN_24_N_CHUNKS = [256]

MNK_FACTORS = [
(1, 1, 1),
(1, 4, 8),
(1, 7, 5),
(1, 7 * 4, 5 * 1),
(13, 17, 67),
(26, 37, 13),
(67, 13, 11),
]


def rand_data(shape):
data = torch.rand(shape).to(torch.half).cuda()
return data
return torch.randn(shape, dtype=torch.half, device="cuda")


@pytest.mark.skipif(not is_marlin_supported(),
reason="Marlin is not supported on this GPU type.")
@pytest.mark.parametrize("k_chunk", K_CHUNKS)
@pytest.mark.parametrize("n_chunk", N_CHUNKS)
@pytest.mark.parametrize("k_chunk", MARLIN_K_CHUNKS)
@pytest.mark.parametrize("n_chunk", MARLIN_N_CHUNKS)
@pytest.mark.parametrize("num_bits", GPTQ_MARLIN_SUPPORTED_NUM_BITS)
@pytest.mark.parametrize("group_size", GPTQ_MARLIN_SUPPORTED_GROUP_SIZES)
@pytest.mark.parametrize("act_order", ACT_ORDER_OPTS)
@@ -82,7 +90,8 @@ def test_marlin_repack(k_chunk, n_chunk, num_bits, group_size, act_order,
q_w, g_idx, sort_indices = sort_weights(q_w, g_idx)

# Pack to Marlin format
marlin_q_w_1 = marlin_weights(q_w, size_k, size_n, num_bits)
marlin_q_w_1 = marlin_weights(q_w, size_k, size_n, num_bits,
marlin_perm[num_bits])

# Run Marlin repack GPU kernel
marlin_q_w_2 = ops.gptq_marlin_repack(
@@ -99,8 +108,8 @@

@pytest.mark.skipif(not is_marlin_supported(),
reason="Marlin is not supported on this GPU type.")
@pytest.mark.parametrize("k_chunk", K_CHUNKS)
@pytest.mark.parametrize("n_chunk", N_CHUNKS)
@pytest.mark.parametrize("k_chunk", MARLIN_K_CHUNKS)
@pytest.mark.parametrize("n_chunk", MARLIN_N_CHUNKS)
@pytest.mark.parametrize("num_bits", GPTQ_MARLIN_SUPPORTED_NUM_BITS)
@pytest.mark.parametrize("group_size", GPTQ_MARLIN_SUPPORTED_GROUP_SIZES)
@pytest.mark.parametrize("mnk_factors", MNK_FACTORS)
@@ -136,7 +145,8 @@ def test_marlin_gemm(
w_ref, marlin_q_w, marlin_s, g_idx, sort_indices, _ = marlin_quantize(
b_weight, num_bits, group_size, act_order)

workspace = MarlinWorkspace(size_n)
workspace = MarlinWorkspace(size_n, GPTQ_MARLIN_MIN_THREAD_N,
GPTQ_MARLIN_MAX_PARALLEL)

output = ops.gptq_marlin_gemm(
a_input,
@@ -155,4 +165,55 @@

torch.cuda.synchronize()

assert torch.allclose(output, output_ref, rtol=1e-2)
max_diff = compute_max_diff(output, output_ref)
print("max_diff = {}".format(max_diff))

assert max_diff < 0.04


@pytest.mark.skipif(not is_marlin_supported(),
reason="Marlin is not supported on this GPU type.")
@pytest.mark.parametrize("k_chunk", MARLIN_24_K_CHUNKS)
@pytest.mark.parametrize("n_chunk", MARLIN_24_N_CHUNKS)
@pytest.mark.parametrize("num_bits", GPTQ_MARLIN_24_SUPPORTED_NUM_BITS)
@pytest.mark.parametrize("group_size", GPTQ_MARLIN_24_SUPPORTED_GROUP_SIZES)
@pytest.mark.parametrize("mnk_factors", MNK_FACTORS)
def test_marlin_24_gemm(k_chunk, n_chunk, num_bits, group_size, mnk_factors):
m_factor, n_factor, k_factor = mnk_factors

size_m = m_factor
size_k = k_chunk * k_factor
size_n = n_chunk * n_factor

print(f"MNK = {size_m} {size_n} {size_k}")
print(f"groupsize = {group_size}")

a_input = rand_data((size_m, size_k))
b_weight = rand_data((size_k, size_n))

(w_24_ref, marlin_24_q_w_comp, marlin_24_meta,
marlin_24_s) = marlin_24_quantize(b_weight, num_bits, group_size)

workspace_24 = MarlinWorkspace(size_n, GPTQ_MARLIN_24_MIN_THREAD_N,
GPTQ_MARLIN_24_MAX_PARALLEL)

output_ref = torch.matmul(a_input, w_24_ref)

output = ops.gptq_marlin_24_gemm(
a_input,
marlin_24_q_w_comp,
marlin_24_meta,
marlin_24_s,
workspace_24.scratch,
num_bits,
a_input.shape[0],
b_weight.shape[1],
a_input.shape[1],
)

torch.cuda.synchronize()

max_diff = compute_max_diff(output, output_ref)
print("max_diff = {}".format(max_diff))

assert max_diff < 0.04
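
Both GEMM tests above now assert on a relative error metric instead of the earlier torch.allclose check. As a rough, hypothetical sketch of what compute_max_diff measures (an assumption for illustration, not a verbatim copy of the vLLM helper), the 0.04 threshold can be read as roughly 4% mean deviation relative to the fp16 reference:

    import torch

    def compute_max_diff_sketch(output: torch.Tensor,
                                output_ref: torch.Tensor) -> float:
        # Hypothetical stand-in for compute_max_diff: mean absolute deviation
        # from the reference, normalized by the reference's mean magnitude.
        return (torch.mean(torch.abs(output - output_ref)) /
                torch.mean(torch.abs(output_ref))).item()

    # Usage sketch: an output with a uniform 1% deviation clears the 4% bound.
    ref = torch.randn(16, 32)
    out = ref * 1.01
    assert compute_max_diff_sketch(out, ref) < 0.04
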
27 changes: 19 additions & 8 deletions vllm/model_executor/layers/quantization/gptq_marlin_24.py
@@ -12,6 +12,15 @@

logger = init_logger(__name__)

GPTQ_MARLIN_24_TILE = 16
GPTQ_MARLIN_24_MIN_THREAD_N = 128
GPTQ_MARLIN_24_MIN_THREAD_K = 128
GPTQ_MARLIN_24_MAX_PARALLEL = 16

GPTQ_MARLIN_24_SUPPORTED_NUM_BITS = [4, 8]
GPTQ_MARLIN_24_SUPPORTED_GROUP_SIZES = [-1, 128]
GPTQ_MARLIN_24_SUPPORTED_SYM = [True]


class GPTQMarlin24Config(QuantizationConfig):
"""Config class for Marlin24.
@@ -25,15 +34,17 @@ def __init__(
self.weight_bits = weight_bits
self.group_size = group_size

if self.weight_bits != 4 and self.weight_bits != 8:
raise ValueError("weight_bits must be 4 or 8. Got = {}".format(
self.weight_bits))

if self.group_size != 128 and self.group_size != -1:
# Verify
if self.weight_bits not in GPTQ_MARLIN_24_SUPPORTED_NUM_BITS:
raise ValueError(
f"Marlin_24 does not support weight_bits = {self.weight_bits}. "
f"Only weight_bits = {GPTQ_MARLIN_24_SUPPORTED_NUM_BITS} "
"are supported.")
if self.group_size not in GPTQ_MARLIN_24_SUPPORTED_GROUP_SIZES:
raise ValueError(
"Currently, only group size 128 and -1 (channelwise) "
"is supported for Marlin24, but got group_size of "
f"{self.group_size}")
f"Marlin_24 does not support group_size = {self.group_size}. "
f"Only group_sizes = {GPTQ_MARLIN_24_SUPPORTED_GROUP_SIZES} "
"are supported.")

# 4 Bits packed into 32 bit datatype.
self.pack_factor = 32 // self.weight_bits
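
For illustration, a minimal, hypothetical usage sketch of the new module-level constants and validation, assuming vllm is installed and that the constructor takes only weight_bits and group_size as suggested by the hunk above (the real class may accept more arguments):

    from vllm.model_executor.layers.quantization.gptq_marlin_24 import (
        GPTQ_MARLIN_24_SUPPORTED_GROUP_SIZES, GPTQ_MARLIN_24_SUPPORTED_NUM_BITS,
        GPTQMarlin24Config)

    # The exported constants drive both the kernel tests and the config checks.
    assert 4 in GPTQ_MARLIN_24_SUPPORTED_NUM_BITS
    assert 128 in GPTQ_MARLIN_24_SUPPORTED_GROUP_SIZES

    # Supported combination: 4-bit weights, group size 128 (or -1 for channelwise).
    cfg = GPTQMarlin24Config(weight_bits=4, group_size=128)
    print(cfg.pack_factor)  # 32 // 4 == 8 quantized weights per 32-bit word

    # An unsupported group size is rejected with a message naming the allowed values.
    try:
        GPTQMarlin24Config(weight_bits=4, group_size=64)
    except ValueError as err:
        print(err)
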
