@@ -369,6 +369,8 @@ def __init__(
             import habana_frameworks.torch.core as htcore  # pylint: disable=E0401
             import habana_frameworks.torch.hpu as hthpu  # pylint: disable=E0401
 
+        self.attention_mask = []
+
     def _gen_auto_scheme(
         self, model: torch.nn.Module, scheme: AutoScheme, dataset: str, device_map: Union[str, int, dict, torch.device]
     ) -> dict[str, dict]:
@@ -809,21 +811,6 @@ def _check_compatibility(self) -> None:
809811 " We are likely to release new algorithm for certain configurations in the future."
810812 )
811813
812- # # Check group_size 32 for auto_round
813- # if (
814- # self.data_type == "int"
815- # and hasattr(self, "formats")
816- # and any(key in fmt for fmt in self.formats for key in ("auto_round", "auto_gptq", "auto_awq"))
817- # ):
818- # for n, m in self.model.named_modules():
819- # if type(m) in self.supported_types:
820- # if m.weight.shape[0] % 32 != 0 or m.weight.shape[1] % 32 != 0:
821- # self.layer_config[n] = {"bits": 16}
822- # logger.info(
823- # f"{n} will not be quantized due to its shape not being divisible by 32,"
824- # " resulting in an exporting issue to autogptq"
825- # )
826-
827814 if (
828815 self .seqlen is not None
829816 and hasattr (self .model , "config" )
@@ -1197,7 +1184,7 @@ def _quantize_embedding_layer(self):
                     module.weight.to(self.device),
                     **{k: config[k] for k in ["bits", "group_size", "super_bits", "super_group_size", "scale_dtype"]},
                 )
-            except RuntimeError as e:
+            except torch.OutOfMemoryError:
                 cuda_error_msg = traceback.format_exc()
                 try:
                     logger.error(cuda_error_msg)
@@ -1298,7 +1285,7 @@ def get_imatrix_hook(module, input, output):
             model = model.to("cpu")
             clear_memory()
             self._quantize_via_rtn_blockwise(all_to_quantized_module_names)
-        except RuntimeError as e:
+        except torch.OutOfMemoryError:
             cuda_error_msg = traceback.format_exc()
             try:
                 logger.error(cuda_error_msg)
@@ -1372,7 +1359,7 @@ def _quantize_layer_via_rtn(self, name: str) -> None:
             )
             m = m.unwrapper({})
             m.to("cpu")
-        except RuntimeError as e:
+        except torch.OutOfMemoryError:
             cuda_error_msg = traceback.format_exc()
             m = m.orig_layer if hasattr(m, "orig_layer") else m
             try:
@@ -1474,7 +1461,7 @@ def _quantize_rtn(self) -> tuple[torch.nn.Module, dict[str, Any]]:
         hook_handles = self._register_act_max_hook(self.model)
         try:
             self._quantize_via_rtn_blockwise(all_to_quantized_module_names)
-        except RuntimeError as e:
+        except torch.OutOfMemoryError:
             logger.warning("Fallback to CPU. Consider using more GPUs via `--device 0,1,2,3`.")
             self.model = self.model.to("cpu")
             clear_memory()
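
Taken together, the exception-handler changes above narrow the CPU fallback so it only triggers on out-of-memory failures; other RuntimeErrors now propagate instead of being silently retried on CPU. A minimal standalone sketch of the pattern, assuming a PyTorch version that exposes torch.OutOfMemoryError (2.5+) and a hypothetical quantize_fn helper:

import torch


def quantize_with_cpu_fallback(quantize_fn, model):
    # Illustrative helper, not part of the auto-round API.
    try:
        return quantize_fn(model)  # first attempt on the accelerator
    except torch.OutOfMemoryError:
        model = model.to("cpu")  # move weights off the device
        if torch.cuda.is_available():
            torch.cuda.empty_cache()  # release cached allocator blocks
        return quantize_fn(model)  # retry on CPU
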
@@ -1932,7 +1919,9 @@ def calib(self, nsamples, bs):
19321919 """
19331920 from auto_round .calib_dataset import get_dataloader
19341921
1922+ need_attention_mask = True
19351923 if isinstance (self .dataset , str ):
1924+ need_attention_mask = False # all supported datasets does not use pad
19361925 dataset = self .dataset .replace (" " , "" ) ##remove all whitespaces
19371926
19381927 # slow here
@@ -1995,6 +1984,41 @@ def calib(self, nsamples, bs):
                 raise error
             except Exception as error:
                 raise error
+            if need_attention_mask:
+                if (
+                    isinstance(data_new, dict)
+                    and "attention_mask" in data_new
+                    and data_new["attention_mask"] is not None
+                ):
+                    new_attention_mask = data_new["attention_mask"]
+                elif (
+                    self.tokenizer is not None
+                    and hasattr(self.tokenizer, "pad_token")
+                    and self.tokenizer.pad_token is not None
+                ):
+                    new_attention_mask = (input_ids != self.tokenizer.pad_token_id).to(torch.long)
+                else:
+                    # Default all ones
+                    new_attention_mask = torch.ones_like(input_ids, dtype=torch.long)
+
+                # For each sample, check if there are trailing repeated tokens.
+                # If so, set the mask of the last token to 0.
+                batch_size, seq_len = input_ids.shape
+                for i in range(batch_size):
+                    last_token = input_ids[i, -1]
+                    # Check for trailing repeats
+                    j = seq_len - 2
+                    repeated = False
+                    while j >= 0 and input_ids[i, j] == last_token:
+                        repeated = True
+                        new_attention_mask[i, j] = 0
+                        j -= 1
+                    # If there was at least one repeat, set the last token's mask to 0 as well
+                    if repeated:
+                        new_attention_mask[i, -1] = 0
+
+                self.attention_mask.extend(list(torch.split(new_attention_mask, 1, dim=0)))
+
             total_cnt += input_ids.shape[0] if len(input_ids.shape) > 1 else 1
             if total_cnt >= nsamples:
                 break
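
A standalone illustration of the trailing-repeat handling added above, using toy token ids (not from the repo): a sample that ends in a run of identical tokens gets every token of that run masked out, including the final one, while a sample without trailing repeats keeps an all-ones mask.

import torch

input_ids = torch.tensor([
    [5, 9, 2, 0, 0, 0],  # ends with a run of repeated tokens
    [5, 9, 2, 7, 3, 8],  # no trailing repeats
])
mask = torch.ones_like(input_ids, dtype=torch.long)

batch_size, seq_len = input_ids.shape
for i in range(batch_size):
    last_token = input_ids[i, -1]
    j = seq_len - 2
    repeated = False
    while j >= 0 and input_ids[i, j] == last_token:
        repeated = True
        mask[i, j] = 0
        j -= 1
    if repeated:
        mask[i, -1] = 0  # mask the final token of the run as well

print(mask)
# tensor([[1, 1, 1, 0, 0, 0],
#         [1, 1, 1, 1, 1, 1]])
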
@@ -2070,7 +2094,7 @@ def try_cache_inter_data_gpucpu(self, block_names, nsamples, layer_names=None, l
             if hasattr(self.model, "hf_device_map") and len(self.model.hf_device_map) > 1:
                 accelerate.hooks.remove_hook_from_submodules(self.model)
 
-        except RuntimeError as e:
+        except torch.OutOfMemoryError:
             cuda_error_msg = traceback.format_exc()
             try:
                 logger.info("switch to cpu to cache block inputs")
@@ -2082,10 +2106,10 @@ def try_cache_inter_data_gpucpu(self, block_names, nsamples, layer_names=None, l
                 if hasattr(self.model, "hf_device_map") and len(self.model.hf_device_map) > 1:
                     accelerate.hooks.remove_hook_from_submodules(
                         self.model
-                    )  ## self.model.hf_device_map has not been changed
+                    )  # self.model.hf_device_map has not been changed
                 self.model = mv_module_from_gpu(self.model)
                 clear_memory()
-                ## Important change after v0.51, on cpu, we use rtn mode for layers in layer_names
+                # Important change after v0.51, on cpu, we use rtn mode for layers in layer_names
                 all_inputs = self.cache_inter_data(
                     block_names, nsamples, layer_names=[], last_cache_name=last_cache_name
                 )
@@ -2397,15 +2421,24 @@ def _quantize_layer(
                     org_input = current_input
                 with torch.no_grad():
                     current_output = layer(org_input)
+                if self.attention_mask:
+                    tmp_attention_mask = [self.attention_mask[i] for i in indices]
+                    tmp_attention_mask = torch.cat(tmp_attention_mask, dim=0).to(device)
+                    tmp_attention_mask.unsqueeze_(-1)
+                else:
+                    tmp_attention_mask = 1.0
 
                 if self.amp:
                     with autocast(device_type=device.split(":")[0], dtype=self.amp_dtype):
                         output_q = wrapper_linear(current_input)  # pylint: disable=not-callable
-                        loss = mse_loss(output_q, current_output)  # pylint: disable=not-callable
+                        loss = mse_loss(  # pylint: disable=not-callable
+                            output_q * tmp_attention_mask, current_output * tmp_attention_mask
+                        )
                 else:
                     output_q = wrapper_linear(current_input)  # pylint: disable=not-callable
                     loss = mse_loss(  # pylint: disable=not-callable
-                        output_q.to(torch.float32), current_output.to(torch.float32)
+                        output_q.to(torch.float32) * tmp_attention_mask,
+                        current_output.to(torch.float32) * tmp_attention_mask,
                     )
                 total_loss += loss.item() / num_elm
 
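
The loss changes in this hunk and the next amount to a masked MSE over the calibration batch. A minimal sketch, assuming self.attention_mask holds one [1, seq_len] tensor per calibration sample and indices selects the samples in the current batch (the helper name here is illustrative, not part of the repo):

import torch
import torch.nn.functional as F


def masked_mse(output_q, current_output, attention_mask, indices, device):
    # Zero out masked token positions before computing the MSE.
    if attention_mask:
        mask = torch.cat([attention_mask[i] for i in indices], dim=0).to(device)
        mask = mask.unsqueeze(-1)  # [batch, seq] -> [batch, seq, 1], broadcasts over the feature dim
    else:
        mask = 1.0  # no masks collected: falls back to plain MSE
    return F.mse_loss(output_q * mask, current_output * mask)
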
@@ -2674,12 +2707,21 @@ def _quantize_block(
                 current_output = to_device(current_output, device)
 
                 output_q = self._get_current_q_output(block, input_ids, input_others, indices, device)
+                if self.attention_mask:
+                    tmp_attention_mask = [self.attention_mask[i] for i in indices]
+                    tmp_attention_mask = torch.cat(tmp_attention_mask, dim=0).to(device)
+                    tmp_attention_mask.unsqueeze_(-1)
+                else:
+                    tmp_attention_mask = 1.0
                 if self.amp:
                     with autocast(device_type=device.split(":")[0], dtype=self.amp_dtype):
-                        loss = mse_loss(output_q, current_output)  # pylint: disable=not-callable
+                        loss = mse_loss(  # pylint: disable=not-callable
+                            output_q * tmp_attention_mask, current_output * tmp_attention_mask
+                        )
                 else:
                     loss = mse_loss(  # pylint: disable=not-callable
-                        output_q.to(torch.float32), current_output.to(torch.float32)
+                        output_q.to(torch.float32) * tmp_attention_mask,
+                        current_output.to(torch.float32) * tmp_attention_mask,
                     )
 
                 total_loss += loss.item() / num_elm
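
Shape-wise, the in-place unsqueeze_(-1) is what lets the per-token mask broadcast against the block output; a tiny sanity check with illustrative shapes:

import torch

output_q = torch.randn(4, 512, 4096)  # [batch, seq, hidden] block output
mask = torch.ones(4, 512, dtype=torch.long)  # one value per token
mask = mask.unsqueeze(-1)  # -> [4, 512, 1]
assert (output_q * mask).shape == output_q.shape

When no masks were collected during calibration, tmp_attention_mask falls back to the scalar 1.0, so both loss branches reduce to the original unmasked MSE.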