Commit dcfacef
Loosen the packing restrictions for mxfp&nvfp (#911)
* Loosen the packing restrictions for mxfp&nvfp, enable Qwen1.5-MoE-A2.7B quantize
* [pre-commit.ci] auto fixes from pre-commit.com hooks (for more information, see https://pre-commit.ci)
* fix UT
* [pre-commit.ci] auto fixes from pre-commit.com hooks (for more information, see https://pre-commit.ci)
* refine mxfp&nvfp layer checker
* fix pylint

---------

Signed-off-by: Zhang, Weiwei1 <weiwei1.zhang@intel.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
1 parent 41986b5 commit dcfacef

7 files changed: +533 −390 lines

auto_round/export/export_to_autoround/qlinear_fp.py

Lines changed: 15 additions & 7 deletions
@@ -38,7 +38,7 @@
 from auto_round.data_type.mxfp import FP32_EXPONENT_BIAS, FP32_MIN_NORMAL
 from auto_round.data_type.nvfp import cast_to_fp4, get_reciprocal
 from auto_round.data_type.utils import reshape_pad_tensor_by_group_size, revert_tensor_by_pad
-from auto_round.utils import _get_packing_device, is_mx_fp, is_nv_fp
+from auto_round.utils import BackendDataType, _get_packing_device, is_mx_fp, is_nv_fp

 # from auto_round.utils import get_weight_compress_dtype
 logger = getLogger(__name__)
@@ -72,14 +72,22 @@ def __init__(
         super().__init__()
         if bits not in [4, 8]:
             raise NotImplementedError("Only 4,8 bits are supported.")
-        if infeatures % 32 != 0 or outfeatures % 32 != 0:
-            raise NotImplementedError("in_feature and out_feature must be divisible by 32.")
         self.is_mx = is_mx_fp(data_type)
         self.is_nv = is_nv_fp(data_type)
-        if self.is_mx and group_size != 32:
-            raise NotImplementedError("Only group_size 32 are supported for mxfp.")
-        if self.is_nv and group_size not in [16, 32]:
-            raise NotImplementedError("Only group_size 16 are supported for nvfp.")
+        if self.is_mx:
+            if group_size != 32:
+                raise NotImplementedError(f"Only group_size 32 are supported for {BackendDataType.MX_FP} data type.")
+            if infeatures % group_size != 0:
+                raise NotImplementedError(
+                    f"in_feature must be divisible by {group_size} for {BackendDataType.MX_FP} data type."
+                )
+        if self.is_nv:
+            if group_size % 16 != 0:
+                raise NotImplementedError(f"Only group_size 16 are supported for {BackendDataType.NV_FP} data type.")
+            if infeatures % group_size != 0:
+                raise NotImplementedError(
+                    f"in_feature must be divisible by {group_size} for {BackendDataType.NV_FP} data type."
+                )
         self.infeatures = infeatures
         self.outfeatures = outfeatures
         self.bits = bits
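For reference, a minimal standalone sketch of the loosened shape rule encoded by the constructor above: out_features no longer needs to be divisible by 32, only in_features must align with the group size. The helper name `is_packable` and the example shapes are hypothetical, chosen only to illustrate the change.

```python
# Hypothetical sketch of the new packing rule; not the project's API.
def is_packable(in_features: int, out_features: int, group_size: int, is_mx: bool, is_nv: bool) -> bool:
    if is_mx and group_size != 32:
        return False                      # mxfp still requires group_size == 32
    if is_nv and group_size % 16 != 0:
        return False                      # nvfp requires a multiple of 16
    return in_features % group_size == 0  # out_features is no longer checked


# Example: an MoE routing gate whose out_features equals the expert count (here 60,
# not divisible by 32) would have been rejected by the old check but passes the new one.
print(is_packable(in_features=2048, out_features=60, group_size=32, is_mx=True, is_nv=False))  # True
```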

auto_round/inference/backend.py

Lines changed: 13 additions & 4 deletions
@@ -127,12 +127,19 @@ def feature_multiply_checker_group_size(
 )


+def in_feature_checker_group_size(in_feature, out_feature, config):
+    group_size = config["group_size"]
+    return in_feature % group_size == 0
+
+
 feature_multiply_checker_32 = functools.partial(feature_multiply_checker, in_feature_multiplier=32)
 feature_multiply_checker_16 = functools.partial(feature_multiply_checker, in_feature_multiplier=16)
 in_output_feature_multiply_checker_32 = functools.partial(
     feature_multiply_checker, in_feature_multiplier=32, out_feature_multiplier=32
 )
-
+in_feature_multiply_checker_32 = functools.partial(
+    feature_multiply_checker, in_feature_multiplier=32, out_feature_multiplier=None
+)
 exllamav2_feature_checker = functools.partial(
     feature_multiply_checker_group_size, in_feature_multiplier=32, out_feature_multiplier=32
 )
@@ -141,6 +148,8 @@ def feature_multiply_checker_group_size(
     feature_multiply_checker_group_size, in_feature_multiplier=1, out_feature_multiplier=64
 )

+mxfp_nvfp_feature_checker = functools.partial(in_feature_checker_group_size)
+

 def fp8_static_scheme_checker(
     in_feature: int,
@@ -239,7 +248,7 @@ def fp8_static_scheme_checker(
     act_data_type=["mx_fp_rceil"],
     act_dynamic=[True],
     priority=0,
-    checkers=[feature_multiply_checker_32],
+    checkers=[mxfp_nvfp_feature_checker],
     alias=["auto_round", "torch"],
     requirements=["auto-round>0.7.0"],
 )
@@ -259,7 +268,7 @@ def fp8_static_scheme_checker(
     act_data_type=["mx_fp_rceil"],
     act_dynamic=[True],
     priority=0,
-    checkers=[feature_multiply_checker_32],
+    checkers=[mxfp_nvfp_feature_checker],
     alias=["auto_round", "torch"],
     requirements=["auto-round>0.7.0"],
 )
@@ -280,7 +289,7 @@ def fp8_static_scheme_checker(
     act_data_type=["nv_fp4_with_static_gs"],
     act_dynamic=[True],
     priority=0,
-    checkers=[feature_multiply_checker_16],
+    checkers=[mxfp_nvfp_feature_checker],
     alias=["auto_round", "torch"],
     requirements=["auto-round>0.7.0"],
 )
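As a rough illustration of how the new checker behaves (a standalone re-implementation mirroring the diff above, with made-up feature sizes): the mxfp/nvfp backend entries now accept any layer whose in_feature is a multiple of the configured group_size, regardless of out_feature.

```python
import functools


# Re-implementation of the checker from the diff above, for illustration only.
def in_feature_checker_group_size(in_feature, out_feature, config):
    group_size = config["group_size"]
    return in_feature % group_size == 0


mxfp_nvfp_feature_checker = functools.partial(in_feature_checker_group_size)

# out_feature (60) is ignored; only in_feature alignment with group_size matters.
print(mxfp_nvfp_feature_checker(2048, 60, {"group_size": 32}))  # True
print(mxfp_nvfp_feature_checker(2050, 60, {"group_size": 32}))  # False
```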

auto_round/utils.py

Lines changed: 12 additions & 0 deletions
@@ -2963,6 +2963,18 @@ def normalize_item(item: Union[str, dict, "QuantizationScheme"], layer_name: str
                     layer_config.setdefault(n, copy.deepcopy(default_dict))
                     layer_config[n].update({"bits": 16, "data_type": "fp", "fixed_by_user": True})
                     logger.warning_once(f"{n} skipped quantization (shape not divisible by 32).")
+    # enforce shape divisibility for mxfp/nvfp
+    if (is_nv_fp(default_dict["data_type"]) or is_mx_fp(default_dict["data_type"])) and not gguf_name:
+        for n, m in model.named_modules():
+            if type(m) in supported_types or m.__class__.__name__ in inner_supported_types:
+                if m.weight.shape[1] % default_dict["group_size"]:
+                    layer_config.setdefault(n, copy.deepcopy(default_dict))
+                    layer_config[n].update(
+                        {"bits": 16, "data_type": "fp", "act_bits": 16, "act_data_type": "fp", "fixed_by_user": True}
+                    )
+                    logger.warning_once(
+                        f"{n} skipped quantization (shape not divisible by {default_dict['group_size']})."
+                    )

     # 9. block layers: mark as in_blocks=True
     for name in get_layer_names_in_block(model, supported_types, quant_block_list, inner_supported_types):
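In rough terms, the added block pins any layer whose weight's second dimension (in_features) is not divisible by the scheme's group_size back to 16-bit float for both weights and activations, instead of failing later during packing. A small sketch of that decision, assuming a hypothetical mx_fp default config (the dict below is illustrative, not the library's actual defaults):

```python
# Hypothetical default scheme dict, mirroring the fields used in the diff above.
default_dict = {"bits": 4, "data_type": "mx_fp", "act_bits": 4, "act_data_type": "mx_fp", "group_size": 32}


def layer_entry(in_features: int) -> dict:
    """Return the per-layer config the fallback above would produce for a given in_features."""
    cfg = dict(default_dict)
    if in_features % default_dict["group_size"]:
        # Shape cannot be grouped: keep the layer in 16-bit float and mark it fixed.
        cfg.update({"bits": 16, "data_type": "fp", "act_bits": 16, "act_data_type": "fp", "fixed_by_user": True})
    return cfg


print(layer_entry(2048)["data_type"])  # 'mx_fp' -> quantized as usual
print(layer_entry(2050)["data_type"])  # 'fp'    -> skipped, kept in 16-bit
```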

test/test_cpu/test_export.py

Lines changed: 0 additions & 269 deletions
@@ -302,275 +302,6 @@ def test_static_afp8_export(self, static_kv_dtype):
             self.assertEqual(f.get_tensor("model.decoder.layers.5.self_attn.v_proj.weight").dtype, torch.float8_e4m3fn)
         shutil.rmtree(quantized_model_path, ignore_errors=True)

-    def test_mxfp4_llmcompressor_format(self):
-        model_name = "/tf_dataset/auto_round/models/facebook/opt-125m"
-        model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype="auto", trust_remote_code=True)
-        from transformers import AutoConfig
-
-        scheme = "MXFP4"
-        layer_config = {}
-        fp_layers_str = "k_proj"
-        from auto_round.utils import get_fp_layer_names
-
-        not_quantize_layer_names = get_fp_layer_names(model, fp_layers_str)
-        for name in not_quantize_layer_names:
-            layer_config[name] = {"bits": 16, "act_bits": 16, "data_type": "float"}
-        autoround = AutoRound(
-            model,
-            self.tokenizer,
-            scheme=scheme,
-            iters=2,
-            seqlen=2,
-            layer_config=layer_config,
-            dataset=self.llm_dataloader,
-        )
-        quantized_model_path = self.save_dir
-        autoround.quantize()
-        compressed_model = autoround.save_quantized(
-            output_dir=quantized_model_path, inplace=True, format="llm_compressor"
-        )
-        tmp_layer = compressed_model.model.decoder.layers[3].self_attn.q_proj
-        skip_layer = compressed_model.model.decoder.layers[3].self_attn.k_proj
-        assert (
-            hasattr(tmp_layer, "weight_scale")
-            and hasattr(tmp_layer, "weight_packed")
-            and tmp_layer.weight_scale.dtype is torch.uint8
-            and tmp_layer.weight_scale.shape[0] == 768
-        ), "Illegal MXFP4 packing name or data_type or shape"
-        assert not hasattr(skip_layer, "weight_scale") and not hasattr(  ## check skipped layers
-            skip_layer, "weight_packed"
-        ), "Illegal MXFP4 quantization for fp_layers"
-        quantization_config = AutoConfig.from_pretrained(
-            quantized_model_path, trust_remote_code=True
-        ).quantization_config
-        assert (
-            quantization_config["format"] == "float-quantized"
-            and quantization_config["config_groups"]["group_0"]["weights"]["is_mx"] is True
-            and quantization_config["config_groups"]["group_0"]["weights"]["num_bits"] == 4
-        ), f"Invalid MXFP4 quantization configuration: {quantization_config}"
-
-        shutil.rmtree("./saved", ignore_errors=True)
-
-    def test_rtn_mxfp4_llmcompressor_format(self):
-        model_name = "/tf_dataset/auto_round/models/facebook/opt-125m"
-        model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype="auto", trust_remote_code=True)
-        from transformers import AutoConfig
-
-        scheme = "MXFP4"
-        layer_config = {}
-        fp_layers_str = "k_proj"
-        from auto_round.utils import get_fp_layer_names
-
-        not_quantize_layer_names = get_fp_layer_names(model, fp_layers_str)
-        for name in not_quantize_layer_names:
-            layer_config[name] = {"bits": 16, "act_bits": 16, "data_type": "float"}
-        autoround = AutoRound(
-            model,
-            self.tokenizer,
-            scheme=scheme,
-            iters=0,
-            seqlen=2,
-            layer_config=layer_config,
-            dataset=self.llm_dataloader,
-        )
-        quantized_model_path = self.save_dir
-        autoround.quantize()
-        compressed_model = autoround.save_quantized(
-            output_dir=quantized_model_path, inplace=True, format="llm_compressor"
-        )
-        tmp_layer = compressed_model.model.decoder.layers[3].self_attn.q_proj
-        skip_layer = compressed_model.model.decoder.layers[3].self_attn.k_proj
-        assert (
-            hasattr(tmp_layer, "weight_scale")
-            and hasattr(tmp_layer, "weight_packed")
-            and tmp_layer.weight_scale.dtype is torch.uint8
-            and tmp_layer.weight_scale.shape[0] == 768
-        ), "Illegal MXFP4 packing name or data_type or shape"
-        assert not hasattr(skip_layer, "weight_scale") and not hasattr(  ## check skipped layers
-            skip_layer, "weight_packed"
-        ), "Illegal MXFP4 quantization for fp_layers"
-        quantization_config = AutoConfig.from_pretrained(
-            quantized_model_path, trust_remote_code=True
-        ).quantization_config
-        assert (
-            quantization_config["format"] == "float-quantized"
-            and quantization_config["config_groups"]["group_0"]["weights"]["is_mx"] is True
-            and quantization_config["config_groups"]["group_0"]["weights"]["num_bits"] == 4
-        ), f"Invalid MXFP4 quantization configuration: {quantization_config}"
-        shutil.rmtree("./saved", ignore_errors=True)
-
-    def test_mxfp8_llmcompressor_format(self):
-        model_name = "/tf_dataset/auto_round/models/facebook/opt-125m"
-        model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype="auto", trust_remote_code=True)
-        from transformers import AutoConfig
-
-        scheme = "MXFP8"
-        autoround = AutoRound(
-            model,
-            self.tokenizer,
-            scheme=scheme,
-            iters=2,
-            seqlen=2,
-            dataset=self.llm_dataloader,
-        )
-        quantized_model_path = self.save_dir
-        compressed_model, _ = autoround.quantize_and_save(output_dir=quantized_model_path, format="llm_compressor")
-        tmp_layer = compressed_model.model.decoder.layers[3].self_attn.q_proj
-        assert (
-            hasattr(tmp_layer, "weight_scale")
-            and hasattr(tmp_layer, "weight")
-            and tmp_layer.weight.dtype is torch.float8_e4m3fn
-            and tmp_layer.weight_scale.dtype is torch.uint8
-            and tmp_layer.weight_scale.shape[0] == 768
-        ), "Illegal MXFP8 packing name or data_type or shape"
-        quantization_config = AutoConfig.from_pretrained(
-            quantized_model_path, trust_remote_code=True
-        ).quantization_config
-        assert (
-            quantization_config["format"] == "float-quantized"
-            and quantization_config["config_groups"]["group_0"]["weights"]["is_mx"] is True
-            and quantization_config["config_groups"]["group_0"]["weights"]["num_bits"] == 8
-        ), f"Invalid MXFP8 quantization configuration: {quantization_config}"
-        folder_size_gb = _get_folder_size(quantized_model_path)
-        # Original opt-125m is < 0.5GB -> quantized mxfp8 model should be smaller but not empty
-        assert (
-            0.15 < folder_size_gb < 0.2
-        ), f"Quantized model folder size {folder_size_gb:.2f} GB is outside the expected range (0.1~0.2 GB)"
-        shutil.rmtree("./saved", ignore_errors=True)
-
-    def test_nvfp4_llmcompressor_format(self):
-        model_name = "/tf_dataset/auto_round/models/facebook/opt-125m"
-        model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype="auto", trust_remote_code=True)
-        from transformers import AutoConfig
-
-        scheme = "NVFP4"
-        autoround = AutoRound(
-            model,
-            self.tokenizer,
-            scheme=scheme,
-            iters=2,
-            seqlen=2,
-            dataset=self.llm_dataloader,
-        )
-        quantized_model_path = self.save_dir
-        compressed_model, _ = autoround.quantize_and_save(output_dir=quantized_model_path, format="llm_compressor")
-        tmp_layer = compressed_model.model.decoder.layers[3].self_attn.q_proj
-        assert (
-            hasattr(tmp_layer, "weight_scale")
-            and hasattr(tmp_layer, "weight_global_scale")
-            and hasattr(tmp_layer, "input_global_scale")
-            and tmp_layer.weight_packed.dtype is torch.uint8
-            and tmp_layer.weight_scale.dtype is torch.float8_e4m3fn
-            and tmp_layer.weight_scale.shape[0] == 768
-        ), "Illegal NVFP4 packing name or data_type or shape"
-        quantization_config = AutoConfig.from_pretrained(
-            quantized_model_path, trust_remote_code=True
-        ).quantization_config
-        assert (
-            quantization_config["format"] == "nvfp4-pack-quantized"
-            and quantization_config["config_groups"]["group_0"]["input_activations"]["num_bits"] == 4
-        ), f"Invalid NVFP4 quantization configuration: {quantization_config}"
-        folder_size_gb = _get_folder_size(quantized_model_path)
-        # Original opt-125m is < 0.5GB -> quantized nvfp4 model should be smaller but not empty
-        assert (
-            0.1 < folder_size_gb < 0.15
-        ), f"Quantized model folder size {folder_size_gb:.2f} GB is outside the expected range (0.1~0.15 GB)"
-        shutil.rmtree("./saved", ignore_errors=True)
-
-    def test_nvfp4_autoround_format(self):
-        model_name = "/tf_dataset/auto_round/models/facebook/opt-125m"
-        model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype="auto", trust_remote_code=True)
-        from transformers import AutoConfig
-
-        scheme = "NVFP4"
-        autoround = AutoRound(
-            model,
-            self.tokenizer,
-            scheme="NVFP4",
-            iters=2,
-            seqlen=2,
-            dataset=self.llm_dataloader,
-        )
-        quantized_model_path = self.save_dir
-        compressed_model, _ = autoround.quantize_and_save(output_dir=quantized_model_path, format="auto_round")
-        tmp_layer = compressed_model.model.decoder.layers[3].self_attn.q_proj
-        assert (
-            hasattr(tmp_layer, "weight_scale")
-            and hasattr(tmp_layer, "weight_global_scale")
-            and hasattr(tmp_layer, "input_global_scale")
-            and tmp_layer.weight_packed.dtype is torch.uint8
-            and tmp_layer.weight_scale.dtype is torch.float8_e4m3fn
-            and tmp_layer.weight_scale.shape[0] == 768
-        ), "Illegal NVFP4 packing name or data_type or shape"
-        shutil.rmtree("./saved", ignore_errors=True)
-
-    def test_nvfp4_autoround_save_quantized(self):
-        model_name = "/tf_dataset/auto_round/models/facebook/opt-125m"
-        model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype="auto", trust_remote_code=True)
-        from transformers import AutoConfig
-
-        scheme = "NVFP4"
-        autoround = AutoRound(
-            model,
-            self.tokenizer,
-            scheme="NVFP4",
-            iters=2,
-            seqlen=2,
-            dataset=self.llm_dataloader,
-        )
-        quantized_model_path = self.save_dir
-        autoround.quantize()
-        compressed_model = autoround.save_quantized(output_dir=quantized_model_path, format="auto_round")
-        tmp_layer = compressed_model.model.decoder.layers[3].self_attn.q_proj
-        assert (
-            hasattr(tmp_layer, "weight_scale")
-            and hasattr(tmp_layer, "weight_global_scale")
-            and hasattr(tmp_layer, "input_global_scale")
-            and tmp_layer.weight_packed.dtype is torch.uint8
-            and tmp_layer.weight_scale.dtype is torch.float8_e4m3fn
-            and tmp_layer.weight_scale.shape[0] == 768
-        ), "Illegal NVFP4 packing name or data_type or shape"
-        shutil.rmtree("./saved", ignore_errors=True)
-
-    def test_nvfp4_moe_actmax_rtn(self):
-        model_name = "/tf_dataset/auto_round/models/deepseek-ai/DeepSeek-V2-Lite"
-        layer_config = {
-            "self_attn": {"bits": 16, "act_bits": 16},
-            "mlp.shared_experts": {"bits": 16, "act_bits": 16},
-        }
-        scheme = "nvfp4"
-        autoround = AutoRound(
-            model_name,
-            scheme=scheme,
-            iters=0,
-            seqlen=2,
-            nsamples=2,
-            dataset=self.llm_dataloader,
-            layer_config=layer_config,
-        )
-        compressed_model, _ = autoround.quantize()
-        assert hasattr(compressed_model.model.layers[1].mlp.experts[0].gate_proj.orig_layer, "act_max")
-
-    def test_nvfp4_moe_actmax_ar(self):
-        model_name = "/tf_dataset/auto_round/models/deepseek-ai/DeepSeek-V2-Lite"
-        layer_config = {
-            "q_proj": {"bits": 16, "act_bits": 16},
-            "mlp.shared_experts": {"bits": 16, "act_bits": 16},
-            "experts.*2": {"bits": 16, "act_bits": 16},
-            "experts.*5": {"bits": 16, "act_bits": 16},
-        }
-        scheme = "nvfp4"
-        autoround = AutoRound(
-            model_name,
-            scheme=scheme,
-            iters=1,
-            seqlen=2,
-            nsamples=2,
-            dataset=self.llm_dataloader,
-            layer_config=layer_config,
-        )
-        autoround.quantize_and_save(output_dir=self.save_dir, inplace=True, format="auto_round")
-

 if __name__ == "__main__":
     unittest.main()
