Commit 4dc605e
review comments
LucasWilkinson committed Aug 1, 2024
1 parent 443d00d commit 4dc605e
Showing 3 changed files with 16 additions and 10 deletions.
5 changes: 3 additions & 2 deletions benchmarks/kernels/benchmark_marlin.py
@@ -10,7 +10,7 @@
GPTQ_MARLIN_24_SUPPORTED_GROUP_SIZES, GPTQ_MARLIN_24_SUPPORTED_QUANT_TYPES)
from vllm.model_executor.layers.quantization.utils.marlin_utils import (
GPTQ_MARLIN_MAX_PARALLEL, GPTQ_MARLIN_MIN_THREAD_N,
-MARLIN_SUPPORTED_GROUP_SIZES, marlin_supported_quant_types)
+MARLIN_SUPPORTED_GROUP_SIZES, query_marlin_supported_quant_types)
from vllm.model_executor.layers.quantization.utils.marlin_utils_test import (
MarlinWorkspace, marlin_quantize)
from vllm.model_executor.layers.quantization.utils.marlin_utils_test_24 import (
@@ -199,7 +199,8 @@ def main(args):
) > 0 and is_k_full not in args.limit_k_full:
continue

-for quant_type in marlin_supported_quant_types(False):
+for quant_type in query_marlin_supported_quant_types(
+        False):
if len(args.limit_num_bits) > 0 and \
quant_type.size_bits not in args.limit_num_bits:
continue
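Note (outside the diff): a small, illustrative sketch of the call-site pattern touched above. Only the import path, the has_zp argument, and quant_type.size_bits come from this commit; the local limit_num_bits list stands in for args.limit_num_bits and the loop body is a placeholder.

# Illustrative sketch of the benchmark's per-quant-type filtering loop.
from vllm.model_executor.layers.quantization.utils.marlin_utils import (
    query_marlin_supported_quant_types)

limit_num_bits = [4, 8]  # stand-in for args.limit_num_bits

# has_zp=False selects the GPTQ-style (no runtime zero-point) quant types.
for quant_type in query_marlin_supported_quant_types(False):
    if limit_num_bits and quant_type.size_bits not in limit_num_bits:
        continue
    # ... run the Marlin GEMM benchmark for this quant_type here ...
    print(f"benchmarking {quant_type} ({quant_type.size_bits}-bit)")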
14 changes: 9 additions & 5 deletions tests/kernels/test_marlin_gemm.py
@@ -16,7 +16,7 @@
from vllm.model_executor.layers.quantization.utils.marlin_utils import (
GPTQ_MARLIN_MAX_PARALLEL, GPTQ_MARLIN_MIN_THREAD_N,
MARLIN_SUPPORTED_GROUP_SIZES, marlin_make_empty_g_idx,
-marlin_permute_scales, marlin_supported_quant_types)
+marlin_permute_scales, query_marlin_supported_quant_types)
from vllm.model_executor.layers.quantization.utils.marlin_utils_fp8 import (
pack_fp8_to_int32)
from vllm.model_executor.layers.quantization.utils.marlin_utils_test import (
@@ -64,7 +64,8 @@ def rand_data(shape, dtype=torch.float16):
reason="Marlin is not supported on this GPU type.")
@pytest.mark.parametrize("k_chunk", MARLIN_K_CHUNKS)
@pytest.mark.parametrize("n_chunk", MARLIN_N_CHUNKS)
@pytest.mark.parametrize("quant_type", marlin_supported_quant_types(False))
@pytest.mark.parametrize("quant_type",
query_marlin_supported_quant_types(False))
@pytest.mark.parametrize("group_size", MARLIN_SUPPORTED_GROUP_SIZES)
@pytest.mark.parametrize("act_order", ACT_ORDER_OPTS)
@pytest.mark.parametrize("mnk_factors", MNK_FACTORS)
@@ -128,7 +129,8 @@ def test_gptq_marlin_repack(k_chunk, n_chunk, quant_type, group_size,
reason="Marlin is not supported on this GPU type.")
@pytest.mark.parametrize("k_chunk", MARLIN_K_CHUNKS)
@pytest.mark.parametrize("n_chunk", MARLIN_N_CHUNKS)
@pytest.mark.parametrize("quant_type", marlin_supported_quant_types(False))
@pytest.mark.parametrize("quant_type",
query_marlin_supported_quant_types(False))
@pytest.mark.parametrize("group_size", MARLIN_SUPPORTED_GROUP_SIZES)
@pytest.mark.parametrize("mnk_factors", MNK_FACTORS)
def test_awq_marlin_repack(k_chunk, n_chunk, quant_type, group_size,
@@ -179,7 +181,8 @@ def test_awq_marlin_repack(k_chunk, n_chunk, quant_type, group_size,
reason="Marlin is not supported on this GPU type.")
@pytest.mark.parametrize("k_chunk", MARLIN_K_CHUNKS)
@pytest.mark.parametrize("n_chunk", MARLIN_N_CHUNKS)
@pytest.mark.parametrize("quant_type", marlin_supported_quant_types(False))
@pytest.mark.parametrize("quant_type",
query_marlin_supported_quant_types(False))
@pytest.mark.parametrize("group_size", MARLIN_SUPPORTED_GROUP_SIZES)
@pytest.mark.parametrize("mnk_factors", MNK_FACTORS)
@pytest.mark.parametrize("act_order", ACT_ORDER_OPTS)
@@ -374,7 +377,8 @@ def test_fp8_marlin_gemm(
reason="Marlin is not supported on this GPU type.")
@pytest.mark.parametrize("k_chunk", MARLIN_K_CHUNKS)
@pytest.mark.parametrize("n_chunk", MARLIN_N_CHUNKS)
@pytest.mark.parametrize("quant_type", marlin_supported_quant_types(True))
@pytest.mark.parametrize("quant_type",
query_marlin_supported_quant_types(True))
@pytest.mark.parametrize("group_size", MARLIN_SUPPORTED_GROUP_SIZES)
@pytest.mark.parametrize("mnk_factors", MNK_FACTORS)
@pytest.mark.parametrize("use_fp32_reduce", USE_FP32_REDUCE_OPTS)
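Note (outside the diff): a minimal pytest sketch of the parametrization pattern used in the tests above; the real tests add group-size, chunk, and GPU-capability parameters and exercise the actual repack/GEMM kernels, so the assertion here is only a placeholder.

# Minimal pytest sketch, assuming the same import path as the diff.
import pytest

from vllm.model_executor.layers.quantization.utils.marlin_utils import (
    query_marlin_supported_quant_types)


@pytest.mark.parametrize("quant_type",
                         query_marlin_supported_quant_types(False))
def test_quant_type_is_supported(quant_type):
    # Placeholder check; the real tests run Marlin repack/GEMM kernels.
    assert quant_type in query_marlin_supported_quant_types(False)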
7 changes: 4 additions & 3 deletions vllm/model_executor/layers/quantization/utils/marlin_utils.py
@@ -25,8 +25,8 @@
# For binary size and compile time, we don't support the same types for with and
# without runtime zero-point. We support common cases, i.e. AWQ and GPTQ.
# TODO: we may want to move this into the C++ so its closer to the actual impl
-def marlin_supported_quant_types(has_zp: bool,
-                                 min_capability: Optional[int] = None):
+def query_marlin_supported_quant_types(has_zp: bool,
+                                       min_capability: Optional[int] = None):
if min_capability is None:
major, minor = current_platform.get_device_capability()
min_capability = major * 10 + minor
@@ -54,7 +54,8 @@ def _check_marlin_supported(
major, minor = current_platform.get_device_capability()
min_capability = major * 10 + minor

-supported_types = marlin_supported_quant_types(has_zp, min_capability)
+supported_types = query_marlin_supported_quant_types(
+    has_zp, min_capability)

if quant_type not in supported_types:
return (False, f"Marlin does not support weight_bits = {quant_type}. "
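Note (outside the diff): a minimal usage sketch for the renamed helper. The import path, the (has_zp, min_capability) signature, the device-capability default, and quant_type.size_bits come from this commit; the variable names and the capability value 80 are illustrative assumptions.

# Hedged usage sketch for query_marlin_supported_quant_types.
from vllm.model_executor.layers.quantization.utils.marlin_utils import (
    query_marlin_supported_quant_types)

# GPTQ-style weights (no runtime zero-point); with min_capability=None the
# helper falls back to the current device's capability (major * 10 + minor).
no_zp_types = query_marlin_supported_quant_types(False)

# AWQ-style weights (runtime zero-point), pinned to compute capability 8.0.
zp_types = query_marlin_supported_quant_types(True, min_capability=80)

for quant_type in no_zp_types:
    print(quant_type, quant_type.size_bits)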
