Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

AWQ: Up to 2.66x higher throughput #2566

Merged
merged 6 commits into from
Jan 27, 2024

Conversation

casper-hansen
Copy link
Contributor

@casper-hansen casper-hansen commented Jan 23, 2024

The strategy is to dequantize and run FP16 matmul for longer sequences. This could probably be faster if we just used cublas instead of torch.matmul.

EDIT: It seems throughput can be over 2x in vLLM because context processing is such a crucial part of the framework.

python benchmarks/benchmark_throughput.py --input-len 1024 --output-len 64 --model TheBloke/Mistral-7B-Instruct-v0.2-AWQ --quantization awq --num-prompts 100 --max-model-len 1070 --dtype half

Tested on 1x A100 (80GB).

Input Len Threshold Main (requests/s, tokens/s) PR (requests/s, tokens/s) Speedup (requests/s)
16384 256 0.19, 3138.15 0.54, 8879.04 2.84x
4096 1024 0.84, 3500.00 2.24, 9336.44 2.66x
1024 1024 3.28, 3570.90 6.70, 7286.77 2.04x
512 512 5.96, 3435.14 11.91, 6860.50 1.99x
256 256 10.39, 3323.90 18.49, 5916.05 1.78x

@casper-hansen casper-hansen marked this pull request as ready for review January 24, 2024 21:13
@casper-hansen casper-hansen changed the title [WIP] AWQ: Faster context processing [WIP] AWQ: Up to 2.66x higher throughput Jan 24, 2024
@casper-hansen casper-hansen changed the title [WIP] AWQ: Up to 2.66x higher throughput AWQ: Up to 2.66x higher throughput Jan 24, 2024
@MichaelJayW
Copy link

Does it affect the accuracy?

@fxmarty
Copy link

fxmarty commented Jan 25, 2024

@casper-hansen that's really cool, in line with the bench here where indeed cublas (that exllama kernel is using for longer sequences) is just better than the AWQ GEMM kernel.

I think it would make our life easier if we had the same kind of dispatch for marlin.

@casper-hansen
Copy link
Contributor Author

Does it affect the accuracy?

This should have no impact on accuracy. The dequantization kernel is strictly equivalent to the dequantization from the GEMM kernel.

@casper-hansen that's really cool, in line with the bench here where indeed cublas (that exllama kernel is using for longer sequences) is just better than the AWQ GEMM kernel.

I think it would make our life easier if we had the same kind of dispatch for marlin.

Yes, I agree that the Marlin kernels could achieve even higher throughput. The most crucial part is just missing - it’s only for symmetric quantization.

@fxmarty
Copy link

fxmarty commented Jan 25, 2024

is it that crucial though? Many int4*fp16 models use symmetric weight quantization successfully

@casper-hansen
Copy link
Contributor Author

is it that crucial though? Many int4*fp16 models use symmetric weight quantization successfully

It may turn out to just be an engineering problem, but from my limited experience, the most popular symmetric weight quantization methods suffer from a higher quantization error.

@WoosukKwon WoosukKwon self-requested a review January 25, 2024 19:41
Copy link
Collaborator

@WoosukKwon WoosukKwon left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hi @casper-hansen, thanks for submitting the PR! Left some minor comments. Please take a look.

setup.py Outdated Show resolved Hide resolved
csrc/quantization/awq/gemm_kernels.cu Show resolved Hide resolved
vllm/model_executor/layers/quantization/awq.py Outdated Show resolved Hide resolved
@casper-hansen
Copy link
Contributor Author

Hi @casper-hansen, thanks for submitting the PR! Left some minor comments. Please take a look.

@WoosukKwon Thanks for the review. I applied your suggested fixes and tested that throughput is as expected.

Copy link
Collaborator

@WoosukKwon WoosukKwon left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM! thanks for the fix!

@WoosukKwon WoosukKwon merged commit beb89f6 into vllm-project:main Jan 27, 2024
16 checks passed
@sitabulaixizawaluduo
Copy link

I tested Llama-13b on A30 with tensor parallel size is 4, and I found awq throughput is lower than fp16.

@casper-hansen
Copy link
Contributor Author

casper-hansen commented Jan 30, 2024

I tested Llama-13b on A30 with tensor parallel size is 4, and I found awq throughput is lower than fp16.

This is as expected. You cannot exceed W16A16 performance with W4A16 when you test for throughput. You would need W4A4 (Atom, lower quality model) or W8A8 (SmoothQuant, also lower quality model).

This is because W4A16 methods require dequantization, so when you test throughput, you become compute bound and then it limits the performance.

EDIT: The throughput can also be lower if the TP implementation is not optimized for quantized models. Not sure if it is in vLLM

Comment on lines +161 to +162
out = ops.awq_dequantize(qweight, scales, qzeros, 0, 0, 0)
out = torch.matmul(reshaped_x, out)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Curious to learn: would this copy the dequantized weights back to the memory before doing torch.matmul? And a potential optimization is through implementing a more efficient mixed precision matmul that saves 1 data transfer to the memory?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You are probably right that there is potential to eliminate overhead. Exllama runs dequantization and then directly calls cublas for matmul inside the same CUDA kernel. Definitely something to explore!

NikolaBorisov pushed a commit to deepinfra/vllm that referenced this pull request Jan 31, 2024
hongxiayang pushed a commit to hongxiayang/vllm that referenced this pull request Feb 13, 2024
alexm-neuralmagic pushed a commit to neuralmagic/nm-vllm that referenced this pull request Feb 13, 2024
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

6 participants