Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Kernel] Expand FP8 support to Ampere GPUs using FP8 Marlin #5975

Merged
merged 19 commits into from
Jul 3, 2024

Conversation

mgoin
Copy link
Member

@mgoin mgoin commented Jun 28, 2024

This work expands FP8 support in vLLM from GPUs with hardware FP8 support (Hopper and Ada Lovelace) to GPUs without native support (currently Ampere) by introducing FP8 Marlin - a fast fused dequantization kernel for FP8 to BF16/FP16 conversion.

Key features:

  • Enables FP8 quantization on a wider range of GPUs (SM 8.0 and 8.7, Ampere)
  • Improves performance up to 2x in memory-bound scenarios
  • Maintains accuracy comparable to FP16 baselines
  • Reduces weight memory usage by 2x, allowing larger batches
  • Simple to use - just specify quantization="fp8" at runtime or use pre-quantized FP8 checkpoints

Implementation details:

  • Based on existing 8-bit integer support in GPTQ Marlin kernel
  • Packs FP8 weights into int32 doublewords (GPTQ format) and then permutes weights into Marlin format
  • Efficient 4xFP8 to 4xFP16/BF16 dequantization using bit arithmetic and SIMT operations

End-to-end performance and accuracy results:
FP8 Marlin A10 E2E Latency in vLLM
FP8 Marlin A100 E2E Latency in vLLM
GSM8k lm-eval with FP8 Marlin in vLLM
Individual layer sweeps:
A10 Layer-wise Sweep _ PyTorch FP16 vs FP8 Marlin MatMul
A100 Layer-wise Sweep _ PyTorch FP16 vs FP8 Marlin MatMul

As shown in the graphs, FP8 Marlin can provide significant speedups with minimal accuracy impact. Performance gains are higher on GPUs with less memory bandwidth (A10, RTX 3090) and for larger models.

Notes:

  • This weight-only approach differs slightly from the existing W8A8 FP8 quantization, offering higher accuracy because the activations have no need to be quantized
  • Currently expanding scales to be channelwise; future work will revert to per-tensor scales
  • This does not include support for MoE models.

Testing:

  • Tested on H100, A100, and A10 GPUs

This enhancement enables more users to benefit from FP8 quantization without hardware restrictions, improving vLLM's performance and efficiency across a broader range of setups!

@mgoin mgoin changed the title [Kernel] Expand FP8 support to Ampere GPUs using FP8 Marlin #331 [Kernel] Expand FP8 support to Ampere GPUs using FP8 Marlin Jun 28, 2024
@robertgshaw2-neuralmagic
Copy link
Collaborator

This is an awesome feature!

Copy link
Collaborator

@comaniac comaniac left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Overall LGTM. Thanks!

tests/kernels/test_marlin_gemm.py Show resolved Hide resolved
@mgoin mgoin enabled auto-merge (squash) July 3, 2024 16:30
@mgoin mgoin merged commit 47f0954 into vllm-project:main Jul 3, 2024
70 checks passed
@fxmarty
Copy link

fxmarty commented Jul 19, 2024

@mgoin awesome feature! I suppose that the perf benchmark was run with cuda graph enabled? Out of curiosity, did you run it without cuda graph?

As this kernel has been integrated in TGI as well, it appears having CUDA graph enabled is rather critical so as to get speedups in the decoding (which I don't really explain to myself - but haven't profiled). In the prefill, as cuda graphs are never used for long enough seqlens, I do get a slight slowdown.

prefill_gpu
decode_gpu

I did not benchmark on vllm, but I suppose the trend is similar. Probably depends on the gpu/tp config/model as well.

related: huggingface/optimum-quanto#241 (comment)

@mgoin
Copy link
Member Author

mgoin commented Jul 19, 2024

Glad you're enjoying it @fxmarty. Thanks for sharing your analysis. My end-to-end benchmarks were all done with cuda graphs enabled as this is the default in vLLM. Note that it is expected to see a slight slow-down at prefill (M>256), we trade this off to see the improvements at decode.

I'm curious, have you seen the same difference for marlin int8 or int4? Aside from this, I think there could be additional tuning for A100 problem shapes.

@HPC4AI
Copy link

HPC4AI commented Aug 26, 2024

Hello, I noticed that you used the dequant_8bit function to dequantize FP8 data to FP16 data, but I'm not clear on the underlying principle. Could you please also provide the code for quantizing FP16 to FP8? Thanks.

Alvant pushed a commit to compressa-ai/vllm that referenced this pull request Oct 26, 2024
@AllenDou
Copy link
Contributor

Hello, I noticed that you used the dequant_8bit function to dequantize FP8 data to FP16 data, but I'm not clear on the underlying principle. Could you please also provide the code for quantizing FP16 to FP8? Thanks.

__device__ inline typename ScalarType<nv_bfloat16>::FragB

https://github.com/IST-DASLab/marlin/blob/1f25790bdd49fba53106164a24666dade68d7c90/marlin/marlin_cuda_kernel.cu#L131

https://github.com/NVIDIA/FasterTransformer/blob/main/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

6 participants