
Commit 284eecd

fix lm head bug and rm clear_mem_reach_threhold (#997)
1 parent 268f7dd commit 284eecd

File tree

4 files changed: +40 / -51 lines

auto_round/compressors/base.py

Lines changed: 18 additions & 13 deletions

@@ -2271,10 +2271,10 @@ def _quantize_layer(
         init_loss = None
         gradient_accumulate_steps = self.batch_size  # Force to low gpu
         batch_size = 1  # Force to low gpu
-        pick_samples = batch_size * gradient_accumulate_steps
-        pick_samples = min(nsamples, pick_samples)
+        global_batch_size = batch_size * gradient_accumulate_steps
+        global_batch_size = min(nsamples, global_batch_size)
         if self.sampler != "rand":
-            whole_indices = torch.randperm(nsamples)[:pick_samples]
+            whole_indices = torch.randperm(nsamples)[:global_batch_size]
         total_loss = 0
         num_elm = 1
         mse_reduction = "mean"
@@ -2285,7 +2285,7 @@ def _quantize_layer(
         for i in range(self.iters):
             total_loss = 0
             if self.sampler == "rand":
-                whole_indices = torch.randperm(nsamples)[:pick_samples]
+                whole_indices = torch.randperm(nsamples)[:global_batch_size]
             if gradient_accumulate_steps != 1:
                 if q_inputs is not None:
                     num_elm = self._get_current_num_elm(q_inputs, whole_indices)
@@ -2564,10 +2564,10 @@ def _quantize_block(
         else:
             nsamples = len(input_ids)
 
-        pick_samples = self.batch_size * self.gradient_accumulate_steps
-        pick_samples = min(nsamples, pick_samples)
+        global_batch_size = self.batch_size * self.gradient_accumulate_steps
+        global_batch_size = min(nsamples, global_batch_size)
         if self.sampler != "rand":
-            whole_indices = torch.randperm(nsamples)[:pick_samples]
+            whole_indices = torch.randperm(nsamples)[:global_batch_size]
         last_best_iter = 0
         best_loss = torch.finfo(torch.float).max
         num_elm = 1
@@ -2579,13 +2579,15 @@ def _quantize_block(
         init_loss = None
         best_params = {}
         total_loss = 0
+        # We assume the block input and output shape is same
+        if self.gradient_accumulate_steps != 1:
+            whole_indices = torch.arange(global_batch_size)
+            num_elm = self._get_current_num_elm(input_ids, whole_indices)
+
         for i in range(self.iters):
             total_loss = 0
             if self.sampler == "rand":
-                whole_indices = torch.randperm(nsamples)[:pick_samples]
-                # We assume the block input and output shape is same
-                if self.gradient_accumulate_steps != 1:
-                    num_elm = self._get_current_num_elm(input_ids, whole_indices)
+                whole_indices = torch.randperm(nsamples)[:global_batch_size]
 
             for tmp_step in range(self.gradient_accumulate_steps):
                 indices = whole_indices[tmp_step * self.batch_size : (tmp_step + 1) * self.batch_size]
@@ -2600,6 +2602,9 @@ def _quantize_block(
                     tmp_attention_mask = [self.attention_mask[i] for i in indices]
                     tmp_attention_mask = torch.cat(tmp_attention_mask, dim=0).to(device)
                     tmp_attention_mask.unsqueeze_(-1)
+                    num_elm = torch.sum(tmp_attention_mask).item()
+                    if num_elm == 0:
+                        num_elm = 1
                 else:
                     tmp_attention_mask = 1.0
                 if self.amp:
@@ -2615,7 +2620,6 @@ def _quantize_block(
 
                 total_loss += loss.item() / num_elm
                 self._scale_loss_and_backward(scaler, loss)
-                clear_memory_if_reached_threshold(threshold=0.85)
 
             if i == 0:
                 init_loss = total_loss
@@ -2655,7 +2659,8 @@ def _quantize_block(
             set_amax_for_all_moe_layers(block, attr_name="orig_layer.act_max")
 
         if self.enable_quanted_input:
-            clear_memory()
+            if self.low_gpu_mem_usage:
+                clear_memory()
             q_outputs = self._get_block_outputs(
                 block,
                 input_ids,
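
The num_elm changes above make the per-micro-batch loss normalization padding aware: when an attention mask is available, the accumulated loss is divided by the number of attended tokens in that micro-batch instead of a count computed once up front. A minimal, self-contained sketch of that bookkeeping (masked_mse_loss and the toy tensors are illustrative, not the library code; padded positions are not zeroed out here for brevity):

import torch
from typing import Optional

def masked_mse_loss(output: torch.Tensor, target: torch.Tensor,
                    attention_mask: Optional[torch.Tensor]) -> torch.Tensor:
    # output/target: (batch, seq, hidden); attention_mask: (batch, seq), 1 = real token, 0 = padding
    mse_sum = torch.nn.functional.mse_loss(output, target, reduction="sum")
    if attention_mask is not None:
        num_elm = int(torch.sum(attention_mask).item())  # count attended tokens, as in the hunk above
        if num_elm == 0:                                  # guard against an all-padding micro-batch
            num_elm = 1
    else:
        num_elm = 1  # without a mask, the diff above precomputes num_elm from the sampled indices instead
    return mse_sum / num_elm

# toy usage: the second sequence carries two padded positions
out, tgt = torch.randn(2, 4, 8), torch.randn(2, 4, 8)
mask = torch.tensor([[1, 1, 1, 1], [1, 1, 0, 0]])
print(masked_mse_loss(out, tgt, mask))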

auto_round/compressors/utils.py

Lines changed: 3 additions & 0 deletions

@@ -377,6 +377,9 @@ def normalize_item(item: Union[str, dict, "QuantizationScheme"], layer_name: str
     if lm_head_name not in layer_config and quant_lm_head:
         layer_config[lm_head_name] = copy.deepcopy(default_dict)
 
+    if not quant_lm_head and not gguf_name:
+        layer_config.pop(lm_head_name, None)
+
     # 8. enforce shape divisibility for int weight-only
     if default_dict["data_type"] == "int" and default_dict["act_bits"] >= 16 and not gguf_name:
         for n, m in model.named_modules():
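
These three added lines are the lm_head fix: when quant_lm_head is off and no GGUF format is requested, any lm_head entry is dropped from layer_config, so it can no longer leak into the saved extra_config (which the updated tests below assert). A standalone sketch of that decision, where resolve_lm_head and its arguments are illustrative rather than the library API:

import copy
from typing import Optional

def resolve_lm_head(layer_config: dict, default_dict: dict, quant_lm_head: bool,
                    gguf_name: Optional[str], lm_head_name: str = "lm_head") -> dict:
    cfg = copy.deepcopy(layer_config)
    if lm_head_name not in cfg and quant_lm_head:
        cfg[lm_head_name] = copy.deepcopy(default_dict)  # opt the head in explicitly
    if not quant_lm_head and not gguf_name:
        cfg.pop(lm_head_name, None)                      # the fix: drop any stale lm_head entry
    return cfg

# a user-supplied lm_head override is discarded unless quant_lm_head is True or a GGUF name is set
print(resolve_lm_head({"lm_head": {"bits": 4}}, {"bits": 4}, quant_lm_head=False, gguf_name=None))
print(resolve_lm_head({}, {"bits": 4}, quant_lm_head=True, gguf_name=None))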

test/test_cpu/test_act_quantization.py

Lines changed: 5 additions & 36 deletions

@@ -154,17 +154,8 @@ def test_act_config_MXFP4_saving(self):
         quantized_model_path = self.save_dir
         autoround.quantize_and_save(output_dir=quantized_model_path, format="auto_round")
         model = AutoModelForCausalLM.from_pretrained(quantized_model_path, device_map="cpu")
-        lmhead_config = model.config.quantization_config.extra_config["lm_head"]
-        assert "act_data_type" in lmhead_config.keys() and lmhead_config["act_data_type"] == "mx_fp_rceil"
-        assert "act_bits" in lmhead_config.keys() and lmhead_config["act_bits"] == 8
-        assert "act_group_size" in lmhead_config.keys() and lmhead_config["act_group_size"] == 32
-        assert "act_sym" in lmhead_config.keys() and lmhead_config["act_sym"]
-        assert "data_type" in lmhead_config.keys() and lmhead_config["data_type"] == "mx_fp"
-        assert "bits" in lmhead_config.keys() and lmhead_config["bits"] == 8
-        assert "group_size" in lmhead_config.keys() and lmhead_config["group_size"] == 32
-        assert "sym" in lmhead_config.keys() and lmhead_config["sym"]
-        assert "super_bits" in lmhead_config.keys() and lmhead_config["super_bits"] is None
-        assert "super_group_size" in lmhead_config.keys() and lmhead_config["super_group_size"] is None
+        assert "lm_head" not in model.config.quantization_config.extra_config
+
         # check inblock layer config values
         kproj_config = model.config.quantization_config.extra_config["model.decoder.layers.1.self_attn.k_proj"]
         assert "act_data_type" in kproj_config.keys() and kproj_config["act_data_type"] == "mx_fp_rceil"
@@ -204,7 +195,7 @@ def test_act_config_NVFP4_saving(self):
 
     def test_WOQ_config_INT_saving(self):
         scheme = "W4A16"
-        layer_config = {"k_proj": {"bits": 8}}  # "lm_head": {"bits": 4},
+        layer_config = {"k_proj": {"bits": 8}}
         autoround = AutoRound(
             self.model_name,
             scheme=scheme,
@@ -218,18 +209,6 @@ def test_WOQ_config_INT_saving(self):
         autoround.quantize_and_save(output_dir=quantized_model_path, format="auto_round")
         model = AutoModelForCausalLM.from_pretrained(quantized_model_path, device_map="cpu")
         extra_config = model.config.quantization_config.extra_config
-        # lmhead_config = extra_config["lm_head"]
-        # assert "act_data_type" in lmhead_config.keys() and lmhead_config["act_data_type"] == "float"
-        # assert "act_bits" in lmhead_config.keys() and lmhead_config["act_bits"] == 16
-        # assert "act_group_size" in lmhead_config.keys() and lmhead_config["act_group_size"] == 128
-        # assert "act_sym" in lmhead_config.keys() and not lmhead_config["act_sym"]
-        # assert "data_type" in lmhead_config.keys() and lmhead_config["data_type"] == "int"
-        # assert "bits" in lmhead_config.keys() and lmhead_config["bits"] == 4
-        # assert "group_size" in lmhead_config.keys() and lmhead_config["group_size"] == 128
-        # assert "sym" in lmhead_config.keys() and not lmhead_config["sym"]
-        # assert "act_dynamic" in lmhead_config.keys() and lmhead_config["act_dynamic"]
-        # assert "super_bits" in lmhead_config.keys() and lmhead_config["super_bits"] is None
-        # assert "super_group_size" in lmhead_config.keys() and lmhead_config["super_group_size"] is None
 
         # check inblock layer config values
         kproj_config = extra_config["model.decoder.layers.1.self_attn.k_proj"]
@@ -270,18 +249,8 @@ def test_act_config_FP8_saving(self):
         from transformers import AutoConfig
 
         extra_config = AutoConfig.from_pretrained(quantized_model_path).quantization_config["extra_config"]
-        lmhead_config = extra_config["lm_head"]
-        assert "act_data_type" in lmhead_config.keys() and lmhead_config["act_data_type"] == "fp"
-        assert "act_bits" in lmhead_config.keys() and lmhead_config["act_bits"] == 8
-        assert "act_group_size" in lmhead_config.keys() and lmhead_config["act_group_size"] == 0
-        assert "act_sym" in lmhead_config.keys() and lmhead_config["act_sym"]
-        assert "data_type" in lmhead_config.keys() and lmhead_config["data_type"] == "fp"
-        assert "bits" in lmhead_config.keys() and lmhead_config["bits"] == 8
-        assert "group_size" in lmhead_config.keys() and lmhead_config["group_size"] == -1
-        assert "sym" in lmhead_config.keys() and lmhead_config["sym"]
-        assert "act_dynamic" in lmhead_config.keys() and not lmhead_config["act_dynamic"]
-        assert "super_bits" in lmhead_config.keys() and lmhead_config["super_bits"] is None
-        assert "super_group_size" in lmhead_config.keys() and lmhead_config["super_group_size"] is None
+        assert "lm_head" not in extra_config
+
         # check inblock layer config values
         kproj_config = extra_config["model.decoder.layers.0.self_attn.k_proj"]
         assert "act_data_type" in kproj_config.keys() and kproj_config["act_data_type"] == "fp"

test/test_cpu/test_autoround.py

Lines changed: 14 additions & 2 deletions

@@ -743,8 +743,20 @@ def test_invalid_layer_config(self):
 
     def test_quant_lm_head(self):
         model_name = "/tf_dataset/auto_round/models/Qwen/Qwen3-8B"
-        ar = AutoRound(model_name, quant_lm_head=True, iters=1, nsamples=1, seqlen=32)
-        ar.quantize()
+        ar = AutoRound(model_name, quant_lm_head=True, iters=0, disable_opt_rtn=True)
+        ar.quantize_and_save(output_dir=self.save_folder, format="auto_round")
+        model = AutoModelForCausalLM.from_pretrained(self.save_folder, device_map="cpu")
+        assert "lm_head" in model.config.quantization_config.extra_config
+        assert model.config.quantization_config.extra_config["lm_head"]["bits"] == 4
+
+    def test_quant_lm_head_layer_config(self):
+        model_name = "/tf_dataset/auto_round/models/Qwen/Qwen3-8B"
+        layer_config = {"lm_head": {"bits": 4}}
+        ar = AutoRound(model_name, quant_lm_head=True, iters=0, disable_opt_rtn=True, layer_config=layer_config)
+        ar.quantize_and_save(output_dir=self.save_folder, format="auto_round")
+        model = AutoModelForCausalLM.from_pretrained(self.save_folder, device_map="cpu")
+        assert "lm_head" in model.config.quantization_config.extra_config
+        assert model.config.quantization_config.extra_config["lm_head"]["bits"] == 4
 
     def test_compressor(self):
         model_name = "Qwen/Qwen2-VL-2B-Instruct"
