
Conversation

HandH1998 (Collaborator) commented Jan 26, 2025

Following #3047, we replace the w8a8 fp8 vLLM kernel with the sgl-kernel implementation. Overall, the w8a8 fp8 sgl-kernel yields higher accuracy on GSM8K. On sm89 (L40), the w8a8 fp8 sgl-kernel delivers 14% higher throughput than the vLLM kernel; on sm90 (H100), both kernels exhibit similar performance.
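
For context, here is a minimal pure-PyTorch sketch of the math that both the vLLM cutlass kernel and the sgl-kernel fuse into a single GEMM: fp8 activations with per-token scales, fp8 weights with per-channel scales, accumulated in fp32 and written out in bf16. The function name and layout are illustrative only, not either library's API.

```python
import torch

def w8a8_fp8_gemm_reference(
    a_fp8: torch.Tensor,    # [M, K] activations, torch.float8_e4m3fn
    b_fp8: torch.Tensor,    # [K, N] weights, torch.float8_e4m3fn
    a_scale: torch.Tensor,  # [M, 1] per-token activation scales
    b_scale: torch.Tensor,  # [1, N] per-channel weight scales
    out_dtype=torch.bfloat16,
) -> torch.Tensor:
    # real_value = fp8_value * scale, so the output is (A * a_scale) @ (B * b_scale).
    a = a_fp8.to(torch.float32) * a_scale
    b = b_fp8.to(torch.float32) * b_scale
    return (a @ b).to(out_dtype)
```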

Benchmark

model: neuralmagic/Meta-Llama-3-8B-Instruct-FP8

sm89-L40

gsm8k

# w8a8 fp8 vllm-kernel

python3 -m sglang.launch_server --model-path neuralmagic/Meta-Llama-3-8B-Instruct-FP8  --trust-remote-code --disable-radix --port 33333
python3 benchmark/gsm8k/bench_sglang.py --num-questions 2000 --parallel 500 --num-shots 8 --port 33333

Accuracy: 0.752
Invalid: 0.002
Latency: 225.308 s
Output throughput: 616.437 token/s

# w8a8 fp8 sgl-kernel

python3 -m sglang.launch_server --model-path neuralmagic/Meta-Llama-3-8B-Instruct-FP8  --trust-remote-code --disable-radix --port 33333
python3 benchmark/gsm8k/bench_sglang.py --num-questions 2000 --parallel 500 --num-shots 8 --port 33333

Accuracy: 0.761
Invalid: 0.001
Latency: 190.474 s
Output throughput: 702.757 token/s

Throughput under various request rates (tok/s):

| request rate | 8 | 16 | 32 | inf |
| --- | --- | --- | --- | --- |
| vllm kernel | 1319.32 | 1942.67 | 2120.59 | 2122.10 |
| sgl kernel | 1347.00 | 2132.91 | 2416.50 | 2422.71 |

sm90-H100

gsm8k

# w8a8 fp8 vllm-kernel

python3 -m sglang.launch_server --model-path neuralmagic/Meta-Llama-3-8B-Instruct-FP8  --trust-remote-code --disable-radix --port 33333
python3 benchmark/gsm8k/bench_sglang.py --num-questions 2000 --parallel 2000 --num-shots 8 --port 33333

Accuracy: 0.754
Invalid: 0.000
Latency: 49.751 s
Output throughput: 2801.533 token/s

# w8a8 fp8 sgl-kernel

python3 -m sglang.launch_server --model-path neuralmagic/Meta-Llama-3-8B-Instruct-FP8  --trust-remote-code --disable-radix --port 33333
python3 benchmark/gsm8k/bench_sglang.py --num-questions 2000 --parallel 2000 --num-shots 8 --port 33333

Accuracy: 0.759
Invalid: 0.001
Latency: 49.257 s
Output throughput: 2805.524 token/s

Throughput under various request rates (tok/s):

| request rate | 8 | 16 | 32 | inf |
| --- | --- | --- | --- | --- |
| vllm kernel | 1468.51 | 2775.83 | 4209.25 | 7121.02 |
| sgl kernel | 1468.41 | 2767.79 | 4236.97 | 7168.12 |

model: neuralmagic/Meta-Llama-3.1-8B-Instruct-FP8-dynamic

activation dynamic quantization

sm89-L40

gsm8k

# w8a8 fp8 vllm-kernel

python3 -m sglang.launch_server --model-path neuralmagic/Meta-Llama-3.1-8B-Instruct-FP8-dynamic  --trust-remote-code --disable-radix --port 33333
python3 benchmark/gsm8k/bench_sglang.py --num-questions 2000 --parallel 500 --num-shots 8 --port 33333

Accuracy: 0.779
Invalid: 0.001
Latency: 226.505 s
Output throughput: 590.980 token/s

# w8a8 fp8 sgl-kernel

python3 -m sglang.launch_server --model-path neuralmagic/Meta-Llama-3.1-8B-Instruct-FP8-dynamic --quantization w8a8_fp8 --trust-remote-code --disable-radix --port 33333
python3 benchmark/gsm8k/bench_sglang.py --num-questions 2000 --parallel 500 --num-shots 8 --port 33333

Accuracy: 0.795
Invalid: 0.002
Latency: 194.464 s
Output throughput: 694.344 token/s

Throughput under various request rates (tok/s):

| request rate | 8 | 16 | 32 | inf |
| --- | --- | --- | --- | --- |
| vllm kernel | 1317.03 | 1928.51 | 2098.10 | 2104.52 |
| sgl kernel | 1344.21 | 2124.57 | 2420.53 | 2445.45 |

sm90-H100

gsm8k

# w8a8 fp8 vllm-kernel

python3 -m sglang.launch_server --model-path neuralmagic/Meta-Llama-3.1-8B-Instruct-FP8-dynamic  --trust-remote-code --disable-radix --port 33333
python3 benchmark/gsm8k/bench_sglang.py --num-questions 2000 --parallel 2000 --num-shots 8 --port 33333

Accuracy: 0.754
Invalid: 0.000
Latency: 49.751 s
Output throughput: 2801.533 token/s

# w8a8 fp8 sgl-kernel

python3 -m sglang.launch_server --model-path neuralmagic/Meta-Llama-3.1-8B-Instruct-FP8-dynamic --quantization w8a8_fp8  --trust-remote-code --disable-radix --port 33333
python3 benchmark/gsm8k/bench_sglang.py --num-questions 2000 --parallel 2000 --num-shots 8 --port 33333

Accuracy: 0.796
Invalid: 0.001
Latency: 47.991 s
Output throughput: 2801.793 token/s

Throughput under various request rates (tok/s):

| request rate | 8 | 16 | 32 | inf |
| --- | --- | --- | --- | --- |
| vllm kernel | 1454.58 | 2692.00 | 4231.36 | 6664.01 |
| sgl kernel | 1454.70 | 2692.19 | 4217.78 | 6693.59 |

zhyncs (Member) commented Jan 26, 2025

Let me bump a new sgl-kernel version to unblock this PR.

merrymercy (Contributor) commented:

  1. Can we still provide a flag to fall back to vLLM's implementation, similar to the custom allreduce kernel (use_vllm_custom_allreduce = os.environ.get("USE_VLLM_CUSTOM_ALLREDUCE", default=True))? We need this at the beginning to quickly debug possible regressions and AMD compatibility; a sketch of the pattern follows this list.
  2. How is the performance with dynamic quantization?
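
A minimal sketch of the requested environment-variable fallback. The flag name USE_VLLM_W8A8_FP8 and the sgl-kernel entry-point name are assumptions for illustration, not necessarily what this PR uses.

```python
import os

def _bool_env(name: str, default: str = "false") -> bool:
    return os.environ.get(name, default).lower() in ("1", "true", "yes")

# Hypothetical flag name, mirroring USE_VLLM_CUSTOM_ALLREDUCE.
USE_VLLM_W8A8_FP8 = _bool_env("USE_VLLM_W8A8_FP8")

def fp8_scaled_mm(a, b, a_scale, b_scale, out_dtype, bias=None):
    if USE_VLLM_W8A8_FP8:
        # Fallback path: vLLM's cutlass w8a8 fp8 kernel.
        from vllm import _custom_ops as vllm_ops
        return vllm_ops.cutlass_scaled_mm(a, b, a_scale, b_scale, out_dtype, bias)
    # Default path: sgl-kernel (entry-point name assumed here).
    from sgl_kernel import fp8_scaled_mm as sgl_fp8_scaled_mm
    return sgl_fp8_scaled_mm(a, b, a_scale, b_scale, out_dtype, bias)
```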

zhyncs (Member) commented Feb 10, 2025

@HandH1998 What is the progress of this PR? Please let me know when it is ready.

HandH1998 (Collaborator, Author) commented:

@zhyncs It should be ready in about two days.

HandH1998 (Collaborator, Author) commented:

@merrymercy @zhyncs
I have added support for falling back to the vLLM cutlass w8a8 fp8 kernel and have benchmarked dynamic quantization. The dynamic-quantization results are appended after the static-quantization results above and show similar performance. However, as the batch size increases, the end-to-end performance with the sgl-kernel falls behind the vLLM kernel. I believe the bottleneck is the Triton kernel per_token_group_quant_fp8, as @BBuf mentioned in #3493. If we replace the Triton implementation with the CUDA implementation, the performance gap should shrink further.
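
For reference, a naive PyTorch version of what a per-token-group fp8 quantization kernel computes: each contiguous group of group_size elements in a token row gets its own scale. The fused Triton/CUDA kernels exist precisely to avoid the extra memory passes this version makes; this sketch is illustrative only.

```python
import torch

FP8_MAX = torch.finfo(torch.float8_e4m3fn).max  # 448.0 for e4m3fn

def per_token_group_quant_fp8_ref(x: torch.Tensor, group_size: int = 128):
    m, k = x.shape
    assert k % group_size == 0
    xg = x.float().view(m, k // group_size, group_size)
    # One scale per (token, group), chosen so the group's max magnitude maps to FP8_MAX.
    scales = xg.abs().amax(dim=-1, keepdim=True).clamp(min=1e-10) / FP8_MAX
    q = (xg / scales).clamp(-FP8_MAX, FP8_MAX).to(torch.float8_e4m3fn)
    return q.view(m, k), scales.view(m, k // group_size)
```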

HandH1998 (Collaborator, Author) commented:

I also added a quantization config, w8a8_fp8, to support inference of models quantized with dynamic per-token activation quantization and static per-channel weight quantization, following #2881. You can use --quantization w8a8_fp8 to load the quantized checkpoint and run inference directly, without any modification to config.json.
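
For clarity, a sketch of what static per-channel weight quantization means here: one fp8 scale per output channel, fixed ahead of time (the checkpoints above already ship their weights in this form). Illustrative only.

```python
import torch

FP8_MAX = torch.finfo(torch.float8_e4m3fn).max

def quantize_weight_per_channel_fp8(w: torch.Tensor):
    # w: [out_features, in_features]; one static scale per output channel.
    w32 = w.float()
    scale = w32.abs().amax(dim=1, keepdim=True).clamp(min=1e-10) / FP8_MAX
    w_fp8 = (w32 / scale).clamp(-FP8_MAX, FP8_MAX).to(torch.float8_e4m3fn)
    return w_fp8, scale  # scale: [out_features, 1]
```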

zhyncs (Member) commented Feb 12, 2025

@HandH1998 #3493 has been merged.

HandH1998 (Collaborator, Author) commented:

@zhyncs
This PR depends on the latest sgl-kernel, but the CI doesn't use it yet.

zhyncs (Member) commented Mar 7, 2025

update this to 0.0.3.post7

"sgl-kernel==0.0.3.post6",

HandH1998 (Collaborator, Author) commented:

> update this to 0.0.3.post7
>
> "sgl-kernel==0.0.3.post6",

Need to upload to PyPI?

zhyncs (Member) commented Mar 7, 2025

> update this to 0.0.3.post7
>
> "sgl-kernel==0.0.3.post6",
>
> Need to upload to PyPI?

done: https://pypi.org/project/sgl-kernel/0.0.3.post7/

HandH1998 (Collaborator, Author) commented:

@zhyncs

The two failed CIs seem to be related to DSv3. I tried to reproduce them locally, but I can't find the lmsys/sglang-ci-dsv3-test model on Hugging Face.

zhyncs (Member) commented Mar 7, 2025

@HandH1998 You can give me your HF username, or use DeepSeek V3/R1 for testing. I have also updated this; if you wish to upgrade, please update it as well:

pip install sgl-kernel==0.0.3.post6 --force-reinstall --no-deps

hebiao064 (Collaborator) commented:

@HandH1998 Do you think we should support an API similar to scaled_fp8_quant?
https://github.com/vllm-project/vllm/blob/8ca7a71df787ad711ad3ac70a5bd2eb2bb398938/vllm/_custom_ops.py#L842

HandH1998 (Collaborator, Author) commented:

> @HandH1998 Do you think we should support an API similar to scaled_fp8_quant? https://github.com/vllm-project/vllm/blob/8ca7a71df787ad711ad3ac70a5bd2eb2bb398938/vllm/_custom_ops.py#L842

The cutlass w8a8 fp8 kernel only supports per-token activation scales, so I only apply per_token_quant. scaled_fp8_quant also offers per_tensor_quant, but that is not compatible with the cutlass w8a8 fp8 kernel, and it brings worse accuracy than per_token_quant with only a slightly better speed. So I don't think it is necessary.
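
To illustrate the trade-off described above, naive PyTorch versions of the two activation-quantization schemes (sketches only, not the actual kernels):

```python
import torch

FP8_MAX = torch.finfo(torch.float8_e4m3fn).max

def per_token_quant_fp8(x: torch.Tensor):
    # One scale per token (row): each token keeps its own dynamic range.
    scale = x.float().abs().amax(dim=-1, keepdim=True).clamp(min=1e-10) / FP8_MAX
    q = (x.float() / scale).clamp(-FP8_MAX, FP8_MAX).to(torch.float8_e4m3fn)
    return q, scale  # scale: [M, 1]

def per_tensor_quant_fp8(x: torch.Tensor):
    # One scale for the whole tensor: cheaper, but an outlier token shrinks
    # the usable range for every other token, which tends to hurt accuracy.
    scale = x.float().abs().amax().clamp(min=1e-10) / FP8_MAX
    q = (x.float() / scale).clamp(-FP8_MAX, FP8_MAX).to(torch.float8_e4m3fn)
    return q, scale  # scalar
```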

HandH1998 (Collaborator, Author) commented:

sglang/scripts/ci_install_dependency.sh

My HF user name is HandH1998.

HandH1998 (Collaborator, Author) commented:

The CI failures are caused by the 0.0.3.post7 sgl-kernel. See #4214. @zhyncs

zhyncs merged commit 0dd6cda into sgl-project:main on Mar 9, 2025 (1 of 4 checks passed).