Skip to content

AWQ Modifier #1177

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 41 commits into from
Apr 21, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
41 commits
Select commit Hold shift + click to select a range
860f3e3
cherry picked files from stale PR #181 branch awq-feature-branch
brian-dellabetta Feb 18, 2025
34fa92a
updated to be compatible with latest, unit tests passing
brian-dellabetta Feb 18, 2025
f67b386
switch to using HooksMixin api
brian-dellabetta Feb 18, 2025
f341dc0
pydantic serialization issue fix
brian-dellabetta Feb 18, 2025
ee76752
switch to accelerate with align_module_device
brian-dellabetta Feb 19, 2025
9e415f2
AWQ running but OOMs unless NUM_CALIBRATION_SAMPLES and MAX_SEQUENCE_…
brian-dellabetta Feb 19, 2025
db767b7
working with larger num_calibration_samples
brian-dellabetta Feb 20, 2025
15a0b16
fix pile dataset issue
brian-dellabetta Feb 20, 2025
91ad7fc
updated config dataclasses
brian-dellabetta Feb 24, 2025
c1c6a6c
OOM error resolved
brian-dellabetta Feb 25, 2025
eb32054
codereview updates
brian-dellabetta Feb 25, 2025
c7be277
minor touchups
brian-dellabetta Feb 25, 2025
ab32f21
updates from debugging
brian-dellabetta Mar 3, 2025
ff857e5
styling
brian-dellabetta Mar 4, 2025
3e79d37
slightly improved rtn calculate_qparams logic
brian-dellabetta Mar 5, 2025
80767ab
code cleanup
brian-dellabetta Mar 10, 2025
d352bcf
rename smoothquant private vars
brian-dellabetta Mar 10, 2025
7ed2e72
squashed codereview updates for rebase
brian-dellabetta Mar 19, 2025
ea41fe5
cleanup fixes from rebase
brian-dellabetta Mar 19, 2025
433bb2b
awq mappings registry
brian-dellabetta Mar 19, 2025
2519643
remove empty_cache calls
brian-dellabetta Mar 19, 2025
3b9b813
resolve attention module forward missing attention_mask input
brian-dellabetta Mar 19, 2025
698b057
improve order of check for optional kwargs setting to None
brian-dellabetta Mar 20, 2025
7b9d85e
run awq one shot example
brian-dellabetta Mar 21, 2025
ab962ce
clean up awq_one_shot example
brian-dellabetta Mar 26, 2025
da16def
rename bits to num_bits
brian-dellabetta Mar 31, 2025
c7e274f
added TODOs
brian-dellabetta Apr 1, 2025
90c266c
update example file
brian-dellabetta Apr 10, 2025
7306de7
revise get_parent_by_name test
brian-dellabetta Apr 10, 2025
99cf589
revert smoothquant changes
brian-dellabetta Apr 10, 2025
64690a8
revert smoothquant changes
brian-dellabetta Apr 10, 2025
cb6f840
sanitize_kwargs cleanup
brian-dellabetta Apr 10, 2025
018b255
remove deprecated AWQModifier apply_clip
brian-dellabetta Apr 14, 2025
39a4745
PR revision
brian-dellabetta Apr 14, 2025
58c3968
add lifecycle to docstring
brian-dellabetta Apr 14, 2025
2980c05
update docstring
brian-dellabetta Apr 15, 2025
f7ece22
remove comment
brian-dellabetta Apr 15, 2025
6e7468b
rearrange so it's clear when hooks are removed
brian-dellabetta Apr 16, 2025
dc63821
Merge branch 'main' into bdellabe/awq-modifier-v3
brian-dellabetta Apr 17, 2025
3beba89
style fixes
brian-dellabetta Apr 17, 2025
d1d3766
revisions from codereview
brian-dellabetta Apr 18, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
106 changes: 106 additions & 0 deletions examples/awq/awq_one_shot.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,106 @@
import lm_eval
from compressed_tensors.quantization import (
QuantizationArgs,
QuantizationScheme,
QuantizationStrategy,
QuantizationType,
)
from lm_eval.utils import make_table
from transformers import AutoModelForCausalLM, AutoTokenizer

from llmcompressor import oneshot
from llmcompressor.modifiers.awq import AWQModifier
from llmcompressor.modifiers.quantization import QuantizationModifier

# This example demonstrates how to:
# 1) Run the `llm-compressor` implementation of AWQ
# 2) Evaluate the compressed model with the lm_eval framework

