
Commit dfbbba4

support attention mask in user's dataset (#930)
1 parent af8708e commit dfbbba4
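
Below, for context, is a minimal sketch of the usage this commit enables: passing a pre-tokenized calibration dataset whose padding should be ignored during tuning. The model id and texts are placeholders, and the call pattern simply mirrors the new tests added in this commit; it is not additional API.

    # Hedged sketch; model id and texts are placeholders, mirroring the new tests.
    from transformers import AutoTokenizer
    from auto_round import AutoRound

    model_name = "Qwen/Qwen3-0.6B"  # placeholder model id
    tokenizer = AutoTokenizer.from_pretrained(model_name)

    # Tokenize with padding; the returned dict carries input_ids and attention_mask.
    batch = tokenizer(["haha", "hello world"], return_tensors="pt",
                      max_length=8, padding="max_length", truncation=True)
    data = [batch.data]  # a list of dicts serves as the calibration dataset

    # With this change, each sample's attention_mask is honored in the tuning loss.
    ar = AutoRound(model_name, iters=1, dataset=data, seqlen=8)
    ar.quantize()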

File tree

3 files changed: +126, -26 lines changed


auto_round/compressors/base.py

Lines changed: 68 additions & 26 deletions
@@ -369,6 +369,8 @@ def __init__(
             import habana_frameworks.torch.core as htcore # pylint: disable=E0401
             import habana_frameworks.torch.hpu as hthpu # pylint: disable=E0401]

+        self.attention_mask = []
+
     def _gen_auto_scheme(
         self, model: torch.nn.Module, scheme: AutoScheme, dataset: str, device_map: Union[str, int, dict, torch.device]
     ) -> dict[str, dict]:
@@ -809,21 +811,6 @@ def _check_compatibility(self) -> None:
                 " We are likely to release new algorithm for certain configurations in the future."
             )

-        # # Check group_size 32 for auto_round
-        # if (
-        #     self.data_type == "int"
-        #     and hasattr(self, "formats")
-        #     and any(key in fmt for fmt in self.formats for key in ("auto_round", "auto_gptq", "auto_awq"))
-        # ):
-        #     for n, m in self.model.named_modules():
-        #         if type(m) in self.supported_types:
-        #             if m.weight.shape[0] % 32 != 0 or m.weight.shape[1] % 32 != 0:
-        #                 self.layer_config[n] = {"bits": 16}
-        #                 logger.info(
-        #                     f"{n} will not be quantized due to its shape not being divisible by 32,"
-        #                     " resulting in an exporting issue to autogptq"
-        #                 )
         if (
             self.seqlen is not None
             and hasattr(self.model, "config")
if (
828815
self.seqlen is not None
829816
and hasattr(self.model, "config")
@@ -1197,7 +1184,7 @@ def _quantize_embedding_layer(self):
                     module.weight.to(self.device),
                     **{k: config[k] for k in ["bits", "group_size", "super_bits", "super_group_size", "scale_dtype"]},
                 )
-            except RuntimeError as e:
+            except torch.OutOfMemoryError:
                 cuda_error_msg = traceback.format_exc()
                 try:
                     logger.error(cuda_error_msg)
@@ -1298,7 +1285,7 @@ def get_imatrix_hook(module, input, output):
             model = model.to("cpu")
             clear_memory()
             self._quantize_via_rtn_blockwise(all_to_quantized_module_names)
-        except RuntimeError as e:
+        except torch.OutOfMemoryError:
             cuda_error_msg = traceback.format_exc()
             try:
                 logger.error(cuda_error_msg)
@@ -1372,7 +1359,7 @@ def _quantize_layer_via_rtn(self, name: str) -> None:
             )
             m = m.unwrapper({})
             m.to("cpu")
-        except RuntimeError as e:
+        except torch.OutOfMemoryError:
             cuda_error_msg = traceback.format_exc()
             m = m.orig_layer if hasattr(m, "orig_layer") else m
             try:
@@ -1474,7 +1461,7 @@ def _quantize_rtn(self) -> tuple[torch.nn.Module, dict[str, Any]]:
         hook_handles = self._register_act_max_hook(self.model)
         try:
             self._quantize_via_rtn_blockwise(all_to_quantized_module_names)
-        except RuntimeError as e:
+        except torch.OutOfMemoryError:
             logger.warning("Fallback to CPU. Consider using more GPUs via `--device 0,1,2,3`.")
             self.model = self.model.to("cpu")
             clear_memory()
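
Several hunks in this file narrow exception handlers from RuntimeError to torch.OutOfMemoryError, so only genuine out-of-memory failures trigger the CPU fallback while other runtime errors propagate. A minimal standalone sketch of that pattern, assuming a PyTorch version that exposes torch.OutOfMemoryError; run_with_cpu_fallback and run_on_device are hypothetical names, not functions from this repository:

    import torch

    def run_with_cpu_fallback(run_on_device, model):
        try:
            # Try the memory-hungry step on the accelerator first.
            return run_on_device(model.to("cuda"))
        except torch.OutOfMemoryError:
            # Only OOM falls back to CPU; other RuntimeErrors surface to the caller.
            torch.cuda.empty_cache()
            return run_on_device(model.to("cpu"))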
@@ -1932,7 +1919,9 @@ def calib(self, nsamples, bs):
         """
         from auto_round.calib_dataset import get_dataloader

+        need_attention_mask = True
         if isinstance(self.dataset, str):
+            need_attention_mask = False # all supported datasets does not use pad
             dataset = self.dataset.replace(" ", "") ##remove all whitespaces

             # slow here
@@ -1995,6 +1984,41 @@ def calib(self, nsamples, bs):
                     raise error
                 except Exception as error:
                     raise error
+            if need_attention_mask:
+                if (
+                    isinstance(data_new, dict)
+                    and "attention_mask" in data_new
+                    and data_new["attention_mask"] is not None
+                ):
+                    new_attention_mask = data_new["attention_mask"]
+                elif (
+                    self.tokenizer is not None
+                    and hasattr(self.tokenizer, "pad_token")
+                    and self.tokenizer.pad_token is not None
+                ):
+                    new_attention_mask = (input_ids != self.tokenizer.pad_token_id).to(torch.long)
+                else:
+                    # Default all ones
+                    new_attention_mask = torch.ones_like(input_ids, dtype=torch.long)
+
+                # For each sample, check if there are trailing repeated tokens
+                # If so, set the mask of the last token to 0
+                batch_size, seq_len = input_ids.shape
+                for i in range(batch_size):
+                    last_token = input_ids[i, -1]
+                    # Check for trailing repeats
+                    j = seq_len - 2
+                    repeated = False
+                    while j >= 0 and input_ids[i, j] == last_token:
+                        repeated = True
+                        new_attention_mask[i, j] = 0
+                        j -= 1
+                    # If there was at least one repeat, set last token mask to 0
+                    if repeated:
+                        new_attention_mask[i, -1] = 0
+
+                self.attention_mask.extend(list(torch.split(new_attention_mask, 1, dim=0)))
+
             total_cnt += input_ids.shape[0] if len(input_ids.shape) > 1 else 1
             if total_cnt >= nsamples:
                 break
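
Condensed, the calibration logic added above derives one mask per sample: it reuses the dataset's attention_mask when present, otherwise builds one from the tokenizer's pad token, otherwise defaults to all ones, and then zeroes out any run of identical trailing tokens (treated as padding that was not marked as such). A standalone sketch of that logic with illustrative names; the real code operates on self.attention_mask inside calib():

    import torch

    def derive_attention_mask(input_ids, attention_mask=None, pad_token_id=None):
        if attention_mask is not None:
            mask = attention_mask.to(torch.long)
        elif pad_token_id is not None:
            mask = (input_ids != pad_token_id).to(torch.long)
        else:
            mask = torch.ones_like(input_ids, dtype=torch.long)

        # Zero out a run of identical trailing tokens, including the last one.
        batch_size, seq_len = input_ids.shape
        for i in range(batch_size):
            last_token = input_ids[i, -1]
            j = seq_len - 2
            repeated = False
            while j >= 0 and input_ids[i, j] == last_token:
                repeated = True
                mask[i, j] = 0
                j -= 1
            if repeated:
                mask[i, -1] = 0
        return mask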
@@ -2070,7 +2094,7 @@ def try_cache_inter_data_gpucpu(self, block_names, nsamples, layer_names=None, l
             if hasattr(self.model, "hf_device_map") and len(self.model.hf_device_map) > 1:
                 accelerate.hooks.remove_hook_from_submodules(self.model)

-        except RuntimeError as e:
+        except torch.OutOfMemoryError:
             cuda_error_msg = traceback.format_exc()
             try:
                 logger.info("switch to cpu to cache block inputs")
@@ -2082,10 +2106,10 @@ def try_cache_inter_data_gpucpu(self, block_names, nsamples, layer_names=None, l
                 if hasattr(self.model, "hf_device_map") and len(self.model.hf_device_map) > 1:
                     accelerate.hooks.remove_hook_from_submodules(
                         self.model
-                    ) ##self.model.hf_device_map has not been changed
+                    ) # self.model.hf_device_map has not been changed
                 self.model = mv_module_from_gpu(self.model)
                 clear_memory()
-                ## Important change after v0.51, on cpu, we use rtn mode for layers in layer_names
+                # Important change after v0.51, on cpu, we use rtn mode for layers in layer_names
                 all_inputs = self.cache_inter_data(
                     block_names, nsamples, layer_names=[], last_cache_name=last_cache_name
                 )
@@ -2397,15 +2421,24 @@ def _quantize_layer(
             org_input = current_input
             with torch.no_grad():
                 current_output = layer(org_input)
+            if self.attention_mask:
+                tmp_attention_mask = [self.attention_mask[i] for i in indices]
+                tmp_attention_mask = torch.cat(tmp_attention_mask, dim=0).to(device)
+                tmp_attention_mask.unsqueeze_(-1)
+            else:
+                tmp_attention_mask = 1.0

             if self.amp:
                 with autocast(device_type=device.split(":")[0], dtype=self.amp_dtype):
                     output_q = wrapper_linear(current_input) # pylint: disable=not-callable
-                    loss = mse_loss(output_q, current_output) # pylint: disable=not-callable
+                    loss = mse_loss( # pylint: disable=not-callable
+                        output_q * tmp_attention_mask, current_output * tmp_attention_mask
+                    )
             else:
                 output_q = wrapper_linear(current_input) # pylint: disable=not-callable
                 loss = mse_loss( # pylint: disable=not-callable
-                    output_q.to(torch.float32), current_output.to(torch.float32)
+                    output_q.to(torch.float32) * tmp_attention_mask,
+                    current_output.to(torch.float32) * tmp_attention_mask,
                 )
             total_loss += loss.item() / num_elm

@@ -2674,12 +2707,21 @@ def _quantize_block(
             current_output = to_device(current_output, device)

             output_q = self._get_current_q_output(block, input_ids, input_others, indices, device)
+            if self.attention_mask:
+                tmp_attention_mask = [self.attention_mask[i] for i in indices]
+                tmp_attention_mask = torch.cat(tmp_attention_mask, dim=0).to(device)
+                tmp_attention_mask.unsqueeze_(-1)
+            else:
+                tmp_attention_mask = 1.0
             if self.amp:
                 with autocast(device_type=device.split(":")[0], dtype=self.amp_dtype):
-                    loss = mse_loss(output_q, current_output) # pylint: disable=not-callable
+                    loss = mse_loss( # pylint: disable=not-callable
+                        output_q * tmp_attention_mask, current_output * tmp_attention_mask
+                    )
             else:
                 loss = mse_loss( # pylint: disable=not-callable
-                    output_q.to(torch.float32), current_output.to(torch.float32)
+                    output_q.to(torch.float32) * tmp_attention_mask,
+                    current_output.to(torch.float32) * tmp_attention_mask,
                 )

             total_loss += loss.item() / num_elm
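
The two hunks above apply the stored masks during tuning: the masks for the current batch are gathered by indices, concatenated, given a trailing dimension via unsqueeze_(-1) so they broadcast over the hidden size, and multiplied into both the quantized and reference outputs before the MSE loss, so padded positions contribute zero error (they still count toward the mean). A reduced sketch with illustrative shapes:

    import torch

    mse_loss = torch.nn.MSELoss()

    batch, seq_len, hidden = 2, 8, 16
    output_q = torch.randn(batch, seq_len, hidden)        # quantized layer/block output
    current_output = torch.randn(batch, seq_len, hidden)  # reference output
    attention_mask = torch.tensor([[1] * 6 + [0] * 2,
                                   [1] * 8])              # 0 marks padded positions

    mask = attention_mask.unsqueeze(-1)  # (batch, seq_len, 1) broadcasts over hidden
    loss = mse_loss(output_q * mask, current_output * mask)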

test/test_cpu/test_autoround.py

Lines changed: 38 additions & 0 deletions
@@ -755,6 +755,44 @@ def test_compressor(self):
         ar = AutoRoundMLLM(model_name)
         self.assertTrue(ar.mllm)

+    def test_attention_mask_in_dataset(self):
+        from transformers import AutoTokenizer
+
+        model_name = "/tf_dataset/auto_round/models/Qwen/Qwen3-0.6B"
+        # model_name = "/models/Qwen3-0.6B"
+        tokenizer = AutoTokenizer.from_pretrained(model_name)
+        text = ["haha", "hello world"]
+        res = tokenizer(text, return_tensors="pt", max_length=8, padding="max_length", truncation=True)
+        data = [res.data]
+
+        text = ["qudd", "hfd"]
+        res = tokenizer(text, return_tensors="pt", max_length=8, padding="max_length", truncation=True)
+        data.append(res.data)
+        from auto_round import AutoRound
+
+        ar = AutoRound(model_name, iters=1, dataset=data, seqlen=8)
+        ar.quantize()
+
+    def test_attention_mask_via_tokenize_in_dataset(self):
+        from transformers import AutoTokenizer
+
+        model_name = "/tf_dataset/auto_round/models/Qwen/Qwen3-0.6B"
+        # model_name = "/models/Qwen3-0.6B"
+        tokenizer = AutoTokenizer.from_pretrained(model_name)
+        text = ["haha", "hello world"]
+        res = tokenizer(text, return_tensors="pt", max_length=8, padding="max_length", truncation=True)
+        res.data.pop("attention_mask")
+        data = [res.data]
+
+        text = ["qudd", "hfd"]
+        res = tokenizer(text, return_tensors="pt", max_length=8, padding="max_length", truncation=True)
+        res.data.pop("attention_mask")
+        data.append(res.data)
+        from auto_round import AutoRound
+
+        ar = AutoRound(model_name, iters=1, dataset=data, seqlen=8)
+        ar.quantize()
+

 if __name__ == "__main__":
     unittest.main()

test/test_cuda/test_main_func.py

Lines changed: 20 additions & 0 deletions
@@ -179,6 +179,26 @@ def test_autoround_asym(self): ##need to install false
         assert accuracy > 0.35
         shutil.rmtree("./saved", ignore_errors=True)

+    def test_attention_mask_lm_head(self):
+        from transformers import AutoTokenizer
+
+        model_name = "/models/Qwen3-8B"
+        # model_name = "/models/Qwen3-0.6B"
+        tokenizer = AutoTokenizer.from_pretrained(model_name)
+        text = ["haha", "hello world"]
+        res = tokenizer(text, return_tensors="pt", max_length=8, padding="max_length", truncation=True)
+        res.data.pop("attention_mask")
+        data = [res.data]
+
+        text = ["qudd", "hfd"]
+        res = tokenizer(text, return_tensors="pt", max_length=8, padding="max_length", truncation=True)
+        res.data.pop("attention_mask")
+        data.append(res.data)
+        from auto_round import AutoRound
+
+        ar = AutoRound(model_name, iters=1, dataset=data, seqlen=8, quant_lm_head=True)
+        ar.quantize()
+

 if __name__ == "__main__":
     unittest.main()
