
Commit 1089004

add support for scheme FP8_STATIC to export llm_compressor format (#816)
* add support for scheme FP8_STATIC to export llm_compressor format

Signed-off-by: n1ck-guo <heng.guo@intel.com>
1 parent bd4bc2c commit 1089004

34 files changed, +621 -420 lines changed
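As a quick orientation for reviewers, the sketch below shows the end-to-end flow this commit enables. It is a minimal sketch, assuming the AutoRound entry point accepts a scheme argument and exposes quantize_and_save() as in recent releases; the model name, output path, and exact keyword names are illustrative rather than taken from this diff.

# Minimal usage sketch (model name, output path, and exact kwargs are illustrative):
# quantize with the FP8_STATIC scheme, then export in llm_compressor format.
from transformers import AutoModelForCausalLM, AutoTokenizer

from auto_round import AutoRound

model_name = "facebook/opt-125m"  # placeholder model
model = AutoModelForCausalLM.from_pretrained(model_name)
tokenizer = AutoTokenizer.from_pretrained(model_name)

ar = AutoRound(model, tokenizer, scheme="FP8_STATIC")  # scheme handled by this PR
ar.quantize_and_save(output_dir="opt-125m-fp8-static", format="llm_compressor")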

auto_round/compressors/base.py

Lines changed: 34 additions & 6 deletions
@@ -865,9 +865,9 @@ def remove_duplicates(lst):
             elif is_nv_fp(self.data_type) or is_mx_fp(self.data_type):
                 format = f"auto_round:{self.data_type}"
             elif is_static_wfp8afp8(self):  # static wfp8afp8
-                format = f"auto_round:{AutoRoundFormat.TORCH_FP8_STATIC.value}"
+                format = f"auto_round:{AutoRoundFormat.FP8_STATIC.value}"
             elif self.data_type == "fp" and self.bits == 8 and self.act_bits >= 16:  # woq fp8
-                format = "auto_round:fp8"
+                format = f"auto_round:{AutoRoundFormat.FP8.value}"
             elif self.act_bits < 16:
                 raise ValueError(
                     "AutoRound format does not support exporting "
@@ -882,6 +882,20 @@ def remove_duplicates(lst):
                     check_compressed_tensors_supported()
                     format = format.replace("llm_compressor", f"llm_compressor:{self.data_type}")
                     formats[index] = format
+                if is_static_wfp8afp8(self):
+                    format = f"llm_compressor:{AutoRoundFormat.FP8_STATIC.value}"
+                    formats[index] = format
+                    if self.act_group_size != 0:
+                        logger.warning(
+                            f"scheme FP8_STATIC export to llm_compressor format only supports act_group_size=0,"
+                            f" but got act_group_size={self.act_group_size}; resetting it to 0"
+                        )
+                        self.act_group_size = 0
+                    if self.group_size > 0:
+                        logger.warning(
+                            f"please note that group_size={self.group_size}"
+                            " may not be supported by the llm_compressor format and cannot be loaded in llm_compressor"
+                        )
                 elif not is_wfp8afp8(self):
                     logger.error(
                         "Currently, the llm_compressor format only supports MXFP/NVFP/FP8. "
@@ -971,13 +985,25 @@ def _check_supported_format(self, format: str) -> bool:
                 )
                 format = "fake"
             else:
-                if not (format == "auto_round" or format == f"auto_round:{AutoRoundFormat.TORCH_FP8_STATIC.value}"):
+                if format not in [
+                    "auto_round",
+                    f"auto_round:{AutoRoundFormat.FP8_STATIC.value}",
+                    f"llm_compressor:{AutoRoundFormat.FP8_STATIC.value}",
+                    "auto_round:llm_compressor",
+                ]:
                     logger.warning(
                         f"Currently only support to export auto_round or fake format for static W{self.bits}AFP8 model,"
                         f" change format {format} to auto_round"
                     )
-                    format = "auto_round"
-        if self.act_group_size != 0 and not self.act_dynamic and format == "auto_round:fp8":
+                    if is_static_wfp8afp8(self):
+                        format = f"auto_round:{AutoRoundFormat.FP8_STATIC.value}"
+                    else:
+                        format = f"auto_round:{AutoRoundFormat.FP8.value}"
+        if (
+            self.act_group_size != 0
+            and not self.act_dynamic
+            and format == f"auto_round:{AutoRoundFormat.FP8.value}"
+        ):
             logger.warning(
                 f"Please note that quantize activation with act_group_size={self.act_group_size}"
                 " may result in failure to export or import normally."
@@ -1198,7 +1224,7 @@ def register_act_hook(model):
         def get_imatrix_hook(module, input, output):
             input = input[0] if isinstance(input, (tuple, list)) else input
             flattened = input.reshape(-1, input.shape[-1]).to(torch.float32)
-            squared = torch.sum(flattened**2, dim=0).to(torch.float32)
+            squared = torch.sum(torch.pow(flattened, 2), dim=0).to(torch.float32)

             if not hasattr(module, "imatrix"):
                 module.imatrix = squared
@@ -3094,6 +3120,8 @@ def save_quantized(
             )
             if format == "llm_compressor" and (is_nv_fp(self.data_type) or is_mx_fp(self.data_type)):
                 format = format.replace("llm_compressor", f"llm_compressor:{self.data_type}")
+            if format == "llm_compressor" and is_static_wfp8afp8(self):
+                format = format.replace("llm_compressor", f"llm_compressor:{AutoRoundFormat.FP8_STATIC.value}")

             from auto_round.export import EXPORT_FORMAT

auto_round/data_type/gguf.py

Lines changed: 10 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -337,7 +337,7 @@ def quant_tensor_gguf_asym_dq(
337337
if bits == 2:
338338
quant_weights = torch.abs(tensor)
339339
elif bits == 4 or bits == 5:
340-
sigma2 = torch.sum(tensor**2, dim=-1, keepdim=True) / 32 ##Note 32 is different from QK_K
340+
sigma2 = torch.sum(torch.pow(tensor, 2), dim=-1, keepdim=True) / 32 ##Note 32 is different from QK_K
341341
av_x = torch.sqrt(sigma2)
342342
quant_weights = torch.abs(tensor) + av_x
343343
params = search_kwargs[bits]
@@ -384,7 +384,9 @@ def quant_tensor_gguf_asym_dq(
384384
if bits == 2:
385385
tmp_quant_weights = torch.abs(tensor)
386386
elif bits == 4 or bits == 5:
387-
sigma2 = torch.sum(tensor**2, dim=-1, keepdim=True) / 32 ## Note 32 is different from QK_K
387+
sigma2 = (
388+
torch.sum(torch.pow(tensor, 2), dim=-1, keepdim=True) / 32
389+
) ## Note 32 is different from QK_K
388390
av_x = torch.sqrt(sigma2)
389391
tmp_quant_weights = torch.abs(tensor) + av_x
390392
quant_weights[replace_index, :] = tmp_quant_weights[replace_index, :]
@@ -395,7 +397,7 @@ def quant_tensor_gguf_asym_dq(
395397
tmp_quant_weights = tmp_quant_weights.view(-1, 1).expand(-1, quant_weights.shape[1])
396398
quant_weights[mean_replace_index, :] = tmp_quant_weights[mean_replace_index, :]
397399

398-
# sigma2 = torch.sum(tensor ** 2, dim=-1, keepdim=True) / QK_K
400+
# sigma2 = torch.sum(torch.pow(tensor, 2), dim=-1, keepdim=True) / QK_K
399401
# if imatrix is None:
400402
# av_x = torch.sqrt(sigma2)
401403
# quant_weights = torch.abs(av_x + tensor * tensor)
@@ -470,7 +472,7 @@ def iterative_wls_quant_search(data, bits=4, rrmin=-1.0, rdelta=0.1, nstep=20, u
470472
quant_data = torch.clamp(torch.round(iscale * (data - rmin)), minq, maxq)
471473
diff = scale * quant_data + rmin - data
472474

473-
best_mad = torch.sum((weights * torch.abs(diff)) if use_mad else weights * diff**2, dim=1, keepdim=True)
475+
best_mad = torch.sum((weights * torch.abs(diff)) if use_mad else weights * torch.pow(diff, 2), dim=1, keepdim=True)
474476

475477
for is_ in range(nstep):
476478
factor = rrmin + rdelta * is_ + maxq - minq
@@ -484,7 +486,7 @@ def iterative_wls_quant_search(data, bits=4, rrmin=-1.0, rdelta=0.1, nstep=20, u
484486
sum_l2 = torch.sum(mul_weights_quant_data * quant_data_new, dim=-1, keepdim=True)
485487
sum_xl = torch.sum(mul_weights_quant_data * data, dim=-1, keepdim=True)
486488

487-
D = sum_w * sum_l2 - sum_l**2
489+
D = sum_w * sum_l2 - torch.pow(sum_l, 2)
488490
this_scale = (sum_w * sum_xl - sum_x * sum_l) / D
489491
this_min = (sum_l2 * sum_x - sum_l * sum_xl) / D
490492
this_min[this_min > 0] = 0
@@ -494,7 +496,7 @@ def iterative_wls_quant_search(data, bits=4, rrmin=-1.0, rdelta=0.1, nstep=20, u
494496
quant_data = torch.clamp(torch.round(reverse_this_scale * (data - this_min)), minq, maxq)
495497
diff = this_scale * quant_data + this_min - data
496498
# diff = this_scale * quant_data_new + this_min - data
497-
mad = torch.sum((weights * torch.abs(diff)) if use_mad else weights * diff**2, dim=-1, keepdim=True)
499+
mad = torch.sum((weights * torch.abs(diff)) if use_mad else weights * torch.pow(diff, 2), dim=-1, keepdim=True)
498500

499501
idx_to_replace = torch.where((mad < best_mad) & (D > 0))[0]
500502
best_mad[idx_to_replace] = mad[idx_to_replace]
@@ -566,7 +568,7 @@ def quant_tensor_gguf_sym_dq(
566568
imatrix = imatrix.to(tensor.device)
567569

568570
# if bits == 3:
569-
# # sigma2 = 2 * torch.sum(tensor ** 2, dim=-1, keepdim=True) / QK_K
571+
# # sigma2 = 2 * torch.sum(torch.pow(tensor, 2), dim=-1, keepdim=True) / QK_K
570572
# # imatrix = imatrix.reshape(1, -1).expand(tensor.numel() // imatrix.numel(), -1).reshape(tensor.shape)
571573
# # quant_weights = imatrix * torch.sqrt(sigma2 + tensor * tensor)
572574
# # scale, int_w = make_qx_quants(tensor, bits=bits, rmse_type=1, qw=quant_weights)
@@ -588,7 +590,7 @@ def quant_tensor_gguf_sym_dq(
588590
if bits == 6:
589591
quant_weights[replace_index] = tensor[replace_index] * tensor[replace_index]
590592
else:
591-
sigma2 = 2 * torch.sum(tensor**2, dim=-1, keepdim=True) / QK_K
593+
sigma2 = 2 * torch.sum(torch.pow(tensor, 2), dim=-1, keepdim=True) / QK_K
592594
tmp_quant_weights = torch.sqrt(sigma2 + tensor * tensor)
593595
quant_weights[replace_index] = tmp_quant_weights[replace_index]
594596
mean_replace_index = (zero_cnt > 0) & (zero_cnt <= group_size // 2)
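All of the gguf.py edits above only swap the ** operator for torch.pow without touching the math; a quick sanity check of that equivalence (assuming a local PyTorch install) could look like this:

# Quick equivalence check for the tensor**2 -> torch.pow(tensor, 2) rewrite.
import torch

x = torch.randn(4, 32, dtype=torch.float32)
assert torch.equal(x**2, torch.pow(x, 2))  # the ** operator dispatches to torch.pow

# the reductions used throughout gguf.py are likewise unchanged
sigma2_old = torch.sum(x**2, dim=-1, keepdim=True) / 32
sigma2_new = torch.sum(torch.pow(x, 2), dim=-1, keepdim=True) / 32
assert torch.allclose(sigma2_old, sigma2_new)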

auto_round/export/export_to_autogptq/export.py

Lines changed: 4 additions & 50 deletions
@@ -47,6 +47,7 @@
 from tqdm import tqdm

 import auto_round.export.export_to_autogptq.qlinear_triton
+from auto_round.export.utils import save_model
 from auto_round.logger import logger
 from auto_round.utils import (
     SUPPORTED_LAYER_TYPES,
@@ -214,54 +215,7 @@ def wrapper(name):
     model.config.quantization_config = quantization_config

     dtype = torch.float16  ##force dtype to fp16
-    save(model, output_dir, safe_serialization=safe_serialization, dtype=dtype)
+    save_model(
+        model, output_dir, safe_serialization=safe_serialization, dtype=dtype, config_file="quantize_config.json"
+    )
     return model
-
-
-def save(
-    model: torch.nn.Module, save_dir: str, max_shard_size: str = "5GB", safe_serialization: bool = True, dtype=None
-):
-    """Save model state dict and configs.
-
-    Args:
-        model (`nn.Module`):
-            Model to be saved. The model can be wrapped or unwrapped.
-        save_dir (`str`):
-            Directory to which to save. Will be created if it doesn't exist.
-        max_shard_size (`str`, defaults to `"10GB"`):
-            The maximum size for a checkpoint before being sharded. Checkpoints shard will then be each of size
-            lower than this size. If expressed as a string, needs to be digits followed by a unit (like `"5MB"`).
-            <Tip warning={true}>
-
-            If a single weight of the model is bigger than `max_shard_size`, it will be in its own checkpoint shard
-            which will be bigger than `max_shard_size`.
-
-            </Tip>
-        safe_serialization (`bool`, defaults to `True`):
-            Whether to save the model using `safetensors` or the traditional PyTorch way (that uses `pickle`).
-    """
-    ##max_shard_size = "10000GB" ## API of auto-gptq with marlin does not support shard size
-    os.makedirs(save_dir, exist_ok=True)
-    try:
-        model.save_pretrained(save_dir, max_shard_size=max_shard_size, safe_serialization=safe_serialization)
-    except ValueError as e:
-        if hasattr(model, "generation_config"):
-            setattr(model.generation_config, "do_sample", True)
-        model.save_pretrained(save_dir, max_shard_size=max_shard_size, safe_serialization=safe_serialization)
-    config_path = os.path.join(save_dir, "config.json")
-    if dtype is not None and dtype != model.dtype and os.path.exists(os.path.join(save_dir, "config.json")):
-        with open(config_path, "r") as file:
-            data = json.load(file)
-        data["torch_dtype"] = str(dtype).split(".")[-1]
-        with open(config_path, "w") as file:
-            json.dump(data, file, indent=2)
-
-    config_file = "quantize_config.json"
-    if hasattr(model, "config") and hasattr(model.config, "quantization_config"):
-        with open(os.path.join(save_dir, config_file), "w", encoding="utf-8") as f:
-            json.dump(model.config.quantization_config, f, indent=2)
-
-    try:
-        copy_python_files_from_model_cache(model, save_dir)
-    except Exception as e:
-        logger.warning("Skipping source model Python file copy due to error: %s", e)

auto_round/export/export_to_autoround/export.py

Lines changed: 19 additions & 50 deletions
@@ -27,6 +27,7 @@
 from tqdm import tqdm

 from auto_round.export.export_to_autoround.utils import REQUIRED_CONFIG_KEYS, check_neq_config
+from auto_round.export.utils import save_model
 from auto_round.logger import logger
 from auto_round.utils import (
     SUPPORTED_FORMATS,
@@ -47,7 +48,8 @@
 class AutoRoundFormat(str, Enum):
     # Weight: FP8, per-channel, may be extended to per-tensor in future
     # Activation: FP8, per-tensor
-    TORCH_FP8_STATIC = "fp8_static"
+    FP8_STATIC = "fp8_static"
+    FP8 = "fp8"


 def dynamic_import_quant_linear_for_packing(backend, bits, group_size, sym, act_bits=16):
@@ -159,11 +161,19 @@ def pack_layer(layer_name, model, backend, device=None):

         return pack_layer(layer_name, model, backend, device)

-    if backend == "auto_round:fp8" or backend == f"auto_round:{AutoRoundFormat.TORCH_FP8_STATIC.value}":
+    if (
+        backend == f"auto_round:{AutoRoundFormat.FP8.value}"
+        or backend == f"auto_round:{AutoRoundFormat.FP8_STATIC.value}"
+    ):
         from auto_round.export.export_to_autoround.export_to_fp8 import pack_layer

         return pack_layer(layer_name, model, backend, device)

+    if backend == "auto_round:llm_compressor":
+        from auto_round.export.export_to_llmcompressor.export_to_static_fp import pack_layer
+
+        return pack_layer(layer_name, model, backend, device)
+
     layer = get_module(model, layer_name)
     if hasattr(layer, "orig_layer"):
         layer = layer.orig_layer
@@ -271,6 +281,11 @@ def save_quantized_as_autoround(output_dir, inplace=True, backend="auto_round:ex

         return save_quantized_as_fp(output_dir, inplace=inplace, backend="auto_round:llm_compressor", **kwargs)

+    if backend == "auto_round:llm_compressor":
+        from auto_round.export.export_to_llmcompressor.export_to_static_fp import save_quantized_as_static_fp
+
+        return save_quantized_as_static_fp(output_dir, inplace=inplace, backend="auto_round:llm_compressor", **kwargs)
+
     if kwargs.get("data_type", "int") == "fp" and kwargs.get("bits", 16) == 8 and kwargs.get("act_bits", 16) >= 16:
         from auto_round.export.export_to_autoround.export_to_fp8 import save_quantized_as_autoround

@@ -280,7 +295,7 @@ def save_quantized_as_autoround(output_dir, inplace=True, backend="auto_round:ex
     if (
         (kwargs.get("sym") is None or kwargs.get("sym"))
         and ("gptq" not in backend and "awq" not in backend)
-        and (AutoRoundFormat.TORCH_FP8_STATIC.value not in backend)
+        and (AutoRoundFormat.FP8_STATIC.value not in backend)
     ):
         backend = backend.replace("auto_round", "auto_round:auto_gptq")

@@ -367,52 +382,6 @@ def wrapper(name):
         dtype = torch.float16  ## awq kernel only supports float16 on cuda
     else:
         dtype = None
-    save(model, output_dir, safe_serialization=safe_serialization, dtype=dtype)
+    save_model(model, output_dir, safe_serialization=safe_serialization, dtype=dtype)

     return model
-
-
-def save(model: nn.Module, save_dir: str, max_shard_size: str = "5GB", safe_serialization: bool = True, dtype=None):
-    """Save model state dict and configs.
-
-    Args:
-        model (`nn.Module`):
-            Model to be saved. The model can be wrapped or unwrapped.
-        save_dir (`str`):
-            Directory to which to save. Will be created if it doesn't exist.
-        max_shard_size (`str`, defaults to `"10GB"`):
-            The maximum size for a checkpoint before being sharded. Checkpoints shard will then be each of size
-            lower than this size. If expressed as a string, needs to be digits followed by a unit (like `"5MB"`).
-            <Tip warning={true}>
-
-            If a single weight of the model is bigger than `max_shard_size`, it will be in its own checkpoint shard
-            which will be bigger than `max_shard_size`.
-
-            </Tip>
-        safe_serialization (`bool`, defaults to `True`):
-            Whether to save the model using `safetensors` or the traditional PyTorch way (that uses `pickle`).
-    """
-    os.makedirs(save_dir, exist_ok=True)
-    try:
-        model.save_pretrained(save_dir, max_shard_size=max_shard_size, safe_serialization=safe_serialization)
-    except ValueError as e:
-        if hasattr(model, "generation_config"):
-            setattr(model.generation_config, "do_sample", True)
-        model.save_pretrained(save_dir, max_shard_size=max_shard_size, safe_serialization=safe_serialization)
-
-    config_path = os.path.join(save_dir, "config.json")
-    if dtype is not None and dtype != model.dtype and os.path.exists(os.path.join(save_dir, "config.json")):
-        with open(config_path, "r") as file:
-            data = json.load(file)
-        data["torch_dtype"] = str(dtype).split(".")[-1]
-        with open(config_path, "w") as file:
-            json.dump(data, file, indent=2)
-    config_file = "quantization_config.json"
-    if hasattr(model, "config") and hasattr(model.config, "quantization_config"):
-        with open(os.path.join(save_dir, config_file), "w", encoding="utf-8") as f:
-            json.dump(model.config.quantization_config, f, indent=2)
-
-    try:
-        copy_python_files_from_model_cache(model, save_dir)
-    except Exception as e:
-        logger.warning("Skipping source model Python file copy due to error: %s", e)
