🚨 Our latest method, QTIP, uses trellis quantization to achieve even higher quality quantized models. This codebase is no longer under active development.
QuIP# is a weight-only post-training quantization method that achieves state-of-the-art performance in extreme compression (
QuIP# is the first PTQ method where 3-bit models scale better than theoretically lossless 4-bit models.
Timed on an RTX 6000 Ada.

| Method | Llama 2 7B (tok/s) | Llama 2 70B (tok/s) |
|---|---|---|
| FP16 | 57.7 | OOM |
| AQLM 2-bit | 81.1 | 8.72 |
| QuIP# 2-bit | 176 | 21.9 |
- Clone the repo.
- Install the requirements via `pip install -r requirements.txt`.
- Build and install the CUDA inference kernels (`cd quiptools && python setup.py install && cd ../`).
- Install the fast-hadamard-transform package. This package is also available through pip, but recently I've had issues installing it through pip.
Example quantization scripts for the Llama family of models are located in `quantize_llama`. Follow these scripts to use QuIP# on other architectures. Within `quantize_llama`:
`hessian_offline_llama.py` contains code to generate model Hessians. Hessian calculation uses an `fp64` accumulator for numerical accuracy. Running this script on a device with slow `fp64` capabilities will take longer; if so, you may want to change the accumulator to `fp32`. The HF repo includes pregenerated Hessians for a variety of models.

- `--batch_size`: Batch size per GPU. Tune so you don't run out of memory.
- `--devset_size`: Size of the devset to use for Hessian generation.
- `--ctx_size`: Context size (sequence length) to use for Hessian generation.
- `--base_model`: Full-precision HF model.
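Assuming the Hessians here are the usual proxy Hessians of QuIP-style methods, namely the second moment H = E[x xᵀ] of the inputs x to each linear layer, the accumulation can be sketched in NumPy as below. The function name and array shapes are hypothetical and this is not the script's actual code; the `dtype` argument only marks where the `fp64` versus `fp32` accumulator choice comes in.

```python
import numpy as np

def accumulate_hessian(batches, dtype=np.float64):
    """Estimate the proxy Hessian H = E[x x^T] for one linear layer (hypothetical helper).

    batches: iterable of arrays of shape (tokens, in_features) holding the
    layer's inputs. dtype sets the accumulator precision: float64 by default,
    float32 is cheaper on hardware with slow fp64 but accumulates more rounding error.
    """
    H, count = None, 0
    for X in batches:
        X = np.asarray(X, dtype=dtype)
        if H is None:
            H = np.zeros((X.shape[1], X.shape[1]), dtype=dtype)
        H += X.T @ X          # sums the outer products x x^T over every token in the batch
        count += X.shape[0]
    if count == 0:
        raise ValueError("no activations seen")
    return H / count          # average over all tokens seen
```

Because H is an average of outer products, it is symmetric and positive semidefinite regardless of the accumulator precision; the precision only affects how much rounding error builds up over a large devset.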
`quantize_finetune_llama.py` contains code to quantize Llama with fine-tuning ("fine-tuning during quantization" in the paper).

- To reproduce earlier QuIP# results without fine-tuning, pass `--ft_epochs 0`.
- `--save_path`: Output path.
- `--base_model`: Full-precision HF model. Llama 1 weights are available at `relaxml/Llama-1-<7,13,30,65>b-hf`.
- `--hessian_path`: Offline Hessians. We provide precomputed Hessians at repo_ids `relaxml/Hessians*-<n>`. These Hessians were computed with `n` samples and the context length and attention mask used to train the original model. To download them, run `python scripts/download_hf.py --folder_path <local path to save Hessians> --repo_id <repo_id> --read_token <huggingface read token>`.
- `--codebook`: Codebook. Use `E8P12` for 2 bits, `E8P12RVQ3B` for 3 bits, and `E8P12RVQ4B` for 4 bits (RVQ stands for residual vector quantization).
- `--scale_override` and `--resid_scale_override`: Post-incoherence-processing scale overrides. We suggest using 0.9 for `E8P12` and the default scales for 3- and 4-bit models. You may want to manually tune these for your specific model.
- `--ft*`: Various fine-tuning arguments. `--ft_grad_ckpt` turns on gradient checkpointing, and `--ft_train_mode` materializes the full quantized matrix during fine-tuning. We recommend turning `--ft_train_mode` on if you have enough memory, since it makes training faster.
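To illustrate what residual vector quantization means, here is a toy two-stage RVQ in NumPy. The 2-D codebooks are made up for illustration and are not the E8P lattice codebooks QuIP# uses: each stage picks the nearest codeword, and the next stage quantizes whatever is left over. The codebook names suggest, though this sketch does not confirm, that the 3- and 4-bit codebooks stack residual stages on top of `E8P12` in a similar way. The helper names are hypothetical.

```python
import numpy as np

def nearest_codeword(codebook, v):
    """Return the row of `codebook` closest to vector v in squared L2 distance."""
    dists = ((codebook - v) ** 2).sum(axis=1)
    return codebook[np.argmin(dists)]

def rvq_quantize(v, codebooks):
    """Toy residual vector quantization.

    Stage k quantizes the residual left by the earlier stages; the
    reconstruction is the sum of the chosen codewords.
    """
    residual = np.asarray(v, dtype=np.float64)
    approx = np.zeros_like(residual)
    for codebook in codebooks:
        q = nearest_codeword(codebook, residual)
        approx = approx + q
        residual = residual - q
    return approx

# Stage 1: 4 codewords for a 2-D vector = 2 bits per 2 weights (1 bit/weight).
coarse = np.array([[x, y] for x in (-1.0, 1.0) for y in (-1.0, 1.0)])
# Stage 2: a scaled-down copy spends one more bit per weight on the residual.
fine = 0.5 * coarse

v = np.array([0.7, -1.4])
one_stage = rvq_quantize(v, [coarse])        # [1.0, -1.0]
two_stage = rvq_quantize(v, [coarse, fine])  # [0.5, -1.5]
```

In this example the second stage cuts the squared error from 0.25 to 0.05, at the cost of one extra bit per weight.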
`finetune_e2e_llama.py` tunes the sign vectors (SU/SV), layernorms, and language model head of a given model (the second fine-tuning step in the paper). The arguments are similar to those of `quantize_finetune_llama.py`. You will need to convert the output of that script to an HF model with `hfize_llama.py` before running this script. The HF-ized model should be passed in through `--hf_path`.

`hfize_llama.py` converts a quantized model to the HF format.
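The paper title refers to Hadamard incoherence, and `finetune_e2e_llama.py` tunes sign vectors SU/SV. Assuming these play the role of the random signs in a two-sided randomized Hadamard transform (our reading, not something this README states), the NumPy sketch below shows the idea: multiplying by orthogonal matrices built from a Hadamard matrix and random ±1 signs spreads an outlier weight across the whole matrix while staying exactly invertible. It uses a dense Sylvester Hadamard matrix for clarity; a fast Walsh-Hadamard transform computes the same product in O(n log n). The helper names are hypothetical.

```python
import numpy as np

def hadamard(n):
    """Orthonormal n x n Hadamard matrix via Sylvester's construction (n = 2^k)."""
    if n < 1 or n & (n - 1):
        raise ValueError("n must be a power of 2")
    H = np.array([[1.0]])
    while H.shape[0] < n:
        H = np.block([[H, H], [H, -H]])
    return H / np.sqrt(n)

def randomized_hadamard(W, su, sv):
    """Return (W_tilde, U, V) with U = H_m diag(su), V = H_n diag(sv), W_tilde = U W V^T.

    U and V are orthogonal, so W = U^T W_tilde V recovers the original exactly.
    """
    m, n = W.shape
    U = hadamard(m) * su  # broadcasting scales column j by su[j], i.e. H_m @ diag(su)
    V = hadamard(n) * sv
    return U @ W @ V.T, U, V

rng = np.random.default_rng(0)
W = rng.standard_normal((8, 16))
W[0, 0] = 50.0                        # a single large outlier weight
su = rng.choice([-1.0, 1.0], size=8)  # random sign vectors
sv = rng.choice([-1.0, 1.0], size=16)
W_tilde, U, V = randomized_hadamard(W, su, sv)
```

After the transform the outlier contributes only about ±50/sqrt(128) ≈ ±4.4 to every entry, so no single weight dominates; this flatter distribution is the intuition behind incoherence processing before rounding to a codebook.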
The scripts in `quantize_llama` are written with the Llama architecture in mind. However, QuIP# is adaptable to any architecture with linear layers. To use QuIP# on a new architecture, identify the relevant linear layers and update the scripts in `quantize_llama`. Feel free to open a GitHub issue if you run into problems.
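When porting to a new architecture, one useful observation is that a proxy Hessian of the form E[x xᵀ] depends only on a layer's input, so linear layers that read the same activation can share one Hessian. The sketch below groups the linear layers of one Hugging Face Llama decoder layer by input. The mapping is written by hand from the HF Llama module names, the activation labels are hypothetical, and the helper is not part of `quantize_llama`; for a new model you would write the analogous mapping.

```python
from collections import defaultdict

# Hand-written description of one HF Llama decoder layer: each linear layer
# mapped to a (hypothetical) label for the activation it reads.
LLAMA_LAYER_INPUTS = {
    "self_attn.q_proj": "input_layernorm_out",
    "self_attn.k_proj": "input_layernorm_out",
    "self_attn.v_proj": "input_layernorm_out",
    "self_attn.o_proj": "attention_out",
    "mlp.gate_proj": "post_attention_layernorm_out",
    "mlp.up_proj": "post_attention_layernorm_out",
    "mlp.down_proj": "mlp_hidden",
}

def group_by_input(layer_inputs):
    """Group linear layers that read the same activation (hypothetical helper).

    Each group needs only one Hessian, since E[x x^T] depends only on the input x.
    """
    groups = defaultdict(list)
    for name, source in layer_inputs.items():
        groups[source].append(name)
    return dict(groups)

groups = group_by_input(LLAMA_LAYER_INPUTS)
```

For Llama this gives four Hessians per decoder layer instead of seven, one for each distinct input.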
`eval` contains evaluation scripts. These scripts may need `CUDA_VISIBLE_DEVICES=0` if you run into CUDA errors due to how HF accelerate works.

- `eval_ppl.py` calculates perplexity on Wikitext2 and C4.
- `eval_zeroshot.py` calculates performance on zero-shot tasks.
- `eval_speed.py` times the forward pass for one token.
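Perplexity is conventionally the exponential of the mean per-token negative log-likelihood. A minimal sketch of that formula follows; it assumes you already have the natural-log probability the model assigned to each ground-truth token, and it does not reproduce how `eval_ppl.py` splits Wikitext2 or C4 into context windows. The function name is hypothetical.

```python
import math

def perplexity(token_logprobs):
    """Perplexity = exp(mean negative log-likelihood) over predicted tokens.

    token_logprobs: natural-log probabilities of each ground-truth next token.
    """
    if not token_logprobs:
        raise ValueError("need at least one token")
    mean_nll = -sum(token_logprobs) / len(token_logprobs)
    return math.exp(mean_nll)
```

As a sanity check, a model that spreads probability uniformly over 50 tokens has perplexity 50, and a model that is always certain and correct has perplexity 1.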
QuIP# was designed to support fast inference. Example inference kernels for recent NVIDIA GPUs can be found in the `quiptools` folder. We are currently missing a 1-bit matrix-vector multiply kernel needed to make 3-bit inference fast, so if you'd like to contribute, feel free to open a pull request.
`eval/interactive_gen.py` contains a simple interactive generation script. It is rudimentary, and you may want to write your own: all it does is call HF's `.generate()` function. HF generate does not currently work out of the box with CUDA graphs, so this script will be very slow, since most of the time is spent on kernel launches. QuIP# should work with any codebase, and people have reported success integrating it with vLLM, so we may switch away from HF in the future. The purpose of this codebase is to provide a reference implementation of QuIP#.
[Update] #65 adds CUDA graph support to HF, so this codebase will support fast inference soon!
Example quantized models (mostly Llama 1 and 2) can be found on our HF repo. To use them, pass the given HF repo_id to `--hf_path`. The 3-bit models are currently significantly slower than the 2- and 4-bit models during generation, since we have not written an optimized matvec CUDA kernel for them yet.
Feel free to open a pull request with a link to your own quantized QuIP# model if you want us to list it here.
https://github.com/chu-tianxiang/QuIP-for-all contains a third-party implementation of QuIP#. We have not verified the correctness of the repo, but it seems to work properly and has out-of-the-box integration with other frameworks (vLLM, gpt-fast, etc.).
Use of Llama models is governed by the Meta license available here. Use of Mistral models is governed by the Apache 2.0 license. Use of this code is governed by the GNU GPL v3 license.
If you found this work useful, please consider citing
@inproceedings{
tseng2024quip,
title={Qu{IP}\${\textbackslash}\#\$: Even Better {LLM} Quantization with Hadamard Incoherence and Lattice Codebooks},
author={Albert Tseng and Jerry Chee and Qingyao Sun and Volodymyr Kuleshov and Christopher De Sa},
booktitle={Forty-first International Conference on Machine Learning},
year={2024},
url={https://openreview.net/forum?id=9BrydUVcoe}
}