10 changes: 2 additions & 8 deletions auto_round/compressors/base.py
@@ -1651,9 +1651,7 @@ def _quantize_via_rtn_blockwise(self, all_to_quantized_module_names: list[str])
         if self.device_map is not None:
             accelerate.hooks.remove_hook_from_submodules(block)
 
-        if (
-            is_nv_fp(self.act_data_type) and any("nv_fp" in format_ for format_ in self.formats)
-        ) or is_static_wfp8afp8(self):
+        if is_nv_fp(self.act_data_type) or is_static_wfp8afp8(self):
             # enable moe experts act_max automatic generation for Linear
             set_amax_for_all_moe_layers(block, attr_name="act_max")
         # Normalize imatrix and quantize layers
@@ -2911,11 +2909,7 @@ def _quantize_block(
         with torch.no_grad():
             unwrapper_block(block, best_params)
 
-        if (
-            is_nv_fp(self.act_data_type)
-            and hasattr(self, "formats")
-            and any("nv_fp" in format_ for format_ in self.formats)
-        ):
+        if is_nv_fp(self.act_data_type):
             # enable moe experts act_max automatic generation for WrapperWALayer
             set_amax_for_all_moe_layers(block, attr_name="orig_layer.act_max")
 
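Review note: both hunks above drop the export-format check from the act_max gating, so act_max generation now depends only on the activation data type. A minimal sketch of the resulting predicate, as a standalone stand-in rather than the library's real helpers; it assumes is_nv_fp() is effectively a substring test on the configured data type, and that the removed hasattr(self, "formats") guard existed because self.formats may not be set on some paths (e.g. RTN runs with iters=0):

def needs_moe_act_max(act_data_type: str, static_wfp8afp8: bool) -> bool:
    # Hypothetical stand-in for the gating in base.py. Before this change,
    # act_max generation also required an nv_fp export format; now the
    # activation data type (or static W8A8-FP8) alone decides.
    return "nv_fp" in (act_data_type or "") or static_wfp8afp8

assert needs_moe_act_max("nv_fp4", False)      # nv_fp activations: act_max is set
assert not needs_moe_act_max("mx_fp4", False)  # non-nv_fp, non-static-fp8: skipped
assert needs_moe_act_max("fp8", True)          # static wfp8afp8 path still gates in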
4 changes: 2 additions & 2 deletions auto_round/utils.py
@@ -2481,7 +2481,7 @@ def set_nested_attr(module, attr_name: str, value):
     attrs = attr_name.split(".")
     for attr in attrs[:-1]:
         if not hasattr(module, attr):
-            raise AttributeError(f"{module} has no attribute '{attr}'")
+            return None  # No need to set act_max for fp layers
         module = getattr(module, attr)
     setattr(module, attrs[-1], value)
 
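A runnable sketch of the behavioral change: the function body below is copied from the hunk above, while the fp layer is a hypothetical stand-in for an expert kept unquantized (bits=16) via layer_config and therefore never wrapped:

import torch

def set_nested_attr(module, attr_name: str, value):
    # Patched helper from the hunk above: a missing intermediate attribute
    # now makes the call a no-op instead of raising AttributeError.
    attrs = attr_name.split(".")
    for attr in attrs[:-1]:
        if not hasattr(module, attr):
            return None  # No need to set act_max for fp layers
        module = getattr(module, attr)
    setattr(module, attrs[-1], value)

fp_layer = torch.nn.Linear(2, 2)  # plain fp layer: no orig_layer wrapper
set_nested_attr(fp_layer, "orig_layer.act_max", torch.tensor(1.0))  # silently skipped
set_nested_attr(fp_layer, "act_max", torch.tensor(1.0))  # direct attribute still set
assert not hasattr(fp_layer, "orig_layer") and hasattr(fp_layer, "act_max")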
@@ -2546,7 +2546,7 @@ def set_amax_for_all_moe_layers(model: torch.nn.Module, layer_name=None, attr_na
             # For other MoE models (like Mixtral) with iterable experts
             try:
                 set_amax_for_uncalibrated_experts(
-                    [getattr(expert, linear_name) for expert in sub_module.experts], attr_name=attr_name
+                    [getattr(expert, linear_name, None) for expert in sub_module.experts], attr_name=attr_name
                 )
             except AttributeError as e:
                 # Provide more helpful debugging information
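The companion change is the getattr default. A hedged sketch of why it matters for mixed-precision MoE blocks; Expert is a hypothetical stand-in (the real code iterates sub_module.experts), and it assumes set_amax_for_uncalibrated_experts tolerates None entries:

import torch

class Expert(torch.nn.Module):
    # Hypothetical expert: only some instances carry the projection,
    # e.g. when layer_config leaves certain experts unquantized.
    def __init__(self, with_proj: bool):
        super().__init__()
        if with_proj:
            self.gate_proj = torch.nn.Linear(2, 2)

experts = [Expert(True), Expert(False)]
linear_name = "gate_proj"
# Old form: getattr(e, linear_name) raised on the first expert missing the
# projection, aborting act_max setup for the whole block. New form yields
# None for such experts so the downstream helper can skip them.
layers = [getattr(e, linear_name, None) for e in experts]
assert layers[0] is not None and layers[1] is None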
39 changes: 39 additions & 0 deletions test/test_cpu/test_export.py
@@ -532,6 +532,45 @@ def test_nvfp4_autoround_save_quantized(self):
         ), "Illegal NVFP4 packing name or data_type or shape"
         shutil.rmtree("./saved", ignore_errors=True)
 
+    def test_nvfp4_moe_actmax_rtn(self):
+        model_name = "/tf_dataset/auto_round/models/deepseek-ai/DeepSeek-V2-Lite"
+        layer_config = {
+            "self_attn": {"bits": 16, "act_bits": 16},
+            "mlp.shared_experts": {"bits": 16, "act_bits": 16},
+        }
+        scheme = "nvfp4"
+        autoround = AutoRound(
+            model_name,
+            scheme=scheme,
+            iters=0,
+            seqlen=2,
+            nsamples=2,
+            dataset=self.llm_dataloader,
+            layer_config=layer_config,
+        )
+        compressed_model, _ = autoround.quantize()
+        assert hasattr(compressed_model.model.layers[1].mlp.experts[0].gate_proj.orig_layer, "act_max")
+
+    def test_nvfp4_moe_actmax_ar(self):
+        model_name = "/tf_dataset/auto_round/models/deepseek-ai/DeepSeek-V2-Lite"
+        layer_config = {
+            "q_proj": {"bits": 16, "act_bits": 16},
+            "mlp.shared_experts": {"bits": 16, "act_bits": 16},
+            "experts.*2": {"bits": 16, "act_bits": 16},
+            "experts.*5": {"bits": 16, "act_bits": 16},
+        }
+        scheme = "nvfp4"
+        autoround = AutoRound(
+            model_name,
+            scheme=scheme,
+            iters=1,
+            seqlen=2,
+            nsamples=2,
+            dataset=self.llm_dataloader,
+            layer_config=layer_config,
+        )
+        autoround.quantize_and_save(output_dir=self.save_dir, inplace=True, format="auto_round")
+
 
 if __name__ == "__main__":
     unittest.main()
30 changes: 30 additions & 0 deletions test/test_cuda/test_export.py
@@ -402,6 +402,36 @@ def test_nvfp4_llmcompressor_format(self):
         # if "France" in prompt:
         #     assert "Paris" in generated_text
 
+    def test_nvfp4_moe_actmax_rtn(self):
+        model_name = "/data0/deepseek-ai/DeepSeek-V2-Lite"
+        scheme = "nvfp4"
+        autoround = AutoRound(
+            model_name,
+            scheme=scheme,
+            iters=0,
+            seqlen=2,
+            nsamples=2,
+            dataset=self.llm_dataloader,
+        )
+        autoround.quantize()
+        quantized_model_path = self.save_dir
+        autoround.save_quantized(output_dir=quantized_model_path, inplace=False, format="auto_round")
+
+    def test_nvfp4_moe_actmax_ar(self):
+        model_name = "/data0/deepseek-ai/DeepSeek-V2-Lite"
+        scheme = "nvfp4"
+        autoround = AutoRound(
+            model_name,
+            scheme=scheme,
+            iters=1,
+            seqlen=2,
+            nsamples=2,
+            dataset=self.llm_dataloader,
+        )
+        autoround.quantize()
+        quantized_model_path = self.save_dir
+        autoround.save_quantized(output_dir=quantized_model_path, inplace=False, format="auto_round")
+
 
 if __name__ == "__main__":
     unittest.main()