RaanA: A Fast, Flexible, and Data-Efficient Post-Training Quantization Algorithm

This repo contains the implementation of the paper RaanA: A Fast, Flexible, and Data-Efficient Post-Training Quantization Algorithm. We've published a PyPI package so the algorithm is ready to use! See below 👇 for instructions on how to use it.

Installation

  • To install from PyPI: pip install raana
  • To build from source:
    pip install build 
    git clone https://github.com/FFTYYY/RaanA
    cd RaanA
    python -m build 
    The generated .whl files will be in dist/.

Quick Start

from transformers import AutoTokenizer, LlamaForCausalLM
from raana import quantize, zeroshot_calibration, trick_centralize, trick_norm_row

# initialize your model
model     = LlamaForCausalLM.from_pretrained(...)
tokenizer = AutoTokenizer.from_pretrained(...)

# quantization
quantized_model = quantize(
    model,                                              # the model to quantize
    b_candidates    = list(range(1,9)),                 # allowed bit-widths
    calibrate_data  = zeroshot_calibration(tokenizer),  # use zero-shot calibration
    avg_bits        = 3.3,                              # average number of bits
)["model"]

# evaluate your model
evaluate(quantized_model, ...)

A Complete Example

To run an example quantization of Llama-2 on WikiText-2 (and reproduce the results reported in the paper):

pip install raana
git clone https://github.com/FFTYYY/RaanA
cd RaanA/examples
python wikitext2.py --model=meta-llama/Llama-2-7b-hf --avgbits=3.3

See examples/wikitext2.py for a complete usage example.

Detailed Usage

The entry point of raana is raana.quantize.

from torch.nn   import Module
from torch      import Tensor
from typing     import Callable

from raana.task_adaptor     import TaskAdaptor
from raana.rotations        import RandomRotation, default_rotation
from raana.tricks           import Trick
from raana.select_layers    import default_linear_selector
from raana.quantized_linear import default_weightbias_extractor, default_matmul
from raana.tricks           import trick_centralize, trick_norm_col
from raana                  import quantize

quantize(
    model               : Module,
    b_candidates        : list[float],
    calibrate_data      : TaskAdaptor,
    avg_bits            : float,
    linear_selector     : Callable[[Module], bool]          = default_linear_selector,
    rotation_maker      : Callable[[], RandomRotation]      = default_rotation,
    trick_makers        : list[Callable[[], Trick]]         = [trick_centralize, trick_norm_col],
    weightbias_extractor: Callable[[Module], tuple[Tensor, Tensor | None]] = default_weightbias_extractor,
    matmul              : Callable[[Tensor, Tensor, Tensor, int], Tensor]  = default_matmul,
)

Required Arguments

model: torch.nn.Module

  • The pytorch model to be quantized.

b_candidates: list[float]

  • Candidate bit-widths allowed for each layer.
  • May include floats smaller than 1; if so, less-than-one-bit quantization is enabled.
  • Example: [0.5, 0.75, 1, 2, 3, 4].

calibrate_data: raana.task_adaptor.TaskAdaptor

  • The calibration data used for quantization.
  • For language modeling tasks, use raana.task_adaptor.LMAdaptor(data: list[str], tokenizer: PreTrainedTokenizer); see the sketch below.
  • For zero-shot calibration in language modeling, use raana.zeroshot_calibration(tokenizer).
  • For non-language-modeling tasks, write your own TaskAdaptor subclass.
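  • A minimal sketch of calibrating with a small list of raw texts via LMAdaptor (the sample texts below are placeholders; model and tokenizer are assumed to be initialized as in the Quick Start above):

    from raana.task_adaptor import LMAdaptor
    from raana import quantize

    # a few raw text samples used as calibration data (placeholders)
    calib_texts = [
        "The quick brown fox jumps over the lazy dog.",
        "Post-training quantization compresses a model without retraining it.",
    ]

    quantized_model = quantize(
        model,
        b_candidates   = list(range(1, 9)),
        calibrate_data = LMAdaptor(calib_texts, tokenizer),
        avg_bits       = 3.3,
    )["model"]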

avg_bits: float

  • Target average number of bits per quantized linear layer. The quantizer will search for the optimal bit allocation under this constraint.

Optional Arguments

linear_selector: Callable[[torch.nn.Module], bool]

  • A function that chooses which sub-modules to quantize.
  • Different model implementations use different linear module types (e.g. some models use nn.Linear while others use nn.Conv1d), so this function lets the user specify which modules should be quantized; a sketch of a custom selector is shown below.
  • Default: select all torch.nn.Linear layers.
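  • A minimal sketch of a custom selector (the size threshold here is an arbitrary illustration, not part of raana):

    from torch import nn
    from raana import quantize

    def my_linear_selector(module: nn.Module) -> bool:
        # quantize only sufficiently large nn.Linear layers
        return isinstance(module, nn.Linear) and module.in_features >= 1024

    quantize(
        ...,
        linear_selector = my_linear_selector,
    )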

rotation_maker: Callable[[], raana.rotations.RandomRotation]

  • A function to construct a random rotation.
  • This parameter leaves flexibility for users to specify their own random rotation implementation.
  • The default implementation is the randomized Hadamard transformation described in the paper.
  • The Hadamard transformation used by the default implementation is simply a matrix multiplication with the Hadamard matrix generated by scipy.linalg.hadamard. To minimize raana's dependencies, the default implementation does not use any fast GPU Hadamard kernels; users are encouraged to install a fast Hadamard kernel themselves and pass it to the quantizer through this parameter.
  • We recommend installing the fast Hadamard implementation from DAO-AILab and passing it to raana:
    from torch import Tensor
    from fast_hadamard_transform import hadamard_transform
    from raana.rotations import PiecewiseHadamard
    from raana import quantize
    
    def hadamard(X: Tensor):
        # normalize it by sqrt(d) to make it an orthonormal operator.
        return hadamard_transform(X) / (X.size(-1) ** 0.5) 
    
    quantize(
        ..., 
        rotation_maker = lambda: PiecewiseHadamard( hadamard = hadamard )
    )
  • Default: randomized Hadamard transformation, using scipy.linalg.hadamard to generate the Hadamard matrix.

trick_makers: list[Callable[[], raana.tricks.Trick]]

  • List of functions that construct tricks; see the paper for the definition of a "trick".
  • Four tricks are currently implemented: trick_centralize, trick_pca, trick_norm_row, trick_norm_col.
  • Default: [trick_centralize, trick_norm_col]. A usage sketch is shown below.
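  • A minimal sketch that swaps in row normalization instead of the default column normalization:

    from raana.tricks import trick_centralize, trick_norm_row
    from raana import quantize

    quantize(
        ...,
        trick_makers = [trick_centralize, trick_norm_row],
    )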

weightbias_extractor: Callable[[nn.Module], tuple[Tensor, Tensor | None]]

  • A function that extracts the weight and bias from a linear module and transforms them into the standard shape.
  • It should return the extracted weight and bias of the provided layer: weight should be a tensor of shape (d_in, d_out), and bias should be None or a tensor of shape (d_out,). A sketch of a custom extractor is shown below.
  • Default: lambda layer: (layer.weight.t().data, layer.bias.data)
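  • A minimal sketch of a custom extractor for layers that already store their weight as (d_in, d_out) (e.g. GPT-2-style transformers Conv1D modules); this is an illustration, adapt it to your model's actual layout:

    from torch import Tensor, nn

    def conv1d_weightbias_extractor(layer: nn.Module) -> tuple[Tensor, Tensor | None]:
        # weight is already stored as (d_in, d_out), so no transpose is needed
        weight = layer.weight.data
        bias   = layer.bias.data if layer.bias is not None else None
        return weight, bias

    quantize(
        ...,
        weightbias_extractor = conv1d_weightbias_extractor,
    )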

matmul: Callable[[Tensor, Tensor, Tensor, int], Tensor]

  • A function to perform low-precision matrix multiplication.
  • Since there is no standard implementation for low-precision uint-float matrix multiplication and we want to minimize raana's dependencies, we leave the implementation of the matrix multiplication to users.
  • The input parameters are X, qW, rescale, B: X is a float tensor, qW is a B-bit uint tensor, and rescale is a float rescaling tensor. The return value of this function should equal (X @ qW - ((2**B - 1) / 2. * X.sum(dim=-1)).view(-1, 1)) * rescale.view(1, -1).
  • Default: cast everything to float32 and do a standard matrix multiplication. Below is the default implementation; a sketch of a custom matmul follows it.
    import torch as tc

    def default_matmul(X: tc.Tensor, qW: tc.Tensor, rescale: tc.Tensor, B: int):
        dtype = X.dtype
        X       = X.to(tc.float32)
        rescale = rescale.to(tc.float32).view(1, -1)
        q_bias  = (float(2 ** B - 1) / 2. * X.sum(dim = -1)).view(-1, 1)
        Z = (X @ qW.to(tc.float32)) * rescale
        Z = Z - q_bias * rescale
        return Z.to(dtype)
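  • A minimal sketch of a custom matmul that keeps the computation in bfloat16 (an illustration of the interface, not an optimized low-precision kernel):

    import torch as tc
    from raana import quantize

    def bf16_matmul(X: tc.Tensor, qW: tc.Tensor, rescale: tc.Tensor, B: int) -> tc.Tensor:
        # same contract as default_matmul, but computed in bfloat16
        X16     = X.to(tc.bfloat16)
        rescale = rescale.to(tc.bfloat16).view(1, -1)
        q_bias  = (float(2 ** B - 1) / 2. * X16.sum(dim = -1)).view(-1, 1)
        Z       = (X16 @ qW.to(tc.bfloat16) - q_bias) * rescale
        return Z.to(X.dtype)

    quantize(
        ...,
        matmul = bf16_matmul,
    )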

Returns

{
    "model" : torch.nn.Module,  # quantized model
    "bits"  : dict[str, float], # allocated bitwidth per layer
    "losses": list[float],      # calibration loss per calibration data
}
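A minimal sketch of consuming the returned dictionary:

result = quantize(model, ...)
quantized_model = result["model"]
for layer_name, bits in result["bits"].items():
    print(f"{layer_name}: {bits} bits")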
