@@ -2099,6 +2099,31 @@ def forward(
             depth_attentions=None if decoder_outputs is None else decoder_outputs.attentions,
         )

+    def _prepare_attention_mask_for_generation(
+        self,
+        input_ids: torch.LongTensor,
+        generation_config: GenerationConfig,
+        kwargs: Dict[str, Any],
+    ) -> torch.LongTensor:
+        pad_token_id = generation_config.pad_token_id
+        eos_token_id = generation_config.eos_token_id
+
+        default_attention_mask = torch.ones(input_ids.shape, dtype=torch.long, device=input_ids.device)
+        if pad_token_id is None:
+            return default_attention_mask
+
+        is_pad_token_in_inputs = (pad_token_id is not None) and torch.isin(input_ids, pad_token_id).any()
+        is_pad_token_not_equal_to_eos_token_id = (eos_token_id is None) or ~torch.isin(
+            eos_token_id, pad_token_id
+        ).any()
+        can_infer_attention_mask = is_pad_token_in_inputs * is_pad_token_not_equal_to_eos_token_id
+        attention_mask_from_padding = input_ids.ne(pad_token_id).long()
+
+        attention_mask = (
+            attention_mask_from_padding * can_infer_attention_mask + default_attention_mask * ~can_infer_attention_mask
+        )
+        return attention_mask
+
     def _prepare_inputs_embeds_for_generation(
         self,
         input_ids: Optional[torch.LongTensor] = None,
@@ -2315,6 +2340,12 @@ def generate(
         kwargs_depth_decoder = depth_decoder_generation_config

         attention_mask = kwargs.pop("attention_mask", None)
+        if attention_mask is None:
+            attention_mask = self._prepare_attention_mask_for_generation(
+                input_ids=input_ids,
+                generation_config=generation_config,
+                kwargs=kwargs,
+            )
         (
             inputs_embeds,
             input_ids,
@@ -2497,11 +2528,11 @@ def prepare_inputs_for_generation(
         batch_size, sequence_length = input_ids.shape
         device = input_ids.device

-        attention_mask = self.model._prepare_4d_causal_attention_mask_with_cache_position(
+        attention_mask = self.decoder.model._prepare_4d_causal_attention_mask_with_cache_position(
             attention_mask,
             sequence_length=sequence_length,
             target_length=past_key_values.get_max_cache_shape(),
-            dtype=self.lm_head.weight.dtype,
+            dtype=self.decoder.lm_head.weight.dtype,
             device=device,
             cache_position=cache_position,
             batch_size=batch_size,
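The new `_prepare_attention_mask_for_generation` helper only infers the mask from padding when a pad token is defined, actually appears in `input_ids`, and is distinct from the EOS token; otherwise it falls back to an all-ones mask. Below is a minimal standalone sketch of that logic, not the patched method itself: the plain `pad_token_id`/`eos_token_id` arguments stand in for the `GenerationConfig` fields, and the function name is illustrative.

from typing import Optional

import torch


def infer_attention_mask(
    input_ids: torch.LongTensor,
    pad_token_id: Optional[int],
    eos_token_id: Optional[int],
) -> torch.LongTensor:
    # Without a pad token there is nothing to infer from: attend everywhere.
    default_mask = torch.ones_like(input_ids, dtype=torch.long)
    if pad_token_id is None:
        return default_mask
    # Padding is only a reliable signal when pad tokens occur in the inputs
    # and the pad token is not reused as the EOS token.
    has_pad = torch.isin(input_ids, torch.tensor(pad_token_id)).any()
    pad_is_not_eos = eos_token_id is None or eos_token_id != pad_token_id
    if has_pad and pad_is_not_eos:
        return input_ids.ne(pad_token_id).long()
    return default_mask


# Example: left-padded batch with pad_token_id=0 and eos_token_id=2.
ids = torch.tensor([[0, 0, 5, 6, 7], [3, 4, 5, 6, 7]])
print(infer_attention_mask(ids, pad_token_id=0, eos_token_id=2))
# tensor([[0, 0, 1, 1, 1],
#         [1, 1, 1, 1, 1]])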