From 5223199e03ac3729eb60043a1ef57156c8af1bc9 Mon Sep 17 00:00:00 2001 From: Michael Goin Date: Wed, 7 Aug 2024 14:23:12 -0400 Subject: [PATCH] [Bugfix][FP8] Fix dynamic FP8 Marlin quantization (#7219) --- tests/quantization/test_fp8.py | 19 +++++++++++++++---- vllm/envs.py | 8 ++++++++ .../model_executor/layers/quantization/fp8.py | 11 ++++++++++- 3 files changed, 33 insertions(+), 5 deletions(-) diff --git a/tests/quantization/test_fp8.py b/tests/quantization/test_fp8.py index a020f7bf37262..ebb06ed20f249 100644 --- a/tests/quantization/test_fp8.py +++ b/tests/quantization/test_fp8.py @@ -9,6 +9,7 @@ from vllm import _custom_ops as ops from vllm.model_executor.layers.quantization.fp8 import (Fp8KVCacheMethod, Fp8LinearMethod) +from vllm.platforms import current_platform MODELS = [ "neuralmagic/Meta-Llama-3-8B-Instruct-FP8-KV", @@ -20,7 +21,12 @@ @pytest.mark.skipif(not is_quant_method_supported("fp8"), reason="FP8 is not supported on this GPU type.") @pytest.mark.parametrize("model_id", MODELS) -def test_model_load_and_run(vllm_runner, model_id: str): +@pytest.mark.parametrize("force_marlin", [False, True]) +def test_model_load_and_run(vllm_runner, model_id: str, force_marlin: bool, + monkeypatch) -> None: + if force_marlin: + monkeypatch.setenv("VLLM_TEST_FORCE_FP8_MARLIN", "1") + with vllm_runner(model_id) as llm: # note: this does not test accuracy, just that we can run through # see lm-eval tests for accuracy @@ -61,7 +67,12 @@ def test_kv_cache_model_load_and_run(vllm_runner, model_id: str): @pytest.mark.skipif(not is_quant_method_supported("fp8"), reason="FP8 is not supported on this GPU type.") @pytest.mark.parametrize("kv_cache_dtype", ["auto", "fp8"]) -def test_load_fp16_model(vllm_runner, kv_cache_dtype: str) -> None: +@pytest.mark.parametrize("force_marlin", [False, True]) +def test_load_fp16_model(vllm_runner, kv_cache_dtype: str, force_marlin: bool, + monkeypatch) -> None: + if force_marlin: + monkeypatch.setenv("VLLM_TEST_FORCE_FP8_MARLIN", "1") + with vllm_runner("facebook/opt-125m", quantization="fp8", kv_cache_dtype=kv_cache_dtype) as llm: @@ -75,9 +86,9 @@ def test_load_fp16_model(vllm_runner, kv_cache_dtype: str) -> None: assert attn._k_scale == 1.0 assert attn._v_scale == 1.0 - capability = torch.cuda.get_device_capability() + capability = current_platform.get_device_capability() capability = capability[0] * 10 + capability[1] - if capability >= 89: + if capability >= 89 and not force_marlin: # For GPUs with hardware support, we keep weights in fp8 assert fc1.weight.dtype == torch.float8_e4m3fn else: diff --git a/vllm/envs.py b/vllm/envs.py index df4c994359dbd..81f30b1d42a13 100644 --- a/vllm/envs.py +++ b/vllm/envs.py @@ -52,6 +52,7 @@ CMAKE_BUILD_TYPE: Optional[str] = None VERBOSE: bool = False VLLM_ALLOW_LONG_MAX_MODEL_LEN: bool = False + VLLM_TEST_FORCE_FP8_MARLIN: bool = False def get_default_cache_root(): @@ -342,6 +343,13 @@ def get_default_config_root(): lambda: (os.environ.get("VLLM_ALLOW_LONG_MAX_MODEL_LEN", "0").strip().lower() in ("1", "true")), + + # If set, forces FP8 Marlin to be used for FP8 quantization regardless + # of the hardware support for FP8 compute. 
+ "VLLM_TEST_FORCE_FP8_MARLIN": + lambda: + (os.environ.get("VLLM_TEST_FORCE_FP8_MARLIN", "0").strip().lower() in + ("1", "true")), } # end-env-vars-definition diff --git a/vllm/model_executor/layers/quantization/fp8.py b/vllm/model_executor/layers/quantization/fp8.py index c829cb836ee4c..cdd2413f5b2c4 100644 --- a/vllm/model_executor/layers/quantization/fp8.py +++ b/vllm/model_executor/layers/quantization/fp8.py @@ -4,6 +4,7 @@ from torch.nn import Module from torch.nn.parameter import Parameter +import vllm.envs as envs from vllm import _custom_ops as ops from vllm.logger import init_logger from vllm.model_executor.layers.fused_moe import FusedMoE, FusedMoEMethodBase @@ -118,7 +119,7 @@ def __init__(self, quant_config: Fp8Config): # kernel for fast weight-only FP8 quantization capability = current_platform.get_device_capability() capability = capability[0] * 10 + capability[1] - self.use_marlin = capability < 89 + self.use_marlin = capability < 89 or envs.VLLM_TEST_FORCE_FP8_MARLIN def create_weights( self, @@ -174,6 +175,14 @@ def process_weights_after_loading(self, layer: Module) -> None: qweight, weight_scale = ops.scaled_fp8_quant(layer.weight, scale=None) + # If using marlin (w8a16), kernel uses channelwise weights, + # so extend the weight scales to be channelwise. + if self.use_marlin: + assert weight_scale.numel() == 1 + weight_scale = convert_to_channelwise( + weight_scale.expand(len(layer.logical_widths)), + layer.logical_widths) + # Update the layer with the new values. layer.weight = Parameter(qweight.t(), requires_grad=False) layer.weight_scale = Parameter(weight_scale, requires_grad=False)
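
Note on the convert_to_channelwise step above: the Marlin w8a16 path consumes channelwise weight scales, while dynamic per-tensor FP8 quantization produces a single scale (hence the `assert weight_scale.numel() == 1`), so the scale is broadcast across the layer's logical shards and then repeated per output channel. The snippet below is a minimal standalone sketch of that expansion, not vLLM's actual convert_to_channelwise helper; the logical_widths values are assumed purely for illustration.

    import torch

    # Assumed example: a fused projection whose logical sub-layers (shards)
    # have these output widths, the kind of values layer.logical_widths holds.
    logical_widths = [4096, 1024, 1024]

    # Dynamic quantization yields one per-tensor scale.
    per_tensor_scale = torch.tensor([0.02])

    # One scale per shard, mirroring weight_scale.expand(len(logical_widths)).
    per_shard_scales = per_tensor_scale.expand(len(logical_widths))

    # Channelwise form: repeat each shard's scale across its output channels.
    channelwise_scale = torch.cat([
        scale.repeat(width)
        for scale, width in zip(per_shard_scales, logical_widths)
    ])

    assert channelwise_scale.shape[0] == sum(logical_widths)  # 6144 scales

To exercise this path on GPUs that do have native FP8 support, the updated tests set VLLM_TEST_FORCE_FP8_MARLIN=1 (via monkeypatch.setenv), which forces self.use_marlin even when capability >= 89.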