
Advanced Quantization Algorithm for LLMs. This is the official implementation of "Optimize Weight Rounding via Signed Gradient Descent for the Quantization of LLMs".


intel/auto-round

AutoRound

Advanced Quantization Algorithm for LLMs


AutoRound is an advanced quantization algorithm for low-bit LLM inference, tailored for a wide range of models. It uses sign gradient descent to fine-tune the rounding values and min-max values of weights in just 200 steps. The result competes with recent methods while adding no inference overhead and keeping the tuning cost low. The image below presents an overview of AutoRound. See our paper on arXiv for more details, and visit low_bit_open_llm_leaderboard for more accuracy data and recipes across various models.
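
The core idea can be illustrated with a toy example. The following is a minimal sketch, not the library's implementation: it tunes a per-weight rounding offset v in [-0.5, 0.5] for a single weight vector using signed gradient descent with an assumed straight-through gradient, and it omits the min-max tuning and block-wise processing that AutoRound performs. All function names here (quantize, loss_and_grad, tune_rounding) are hypothetical; the step size 1.0/iters mirrors the default lr listed under Detailed Hyperparameters.

```python
import random

def quantize(w, v, scale, qmax):
    """Fake-quantize: round(w/scale + v), clip to [-qmax, qmax], rescale."""
    return [max(-qmax, min(qmax, round(wi / scale + vi))) * scale
            for wi, vi in zip(w, v)]

def loss_and_grad(w, wq, xs, scale):
    """Squared output error sum((x.wq - x.w)^2) and its gradient w.r.t. v.

    round() has zero gradient almost everywhere, so a straight-through
    estimator is assumed: d(wq_i)/d(v_i) is treated as `scale`.
    """
    loss, grad = 0.0, [0.0] * len(w)
    for x in xs:
        err = sum(xi * (qi - wi) for xi, qi, wi in zip(x, wq, w))
        loss += err * err
        for i, xi in enumerate(x):
            grad[i] += 2.0 * err * xi * scale
    return loss, grad

def tune_rounding(w, xs, bits=4, iters=200):
    qmax = 2 ** (bits - 1) - 1
    scale = max(abs(wi) for wi in w) / qmax
    v = [0.0] * len(w)  # v = 0 is plain round-to-nearest (RTN)
    lr = 1.0 / iters
    best_loss, best_v, rtn_loss = float("inf"), list(v), None
    for step in range(iters):
        loss, grad = loss_and_grad(w, quantize(w, v, scale, qmax), xs, scale)
        if step == 0:
            rtn_loss = loss
        if loss < best_loss:
            best_loss, best_v = loss, list(v)
        # Signed update: only the gradient's sign is used, so each step
        # moves every offset by at most lr; offsets stay in [-0.5, 0.5].
        v = [max(-0.5, min(0.5, vi - lr * ((gi > 0) - (gi < 0))))
             for vi, gi in zip(v, grad)]
    return best_v, best_loss, rtn_loss

random.seed(0)
w = [random.uniform(-1.0, 1.0) for _ in range(16)]
xs = [[random.gauss(0.0, 1.0) for _ in range(16)] for _ in range(32)]
best_v, best_loss, rtn_loss = tune_rounding(w, xs)
print(f"RTN loss: {rtn_loss:.4f}, tuned loss: {best_loss:.4f}")
```

Because the best offsets seen so far are kept, the tuned loss in this sketch can never be worse than round-to-nearest, which is the first candidate evaluated.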

What's New

  • [2024/10] AutoRound has been integrated into torch/ao; see their release notes.
  • [2024/10] Important update: We now support full-range symmetric quantization and have made it the default configuration. This configuration is typically better than or comparable to asymmetric quantization and significantly outperforms other symmetric variants, especially at low bit-widths such as 2-bit.
  • [2024/09] AutoRound format supports several LVM models; see the examples for Qwen2-VL, Phi-3-vision, and Llava.
  • [2024/08] AutoRound format supports Intel Gaudi2 devices. Please refer to Intel/Qwen2-7B-int4-inc.
  • [2024/08] AutoRound introduces several experimental features, including fast tuning of norm/bias parameters (for 2-bit and W4A4), activation quantization, and the mx_fp data type.
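
To see why the integer range matters most at very low bit-widths, consider the grid of representable levels. The sketch below rests on an assumption: it takes "full-range" to mean using all 2^bits signed integer levels, whereas a restricted symmetric scheme drops the most negative level to keep the grid mirror-symmetric. The function symmetric_levels is hypothetical, and AutoRound's exact scheme may differ.

```python
def symmetric_levels(bits, full_range):
    """Return the signed integer levels available at a given bit-width."""
    hi = 2 ** (bits - 1) - 1
    # Restricted range mirrors hi; full range also uses the extra negative level.
    lo = -(2 ** (bits - 1)) if full_range else -hi
    return list(range(lo, hi + 1))

for bits in (2, 4):
    restricted = symmetric_levels(bits, full_range=False)
    full = symmetric_levels(bits, full_range=True)
    print(f"{bits}-bit: restricted uses {len(restricted)} levels, full range uses {len(full)}")
```

Under this assumption, a restricted 2-bit grid has only three levels out of four possible, so a quarter of the code space goes unused; at 4 bits the loss is one level out of sixteen.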

Installation

Build from Source

pip install -vvv --no-build-isolation -e .

Install from pypi

pip install auto-round

Model Quantization

Basic Usage (Gaudi2/CPU/GPU)

A user guide listing all supported arguments is available by running auto-round -h in the terminal. You can also use auto_round in place of auto-round. Set the export format with --format; exporting multiple formats at once is supported.

CUDA_VISIBLE_DEVICES=0 auto-round \
    --model facebook/opt-125m \
    --bits 4 \
    --group_size 128 \
    --format "auto_round,auto_gptq" \
    --disable_eval \
    --output_dir ./tmp_autoround

We provide two additional recipes: one for best accuracy, and one for fast tuning with low memory use. Details are below.

Other Recipes
## best accuracy, 3X slower, low_gpu_mem_usage could save ~20G but ~30% slower
CUDA_VISIBLE_DEVICES=0 auto-round \
  --model facebook/opt-125m \
  --bits 4 \
  --group_size 128 \
  --nsamples 512 \
  --iters 1000 \
  --low_gpu_mem_usage \
  --disable_eval 
## fast and low memory, 2-3X speedup, slight accuracy drop at W4G128
CUDA_VISIBLE_DEVICES=0 auto-round \
  --model facebook/opt-125m \
  --bits 4 \
  --group_size 128 \
  --nsamples 128 \
  --iters 200 \
  --seqlen 512 \
  --batch_size 4 \
  --disable_eval 
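
In the recipe comments above, W4G128 denotes 4-bit weights with a group size of 128, matching --bits 4 and --group_size 128. As a rough illustration of what the group size controls, the sketch below applies plain round-to-nearest symmetric quantization with one scale per group. It is not AutoRound's tuned rounding, and quantize_groupwise and dequantize_groupwise are hypothetical helpers.

```python
import random

def quantize_groupwise(weights, bits=4, group_size=128):
    """Round-to-nearest symmetric quantization with one scale per group."""
    qmax = 2 ** (bits - 1) - 1
    q_values, scales = [], []
    for start in range(0, len(weights), group_size):
        group = weights[start:start + group_size]
        # Each group gets its own scale, so an outlier only affects its own group.
        # Fall back to 1.0 for an all-zero group to avoid dividing by zero.
        scale = max(abs(w) for w in group) / qmax or 1.0
        scales.append(scale)
        q_values.extend(max(-qmax, min(qmax, round(w / scale))) for w in group)
    return q_values, scales

def dequantize_groupwise(q_values, scales, group_size=128):
    """Map integers back to floats using the scale of each value's group."""
    return [q * scales[i // group_size] for i, q in enumerate(q_values)]

random.seed(0)
weights = [random.gauss(0.0, 0.02) for _ in range(512)]
q_values, scales = quantize_groupwise(weights, bits=4, group_size=128)
restored = dequantize_groupwise(q_values, scales, group_size=128)
max_err = max(abs(w - r) for w, r in zip(weights, restored))
print(len(scales), "scales, max abs error", max_err)
```

A smaller group size means more scales to store but a tighter fit per group; a larger one does the opposite.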

Formats

AutoRound Format: This format is well-suited for CPU and HPU devices, 2-bit quantization, and mixed-precision inference. [2,4] bits are supported. It also benefits from the Marlin kernel, which can boost inference performance notably. However, it has not yet gained widespread community adoption.

AutoGPTQ Format: This format is well-suited for symmetric quantization on CUDA devices and is widely adopted by the community. [2,3,4,8] bits are supported. It also benefits from the Marlin kernel, which can boost inference performance notably. However, the asymmetric kernel has issues that can cause considerable accuracy drops, particularly for 2-bit quantization and small models. Additionally, symmetric quantization tends to perform poorly at 2-bit precision.

AutoAWQ Format: This format is well-suited for asymmetric 4-bit quantization on CUDA devices and is widely adopted within the community. Only 4-bit quantization is supported. It features specialized layer fusion tailored for Llama models.

API Usage (Gaudi2/CPU/GPU)

from transformers import AutoModelForCausalLM, AutoTokenizer

model_name = "facebook/opt-125m"
model = AutoModelForCausalLM.from_pretrained(model_name)
tokenizer = AutoTokenizer.from_pretrained(model_name)

from auto_round import AutoRound

bits, group_size, sym = 4, 128, True
autoround = AutoRound(model, tokenizer, bits=bits, group_size=group_size, sym=sym)

## the best accuracy, 3X slower, low_gpu_mem_usage could save ~20G but ~30% slower
# autoround = AutoRound(model, tokenizer, nsamples=512, iters=1000, low_gpu_mem_usage=True, bits=bits, group_size=group_size, sym=sym)

## fast and low memory, 2-3X speedup, slight accuracy drop at W4G128
# autoround = AutoRound(model, tokenizer, nsamples=128, iters=200, seqlen=512, batch_size=4, bits=bits, group_size=group_size, sym=sym )

autoround.quantize()
output_dir = "./tmp_autoround"
## format= 'auto_round'(default in version>0.3.0), 'auto_gptq', 'auto_awq'
autoround.save_quantized(output_dir, format='auto_round', inplace=True) 
Detailed Hyperparameters
  • model: The PyTorch model to be quantized.

  • tokenizer: An optional tokenizer for processing input data. If None, a dataset must be provided.

  • bits (int): Number of bits for quantization (default is 4).

  • group_size (int): Size of the quantization group (default is 128).

  • sym (bool): Whether to use symmetric quantization (default is True).

  • enable_quanted_input (bool): Whether to use the output of the previous quantized block as the input for the current block for tuning (default is True).

  • enable_minmax_tuning (bool): Whether to enable weight min-max tuning (default is True).

  • iters (int): Number of tuning iterations (default is 200).

  • lr (float): The learning rate for rounding value (default is None, it will be set to 1.0/iters automatically).

  • minmax_lr (float): The learning rate for min-max tuning (default is None, it will be set to lr automatically).

  • nsamples (int): Number of samples for tuning (default is 128).

  • seqlen (int): Data length of the sequence for tuning (default is 2048).

  • batch_size (int): Batch size for training (default is 8).

  • scale_dtype (str): The data type of quantization scale to be used (default is "float16"), different kernels have different choices.

  • amp (bool): Whether to use automatic mixed precision (default is True).

  • nblocks (int): Packing several blocks as one for tuning together (default is 1).

  • gradient_accumulate_steps (int): Number of gradient accumulation steps (default is 1).

  • low_gpu_mem_usage (bool): Whether to save GPU memory at the cost of ~20% more tuning time (default is False).

  • dataset Union[str, list, tuple, torch.utils.data.DataLoader]: The dataset name for tuning (default is "NeelNanda/pile-10k"). Local JSON files and combinations of datasets are supported, e.g. "./tmp.json,NeelNanda/pile-10k:train,mbpp:train+validation+test".

  • layer_config (dict): Configuration for weight quantization (default is None), mainly for mixed bits or mixed precision.

  • device: The device to be used for tuning. The default is set to 'auto', allowing for automatic detection.
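
For mixed-bit quantization, layer_config can assign different settings to individual layers. The sketch below builds such a dictionary in plain Python. It assumes that layer_config maps layer names to per-layer overrides with keys such as bits and group_size, and that unlisted layers keep the global settings. That schema, the helper build_layer_config, and the example layer names are all assumptions to verify against the version you have installed.

```python
def build_layer_config(layer_names, sensitive_keywords, high_bits=8, group_size=128):
    """Give layers whose name contains a keyword a higher bit-width.

    Assumption: layers absent from the result keep the global settings.
    """
    config = {}
    for name in layer_names:
        if any(key in name for key in sensitive_keywords):
            config[name] = {"bits": high_bits, "group_size": group_size}
    return config

# Hypothetical layer names, for illustration only.
layer_names = [
    "model.decoder.layers.0.self_attn.q_proj",
    "model.decoder.layers.0.self_attn.out_proj",
    "model.decoder.layers.0.fc1",
]
layer_config = build_layer_config(layer_names, sensitive_keywords=["out_proj"])
print(layer_config)

# Assumed usage with the API shown earlier:
# autoround = AutoRound(model, tokenizer, bits=4, group_size=128, layer_config=layer_config)
```

In practice the layer names would come from model.named_modules() rather than a hand-written list.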

Model Inference

Please run the quantization code first.

AutoRound format

CPU: Requires auto_round version >0.3.1. Install either intel-extension-for-pytorch (much faster on Intel CPUs) with pip install intel-extension-for-pytorch, or intel-extension-for-transformers with pip install intel-extension-for-transformers.

HPU: A Docker image with the Gaudi Software Stack is recommended. More details can be found in the Gaudi Guide.

CUDA: No extra steps are needed for symmetric quantization. For asymmetric quantization, install auto-round from source.

CPU/HPU/CUDA

from transformers import AutoModelForCausalLM, AutoTokenizer
from auto_round import AutoRoundConfig

# Options: "auto", "cpu", "hpu", "cuda", "cuda:marlin".
# "cuda:marlin" requires auto_round>0.3.1 and 'pip install -v gptqmodel --no-build-isolation'.
backend = "auto"
quantization_config = AutoRoundConfig(
    backend=backend
)
quantized_model_path = "./tmp_autoround"
model = AutoModelForCausalLM.from_pretrained(quantized_model_path,
                                             device_map=backend.split(':')[0], quantization_config=quantization_config)
tokenizer = AutoTokenizer.from_pretrained(quantized_model_path)
text = "There is a girl who likes adventure,"
inputs = tokenizer(text, return_tensors="pt").to(model.device)
print(tokenizer.decode(model.generate(**inputs, max_new_tokens=50)[0]))

Evaluation
auto-round --model saved_quantized_model \
    --eval \
    --task lambada_openai \
    --eval_bs 1

AutoGPTQ/AutoAWQ format

from transformers import AutoModelForCausalLM, AutoTokenizer

quantized_model_path = "./tmp_autoround"
model = AutoModelForCausalLM.from_pretrained(quantized_model_path,
                                             device_map="auto")
tokenizer = AutoTokenizer.from_pretrained(quantized_model_path)
text = "There is a girl who likes adventure,"
inputs = tokenizer(text, return_tensors="pt").to(model.device)
print(tokenizer.decode(model.generate(**inputs, max_new_tokens=50)[0]))

Support List

AutoRound supports nearly all major large language models.

Please note that an asterisk (*) indicates third-party quantized models, which may lack accuracy data and use a different recipe. We greatly appreciate their efforts and encourage more users to share their models, as we cannot release most of the models ourselves.

Model                                   Supported
meta-llama/Meta-Llama-3.1-70B-Instruct  recipe
meta-llama/Meta-Llama-3.1-8B-Instruct   model-kaitchup-autogptq-int4*, model-kaitchup-autogptq-sym-int4*, recipe
meta-llama/Meta-Llama-3.1-8B            model-kaitchup-autogptq-sym-int4*
Qwen/Qwen-VL                            accuracy, recipe
Qwen/Qwen2-7B                           model-autoround-sym-int4, model-autogptq-sym-int4
Qwen/Qwen2-57B-A14B-Instruct            model-autoround-sym-int4, model-autogptq-sym-int4
01-ai/Yi-1.5-9B                         model-LnL-AI-autogptq-int4*
01-ai/Yi-1.5-9B-Chat                    model-LnL-AI-autogptq-int4*
Intel/neural-chat-7b-v3-3               model-autogptq-int4
Intel/neural-chat-7b-v3-1               model-autogptq-int4
TinyLlama-1.1B-intermediate             model-LnL-AI-autogptq-int4*
mistralai/Mistral-7B-v0.1               model-autogptq-lmhead-int4, model-autogptq-int4
google/gemma-2b                         model-autogptq-int4
tiiuae/falcon-7b                        model-autogptq-int4-G64
sapienzanlp/modello-italia-9b           model-fbaldassarri-autogptq-int4*
microsoft/phi-2                         model-autoround-sym-int4, model-autogptq-sym-int4
microsoft/Phi-3.5-mini-instruct         model-kaitchup-autogptq-sym-int4*
microsoft/Phi-3-vision-128k-instruct    recipe
mistralai/Mistral-7B-Instruct-v0.2      accuracy, recipe, example
mistralai/Mixtral-8x7B-Instruct-v0.1    accuracy, recipe, example
mistralai/Mixtral-8x7B-v0.1             accuracy, recipe, example
meta-llama/Meta-Llama-3-8B-Instruct     accuracy, recipe, example
google/gemma-7b                         accuracy, recipe, example
meta-llama/Llama-2-7b-chat-hf           accuracy, recipe, example
Qwen/Qwen1.5-7B-Chat                    accuracy, sym recipe, asym recipe, example
baichuan-inc/Baichuan2-7B-Chat          accuracy, recipe, example
01-ai/Yi-6B-Chat                        accuracy, recipe, example
facebook/opt-2.7b                       accuracy, recipe, example
bigscience/bloom-3b                     accuracy, recipe, example
EleutherAI/gpt-j-6b                     accuracy, recipe, example

Integration

AutoRound has been integrated into multiple repositories.

Intel Neural Compressor

ModelCloud/GPTQModel

pytorch/ao

Reference

If you find AutoRound useful for your research, please cite our paper:

@article{cheng2023optimize,
  title={Optimize Weight Rounding via Signed Gradient Descent for the Quantization of LLMs},
  author={Cheng, Wenhua and Zhang, Weiwei and Shen, Haihao and Cai, Yiyang and He, Xin and Lv, Kaokao and Liu, Yi},
  journal={arXiv preprint arXiv:2309.05516},
  year={2023}
}