MODEL_ID = "meta-llama/Meta-Llama-3-8B-Instruct"
DATASET_ID = "mit-han-lab/pile-val-backup"
DATASET_SPLIT = "validation"
# Number of calibration examples fed to AWQ; larger values improve scale
# estimates at the cost of memory/time.
NUM_CALIBRATION_SAMPLES = 256
# Every calibration sample is truncated/filtered to exactly this token length.
MAX_SEQUENCE_LENGTH = 512
# Compressed model is saved next to the script, e.g. "Meta-Llama-3-8B-Instruct-awq-asym"
OUTPUT_DIR = MODEL_ID.split("/")[-1] + "-awq-asym"

#
# 1) Run LLM Compressor AWQ implementation
#

# Recipe: AWQ first rescales weights/activations, then the quantization
# modifier applies W4 asymmetric group-wise (group_size=128) INT quantization
# to all Linear layers except the lm_head.
recipe = [
    # NOTE(review): commit "rename bits to num_bits" and the `num_bits=4`
    # kwarg below suggest this may need to be `num_bits=4` — verify against
    # AWQModifier's signature.
    AWQModifier(bits=4, symmetric=False),
    QuantizationModifier(
        ignore=["lm_head"],
        config_groups={
            "group_0": QuantizationScheme(
                targets=["Linear"],
                weights=QuantizationArgs(
                    num_bits=4,
                    type=QuantizationType.INT,
                    dynamic=False,
                    symmetric=False,
                    strategy=QuantizationStrategy.GROUP,
                    group_size=128,
                ),
            )
        },
    ),
]

# Load the model sharded across available devices with its native dtype.
model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID, device_map="auto", torch_dtype="auto"
)
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, trust_remote_code=True)


def get_calib_dataset(tokenizer):
    """Build the AWQ calibration dataset from the pile validation split.

    Loads an oversampled slice of ``DATASET_ID``, tokenizes each example,
    and keeps only examples that reach ``MAX_SEQUENCE_LENGTH`` tokens so
    every calibration sample has a uniform length.

    :param tokenizer: tokenizer used to encode the raw text
    :return: a ``datasets.Dataset`` with a single ``input_ids`` column and
        at most ``NUM_CALIBRATION_SAMPLES`` rows
    """
    from datasets import load_dataset

    # Oversample 100x so that enough long examples survive the length
    # filter below.
    ds = load_dataset(
        DATASET_ID,
        split=f"{DATASET_SPLIT}[:{NUM_CALIBRATION_SAMPLES*100}]",
    )

    def preprocess(example):
        # Tokenize and truncate to the calibration sequence length.
        return {
            "input_ids": tokenizer.encode(example["text"].strip())[:MAX_SEQUENCE_LENGTH]
        }

    ds = (
        ds.shuffle(seed=42)
        .map(preprocess, remove_columns=ds.column_names)
        # After truncation, `>=` keeps only examples of exactly
        # MAX_SEQUENCE_LENGTH tokens.
        .filter(lambda example: len(example["input_ids"]) >= MAX_SEQUENCE_LENGTH)
    )

    # Robustness fix: the filter may leave fewer rows than requested, in
    # which case `select(range(NUM_CALIBRATION_SAMPLES))` would raise.
    # Clamp to the available count and warn instead of crashing.
    num_samples = min(len(ds), NUM_CALIBRATION_SAMPLES)
    if num_samples < NUM_CALIBRATION_SAMPLES:
        print(
            f"Warning: only {num_samples} calibration samples with "
            f">={MAX_SEQUENCE_LENGTH} tokens were found "
            f"(requested {NUM_CALIBRATION_SAMPLES})."
        )

    return ds.select(range(num_samples))


# Apply the recipe in one shot: runs AWQ scale search over the calibration
# set, then quantizes the weights and saves the compressed model.
oneshot(
    model=model,
    dataset=get_calib_dataset(tokenizer=tokenizer),
    recipe=recipe,
    output_dir=OUTPUT_DIR,
    max_seq_length=MAX_SEQUENCE_LENGTH,
    num_calibration_samples=NUM_CALIBRATION_SAMPLES,
)

print("Done! model saved to", OUTPUT_DIR)

#
# 2) Evaluate model on wikitext perplexity
#

# Load the compressed model into vLLM and score it with lm_eval.
# gpu_memory_utilization is kept at 0.5 since the calibration model may
# still hold GPU memory in the same process.
results = lm_eval.simple_evaluate(
    model="vllm",
    model_args={
        "pretrained": OUTPUT_DIR,
        "add_bos_token": True,
        "dtype": "bfloat16",
        "gpu_memory_utilization": 0.5,
    },
    tasks=["wikitext"],
    num_fewshot=5,
    batch_size="auto",
)
print(make_table(results))
4 changes: 4 additions & 0 deletions src/llmcompressor/modifiers/awq/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
# flake8: noqa
"""Public API for the AWQ modifier package: re-exports the modifier
implementation (`base`) and the per-architecture layer mappings
(`mappings`)."""

from .base import *
from .mappings import *
Loading