Skip to content

Commit

Permalink
[Kernel] Expand FP8 support to Ampere GPUs using FP8 Marlin (vllm-pro…
Browse files Browse the repository at this point in the history
  • Loading branch information
mgoin authored Jul 3, 2024
1 parent 7cd2ebb commit 47f0954
Show file tree
Hide file tree
Showing 11 changed files with 1,587 additions and 44 deletions.
1 change: 1 addition & 0 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -171,6 +171,7 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
"csrc/quantization/marlin/sparse/marlin_24_cuda_kernel.cu"
"csrc/quantization/gptq_marlin/gptq_marlin.cu"
"csrc/quantization/gptq_marlin/gptq_marlin_repack.cu"
"csrc/quantization/fp8/fp8_marlin.cu"
"csrc/custom_all_reduce.cu"
"csrc/quantization/cutlass_w8a8/scaled_mm_entry.cu"
"csrc/quantization/cutlass_w8a8/scaled_mm_c2x.cu"
Expand Down
5 changes: 5 additions & 0 deletions csrc/ops.h
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,11 @@ torch::Tensor gptq_marlin_repack(torch::Tensor& b_q_weight, torch::Tensor& perm,
int64_t size_k, int64_t size_n,
int64_t num_bits);

torch::Tensor fp8_marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight,
torch::Tensor& b_scales, torch::Tensor& workspace,
int64_t num_bits, int64_t size_m, int64_t size_n,
int64_t size_k);

bool cutlass_scaled_mm_supports_fp8(int64_t cuda_device_capability);

void cutlass_scaled_mm(torch::Tensor& out, torch::Tensor const& a,
Expand Down
1,308 changes: 1,308 additions & 0 deletions csrc/quantization/fp8/fp8_marlin.cu

Large diffs are not rendered by default.

4 changes: 4 additions & 0 deletions csrc/torch_bindings.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -137,6 +137,10 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
ops.def("gptq_marlin_repack", &gptq_marlin_repack);
ops.impl("gptq_marlin_repack", torch::kCUDA, &gptq_marlin_repack);

// fp8_marlin Optimized Quantized GEMM for FP8 weight-only.
ops.def("fp8_marlin_gemm", &fp8_marlin_gemm);
ops.impl("fp8_marlin_gemm", torch::kCUDA, &fp8_marlin_gemm);

// CUTLASS w8a8 GEMM, supporting symmetric per-tensor or per-row/column
// quantization.
ops.def(
Expand Down
3 changes: 2 additions & 1 deletion docs/source/quantization/fp8.rst
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,8 @@ FP8
==================

vLLM supports FP8 (8-bit floating point) weight and activation quantization using hardware acceleration on GPUs such as Nvidia H100 and AMD MI300x.
Currently, only Hopper and Ada Lovelace GPUs are supported.
Currently, only Hopper and Ada Lovelace GPUs are officially supported for W8A8.
Ampere GPUs are supported for W8A16 (weight-only FP8) utilizing Marlin kernels.
Quantization of models with FP8 allows for a 2x reduction in model memory requirements and up to a 1.6x improvement in throughput with minimal impact on accuracy.

Please visit the HF collection of `quantized FP8 checkpoints of popular LLMs ready to use with vLLM <https://huggingface.co/collections/neuralmagic/fp8-llms-for-vllm-666742ed2b78b7ac8df13127>`_.
Expand Down
2 changes: 1 addition & 1 deletion docs/source/quantization/supported_hardware.rst
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ Implementation Volta Turing Ampere Ada Hopper AMD GPU Intel GPU x86
AQLM ✅ ✅ ✅ ✅ ✅ ❌ ❌ ❌ ❌ ❌
AWQ ❌ ✅ ✅ ✅ ✅ ❌ ❌ ❌ ❌ ❌
DeepSpeedFP ✅ ✅ ✅ ✅ ✅ ❌ ❌ ❌ ❌ ❌
FP8 ❌ ❌ ✅ ✅ ❌ ❌ ❌ ❌ ❌
FP8 ❌ ❌ ✅ ✅ ❌ ❌ ❌ ❌ ❌
Marlin ❌ ❌ ✅ ✅ ✅ ❌ ❌ ❌ ❌ ❌
GPTQ ✅ ✅ ✅ ✅ ✅ ❌ ❌ ❌ ❌ ❌
SqueezeLLM ✅ ✅ ✅ ✅ ✅ ❌ ❌ ❌ ❌ ❌
Expand Down
88 changes: 84 additions & 4 deletions tests/kernels/test_marlin_gemm.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,15 +8,16 @@
from vllm import _custom_ops as ops
from vllm.model_executor.layers.quantization.gptq_marlin import (
GPTQ_MARLIN_MAX_PARALLEL, GPTQ_MARLIN_MIN_THREAD_N,
GPTQ_MARLIN_SUPPORTED_GROUP_SIZES, GPTQ_MARLIN_SUPPORTED_NUM_BITS)
GPTQ_MARLIN_SUPPORTED_GROUP_SIZES, GPTQ_MARLIN_SUPPORTED_NUM_BITS,
marlin_permute_scales)
from vllm.model_executor.layers.quantization.gptq_marlin_24 import (
GPTQ_MARLIN_24_MAX_PARALLEL, GPTQ_MARLIN_24_MIN_THREAD_N,
GPTQ_MARLIN_24_SUPPORTED_GROUP_SIZES, GPTQ_MARLIN_24_SUPPORTED_NUM_BITS)
from vllm.model_executor.layers.quantization.utils.marlin_perms import (
marlin_perm)
from vllm.model_executor.layers.quantization.utils.marlin_utils import (
MarlinWorkspace, compute_max_diff, is_marlin_supported, marlin_24_quantize,
marlin_quantize, marlin_weights)
marlin_quantize, marlin_weights, pack_fp8_to_int32)
from vllm.model_executor.layers.quantization.utils.quant_utils import (
gptq_pack, quantize_weights, sort_weights)

