This repo contains the implementation of the paper "RaanA: A Fast, Flexible, and Data-Efficient Post-Training Quantization Algorithm". We've created a PyPI package to make the algorithm ready to use! See below 👇 for instructions on how to use it.
- To install from PyPI:
pip install raana
- To build from source:
pip install build
git clone https://github.com/FFTYYY/RaanA
cd RaanA
python -m build
The generated .whl files will be in dist/.
from transformers import AutoTokenizer, LlamaForCausalLM
from raana import quantize, zeroshot_calibration, trick_centralize, trick_norm_row
# initialize your model
model = LlamaForCausalLM.from_pretrained(...)
tokenizer = AutoTokenizer.from_pretrained(...)
# quantization
quantized_model = quantize(
    model,                                             # the model to quantize
    b_candidates = list(range(1, 9)),                  # allowed bit-widths
    calibrate_data = zeroshot_calibration(tokenizer),  # use zero-shot calibration
    avg_bits = 3.3,                                    # average number of bits
)["model"]
# evaluate your model
evaluate(quantized_model, ...)
To run the example quantization of Llama 2 on WikiText-2 (and reproduce the results reported in the paper):
pip install raana
git clone https://github.com/FFTYYY/RaanA
cd RaanA/examples
python wikitext2.py --model=meta-llama/Llama-2-7b-hf --avgbits=3.3
See examples/wikitext2.py for a complete usage example.
The entry point of raana is raana.quantize.
from torch.nn import Module
from torch import Tensor
from typing import Callable
from raana.task_adaptor import TaskAdaptor
from raana.rotations import RandomRotation, default_rotation
from raana.tricks import Trick
from raana.select_layers import default_linear_selector
from raana.quantized_linear import default_weightbias_extractor, default_matmul
from raana.tricks import trick_centralize, trick_norm_col
from raana import quantize
quantize(
    model               : Module,
    b_candidates        : list[float],
    calibrate_data      : TaskAdaptor,
    avg_bits            : float,
    linear_selector     : Callable[[Module], bool] = default_linear_selector,
    rotation_maker      : Callable[[], RandomRotation] = default_rotation,
    trick_makers        : list[Callable[[], Trick]] = [trick_centralize, trick_norm_col],
    weightbias_extractor: Callable[[Module], tuple[Tensor, Tensor | None]] = default_weightbias_extractor,
    matmul              : Callable[[Tensor, Tensor, Tensor, int], Tensor] = default_matmul,
)
model: torch.nn.Module
- The pytorch model to be quantized.
b_candidates: list[float]
- Candidate bit-widths allowed for each layer.
- May include values smaller than 1, in which case less-than-one-bit quantization is enabled.
- Example: [0.5, 0.75, 1, 2, 3, 4].
calibrate_data: raana.task_adaptor.TaskAdaptor
- The calibration data used for quantization.
- For language modeling tasks, you can use raana.task_adaptor.LMAdaptor(data: list[str], tokenizer: PreTrainedTokenizer).
- For zero-shot calibration in language modeling, use raana.zeroshot_calibration(tokenizer).
- For non-language-modeling tasks, you can write your own TaskAdaptor class.
avg_bits: float
- Target average number of bits per quantized linear layer. The quantizer will search for the optimal bit allocation under this constraint.
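To make the idea of "bit allocation under an average-bit constraint" concrete, here is a minimal toy sketch. It is not the search that raana performs: the brute-force enumeration, the error model (error shrinking 4x per extra bit), the per-layer `sensitivities`, and the equal weighting of layers in the average are all assumptions made for illustration.

```python
from itertools import product

def allocate_bits(sensitivities, b_candidates, avg_bits):
    """Toy brute-force allocation: choose one bit-width per layer so that a
    hypothetical error model is minimized while the (unweighted) average
    bit-width stays within avg_bits. Only feasible for a handful of layers."""
    budget = avg_bits * len(sensitivities)
    best_cost, best_choice = float("inf"), None
    for choice in product(b_candidates, repeat=len(sensitivities)):
        if sum(choice) > budget + 1e-9:
            continue  # violates the average-bit constraint
        # Hypothetical error model: each extra bit reduces the error by 4x.
        cost = sum(s * 4.0 ** (-b) for s, b in zip(sensitivities, choice))
        if cost < best_cost:
            best_cost, best_choice = cost, choice
    if best_choice is None:
        raise ValueError("no allocation satisfies the average-bit constraint")
    return list(best_choice)

# Three layers with made-up sensitivities; more sensitive layers get more bits.
sensitivities = [10.0, 1.0, 0.1]
bits = allocate_bits(sensitivities, b_candidates=[1, 2, 3, 4], avg_bits=2.0)
print(bits)
```

The point of the sketch is only that a fixed budget is distributed unevenly: layers whose error matters more receive more bits, and the others are pushed below the average.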
linear_selector: Callable[[torch.nn.Module], bool]
- A function to choose which sub-modules to quantize.
- Different model implementations use different types of linear modules (e.g. some models use nn.Linear while others use nn.Conv1d), so this function lets the user specify which linear modules to quantize.
- Default: select all torch.nn.Linear layers.
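As an illustration, a custom selector might look like the sketch below. The rule itself (only quantize `nn.Linear` layers with at least 16 input features) and the name `wide_linear_selector` are hypothetical. The loop over `model.modules()` merely emulates how such a predicate could be applied to each sub-module; how raana traverses the model internally is an assumption here.

```python
import torch.nn as nn

def wide_linear_selector(module: nn.Module) -> bool:
    # Hypothetical rule: quantize only nn.Linear layers that are wide enough.
    return isinstance(module, nn.Linear) and module.in_features >= 16

# A toy model, used only to show which sub-modules the predicate picks.
model = nn.Sequential(
    nn.Linear(8, 16),
    nn.ReLU(),
    nn.Linear(16, 4),
    nn.LayerNorm(4),
)

selected = [m for m in model.modules() if wide_linear_selector(m)]
print(selected)
```

Such a function would then be passed as `linear_selector = wide_linear_selector` to `quantize`.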
rotation_maker: Callable[[], raana.rotations.RandomRotation]
- A function to construct a random rotation.
- This parameter leaves flexibility for users to specify their own random rotation implementation.
- The default implementation is randomized Hadamard Transformation, as described in the paper.
- The Hadamard transformation used by default is simply a matrix multiplication with the Hadamard matrix generated by scipy.linalg.hadamard. To minimize the dependencies of raana, the default implementation does not use any fast GPU Hadamard kernels. Users are encouraged to install fast Hadamard kernels themselves and pass them to the quantizer through this parameter.
- We recommend installing the fast Hadamard implementation from Dao-AILab and passing it to raana:
from torch import Tensor
from fast_hadamard_transform import hadamard_transform
from raana.rotations import PiecewiseHadamard
from raana import quantize

def hadamard(X: Tensor):
    # normalize by sqrt(d) to make it an orthonormal operator
    return hadamard_transform(X) / (X.size(-1) ** 0.5)

quantize(
    ...,
    rotation_maker = lambda: PiecewiseHadamard(hadamard = hadamard),
)
- Default: randomized Hadamard transformation, using scipy.linalg.hadamard as the implementation of the Hadamard transformation.
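For intuition, the sketch below builds a small randomized Hadamard rotation with NumPy and SciPy and checks that it is orthonormal. It uses the common construction "normalized Hadamard matrix times a random sign diagonal"; whether raana's default rotation applies exactly this construction (for example the same ordering, or a piecewise variant for dimensions that are not powers of two) is an assumption, and this code does not use raana's `RandomRotation` class.

```python
import numpy as np
from scipy.linalg import hadamard

d = 8  # scipy.linalg.hadamard requires a power of two
rng = np.random.default_rng(0)

# Normalizing by sqrt(d) turns the +/-1 Hadamard matrix into an orthonormal one.
H = hadamard(d).astype(np.float64) / np.sqrt(d)

# Random signs; H * signs scales column j by signs[j], i.e. H @ diag(signs).
signs = rng.choice([-1.0, 1.0], size=d)
R = H * signs

x = rng.standard_normal(d)
y = R @ x  # rotated vector: same norm as x, entries mixed across coordinates
```

Because `R` is orthonormal, rotating the input and the weights consistently leaves the product unchanged while spreading large entries across coordinates.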
trick_makers: list[Callable[[], raana.tricks.Trick]]
- List of functions that construct tricks. See the paper for the definition of "trick".
- Four tricks are currently implemented: trick_centralize, trick_pca, trick_norm_row, trick_norm_col.
- Default: [trick_centralize, trick_norm_col].
weightbias_extractor: Callable[[nn.Module], tuple[Tensor, Tensor | None]]
- A function to extract weight and bias matrices from a linear module and transform them into the standard shape.
- This function should return the weight and bias extracted from the provided layer. weight should be a tensor of shape (d_in, d_out), and bias should be None or a tensor of shape (d_out,).
- Default: lambda layer: (layer.weight.t().data, layer.bias.data)
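A custom extractor can be sketched as follows. The name `safe_weightbias_extractor` is hypothetical. The sketch handles layers created with `bias=False`, for which `layer.bias` is `None`; taken literally, the one-line default shown above would raise an error on such a layer when it accesses `layer.bias.data`.

```python
import torch.nn as nn

def safe_weightbias_extractor(layer: nn.Module):
    # nn.Linear stores its weight as (d_out, d_in); transpose to (d_in, d_out).
    weight = layer.weight.t().data
    # Return None instead of failing when the layer has no bias.
    bias = layer.bias.data if layer.bias is not None else None
    return weight, bias

layer = nn.Linear(8, 4, bias=False)
weight, bias = safe_weightbias_extractor(layer)
print(weight.shape, bias)
```

For a module type that stores its weight in a different layout, the transpose would be adjusted so that the returned weight still has shape (d_in, d_out).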
matmul: Callable[[Tensor, Tensor, Tensor, int], Tensor]
- A function to perform low-precision matrix multiplication.
- Since there is no standard implementation of low-precision uint-float matrix multiplication and we want to minimize the dependencies of raana, we leave the implementation of matrix multiplication to users.
- The input parameters are X, qW, rescale, B. X is a float tensor, qW is a B-bit uint tensor, and rescale is a float rescale tensor. The return value of this function should equal (X@qW - ((2**B-1)/2.*X.sum(dim=-1)).view(-1,1)) * rescale.view(1,-1).
- Default: convert everything to float32 and perform a standard matrix multiplication. The default implementation is shown below.
import torch as tc

def default_matmul(X: tc.Tensor, qW: tc.Tensor, rescale: tc.Tensor, B: int):
    dtype = X.dtype
    X = X.to(tc.float32)
    rescale = rescale.to(tc.float32).view(1, -1)
    # offset term that re-centers the unsigned integers around zero
    q_bias = (float(2 ** B - 1) / 2. * X.sum(dim = -1)).view(-1, 1)
    Z = (X @ qW.to(tc.float32)) * rescale
    Z = Z - q_bias * rescale
    return Z.to(dtype)
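One way to read the required return value: it equals multiplying X by a dequantized weight matrix whose integer entries are first shifted by (2**B-1)/2 and then scaled per output column. The sketch below checks this identity numerically. The helper names `reference_matmul` and `dequantize` are introduced here for illustration and are not part of raana; a custom `matmul` could be validated against the formula in a similar way.

```python
import torch

def reference_matmul(X, qW, rescale, B):
    # The formula stated above, written out directly.
    offset = (2 ** B - 1) / 2.0
    return (X @ qW - (offset * X.sum(dim=-1)).view(-1, 1)) * rescale.view(1, -1)

def dequantize(qW, rescale, B):
    # Shift unsigned integers to be centered at zero, then scale each column.
    offset = (2 ** B - 1) / 2.0
    return (qW - offset) * rescale.view(1, -1)

torch.manual_seed(0)
B, n, d_in, d_out = 3, 5, 16, 4
X = torch.randn(n, d_in, dtype=torch.float64)
# Random B-bit unsigned integers, stored as floats for this check.
qW = torch.randint(0, 2 ** B, (d_in, d_out)).to(torch.float64)
rescale = torch.rand(d_out, dtype=torch.float64) + 0.1

Z_formula = reference_matmul(X, qW, rescale, B)
Z_dequant = X @ dequantize(qW, rescale, B)
print(torch.allclose(Z_formula, Z_dequant))
```

The formula avoids materializing the dequantized matrix, which is what allows a low-precision kernel to operate on qW directly.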
quantize returns a dictionary of the following form:
{
    "model" : torch.nn.Module,  # quantized model
    "bits"  : dict[str, float], # allocated bit-width per layer
    "losses": list[float],      # calibration loss per calibration sample
}
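The "bits" and "losses" entries can be summarized with a few lines of plain Python. The sketch below uses a made-up result dictionary: the layer names and all numbers are hypothetical placeholders, not output from an actual run, and the "model" entry is omitted.

```python
from collections import Counter

# Hypothetical stand-in for the dictionary returned by quantize(...).
result = {
    "bits": {
        "layers.0.q_proj": 4,
        "layers.0.k_proj": 3,
        "layers.0.v_proj": 3,
        "layers.0.o_proj": 2,
    },
    "losses": [2.1, 1.9, 2.0],
}

# How many layers received each bit-width.
bit_histogram = Counter(result["bits"].values())
# Unweighted mean bit-width over layers, and mean calibration loss.
mean_bits = sum(result["bits"].values()) / len(result["bits"])
mean_loss = sum(result["losses"]) / len(result["losses"])

print(dict(bit_histogram), mean_bits, mean_loss)
```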