src/transformers/generation
1 file changed: +8 -9 lines changed

@@ -658,15 +658,14 @@ def prepare_inputs_for_generation(
     token_type_ids = getattr(model_input, "token_type_ids", None)
     causal_mask_creation_function = getattr(self, "create_masks_for_generate", create_masks_for_generate)
     attention_mask = causal_mask_creation_function(
-        self.config,
-        torch.empty(
-            (batch_size, sequence_length), dtype=self.dtype
-        ),  # we only need batch size, seq_length and dtype here - we don't care about the values
-        attention_mask,
-        cache_position,
-        past_key_values,
-        output_attentions,
-        token_type_ids,
+        config=self.config,
+        # we only need batch size, seq_length and dtype here - we don't care about the values of the embeddings
+        input_embeds=torch.empty((batch_size, sequence_length), dtype=self.dtype),
+        attention_mask=attention_mask,
+        cache_position=cache_position,
+        past_key_values=past_key_values,
+        output_attentions=output_attentions,
+        token_type_ids=token_type_ids,
     )
 else:
     attention_mask = causal_mask_creation_function(
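The change switches the call to causal_mask_creation_function from positional to keyword arguments and passes a torch.empty placeholder whose only job is to carry the batch size, sequence length, and dtype. Below is a minimal, self-contained sketch of that calling pattern; the create_masks_for_generate stand-in and its signature are assumptions for illustration, not the actual transformers implementation.

import torch


def create_masks_for_generate(config=None, input_embeds=None, attention_mask=None,
                              cache_position=None, past_key_values=None,
                              output_attentions=False, token_type_ids=None):
    # Hypothetical stand-in: only the shape and dtype of the placeholder embeddings
    # are read; the values themselves are never touched.
    batch_size, sequence_length = input_embeds.shape[:2]
    causal = torch.tril(torch.ones(sequence_length, sequence_length, dtype=torch.bool))
    return causal.expand(batch_size, sequence_length, sequence_length)


batch_size, sequence_length = 2, 5
mask = create_masks_for_generate(
    config=None,
    # we only need batch size, seq_length and dtype here - the values are never read
    input_embeds=torch.empty((batch_size, sequence_length), dtype=torch.float32),
    attention_mask=None,
    cache_position=torch.arange(sequence_length),
    past_key_values=None,
    output_attentions=False,
    token_type_ids=None,
)
print(mask.shape)  # torch.Size([2, 5, 5])

Passing every argument by keyword keeps the call robust if a model overrides create_masks_for_generate with a different parameter order, which appears to be the motivation for the diff.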