Expand All @@ -38,9 +39,11 @@
(67, 13, 11),
]

DTYPES = [torch.float16, torch.bfloat16]

def rand_data(shape):
return torch.randn(shape, dtype=torch.half, device="cuda")

def rand_data(shape, dtype=torch.float16):
return torch.randn(shape, dtype=dtype, device="cuda")


@pytest.mark.skipif(not is_marlin_supported(),
Expand Down Expand Up @@ -217,3 +220,80 @@ def test_marlin_24_gemm(k_chunk, n_chunk, num_bits, group_size, mnk_factors):
print("max_diff = {}".format(max_diff))

assert max_diff < 0.04


@pytest.mark.skipif(not is_marlin_supported(),
reason="Marlin is not supported on this GPU type.")
@pytest.mark.parametrize("k_chunk", MARLIN_K_CHUNKS)
@pytest.mark.parametrize("n_chunk", MARLIN_N_CHUNKS)
@pytest.mark.parametrize("num_bits", [8])
@pytest.mark.parametrize("group_size", [-1])
@pytest.mark.parametrize("mnk_factors", MNK_FACTORS)
@pytest.mark.parametrize("dtype", DTYPES)
def test_fp8_marlin_gemm(
k_chunk,
n_chunk,
num_bits,
group_size,
mnk_factors,
dtype,
):
m_factor, n_factor, k_factor = mnk_factors

size_m = m_factor
size_k = k_chunk * k_factor
size_n = n_chunk * n_factor

print(f"MNK = {size_m} {size_n} {size_k}")
print(f"groupsize = {group_size}")

a_input = rand_data((size_m, size_k), dtype=dtype)
b_weight = rand_data((size_k, size_n), dtype=dtype)

# WEIGHTS
fp8_weight, weight_scale = ops.scaled_fp8_quant(b_weight, scale=None)
# Repack weights to gptq format (packed int32 elements)
packed_gptq_qweight = pack_fp8_to_int32(fp8_weight)
# Repack weights to marlin format
marlin_qweight = ops.gptq_marlin_repack(
b_q_weight=packed_gptq_qweight,
perm=torch.empty(0, dtype=torch.int, device="cuda"),
size_k=size_k,
size_n=size_n,
num_bits=8,
)

# WEIGHT SCALES
# Currently Marlin doesn't support per-tensor scales, so we
# expand it to channelwise
scales = weight_scale.repeat(1, size_n).to(a_input.dtype).to("cuda")
# Permute scales
marlin_scales = marlin_permute_scales(
s=scales,
size_k=size_k,
size_n=size_n,
group_size=-1,
num_bits=8,
)

workspace = MarlinWorkspace(size_n, GPTQ_MARLIN_MIN_THREAD_N,
GPTQ_MARLIN_MAX_PARALLEL)

output = ops.fp8_marlin_gemm(
a=a_input,
b_q_weight=marlin_qweight,
b_scales=marlin_scales,
workspace=workspace.scratch,
num_bits=num_bits,
size_m=a_input.shape[0],
size_n=b_weight.shape[1],
size_k=a_input.shape[1],
)
output_ref = torch.matmul(a_input, b_weight)

torch.cuda.synchronize()

max_diff = compute_max_diff(output, output_ref)
print("max_diff = {}".format(max_diff))

assert max_diff < 0.04
19 changes: 14 additions & 5 deletions tests/quantization/test_fp8.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
import torch

from tests.quantization.utils import is_quant_method_supported
from vllm._custom_ops import scaled_fp8_quant
from vllm import _custom_ops as ops
from vllm.model_executor.layers.quantization.fp8 import Fp8LinearMethod

MODELS = [
Expand Down Expand Up @@ -35,7 +35,16 @@ def test_load_fp16_model(vllm_runner) -> None:
model = llm.model.llm_engine.model_executor.driver_worker.model_runner.model # noqa: E501
fc1 = model.model.decoder.layers[0].fc1
assert isinstance(fc1.quant_method, Fp8LinearMethod)
assert fc1.weight.dtype == torch.float8_e4m3fn

capability = torch.cuda.get_device_capability()
capability = capability[0] * 10 + capability[1]
if capability >= 89:
# For GPUs with hardware support, we keep weights in fp8
assert fc1.weight.dtype == torch.float8_e4m3fn
else:
# For GPUs without hardware support, we pack the fp8 weights
# for weight-only quantization using Marlin kernels
assert fc1.weight.dtype == torch.int32


@pytest.mark.skipif(not is_quant_method_supported("fp8"),
Expand Down Expand Up @@ -63,19 +72,19 @@ def per_tensor_dequantize(tensor, inv_scale, dtype):
x = (torch.randn(size=(11, 11), device="cuda") * 13).to(dtype)

# Dynamic quantization
ref_y, inv_scale = scaled_fp8_quant(x, None)
ref_y, inv_scale = ops.scaled_fp8_quant(x, None)
ref_y = per_tensor_dequantize(ref_y, inv_scale, dtype)

# Reference dynamic quantizaton
y = quantize_ref(x, inv_scale)
assert torch.allclose(ref_y, per_tensor_dequantize(y, inv_scale, dtype))

# Static quantization
y, _ = scaled_fp8_quant(x, inv_scale)
y, _ = ops.scaled_fp8_quant(x, inv_scale)
assert torch.allclose(ref_y, per_tensor_dequantize(y, inv_scale, dtype))

# Padding
y, _ = scaled_fp8_quant(x, inv_scale, batch_dim_padding=17)
y, _ = ops.scaled_fp8_quant(x, inv_scale, batch_dim_padding=17)
assert y.shape[0] == 17
assert torch.allclose(
ref_y,
Expand Down
9 changes: 9 additions & 0 deletions vllm/_custom_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -271,6 +271,15 @@ def gptq_marlin_gemm(a: torch.Tensor, b_q_weight: torch.Tensor,
size_k, is_k_full)


# fp8 marlin
def fp8_marlin_gemm(a: torch.Tensor, b_q_weight: torch.Tensor,
b_scales: torch.Tensor, workspace: torch.Tensor,
num_bits: int, size_m: int, size_n: int,
size_k: int) -> torch.Tensor:
return torch.ops._C.fp8_marlin_gemm(a, b_q_weight, b_scales, workspace,
num_bits, size_m, size_n, size_k)


# fp8
def scaled_fp8_quant(
input: torch.Tensor,
Expand Down
Loading

0 comments on commit 47f0954

Please sign in to comment.