
Conversation


@YouNeedCryDear YouNeedCryDear commented Jul 14, 2025

Motivation

This PR adds basic support for RTN quantization, as a first step toward supporting calibration-free RTN-based quantization for accurate and accelerated INT4/INT8 inference (see this paper for details).

RTN is a simple quantization method that requires neither calibration data nor a corresponding calibration process.
As such, it can be applied on-the-fly (i.e., while loading an original model) quickly and cheaply, even on a system that does not have enough memory to host the original (unquantized) model. Yet RTN is often believed to lag behind more advanced quantization techniques in two crucial areas: generation throughput and accuracy.
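For context, the core of the method fits in a few lines: each group of weights is scaled symmetrically by its largest magnitude and rounded to the nearest integer. Below is a minimal sketch of group-wise RTN quantization and dequantization; the function names and signatures are illustrative, not this PR's API, and it assumes in_features is divisible by the group size (the clamp on the scale guards against all-zero groups, matching a suggestion in the review below).

import torch

def rtn_quantize_sketch(w: torch.Tensor, num_bits: int = 4, group_size: int = 128):
    """Illustrative group-wise round-to-nearest quantization (not the PR's API)."""
    out_features, in_features = w.shape
    groups = w.reshape(out_features, in_features // group_size, group_size)
    q_range = 2**num_bits
    # One symmetric scale per group, derived from the group's largest magnitude.
    max_abs = groups.abs().amax(dim=-1, keepdim=True)
    scale = (max_abs * 2.0 / (q_range - 1)).clamp(min=1e-6)  # avoid div-by-zero
    # Round to nearest, shift into the unsigned range, truncate to [0, q_range - 1].
    q = (torch.round(groups / scale) + q_range // 2).clamp(0, q_range - 1)
    return q.to(torch.uint8).reshape(out_features, in_features), scale

def rtn_dequantize_sketch(q, scale, num_bits: int = 4, group_size: int = 128):
    """Inverse of the sketch above: unshift and rescale."""
    out_features, in_features = q.shape
    groups = q.reshape(out_features, in_features // group_size, group_size).float()
    return ((groups - 2**num_bits // 2) * scale).reshape(out_features, in_features)

The round-trip error of this scheme is at most half a quantization step per weight, which is exactly the round-to-nearest guarantee.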

As this paper shows, both issues can be alleviated through the use of efficient CUDA kernels based on Marlin (for throughput) and selective quantization (for accuracy). The latter is a simple mechanism that lets a user select layers and/or specific linear modules to be quantized to a higher precision. For instance, keeping just part of one layer of the Llama-3.1 70B model in 8-bit precision, while quantizing the rest of that layer and all other 79 layers to 4 bits, leads to a substantially improved recovery rate, on par with or better than other techniques:
[Figure: recovery rates of Llama-3.1 70B quantized with selective RTN versus other quantization techniques]

Note that this adds less than 0.05 bits per weight on average, resulting in only an insignificant memory increase.
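At its core, selective quantization only needs a rule mapping a module to a bit-width. A hypothetical illustration of what such a rule could look like (the pattern set and function name are made up, not this PR's interface):

# Modules matching these names stay in 8-bit; everything else gets 4 bits.
HIGH_PRECISION_MODULES = {"model.layers.0.mlp.down_proj"}

def bits_for(module_name: str, default_bits: int = 4, high_bits: int = 8) -> int:
    return high_bits if module_name in HIGH_PRECISION_MODULES else default_bits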

As noted above, this PR provides a basic Python-based implementation of RTN that supports quantizing models on-the-fly.
Once approved, we intend to enhance it with:

Optimized CUDA (Marlin-based) kernels (for fast GEMM operations).
Support for selective quantization (for improved accuracy).

Modifications

Porting @sakogan's PR to SGLang:

  1. Add rtn linear and MoE quantization methods
  2. Include rtn in command-line args and imports
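As a usage sketch, selecting the new method from Python could look like the following. This assumes SGLang's offline Engine, which forwards server arguments as keyword arguments; the exact argument surface may differ, and the model path is a placeholder.

import sglang as sgl

llm = sgl.Engine(
    model_path="meta-llama/Llama-3.1-8B-Instruct",  # placeholder model
    quantization="rtn",  # method name added by this PR; weights quantized at load time
)
print(llm.generate("The capital of France is"))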

Checklist


@gemini-code-assist gemini-code-assist bot left a comment


Summary of Changes

Hello @YouNeedCryDear, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request lays the groundwork for efficient INT4/INT8 inference by integrating a basic, Python-based implementation of Round-To-Nearest (RTN) quantization. It enables models to be quantized on-the-fly during loading, supporting both linear and MoE layers, and makes this new quantization method accessible through existing configuration mechanisms. This is a foundational step, with future plans to enhance performance via optimized CUDA kernels and improve accuracy through selective quantization.

Highlights

  • New RTN Quantization Method: Introduced Round-To-Nearest (RTN) quantization, a calibration-free method for INT4/INT8 inference, with support for both standard linear layers and Mixture-of-Experts (MoE) layers.
  • On-the-Fly Weight Quantization: Implemented custom RTNTensor and RTNParameter wrappers that enable weights to be quantized automatically during the model loading process.
  • Command-Line Integration: Added rtn as a selectable quantization method, allowing users to specify RTN quantization via command-line arguments.
  • Core Quantization Logic: Provided rtn_quantize and rtn_dequantize functions, which handle the underlying 4-bit and 8-bit quantization and dequantization operations with group-wise scaling.
  • Torch.compile Compatibility: Included a fix_weights utility to replace the custom RTNParameter instances with standard torch.nn.Parameter after weights are loaded, ensuring compatibility with torch.compile.
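The last highlight boils down to swapping custom parameter wrappers for plain tensors once loading (and on-the-fly quantization) has finished, since torch.compile traces poorly over parameter subclasses. A minimal sketch of the idea, not the PR's actual fix_weights code:

import torch

def fix_weights_sketch(module: torch.nn.Module) -> None:
    # Replace any custom nn.Parameter subclass on this module with a plain
    # nn.Parameter holding the same (already quantized) data.
    for name, param in list(module.named_parameters(recurse=False)):
        if type(param) is not torch.nn.Parameter:
            module.register_parameter(
                name, torch.nn.Parameter(param.data, requires_grad=False)
            )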

@YouNeedCryDear YouNeedCryDear changed the title Feature/rtn quant [Feature] A calibration-free RTN-based quantization for accurate and accelerated INT4/INT8 inference Jul 14, 2025

@gemini-code-assist gemini-code-assist bot left a comment


Code Review

This PR introduces RTN quantization. There's a potential division-by-zero error during quantization and an inconsistency in the dimension used for packing 4-bit weights. Addressing these will ensure the feature is robust and correct.

Comment on lines 437 to 441
# Pack two 4-bit values into each byte.
inputs_q = (inputs_q[:, :, 1::2] << 4) | (inputs_q[:, :, ::2] & 0xF)
inputs_q = inputs_q.reshape(
    tensor.shape[0], tensor.shape[1] // 2, tensor.shape[2]
)

critical

The 4-bit quantization logic packs values along the last dimension (dim=2), but it should pack along the output dimension (dim=1) for batched tensors to align with how RTNLinearMethod creates weights. This misalignment can lead to corrupted data. Modify rtn_quantize to pack along dim=1.

    if num_bits == 4:
        # Pack two 4-bit values into each byte along the output dimension (dim=1).
        inputs_q = (inputs_q[:, 1::2, :] << 4) | (inputs_q[:, ::2, :] & 0xF)
        inputs_q = inputs_q.contiguous()

input_min = torch.min(input_flat, dim=2, keepdim=True)[0]
input_max = torch.max(input_flat, dim=2, keepdim=True)[0]
input_max_abs = torch.max(input_min.abs(), input_max.abs())
scale = input_max_abs * 2.0 / (q_range - 1)

high

The scale tensor can be zero if input_max_abs is zero, leading to division by zero in the subsequent calculation. To avoid this, add a small value to the scale to ensure it's never zero.

Suggested change
- scale = input_max_abs * 2.0 / (q_range - 1)
+ scale = (input_max_abs * 2.0 / (q_range - 1)).clamp(min=1e-6)

input_max = torch.max(input_flat, dim=2, keepdim=True)[0]
input_max_abs = torch.max(input_min.abs(), input_max.abs())
scale = input_max_abs * 2.0 / (q_range - 1)
# Scale each input group, truncate and round to the nearest integer.

This comment should read: "Scale each input group, round to the nearest integer, shift the range and truncate."

class TestRTNQuantization(CustomTestCase):
    @classmethod
    def setUpClass(cls):
        cls.model = "/models/Llama-3.1-8B-Instruct"

Do we want a test with an MoE model as well?
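For reference, such a test could reuse the pattern above; a sketch, with a placeholder model path:

class TestRTNQuantizationMoE(CustomTestCase):
    @classmethod
    def setUpClass(cls):
        # Placeholder checkpoint; any MoE model exercises the MoE RTN path.
        cls.model = "/models/Mixtral-8x7B-Instruct-v0.1"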
