Skip to content

Commit

Permalink
[Bugfix][FP8] Fix dynamic FP8 Marlin quantization (vllm-project#7219)
Browse files Browse the repository at this point in the history
  • Loading branch information
mgoin authored Aug 7, 2024
1 parent fde47d3 commit 5223199
Show file tree
Hide file tree
Showing 3 changed files with 33 additions and 5 deletions.
19 changes: 15 additions & 4 deletions tests/quantization/test_fp8.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from vllm import _custom_ops as ops
from vllm.model_executor.layers.quantization.fp8 import (Fp8KVCacheMethod,
Fp8LinearMethod)
from vllm.platforms import current_platform

MODELS = [
"neuralmagic/Meta-Llama-3-8B-Instruct-FP8-KV",
Expand All @@ -20,7 +21,12 @@
@pytest.mark.skipif(not is_quant_method_supported("fp8"),
reason="FP8 is not supported on this GPU type.")
@pytest.mark.parametrize("model_id", MODELS)
def test_model_load_and_run(vllm_runner, model_id: str):
@pytest.mark.parametrize("force_marlin", [False, True])
def test_model_load_and_run(vllm_runner, model_id: str, force_marlin: bool,
monkeypatch) -> None:
if force_marlin:
monkeypatch.setenv("VLLM_TEST_FORCE_FP8_MARLIN", "1")

with vllm_runner(model_id) as llm:
# note: this does not test accuracy, just that we can run through
# see lm-eval tests for accuracy
Expand Down Expand Up @@ -61,7 +67,12 @@ def test_kv_cache_model_load_and_run(vllm_runner, model_id: str):
@pytest.mark.skipif(not is_quant_method_supported("fp8"),
reason="FP8 is not supported on this GPU type.")
@pytest.mark.parametrize("kv_cache_dtype", ["auto", "fp8"])
def test_load_fp16_model(vllm_runner, kv_cache_dtype: str) -> None:
@pytest.mark.parametrize("force_marlin", [False, True])
def test_load_fp16_model(vllm_runner, kv_cache_dtype: str, force_marlin: bool,
monkeypatch) -> None:
if force_marlin:
monkeypatch.setenv("VLLM_TEST_FORCE_FP8_MARLIN", "1")

with vllm_runner("facebook/opt-125m",
quantization="fp8",
kv_cache_dtype=kv_cache_dtype) as llm:
Expand All @@ -75,9 +86,9 @@ def test_load_fp16_model(vllm_runner, kv_cache_dtype: str) -> None:
assert attn._k_scale == 1.0
assert attn._v_scale == 1.0

capability = torch.cuda.get_device_capability()
capability = current_platform.get_device_capability()
capability = capability[0] * 10 + capability[1]
if capability >= 89:
if capability >= 89 and not force_marlin:
# For GPUs with hardware support, we keep weights in fp8
assert fc1.weight.dtype == torch.float8_e4m3fn
else:
Expand Down
8 changes: 8 additions & 0 deletions vllm/envs.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@
CMAKE_BUILD_TYPE: Optional[str] = None
VERBOSE: bool = False
VLLM_ALLOW_LONG_MAX_MODEL_LEN: bool = False
VLLM_TEST_FORCE_FP8_MARLIN: bool = False


def get_default_cache_root():
Expand Down Expand Up @@ -342,6 +343,13 @@ def get_default_config_root():
lambda:
(os.environ.get("VLLM_ALLOW_LONG_MAX_MODEL_LEN", "0").strip().lower() in
("1", "true")),

# If set, forces FP8 Marlin to be used for FP8 quantization regardless
# of the hardware support for FP8 compute.
"VLLM_TEST_FORCE_FP8_MARLIN":
lambda:
(os.environ.get("VLLM_TEST_FORCE_FP8_MARLIN", "0").strip().lower() in
("1", "true")),
}

# end-env-vars-definition
Expand Down
11 changes: 10 additions & 1 deletion vllm/model_executor/layers/quantization/fp8.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from torch.nn import Module
from torch.nn.parameter import Parameter

import vllm.envs as envs
from vllm import _custom_ops as ops
from vllm.logger import init_logger
from vllm.model_executor.layers.fused_moe import FusedMoE, FusedMoEMethodBase
Expand Down Expand Up @@ -118,7 +119,7 @@ def __init__(self, quant_config: Fp8Config):
# kernel for fast weight-only FP8 quantization
capability = current_platform.get_device_capability()
capability = capability[0] * 10 + capability[1]
self.use_marlin = capability < 89
self.use_marlin = capability < 89 or envs.VLLM_TEST_FORCE_FP8_MARLIN

def create_weights(
self,
Expand Down Expand Up @@ -174,6 +175,14 @@ def process_weights_after_loading(self, layer: Module) -> None:
qweight, weight_scale = ops.scaled_fp8_quant(layer.weight,
scale=None)

# If using marlin (w8a16), kernel uses channelwise weights,
# so extend the weight scales to be channelwise.
if self.use_marlin:
assert weight_scale.numel() == 1
weight_scale = convert_to_channelwise(
weight_scale.expand(len(layer.logical_widths)),
layer.logical_widths)

# Update the layer with the new values.
layer.weight = Parameter(qweight.t(), requires_grad=False)
layer.weight_scale = Parameter(weight_scale, requires_grad=False)
Expand Down

0 comments on commit 5223199

Please sign in to comment.