[Feature] A calibration-free RTN-based quantization for accurate and accelerated INT4/INT8 inference #18768
Merged
Changes from all commits (4 commits)
New file (+28 lines):

```python
# SPDX-License-Identifier: Apache-2.0
# Copyright © 2025, Oracle and/or its affiliates.
"""Tests RTN quantization startup and generation,
doesn't test correctness
"""
import pytest

from tests.quantization.utils import is_quant_method_supported

MODELS = ["microsoft/Phi-3-mini-4k-instruct"]


@pytest.mark.skipif(not is_quant_method_supported("rtn"),
                    reason="RTN is not supported on this GPU type.")
@pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("dtype", ["bfloat16"])
@pytest.mark.parametrize("max_tokens", [10])
def test_model_rtn_startup(
    hf_runner,
    vllm_runner,
    example_prompts,
    model: str,
    dtype: str,
    max_tokens: int,
) -> None:

    with vllm_runner(model, dtype=dtype, quantization="rtn") as vllm_model:
        vllm_model.generate_greedy(example_prompts, max_tokens)
```
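For context, here is a minimal offline-inference sketch of how this quantization mode would be enabled outside the test harness. It assumes vLLM's standard `LLM`/`SamplingParams` API and uses the `RTN_NUM_BITS`/`RTN_GROUP_SIZE` environment variables read by the new module later in this diff; the prompt and model choice are illustrative, and none of this is part of the PR itself.

```python
import os

# The RTN module reads these at import time, so set them before importing vLLM.
# Defaults are 8-bit weights with a group size of 128.
os.environ["RTN_NUM_BITS"] = "4"
os.environ["RTN_GROUP_SIZE"] = "128"

from vllm import LLM, SamplingParams

# "rtn" is the quantization method name registered by this PR; no calibration
# data is needed, since weights are quantized on the fly while they are loaded.
llm = LLM(model="microsoft/Phi-3-mini-4k-instruct",
          dtype="bfloat16",
          quantization="rtn")

outputs = llm.generate(["The capital of France is"],
                       SamplingParams(temperature=0.0, max_tokens=10))
print(outputs[0].outputs[0].text)
```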
New file (+288 lines):

```python
# SPDX-License-Identifier: Apache-2.0
# Copyright © 2025, Oracle and/or its affiliates.

import os
from typing import Any, Optional

import torch
import torch.nn.functional as F
from torch.nn.parameter import Parameter

from vllm.logger import init_logger
from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase,
                                               set_weight_attrs)
from vllm.model_executor.layers.quantization import QuantizationMethods
from vllm.model_executor.layers.quantization.base_config import (
    QuantizationConfig)

logger = init_logger(__name__)
"""By default, use 8 bit as target precision, but it can be
overridden by setting the RTN_NUM_BITS envvar
"""
NUM_BITS = os.getenv('RTN_NUM_BITS', "8")
"""By default, use group size of 128 parameters, but it can be
overridden by setting the RTN_GROUP_SIZE envvar
"""
GROUP_SIZE = os.getenv('RTN_GROUP_SIZE', "128")


class RTNConfig(QuantizationConfig):
    """Config class for RTN.
    """

    def __init__(
        self,
        weight_bits: int = int(NUM_BITS),
        group_size: int = int(GROUP_SIZE),
    ) -> None:
        self.weight_bits = weight_bits
        self.group_size = group_size

        if self.weight_bits != 4 and self.weight_bits != 8:
            raise ValueError(
                "Currently, only 4-bit or 8-bit weight quantization is "
                f"supported for RTN, but got {self.weight_bits} bits.")

    def __repr__(self) -> str:
        return (f"RTNConfig(weight_bits={self.weight_bits}, "
                f"group_size={self.group_size})")

    @classmethod
    def get_name(cls) -> QuantizationMethods:
        return "rtn"

    @classmethod
    def get_supported_act_dtypes(cls) -> list[torch.dtype]:
        return [torch.bfloat16, torch.half]

    @classmethod
    def get_min_capability(cls) -> int:
        return 80

    @classmethod
    def get_config_filenames(cls) -> list[str]:
        return []

    @classmethod
    def from_config(cls, config: dict[str, Any]) -> "RTNConfig":
        weight_bits = cls.get_from_keys(config, ["bits"])
        group_size = cls.get_from_keys(config, ["group_size"])
        return cls(weight_bits, group_size)

    def get_quant_method(self, layer: torch.nn.Module,
                         prefix: str) -> Optional["RTNLinearMethod"]:
        if isinstance(layer, LinearBase):
            return RTNLinearMethod(self)
        return None


class RTNTensor:
    """A wrapper over Tensor that enables quantization on-the-fly by
    overloading the copy_ method.
    """

    def __init__(self, data: torch.Tensor, scale: torch.Tensor,
                 quant_config: RTNConfig) -> None:
        self.data = data
        self.scale = scale
        self.quant_config = quant_config

    def narrow(self, dim, start, length):
        factor = 1 if self.quant_config.weight_bits == 8 else 2
        return RTNTensor(
            self.data.narrow(dim, start // factor, length // factor),
            self.scale.narrow(dim, start, length), self.quant_config)

    @property
    def shape(self):
        shape = self.data.shape
        factor = 1 if self.quant_config.weight_bits == 8 else 2
        return torch.Size((shape[0] * factor, shape[1]))

    def copy_(self, loaded_weight: torch.Tensor) -> None:
        qweight, weight_scale = rtn_quantize(loaded_weight.cuda(),
                                             self.quant_config.weight_bits,
                                             self.quant_config.group_size)

        self.data.copy_(qweight)
        self.scale.data.copy_(weight_scale)


class RTNParameter(Parameter):
    """A wrapper over Parameter that returns RTNTensor (a wrapper over Tensor)
    when its data is accessed. We need this wrapper for the data loading phase
    only, so we can intercept a weight copying function (torch.Tensor.copy_)
    and apply quantization on-the-fly.
    """

    def __new__(cls, data: torch.Tensor, **kwargs):
        return super().__new__(cls, data=data, requires_grad=False)

    def __init__(self, data: torch.Tensor, scale: torch.Tensor,
                 quant_config: RTNConfig) -> None:
        self.scale = scale
        self.quant_config = quant_config

    @property
    def data(self):
        return RTNTensor(super().data, self.scale, self.quant_config)


class RTNLinearMethod(LinearMethodBase):
    """Linear method for RTN.

    Args:
        quant_config: The RTN quantization config.
    """

    def __init__(self, quant_config: RTNConfig):
        self.quant_config = quant_config

    def create_weights(
        self,
        layer: torch.nn.Module,
        input_size_per_partition: int,
        output_partition_sizes: list[int],
        input_size: int,
        output_size: int,
        params_dtype: torch.dtype,
        **extra_weight_attrs,
    ):
        output_size_per_partition = sum(output_partition_sizes)
        num_groups_per_col = (input_size_per_partition //
                              self.quant_config.group_size
                              if self.quant_config.group_size != -1 else 1)

        scale = Parameter(
            torch.empty(output_size_per_partition,
                        num_groups_per_col,
                        dtype=params_dtype),
            requires_grad=False,
        )
        factor = 1 if self.quant_config.weight_bits == 8 else 2

        weight = RTNParameter(data=torch.empty(output_size_per_partition //
                                               factor,
                                               input_size_per_partition,
                                               dtype=torch.int8),
                              scale=scale,
                              quant_config=self.quant_config)

        layer.register_parameter("weight", weight)
        set_weight_attrs(weight, {
            **extra_weight_attrs,
            "input_dim": 1,
            "output_dim": 0,
        })

        layer.register_parameter("scale", scale)
        layer.output_size_per_partition = output_size_per_partition

    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
        """torch.compile does not know how to deal with a Parameter subclass
        (aka RTNParameter). As we don't really need RTNParameters for the
        forward pass, we replace them with equivalent instances of Parameters.
        """
        old_weight = layer.weight
        assert isinstance(old_weight, RTNParameter)
        data = old_weight.data.data

        delattr(layer, "weight")

        new_weight = Parameter(data=data, requires_grad=False)
        layer.register_parameter("weight", new_weight)

    def apply(self,
              layer: torch.nn.Module,
              x: torch.Tensor,
              bias: Optional[torch.Tensor] = None) -> torch.Tensor:
        qweight = layer.weight
        scale = layer.scale

        weight = rtn_dequantize(qweight, scale)
        out = F.linear(x, weight)
        del weight
        if bias is not None:
            out.add_(bias)

        return out


def rtn_quantize(tensor: torch.Tensor, num_bits: int,
                 group_size: int) -> tuple[torch.Tensor, torch.Tensor]:
    """Quantize a tensor using per-group static scaling factor.

    Args:
        tensor: The input tensor.
        num_bits: Target precision for the result (supported values are
                  8 or 4).
        group_size: Quantization granularity.
                    If equal to -1, each row in the input tensor is treated
                    as one group.
    """

    q_range = 2**num_bits
    num_groups = (tensor.shape[0] * tensor.shape[1] //
                  group_size if group_size != -1 else tensor.shape[0])
    """Calculate a scaling factor per input group.
    """
    input_flat = tensor.reshape(num_groups, -1)
    input_min = torch.min(input_flat, dim=1, keepdim=True)[0]
    input_max = torch.max(input_flat, dim=1, keepdim=True)[0]
    input_max_abs = torch.max(input_min.abs(), input_max.abs())
    scale = (input_max_abs * 2.0 / (q_range - 1))
    """Scale each input group, truncate and round to the nearest integer.
    """
    scaled_input = input_flat / scale
    scaled_input = scaled_input.clamp(-q_range // 2, q_range // 2 - 1)
    scaled_input = scaled_input.round()

    scale = scale.reshape(tensor.shape[0], -1).contiguous()
    inputs_q = scaled_input.reshape(tensor.shape).to(torch.int8)
    inputs_q = inputs_q.contiguous()

    if num_bits == 4:
        """Pack two 4-bit values into each byte.
        """
        inputs_q = (inputs_q[:, 1::2] << 4) | (inputs_q[:, ::2] & 0xf)
        inputs_q = inputs_q.reshape(tensor.shape[0] // 2, tensor.shape[1])
        inputs_q = inputs_q.contiguous()

    return inputs_q, scale


def rtn_dequantize(tensor: torch.Tensor, scale: torch.Tensor) -> torch.Tensor:
    """Dequantize a tensor using per-group static scaling factors.

    Args:
        tensor: The input tensor.
        scale: The tensor with per-group scale factors.
    """

    num_groups = scale.size(0) * scale.size(1)
    input_dim, output_dim = tensor.shape

    num_bits = 8 if input_dim == scale.size(0) else 4
    if num_bits == 4:
        input_dim *= 2

    data = torch.empty((input_dim, output_dim),
                       dtype=scale.dtype,
                       device=tensor.device)

    if num_bits == 8:
        data.copy_(tensor)
    else:
        """Unpack two 4-bit values from each byte.
        """
        tensor = tensor.reshape(input_dim, output_dim // 2)
        for i in range(2):
            data[:, i::2] = (tensor << 4 * (1 - i)) >> 4
    """Scale each input group with its scaling factor.
    """
    scale = scale.reshape(num_groups, -1)
    data = data.reshape(num_groups, -1)
    data = torch.mul(data, scale)

    input_deq = data.reshape((input_dim, output_dim)).contiguous()
    return input_deq
```
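As a sanity check on the two helpers above, here is a small round-trip sketch. It is not part of the PR; the import path is an assumption based on the other imports in the file, and the expected shapes follow from the per-group layout described in the docstrings.

```python
import torch

# Assumed module path; the diff view does not show the new file's name.
from vllm.model_executor.layers.quantization.rtn import (rtn_dequantize,
                                                         rtn_quantize)

w = torch.randn(256, 512, dtype=torch.bfloat16)

# 8-bit weights, one scale per group of 128 values along each row.
qweight, scale = rtn_quantize(w, num_bits=8, group_size=128)
print(qweight.dtype, tuple(qweight.shape))  # torch.int8 (256, 512)
print(tuple(scale.shape))                   # (256, 4): 512 / 128 groups per row

# Dequantize and compare; the per-element error should be on the order of one
# scale step, since values are rounded (and the extremes clipped) per group.
w_hat = rtn_dequantize(qweight, scale)
print((w - w_hat).abs().max())
```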
Is there a reason why you are using Phi here, e.g. to get around sharded weight loading? IIRC Phi models have their mergeable layers like q/k/v already merged in the checkpoint as qkv_proj. I notice you override the weight loading with your RTNParameter class, so I'm curious whether it works with an un-merged checkpoint like Llama.
Yes, absolutely, it works with any dense model, including un-merged Llama checkpoints. The Phi model is an arbitrary choice of a small dense model; happy to change it to something else.
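A hypothetical follow-up to this exchange: extending the MODELS list in the new test would cover both checkpoint layouts. The Llama model name below is an illustrative choice and is not part of this PR.

```python
# Hypothetical extension of the test's MODELS list (not in this PR): a Llama
# checkpoint stores q/k/v as separate shards, so loading it exercises the
# sharded path that RTNTensor.narrow() handles via the packing factor.
MODELS = [
    "microsoft/Phi-3-mini-4k-instruct",  # qkv already merged in the checkpoint
    "meta-llama/Llama-3.2-1B-Instruct",  # illustrative un-merged checkpoint
]
```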