Commit b5efafa

Moving PR from llm-compressor#2112
1 parent 797d301 commit b5efafa

File tree

3 files changed: +400, -0 lines changed

src/compressed_tensors/converters/__init__.py

Whitespace-only changes.
Lines changed: 344 additions & 0 deletions
@@ -0,0 +1,344 @@
"""
Convert AutoAWQ models to compressed-tensors compatible format.

This module offers the functionality to convert models quantized with AutoAWQ into
compressed models in the compressed-tensors format, which can then be served with vLLM.
This module can be used as a CLI tool or as a Python API.

## CLI Usage

```sh
python -m compressed_tensors.converters.autoawq \
    --model-name-or-path /path/to/model \
    --output-dir /path/to/compressed/model \
    --quantization-format naive-quantized
```

For more information, run `python -m compressed_tensors.converters.autoawq --help`
or refer to the `ConversionArgs` dataclass below.

## Python API Usage

```python
from compressed_tensors.converters.autoawq import load_and_convert_from_autoawq

awq_model_path = "/path/to/model"  # can also be a model ID on the Hugging Face Hub
model = load_and_convert_from_autoawq(awq_model_path)
model.generate(...)  # the converted model is now ready to be used.
```
"""

import glob
import os
import re
from dataclasses import dataclass, field
from pathlib import Path
from tempfile import TemporaryDirectory
from typing import Any, Literal, cast

import torch
import transformers
from auto_round.export.export_to_awq.utils import (
    reverse_awq_order,
    unpack_awq,
)
from compressed_tensors import ModelCompressor
from compressed_tensors.quantization import (
    QuantizationArgs,
    QuantizationConfig,
    QuantizationScheme,
    QuantizationStatus,
    QuantizationStrategy,
    QuantizationType,
)
from huggingface_hub import load_state_dict_from_file, snapshot_download


def is_autoawq_model(model_path: Path, trust_remote_code: bool = False) -> bool:
    config = transformers.AutoConfig.from_pretrained(
        model_path, trust_remote_code=trust_remote_code
    )
    if not hasattr(config, "quantization_config"):
        return False

    quantization_config = cast(dict[str, Any], config.quantization_config)
    return quantization_config.get("quant_method") == "awq"


def resolve_model_path(model_name_or_path: str) -> Path:
    if os.path.isdir(model_name_or_path):
        return Path(model_name_or_path)
    else:
        # If the input is a model ID, download the model from the Hugging Face Hub and
        # return the path to the local directory.
        return Path(snapshot_download(model_name_or_path))


def load_state_dict_from_model_dir(model_path: Path) -> dict[str, torch.Tensor]:
    weight_files = glob.glob(str(model_path / "*.safetensors"))
    if not weight_files:
        weight_files = glob.glob(str(model_path / "*.bin"))

    state_dict = {}
    for weight_file in weight_files:
        state_dict.update(
            load_state_dict_from_file(
                weight_file, map_location="cpu", weights_only=True
            )
        )
    return state_dict


def dequantize_gemm(
    state_dict: dict[str, torch.Tensor], prefix: str, autoawq_config: dict[str, Any]
) -> None:
    num_bits = cast(int, autoawq_config.get("bits"))
    group_size = cast(int, autoawq_config.get("group_size"))

    qweight = state_dict.pop(f"{prefix}.qweight")
    scales = state_dict.pop(f"{prefix}.scales")
    qzeros = state_dict.pop(f"{prefix}.qzeros")

    def dequantize_gemm_original(
        qweight: torch.Tensor,
        qzeros: torch.Tensor,
        scales: torch.Tensor,
        bits: int,
        group_size: int,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Modified from auto_round.export.export_to_awq.utils.dequantize_gemm."""
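        # Layout note (AWQ "gemm" checkpoints): each int32 element of qweight/qzeros
        # packs 32 // bits quantized values (8 values for 4-bit) in an interleaved
        # order, so they must be unpacked and re-ordered below before they can be
        # combined element-wise with the scales and zero points.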
        # Unpack the qweight and qzeros tensors
        iweight, izeros = unpack_awq(qweight, qzeros, bits)
        # Reverse the order of the iweight and izeros tensors
        iweight, izeros = reverse_awq_order(iweight, izeros, bits)

        # overflow checks
        iweight = torch.bitwise_and(iweight, (2**bits) - 1)
        izeros = torch.bitwise_and(izeros, (2**bits) - 1)

        # fp16 weights
        scales_interleaved = scales.repeat_interleave(group_size, dim=0)
        izeros_interleaved = izeros.repeat_interleave(group_size, dim=0)
        fweight = (iweight - izeros_interleaved) * scales_interleaved

        return fweight, izeros

    weight, zero_point = dequantize_gemm_original(
        qweight, qzeros, scales, num_bits, group_size
    )

    # AutoAWQ uses [0, 2^bits - 1], e.g., [0, 15], for quantized weights, but
    # compressed-tensors uses [-2^(bits - 1), 2^(bits - 1) - 1], e.g., [-8, 7].
    # Therefore, we need to shift the zero point by 2^(bits - 1) to match the range
    # of compressed-tensors and to allow correct quant/dequantization.
    shifted_zero_point = zero_point - 2 ** (num_bits - 1)
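    # Worked example for num_bits=4: an AutoAWQ zero point of 11 becomes 11 - 8 = 3
    # and a quantized value of 5 becomes 5 - 8 = -3, so (5 - 11) * s == (-3 - 3) * s
    # and the dequantized weight is unchanged by the shift.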

    state_dict.update(
        {
            f"{prefix}.weight": weight.T,
            f"{prefix}.weight_scale": scales.T,
            f"{prefix}.weight_zero_point": shifted_zero_point.T,
        }
    )


def dequantize_autoawq_state_dict(
    state_dict: dict[str, torch.Tensor], autoawq_config: dict[str, Any]
) -> dict[str, torch.Tensor]:
    version = cast(str, autoawq_config.get("version"))

    # TODO: maybe add support for other versions?
    match version:
        case "gemm":
            dequantize_fn = dequantize_gemm
        case _:
            raise ValueError(f"Unsupported version: {version}")

    keys = list(state_dict.keys())
    for key in filter(lambda k: k.endswith("qweight"), keys):
        prefix = key.removesuffix(".qweight")
        dequantize_fn(state_dict, prefix, autoawq_config)

    return state_dict


def convert_and_save(
    model_name_or_path: str,
    output_dir: str,
    quantization_format: str,
    overwrite: bool = False,
    trust_remote_code: bool = False,
) -> None:
    """Convert an AutoAWQ model to a compressed model and save it.

    Steps:

    1. Load the model weights directly.
    2. Dequantize the weights accordingly.
    3. Load the model with the dequantized weights.
    4. Add the quantization parameters to the model.
    5. Re-pack the weights using `ModelCompressor` with the correct configuration.
    6. Save the model to the output directory.

    :param model_name_or_path: Model ID on huggingface hub or path to local model.
    :param output_dir: Path to save the converted model.
    :param quantization_format: Compression format to be saved.
    :param overwrite: Overwrite the existing output directory if it exists.
    :param trust_remote_code: Whether to trust remote code.
    """
    output_exists = os.path.exists(output_dir)
    is_directory = os.path.isdir(output_dir) if output_exists else False
    is_empty_dir = False
    if output_exists and is_directory:
        is_empty_dir = not any(os.scandir(output_dir))

    if not output_exists:
        # Safe: output_dir does not exist
        pass
    elif not is_directory or (not is_empty_dir and not overwrite):
        raise FileExistsError(
            f"{output_dir=} already exists. Set `overwrite=True` to"
            " overwrite the existing directory."
        )

    model_path = resolve_model_path(model_name_or_path)
    if not is_autoawq_model(model_path, trust_remote_code):
        raise ValueError("Model is not an AutoAWQ model")

    config = transformers.AutoConfig.from_pretrained(
        model_path, trust_remote_code=trust_remote_code
    )
    autoawq_config = cast(dict[str, Any], config.quantization_config)
    num_bits = cast(int, autoawq_config.get("bits"))
    is_symmetric = not autoawq_config.get("zero_point")
    group_size = cast(int, autoawq_config.get("group_size"))
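    # For reference, a typical AutoAWQ `quantization_config` looks roughly like
    # {"quant_method": "awq", "version": "gemm", "bits": 4, "group_size": 128,
    # "zero_point": True} (values illustrative), i.e. 4-bit, group-wise, asymmetric
    # weight quantization.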

    # Convert AutoAWQ's substring-based ignore list to llm-compressor's regex format
    # Usage in AutoAWQ:
    # ```python
    # if any(key in name for key in modules_to_not_convert): ...
    # ```
    # See https://github.com/casper-hansen/AutoAWQ/blob/88e4c76b20755db275574e6a03c83c84ba3bece5/awq/utils/module.py#L62
    modules_to_not_convert = autoawq_config.get("modules_to_not_convert", None)
    ignore = []
    if modules_to_not_convert is not None:
        # Convert each substring pattern to a regex pattern that matches it anywhere
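        # e.g., modules_to_not_convert=["visual"] becomes "re:.*visual.*", so any
        # module whose name contains "visual" is left unquantized.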
        for module in modules_to_not_convert:
            ignore.append(f"re:.*{re.escape(module)}.*")

    ignore.append("lm_head")  # AutoAWQ ignores lm_head by default

    # 1. Load the model weights directly.
    state_dict = load_state_dict_from_model_dir(model_path)

    # 2. Dequantize the weights accordingly.
    state_dict = dequantize_autoawq_state_dict(state_dict, autoawq_config)

    # 3. Load the model with the dequantized weights.
    del config.quantization_config  # remove to avoid loading with AutoAWQ.
    with transformers.modeling_utils.no_init_weights():
        model = transformers.AutoModelForCausalLM.from_config(
            config, torch_dtype=torch.float16, trust_remote_code=trust_remote_code
        )

    model.load_state_dict(state_dict, strict=False)

    # 4. Add the quantization parameters to the model.
    quantization_scheme = QuantizationScheme(
        targets=["Linear"],
        weights=QuantizationArgs(
            num_bits=num_bits,
            type=QuantizationType.INT,
            symmetric=is_symmetric,
            group_size=group_size,
            strategy=QuantizationStrategy.GROUP,
        ),
    )
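
    # Attach the scheme only to modules that actually carry AWQ parameters (their
    # dequantized zero points were stored as `weight_zero_point` above); ignored
    # modules keep their original unquantized weights.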
    for key in filter(lambda k: k.endswith("weight_zero_point"), state_dict.keys()):
        module_name = key.removesuffix(".weight_zero_point")
        setattr(
            model.get_submodule(module_name), "quantization_scheme", quantization_scheme
        )

    quant_config = QuantizationConfig(
        config_groups={"group_0": quantization_scheme},
        quant_method="compressed-tensors",
        quantization_status=QuantizationStatus.COMPRESSED,
        format=quantization_format,
        ignore=ignore,
    )

    # 5. Re-pack the weights using `ModelCompressor`.
    compressor = ModelCompressor(quantization_config=quant_config)
    compressed_state_dict = compressor.compress(model, state_dict, show_progress=True)
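    # Roughly: "naive-quantized" stores the integer weights unpacked alongside their
    # scales and zero points, while "pack-quantized" additionally packs the low-bit
    # values into int32 tensors (see the TODO on ConversionArgs below).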

    # 6. Save the model.
    tokenizer = transformers.AutoTokenizer.from_pretrained(
        model_name_or_path, trust_remote_code=trust_remote_code
    )
    model.save_pretrained(output_dir, state_dict=compressed_state_dict)
    tokenizer.save_pretrained(output_dir)
    compressor.update_config(output_dir)


def load_and_convert_from_autoawq(
    model_name_or_path: str,
    quantization_format: str = "naive-quantized",
    trust_remote_code: bool = False,
) -> transformers.modeling_utils.PreTrainedModel:
    """
    Load an AutoAWQ checkpoint and convert it to a compressed model.

    :param model_name_or_path: Model ID on huggingface hub or path to local model.
    :param quantization_format: Compression format to be saved.
    :param trust_remote_code: Whether to trust remote code.
    :return: A compressed model.
    """
    with TemporaryDirectory() as temp_dir:
        convert_and_save(
            model_name_or_path,
            temp_dir,
            quantization_format,
            trust_remote_code=trust_remote_code,
        )
        return transformers.AutoModelForCausalLM.from_pretrained(
            temp_dir, torch_dtype=torch.float16, trust_remote_code=trust_remote_code
        )


@dataclass
class ConversionArgs:
    model_name_or_path: str = field(
        metadata={"help": "Model ID on huggingface hub or path to local model."},
    )
    output_dir: str = field(
        metadata={"help": "Path to save the converted model."},
    )
    quantization_format: Literal["naive-quantized", "pack-quantized"] = field(
        default="naive-quantized",
        metadata={"help": "Compression format to be saved."},
    )  # TODO: switch default to pack-quantized once supported by compressed-tensors.
    overwrite: bool = field(
        default=False,
        metadata={"help": "Overwrite the existing output directory if it exists."},
    )
    trust_remote_code: bool = field(
        default=False,
        metadata={"help": "Whether to trust remote code."},
    )


__all__ = ["convert_and_save", "load_and_convert_from_autoawq", "ConversionArgs"]


if __name__ == "__main__":
    parser = transformers.HfArgumentParser(ConversionArgs)
    args = parser.parse_args_into_dataclasses()[0]
    convert_and_save(
        args.model_name_or_path,
        args.output_dir,
        args.quantization_format,
        args.overwrite,
        args.trust_remote_code,
    )

0 commit comments
