Enhance Auto-Round #870

Merged · 10 commits · Sep 18, 2024

32 changes: 22 additions & 10 deletions torchao/_models/llama/eval.py
@@ -128,19 +128,28 @@ def run_evaluation(

_tokenizer = AutoTokenizer.from_pretrained(checkpoint_path.parent)
# parse args from quantization string:
# autoround-<model_device>-<quant_lm_head>-<iters>-<groupsize>-<batch_size>-<seqlen>-<nsamples>
# autoround-<model_device>-<quant_lm_head>-<iters>-<groupsize>-<batch_size>-<seqlen>-<nsamples>-<grad_acc_steps>-<c>
_quant_args = quantization.split("-")
_default_quant_args = [False, 200, 128, 8, 2048, 128]
_default_quant_args = [False, 200, 128, 8, 2048, 128, 1, 0]
_model_devie = _quant_args[1] if len(_quant_args) > 1 else device
_quant_args = _quant_args[2:]
quant_lm_head, iters, groupsize, batch_size, seqlen, nsamples = [
int(x) for x in _quant_args
] + _default_quant_args[len(_quant_args) :]
(
quant_lm_head,
iters,
groupsize,
batch_size,
seqlen,
nsamples,
grad_acc_steps,
compile_optimization_process,
) = [int(x) for x in _quant_args] + _default_quant_args[len(_quant_args) :]
model = model.to(_model_devie)
print(
(
f"Quantizing model with autoround(iters={iters}, groupsize={groupsize}, "
f"quant_lm_head={quant_lm_head}, batch_size={batch_size}, seqlen={seqlen}, nsamples={nsamples})"
f"quant_lm_head={quant_lm_head}, batch_size={batch_size}, seqlen={seqlen}, nsamples={nsamples}, "
f"gradient_accumulate_steps={grad_acc_steps}, "
f"compile_optimization_process={compile_optimization_process})"
)
)
with torch.device(_model_devie):
@@ -161,9 +170,11 @@ def run_evaluation(
is_target_module=is_target_module,
bits=4,
seqlen=seqlen,
bs=batch_size,
batch_size=batch_size,
iters=iters,
nsamples=nsamples,
gradient_accumulate_steps=grad_acc_steps,
compile_optimization_process=compile_optimization_process == 1,
)
model.to(device)
model.reset_caches()
@@ -195,9 +206,10 @@ def run_evaluation(
"--quantization",
type=str,
help=(
"Which quantization techniques to apply: int8dq, int8wo, fp6, int4wo-<groupsize>, int4wo-<groupsize>-gptq, "
"autoquant, autoquant-int4, int4wo-<groupsize>-hqq, uintx-<nbits>-<groupsize>, uintx-<nbits>-<groupsize>-hqq, "
"sparse-marlin, autoround-<model_device>-<quant_lm_head>-<iters>-<groupsize>-<batch_size>-<seqlen>-<nsamples>"
"Which quantization techniques to apply: int8dq, int8wo, fp6, int4wo-<groupsize>, "
"int4wo-<groupsize>-gptq, autoquant, autoquant-int4, int4wo-<groupsize>-hqq, "
"uintx-<nbits>-<groupsize>, uintx-<nbits>-<groupsize>-hqq, sparse-marlin, "
"autoround-<model_device>-<quant_lm_head>-<iters>-<groupsize>-<batch_size>-<seqlen>-<nsamples>-<grad_acc_steps>-<c>"
),
)
parser.add_argument('--compile', action='store_true', help='Whether to compile the model.')
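
To make the extended `--quantization` string concrete, the snippet below is a standalone sketch of the parsing logic introduced above (it is not the actual eval.py entry point, and the string value is a hypothetical example):

```python
# Standalone sketch of the autoround quantization-string parsing shown in the hunk above.
quantization = "autoround-cuda-1-200-128-8-2048-128-2-1"  # hypothetical example value
device = "cuda"

_quant_args = quantization.split("-")
_default_quant_args = [False, 200, 128, 8, 2048, 128, 1, 0]
_model_device = _quant_args[1] if len(_quant_args) > 1 else device
_quant_args = _quant_args[2:]
(
    quant_lm_head,
    iters,
    groupsize,
    batch_size,
    seqlen,
    nsamples,
    grad_acc_steps,
    compile_optimization_process,
) = [int(x) for x in _quant_args] + _default_quant_args[len(_quant_args):]

# Result: quant_lm_head=1, iters=200, groupsize=128, batch_size=8, seqlen=2048,
# nsamples=128, grad_acc_steps=2, compile_optimization_process=1, on device "cuda".
```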
58 changes: 40 additions & 18 deletions torchao/prototype/autoround/README.md
@@ -10,6 +10,24 @@ Auto-Round is an advanced quantization algorithm designed for low-bit LLM infere
python autoround_llm.py -m /model/name/or/path
```

This script applies `Auto-Round` to a given model directly. The available configuration options are listed below, followed by an example invocation:

| Argument | Default | Description |
|------------------------------------|----------------------------|-------------------------------------------------------------------|
| `model_name_or_path` |`"facebook/opt-125m"` | Pretrained model name or path |
| `dataset_name` | `"NeelNanda/pile-10k"` | Dataset name for calibration |
| `iters` | 200 | Number of steps for optimizing each block |
| `bits` | 4 | Number of bits for quantization |
| `batch_size` | 8 | Batch size for calibration |
| `nsamples` | 128 | Number of samples for calibration process |
| `seqlen` | 2048 | Sequence length for each sample |
| `group_size` | 128 | Group size for quantization |
| `gradient_accumulate_steps` | 1 | Number of steps for accumulating gradients <br> before performing the backward pass |
| `quant_lm_head` | `False` | Whether to quantize the `lm_head` |
| `use_optimized_layer_output` | `False` | Whether to use optimized layer output as input for the next layer |
| `compile_optimization_process` | `False` | Whether to compile the optimization process |
| `model_device` | `"cuda"` | Device for loading the float model (choices: `cpu`, `cuda`) |
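
For example, a run mirroring the `auto-round-4bit*` setting reported in the results below might look like the following (the model path is illustrative; run `python autoround_llm.py -h` for the authoritative flag list):

```
python autoround_llm.py -m meta-llama/Meta-Llama-3.1-8B-Instruct --iters 200 --batch_size 4 --gradient_accumulate_steps 2
```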


> [!NOTE]
> Before running, ensure you have installed `auto-round` with `pip install -r requirements.txt`.
@@ -71,31 +89,35 @@ quantize_(model, apply_auto_round(), is_target_module)

## End-to-End Results
### [meta-llama/Meta-Llama-3.1-8B-Instruct](https://huggingface.co/meta-llama/Meta-Llama-3.1-8B-Instruct)
| | Avg. | Mmlu | Piqa | Winogrande | Hellaswag | Lambada_openai |
| -------------- | ------- | ------ | ------ | ---------- | --------- | -------------- |
| bf16 | 0.7080 | 0.6783 | 0.8003 | 0.7403 | 0.5910 | 0.7303 |
| auto-round-4bit | 0.6988 | 0.6533 | 0.7949 | 0.7372 | 0.5837 | 0.7250 |
| torchao-int4wo | 0.6883 | 0.6363 | 0.7938 | 0.7348 | 0.5784 | 0.6980 |
| | Avg. | Mmlu | Piqa | Winogrande | Hellaswag | Lambada_openai |
| ---------------- | ------ | ------ | ------ | ---------- | --------- | -------------- |
| bf16 | 0.7080 | 0.6783 | 0.8003 | 0.7403 | 0.5910 | 0.7303 |
| torchao-int4wo | 0.6883 | 0.6363 | 0.7938 | 0.7348 | 0.5784 | 0.6980 |
| autoround-4bit | 0.6996 | 0.6669 | 0.7916 | 0.7285 | 0.5846 | 0.7262 |
| autoround-4bit* | 0.7010 | 0.6621 | 0.7976 | 0.7316 | 0.5847 | 0.7291 |

### [meta-llama/Meta-Llama-3-8B-Instruct](https://huggingface.co/meta-llama/Meta-Llama-3-8B-Instruct)
| | Avg. | Mmlu | Piqa | Winogrande | Hellaswag | Lambada_openai |
| -------------- | ------- | ------ | ------ | ---------- | --------- | -------------- |
| bf16 | 0.6881 | 0.6389 | 0.7840 | 0.7222 | 0.5772 | 0.7184 |
| auto-round-4bit | 0.6818 | 0.6232 | 0.7862 | 0.7230 | 0.5661 | 0.7105 |
| torchao-int4wo | 0.6728 | 0.5939 | 0.7737 | 0.7222 | 0.5612 | 0.7132 |
| | Avg. | Mmlu | Piqa | Winogrande | Hellaswag | Lambada_openai |
| ---------------- | ------ | ------ | ------ | ---------- | --------- | -------------- |
| bf16 | 0.6881 | 0.6389 | 0.7840 | 0.7222 | 0.5772 | 0.7184 |
| torchao-int4wo | 0.6728 | 0.5939 | 0.7737 | 0.7222 | 0.5612 | 0.7132 |
| autoround-4bit | 0.6796 | 0.6237 | 0.7758 | 0.7198 | 0.5664 | 0.7122 |
| autoround-4bit* | 0.6827 | 0.6273 | 0.7737 | 0.7348 | 0.5657 | 0.7120 |


### [meta-llama/Llama-2-7b-chat-hf](https://huggingface.co/meta-llama/Llama-2-7b-chat-hf)
| | Avg. | Mmlu | Piqa | Winogrande | Hellaswag | Lambada_openai |
| -------------- | ------- | ------ | ------ | ---------- | --------- | -------------- |
| bf16 | 0.6347 | 0.4647 | 0.7644 | 0.6606 | 0.577 | 0.7070 |
| auto-round-4bit | 0.6327 | 0.4534 | 0.7590 | 0.6661 | 0.5706 | 0.7143 |
| torchao-int4wo | 0.6252 | 0.4427 | 0.7617 | 0.6654 | 0.5674 | 0.6889 |
| | Avg. | Mmlu | Piqa | Winogrande | Hellaswag | Lambada_openai |
| ---------------- | ------ | ------ | ------ | ---------- | --------- | -------------- |
| bf16 | 0.6347 | 0.4647 | 0.7644 | 0.6606 | 0.5770 | 0.7070 |
| torchao-int4wo | 0.6252 | 0.4427 | 0.7617 | 0.6654 | 0.5674 | 0.6889 |
| autoround-4bit | 0.6311 | 0.4548 | 0.7606 | 0.6614 | 0.5717 | 0.7072 |
| autoround-4bit* | 0.6338 | 0.4566 | 0.7661 | 0.6646 | 0.5688 | 0.7130 |

> [!NOTE]
> - `auto-round-4bit` represents the following configuration: `bits=4`, `iters=200`, `seqlen=2048`, `train_bs=8`, `group_size=128`, and `quant_lm_head=False`. <br>
> - `torchao-int4wo` represents `int4_weight_only(group_size=128)` and `quant_lm_head=False`.
> - If the model includes operations without a deterministic implementation (such as Flash Attention), the results may differ slightly.
> - `torchao-int4wo` quantizes the model to 4 bits with a group size of 128 (`int4_weight_only(group_size=128)`) while leaving the `lm-head` unquantized. <br>
> - `auto-round-4bit` uses the default configuration from [quick start](#quick-start). <br>
> - `auto-round-4bit*` follows the same settings as `auto-round-4bit`, but with `gradient_accumulate_steps=2` and `batch_size=4`, which accumulates two batches (4 samples per batch) before performing the backward pass; see the sketch after this note. <br>
> - To reproduce results, run `eval_autoround.py` with `AO_USE_DETERMINISTIC_ALGORITHMS=1`.
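
The accumulation behaviour described above follows the standard gradient-accumulation pattern. The sketch below is a generic illustration only (names such as `batches` and `compute_loss` are placeholders, and this is not the `auto_round` internals):

```python
def optimize_with_accumulation(model, optimizer, batches, compute_loss, gradient_accumulate_steps=2):
    """Generic gradient-accumulation loop: with batch_size=4 and
    gradient_accumulate_steps=2, eight samples contribute to each optimizer
    step, matching the data seen per update with batch_size=8 and no accumulation."""
    optimizer.zero_grad()
    for step, batch in enumerate(batches):
        # Scale the loss so the accumulated gradient averages over the virtual batch.
        loss = compute_loss(model, batch) / gradient_accumulate_steps
        loss.backward()  # gradients accumulate in param.grad across iterations
        if (step + 1) % gradient_accumulate_steps == 0:
            optimizer.step()
            optimizer.zero_grad()
```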


## Credits
48 changes: 38 additions & 10 deletions torchao/prototype/autoround/autoround_llm.py
@@ -1,11 +1,11 @@
import argparse
import logging
from typing import Optional

import torch

import torchao
import torchao.prototype.autoround.utils as ar_utils

from torchao.prototype.autoround.core import (
apply_auto_round,
prepare_model_for_applying_auto_round_,
@@ -26,9 +26,11 @@ def quantize_model_with_autoround_(
iters: int = 200,
seqlen: int = 2048,
dataset_name: str = "NeelNanda/pile-10k",
bs: int = 8,
batch_size: int = 8,
nsamples: int = 128,
use_optimized_layer_output: bool = False,
gradient_accumulate_steps: Optional[int] = 1,
compile_optimization_process: Optional[bool] = False,
):
# Step 1. Prepare the model for applying auto-round

@@ -42,6 +44,8 @@ def quantize_model_with_autoround_(
group_size,
iters,
use_optimized_layer_output,
gradient_accumulate_steps,
compile_optimization_process,
device=device,
)

@@ -50,7 +54,7 @@ def quantize_model_with_autoround_(
tokenizer,
seqlen=seqlen,
dataset_name=dataset_name,
bs=bs,
bs=batch_size,
nsamples=nsamples,
)
input_ids_lst = []
@@ -104,9 +108,11 @@ def main(args):
iters=args.iters,
seqlen=args.seqlen,
dataset_name=args.dataset_name,
bs=args.train_bs,
batch_size=args.batch_size,
nsamples=args.nsamples,
use_optimized_layer_output=args.use_optimized_layer_output,
gradient_accumulate_steps=args.gradient_accumulate_steps,
compile_optimization_process=args.compile_optimization_process,
)
# Revert the `use_cache` for generation stage.
model.config.use_cache = True
@@ -124,7 +130,7 @@ def main(args):
"--model_name_or_path",
type=str,
default="facebook/opt-125m",
help="Model name or path",
help="Pretrained model name or path",
)
parser.add_argument(
"--dataset_name",
@@ -136,37 +142,59 @@
"--iters",
default=200,
type=int,
help="Number of iterations for auto-round optimization",
help="Number of steps for optimizing each block",
)
parser.add_argument(
"--bits", default=4, type=int, help="Number of bits for quantization"
)
parser.add_argument(
"--train_bs", default=8, type=int, help="Batch size for auto-round optimization"
"--batch_size", default=8, type=int, help="Batch size for calibration"
)
parser.add_argument(
"--nsamples",
default=128,
type=int,
help="Number of samples for calibration process",
)
parser.add_argument(
"--group_size",
default=128,
type=int,
help="Group size for quantization",
)
parser.add_argument(
"--seqlen",
default=2048,
type=int,
help="Sequence length for calibration process",
help="Sequence length for each samples",
)
parser.add_argument(
"--gradient_accumulate_steps",
default=1,
type=int,
help=(
"Number of steps for accumulating gradients before performing"
"the backward pass when optimizing each target module"
),
)
parser.add_argument(
"--quant_lm_head",
default=False,
action="store_true",
help="Quantize the `lm_head` or not",
help="Whether to quantize the `lm_head`",
)
parser.add_argument(
"--use_optimized_layer_output",
default=False,
action="store_true",
help="Use the optimized layer output for next layer or not",
help="Whether to use optimized layer output as input for the next layer",
)
parser.add_argument(
"-c",
"--compile_optimization_process",
default=False,
action="store_true",
help="Whether to compile the optimization process",
)
parser.add_argument(
"-d",
18 changes: 15 additions & 3 deletions torchao/prototype/autoround/core.py
@@ -20,6 +20,8 @@ class _AutoRoundConfig:
group_size: int = 128
iters: int = 200
use_optimized_layer_output: bool = False
gradient_accumulate_steps: int = 1
compile_optimization_process: bool = False


_auto_round_config = _AutoRoundConfig()
@@ -82,6 +84,8 @@ def prepare_model_for_applying_auto_round_(
group_size: int = 128,
iters: int = 200,
use_optimized_layer_output: bool = False,
gradient_accumulate_steps: Optional[int] = 1,
compile_optimization_process: Optional[bool] = False,
device: Optional[torch.types.Device] = None,
):
"""Prepares the model for applying auto round optimization.
@@ -94,7 +98,10 @@
group_size (int, optional): The group size for quantization. Defaults to 128.
iters (int, optional): The number of iterations for optimization. Defaults to 200.
use_optimized_layer_output (bool, optional): Whether to use optimized layer output. Defaults to False.
device (Optional[torch.types.Device], optional): The device to use for accelrating optimization and calibration.
gradient_accumulate_steps (Optional[int]): Number of steps for accumulating gradients before
performing the backward pass when optimizing each target module. Defaults to 1.
compile_optimization_process (Optional[bool]): Whether to compile the optimization process. Defaults to False.
device (Optional[torch.types.Device]): The device to use for accelerating optimization and calibration.
Defaults to None.
"""
_multi_tensor_config.device = device
@@ -105,6 +112,8 @@
_auto_round_config.group_size = group_size
_auto_round_config.iters = iters
_auto_round_config.use_optimized_layer_output = use_optimized_layer_output
_auto_round_config.gradient_accumulate_steps = gradient_accumulate_steps
_auto_round_config.compile_optimization_process = compile_optimization_process

logging.warning(f"config {_auto_round_config}")

@@ -172,7 +181,7 @@ def to_uintx_weight(input_float):
quant_min = 0
quant_max = _auto_round_config.bits**2 - 1
block_size = (1, observed_linear.group_size)
from torchao.dtypes.uintx.Uintx import (
from torchao.dtypes.uintx.uintx import (
_BIT_WIDTH_TO_DTYPE,
UintxLayoutType,
)
@@ -312,9 +321,12 @@ def _apply_auto_round_optimization(
bits=config.bits,
iters=config.iters,
group_size=config.group_size,
gradient_accumulate_steps=config.gradient_accumulate_steps,
amp=True,
model_dtype=next(block.parameters()).dtype,
)
if config.compile_optimization_process:
rounder.quant_block_v2_ = torch.compile(rounder.quant_block_v2_)

with torch.enable_grad():
rounder.quant_block_v2_(
Expand All @@ -326,7 +338,7 @@ def _apply_auto_round_optimization(
block.to(orig_device)


@ar_utils.dump_elapsed_time()
@ar_utils.dump_elapsed_time(record=True)
@torch.no_grad()
def apply_auto_round_optimization(
module: torch.nn.Module,
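
For orientation, the hooks and config fields touched in this file plug into the flow below, condensed from the README's API section and `autoround_llm.py`. This is a minimal sketch, not the library's reference recipe: the leading positional parameters of `prepare_model_for_applying_auto_round_` and the `is_target_module` predicate are assumptions worth checking against the source, and a real run would feed many more calibration samples.

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

from torchao.prototype.autoround.core import (
    apply_auto_round,
    prepare_model_for_applying_auto_round_,
)
from torchao.quantization import quantize_

# Load a small float model (the script's default checkpoint) for illustration.
model_name = "facebook/opt-125m"
model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.bfloat16).eval()
tokenizer = AutoTokenizer.from_pretrained(model_name)
model.config.use_cache = False  # autoround_llm.py disables the KV cache during calibration

# Placeholder predicate choosing which modules to optimize; the README defines its own
# filter for llama-style models. Here it targets every OPT decoder layer.
decoder_layer_cls = type(model.model.decoder.layers[0])

def is_target_module(module, fqn):
    return isinstance(module, decoder_layer_cls)

# 1. Attach the auto-round hooks and record the new options on _auto_round_config.
#    The keyword names match the hunks above; (model, is_target_module, bits) are assumed.
prepare_model_for_applying_auto_round_(
    model,
    is_target_module,
    bits=4,
    group_size=128,
    iters=200,
    gradient_accumulate_steps=2,
    compile_optimization_process=True,  # wraps rounder.quant_block_v2_ in torch.compile
    device="cpu",
)

# 2. Run calibration data through the model so the hooks capture block inputs/outputs.
with torch.no_grad():
    batch = tokenizer(["torchao auto-round calibration sample"], return_tensors="pt")
    model(**batch)

# 3. Optimize each captured block and swap the quantized weights into the model.
quantize_(model, apply_auto_round(), is_target_module)
```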