Skip to content

Conversation

@sakogan
Copy link
Contributor

@sakogan sakogan commented May 27, 2025

This PR adds basic support for RTN quantization, as a first step for supporting a calibration-free RTN-based quantization for accurate and accelerated INT4/INT8 inference (see this paper for details).

RTN is a simple quantization method that does not require any calibration data nor a corresponding calibration process.
As such, it can be applied on-the-fly (i.e., while loading an original model) in a fast and cheap way, even on a system that does not have enough memory to host the original (unquantized) model. Yet, RTN is often believed to lag behind more advanced quantization techniques in two crucial areas – generation throughput and accuracy.

As this paper shows, both issues can be alleviated, through the use of efficient CUDA kernels based on Marlin (for throughput) and selective quantization (for accuracy). The latter is a simple mechanism that allows a user to select layers and/or specific linear modules that should be quantized to a higher precision. For instance, leaving just a part of one layer of Llama-3.1 70B model in 8 bit precision, while quantizing the rest of that layer and all other 79 layers into 4 bits leads to a substantially improved recovery rate, on-par with or better than other techniques:
Screenshot 2025-05-27 at 12 38 22 PM
Note that this adds less than 0.05 bits per weight on average, resulting in only insignificant memory increase.

As noted above, this PR is for basic Python-based implementation for RTN that supports quantizing models on-the-fly.
Once approved, we intend to enhance it with:

  • Optimized CUDA (Marlin-based) kernels (for fast GEMM operations).
  • Support for selective quantization (for improved accuracy)
  • Support for MoE models

Signed-off-by: Alex Kogan <alex.kogan@oracle.com>
@github-actions
Copy link

👋 Hi! Thank you for contributing to the vLLM project.

💬 Join our developer Slack at https://slack.vllm.ai to discuss your PR in #pr-reviews, coordinate on features in #feat- channels, or join special interest groups in #sig- channels.

Just a reminder: PRs would not trigger full CI run by default. Instead, it would only run fastcheck CI which starts running only a small and essential subset of CI tests to quickly catch errors. You can run other CI tests on top of those by going to your fastcheck build on Buildkite UI (linked in the PR checks section) and unblock them. If you do not have permission to unblock, ping simon-mo or khluu to add you in our Buildkite org.

Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.

To run CI, PR reviewers can either: Add ready label to the PR or enable auto-merge.

🚀

Signed-off-by: Alex Kogan <alex.kogan@oracle.com>
@sakogan
Copy link
Contributor Author

sakogan commented Jun 16, 2025

@mgoin @robertgshaw2-redhat @tlrmchlsmth Can you, please, take a look at this PR and let me know if you have any comments? Thanks!


from tests.quantization.utils import is_quant_method_supported

MODELS = ["microsoft/Phi-3-mini-4k-instruct"]
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is there a reason why you are using Phi here, like to get around sharded weight loading? IIRC Phi models have their mergable layers like q/k/v already merged in the checkpoint as qkv_proj. I notice you override the weight loading with your RTNParameter class so I'm curious if it works with an un-merged checkpoint like Llama

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yes, absolutely it works with any dense model, including un-merged LLama checkpoints. The Phi model is an arbitrary choice of a small dense model, happy to change it to something else

Co-authored-by: Michael Goin <mgoin64@gmail.com>
Signed-off-by: Alex Kogan <alex.kogan@oracle.com>
@sakogan sakogan requested a review from mgoin June 25, 2025 13:59
@mgoin mgoin enabled auto-merge (squash) July 1, 2025 02:01
@mgoin mgoin added ready ONLY add when PR is ready to merge/full CI is needed quantization labels Jul 1, 2025
@mgoin mgoin merged commit 2794935 into vllm-project:main Jul 1, 2025
86 checks passed
@renjie0
Copy link

renjie0 commented Jul 9, 2025

"Yet, RTN is often believed to lag behind more advanced quantization techniques in two crucial areas – generation throughput and accuracy." How does it look like now with your latest improvement?

@YouNeedCryDear
Copy link

"Yet, RTN is often believed to lag behind more advanced quantization techniques in two crucial areas – generation throughput and accuracy." How does it look like now with your latest improvement?

You can take a look at the paper mentioned above for more detail.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

quantization ready ONLY add when PR is ready to merge/full CI is needed

Projects

None yet

Development

Successfully merging this pull request may close these issues.

4 participants