28 changes: 28 additions & 0 deletions tests/quantization/test_rtn.py
@@ -0,0 +1,28 @@
# SPDX-License-Identifier: Apache-2.0
# Copyright © 2025, Oracle and/or its affiliates.
"""Tests RTN quantization startup and generation,
doesn't test correctness
"""
import pytest

from tests.quantization.utils import is_quant_method_supported

MODELS = ["microsoft/Phi-3-mini-4k-instruct"]
Member

Is there a reason why you are using Phi here, like to get around sharded weight loading? IIRC Phi models have their mergeable layers like q/k/v already merged in the checkpoint as qkv_proj. I notice you override the weight loading with your RTNParameter class, so I'm curious whether it works with an un-merged checkpoint like Llama.

Contributor Author

Yes, absolutely, it works with any dense model, including un-merged Llama checkpoints. The Phi model is an arbitrary choice of a small dense model; happy to change it to something else.
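As a sketch of that point (the Llama checkpoint name below is an assumed example, not taken from this PR), the test's MODELS list could be extended to cover an un-merged checkpoint:

MODELS = [
    "microsoft/Phi-3-mini-4k-instruct",
    # Assumed example of an un-merged checkpoint with separate q/k/v projections.
    "meta-llama/Llama-3.2-1B-Instruct",
]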



@pytest.mark.skipif(not is_quant_method_supported("rtn"),
reason="RTN is not supported on this GPU type.")
@pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("dtype", ["bfloat16"])
@pytest.mark.parametrize("max_tokens", [10])
def test_model_rtn_startup(
hf_runner,
vllm_runner,
example_prompts,
model: str,
dtype: str,
max_tokens: int,
) -> None:

with vllm_runner(model, dtype=dtype, quantization="rtn") as vllm_model:
vllm_model.generate_greedy(example_prompts, max_tokens)
3 changes: 3 additions & 0 deletions vllm/model_executor/layers/quantization/__init__.py
@@ -35,6 +35,7 @@
"moe_wna16",
"torchao",
"auto-round",
"rtn",
]
QUANTIZATION_METHODS: list[str] = list(get_args(QuantizationMethods))

@@ -110,6 +111,7 @@ def get_quantization_config(quantization: str) -> type[QuantizationConfig]:
from .neuron_quant import NeuronQuantConfig
from .ptpc_fp8 import PTPCFp8Config
from .qqq import QQQConfig
from .rtn import RTNConfig
from .torchao import TorchAOConfig
from .tpu_int8 import Int8TpuConfig

@@ -142,6 +144,7 @@ def get_quantization_config(quantization: str) -> type[QuantizationConfig]:
"moe_wna16": MoeWNA16Config,
"torchao": TorchAOConfig,
"auto-round": AutoRoundConfig,
"rtn": RTNConfig
}
# Update the `method_to_config` with customized quantization methods.
method_to_config.update(_CUSTOMIZED_METHOD_TO_QUANT_CONFIG)
288 changes: 288 additions & 0 deletions vllm/model_executor/layers/quantization/rtn.py
@@ -0,0 +1,288 @@
# SPDX-License-Identifier: Apache-2.0
# Copyright © 2025, Oracle and/or its affiliates.

import os
from typing import Any, Optional

import torch
import torch.nn.functional as F
from torch.nn.parameter import Parameter

from vllm.logger import init_logger
from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase,
set_weight_attrs)
from vllm.model_executor.layers.quantization import QuantizationMethods
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig)

logger = init_logger(__name__)
"""By default, use 8 bit as target precision, but it can be
overridden by setting the RTN_NUM_BITS envvar
"""
NUM_BITS = os.getenv('RTN_NUM_BITS', "8")
"""By default, use group size of 128 parameters, but it can be
overridden by setting the RTN_GROUP_SIZE envvar
"""
GROUP_SIZE = os.getenv('RTN_GROUP_SIZE', "128")
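# Example (illustrative usage, not prescribed by this change): launching with
#   RTN_NUM_BITS=4 RTN_GROUP_SIZE=-1
# selects 4-bit weights with a single quantization group per output row.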


class RTNConfig(QuantizationConfig):
"""Config class for RTN.
"""

def __init__(
self,
weight_bits: int = int(NUM_BITS),
group_size: int = int(GROUP_SIZE),
) -> None:
self.weight_bits = weight_bits
self.group_size = group_size

if self.weight_bits != 4 and self.weight_bits != 8:
raise ValueError(
"Currently, only 4-bit or 8-bit weight quantization is "
f"supported for RTN, but got {self.weight_bits} bits.")

def __repr__(self) -> str:
return (f"RTNConfig(weight_bits={self.weight_bits}, "
f"group_size={self.group_size})")

@classmethod
def get_name(cls) -> QuantizationMethods:
return "rtn"

@classmethod
def get_supported_act_dtypes(cls) -> list[torch.dtype]:
return [torch.bfloat16, torch.half]

@classmethod
def get_min_capability(cls) -> int:
return 80

@classmethod
def get_config_filenames(cls) -> list[str]:
return []

@classmethod
def from_config(cls, config: dict[str, Any]) -> "RTNConfig":
weight_bits = cls.get_from_keys(config, ["bits"])
group_size = cls.get_from_keys(config, ["group_size"])
return cls(weight_bits, group_size)
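
# Example: RTNConfig.from_config({"bits": 4, "group_size": 128}) yields
# RTNConfig(weight_bits=4, group_size=128).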

def get_quant_method(self, layer: torch.nn.Module,
prefix: str) -> Optional["RTNLinearMethod"]:
if isinstance(layer, LinearBase):
return RTNLinearMethod(self)
return None


class RTNTensor:
"""A wrapper over Tensor that enables quantization on-the-fly by
overloading the copy_ method.
"""

def __init__(self, data: torch.Tensor, scale: torch.Tensor,
quant_config: RTNConfig) -> None:
self.data = data
self.scale = scale
self.quant_config = quant_config

def narrow(self, dim, start, length):
factor = 1 if self.quant_config.weight_bits == 8 else 2
return RTNTensor(
self.data.narrow(dim, start // factor, length // factor),
self.scale.narrow(dim, start, length), self.quant_config)

@property
def shape(self):
shape = self.data.shape
factor = 1 if self.quant_config.weight_bits == 8 else 2
return torch.Size((shape[0] * factor, shape[1]))

def copy_(self, loaded_weight: torch.Tensor) -> None:
qweight, weight_scale = rtn_quantize(loaded_weight.cuda(),
self.quant_config.weight_bits,
self.quant_config.group_size)

self.data.copy_(qweight)
self.scale.data.copy_(weight_scale)


class RTNParameter(Parameter):
"""A wrapper over Parameter that returns RTNTensor (a wrapper over Tensor)
when its data is accessed. We need this wrapper for the data loading phase
only, so we can intercept a weight copying function (torch.Tensor.copy_)
and apply quantization on-the-fly.
"""

def __new__(cls, data: torch.Tensor, **kwargs):
return super().__new__(cls, data=data, requires_grad=False)

def __init__(self, data: torch.Tensor, scale: torch.Tensor,
quant_config: RTNConfig) -> None:
self.scale = scale
self.quant_config = quant_config

@property
def data(self):
return RTNTensor(super().data, self.scale, self.quant_config)


class RTNLinearMethod(LinearMethodBase):
"""Linear method for RTN.

Args:
quant_config: The RTN quantization config.
"""

def __init__(self, quant_config: RTNConfig):
self.quant_config = quant_config

def create_weights(
self,
layer: torch.nn.Module,
input_size_per_partition: int,
output_partition_sizes: list[int],
input_size: int,
output_size: int,
params_dtype: torch.dtype,
**extra_weight_attrs,
):
output_size_per_partition = sum(output_partition_sizes)
num_groups_per_col = (input_size_per_partition //
self.quant_config.group_size
if self.quant_config.group_size != -1 else 1)

scale = Parameter(
torch.empty(output_size_per_partition,
num_groups_per_col,
dtype=params_dtype),
requires_grad=False,
)
factor = 1 if self.quant_config.weight_bits == 8 else 2

weight = RTNParameter(data=torch.empty(output_size_per_partition //
factor,
input_size_per_partition,
dtype=torch.int8),
scale=scale,
quant_config=self.quant_config)

layer.register_parameter("weight", weight)
set_weight_attrs(weight, {
**extra_weight_attrs,
"input_dim": 1,
"output_dim": 0,
})

layer.register_parameter("scale", scale)
layer.output_size_per_partition = output_size_per_partition

def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
"""torch.compile does not know how to deal with a Parameter subclass
(aka RTNParameter). As we don't really need RTNParameters for the
forward pass, we replace them with equivalent instances of Parameters.
"""
old_weight = layer.weight
assert isinstance(old_weight, RTNParameter)
data = old_weight.data.data

delattr(layer, "weight")

new_weight = Parameter(data=data, requires_grad=False)
layer.register_parameter("weight", new_weight)

def apply(self,
layer: torch.nn.Module,
x: torch.Tensor,
bias: Optional[torch.Tensor] = None) -> torch.Tensor:
qweight = layer.weight
scale = layer.scale

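# The full-precision weight below exists only for the duration of this matmul:
# it is rebuilt on the fly from the stored int8 data and freed right after.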
weight = rtn_dequantize(qweight, scale)
out = F.linear(x, weight)
del weight
if bias is not None:
out.add_(bias)

return out


def rtn_quantize(tensor: torch.Tensor, num_bits: int,
group_size: int) -> tuple[torch.Tensor, torch.Tensor]:
"""Quantize a tensor using per-group static scaling factor.

Args:
tensor: The input tensor.
num_bits: Target precision for the result (supported values are
8 or 4).
group_size: Quantization granularity.
If equal to -1, each row in the input tensor is treated
as one group.
"""

q_range = 2**num_bits
num_groups = (tensor.shape[0] * tensor.shape[1] //
group_size if group_size != -1 else tensor.shape[0])
"""Calculate a scaling factor per input group.
"""
input_flat = tensor.reshape(num_groups, -1)
input_min = torch.min(input_flat, dim=1, keepdim=True)[0]
input_max = torch.max(input_flat, dim=1, keepdim=True)[0]
input_max_abs = torch.max(input_min.abs(), input_max.abs())
scale = (input_max_abs * 2.0 / (q_range - 1))
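# E.g. with num_bits == 8, q_range == 256 and scale == max_abs / 127.5, so
# values in [-max_abs, max_abs] map to roughly [-127.5, 127.5] before the
# clamp below restricts them to [-128, 127].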
"""Scale each input group, truncate and round to the nearest integer.
"""
scaled_input = input_flat / scale
scaled_input = scaled_input.clamp(-q_range // 2, q_range // 2 - 1)
scaled_input = scaled_input.round()

scale = scale.reshape(tensor.shape[0], -1).contiguous()
inputs_q = scaled_input.reshape(tensor.shape).to(torch.int8)
inputs_q = inputs_q.contiguous()

if num_bits == 4:
"""Pack two 4-bit values into each byte.
"""
inputs_q = (inputs_q[:, 1::2] << 4) | (inputs_q[:, ::2] & 0xf)
inputs_q = inputs_q.reshape(tensor.shape[0] // 2, tensor.shape[1])
inputs_q = inputs_q.contiguous()

return inputs_q, scale


def rtn_dequantize(tensor: torch.Tensor, scale: torch.Tensor) -> torch.Tensor:
"""Dequantize a tensor using per-group static scaling factors.

Args:
tensor: The input tensor.
scale: The tensor with per-group scale factors.
"""

num_groups = scale.size(0) * scale.size(1)
input_dim, output_dim = tensor.shape

num_bits = 8 if input_dim == scale.size(0) else 4
if num_bits == 4:
input_dim *= 2

data = torch.empty((input_dim, output_dim),
dtype=scale.dtype,
device=tensor.device)

if num_bits == 8:
data.copy_(tensor)
else:
"""Unpack two 4-bit values from each byte.
"""
tensor = tensor.reshape(input_dim, output_dim // 2)
for i in range(2):
data[:, i::2] = (tensor << 4 * (1 - i)) >> 4
"""Scale each input group with its scaling factor.
"""
scale = scale.reshape(num_groups, -1)
data = data.reshape(num_groups, -1)
data = torch.mul(data, scale)

input_deq = data.reshape((input_dim, output_dim)).contiguous()
return input_deq
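
A minimal round-trip sketch of the two helpers above (an assumed standalone check, not part of this diff), using an 8-bit, group-size-128 configuration on a random CPU weight:

import torch

from vllm.model_executor.layers.quantization.rtn import (rtn_dequantize,
                                                          rtn_quantize)

w = torch.randn(128, 256, dtype=torch.float32)
qweight, scale = rtn_quantize(w, num_bits=8, group_size=128)
w_deq = rtn_dequantize(qweight, scale)
# Every element should be reconstructed to within one quantization step,
# i.e. its per-group scale; scale.max() is a safe upper bound.
assert (w_deq - w).abs().max() <= scale.max()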