Skip to content

Commit 68ca65e

Browse files
committed
Update utils.py
1 parent 11759bb commit 68ca65e

File tree

1 file changed

+8
-9
lines changed

1 file changed

+8
-9
lines changed

src/transformers/generation/utils.py

Lines changed: 8 additions & 9 deletions
Original file line number · Diff line number · Diff line change
@@ -658,15 +658,14 @@ def prepare_inputs_for_generation(
658658
token_type_ids = getattr(model_input, "token_type_ids", None)
659659
causal_mask_creation_function = getattr(self, "create_masks_for_generate", create_masks_for_generate)
660660
attention_mask = causal_mask_creation_function(
661-
self.config,
662-
torch.empty(
663-
(batch_size, sequence_length), dtype=self.dtype
664-
), # we only need batch size, seq_length and dtype here - we don't care about the values
665-
attention_mask,
666-
cache_position,
667-
past_key_values,
668-
output_attentions,
669-
token_type_ids,
661+
config=self.config,
662+
# we only need batch size, seq_length and dtype here - we don't care about the values of the embeddings
663+
input_embeds=torch.empty((batch_size, sequence_length), dtype=self.dtype),
664+
attention_mask=attention_mask,
665+
cache_position=cache_position,
666+
past_key_values=past_key_values,
667+
output_attentions=output_attentions,
668+
token_type_ids=token_type_ids,
670669
)
671670
else:
672671
attention_mask = causal_mask_creation_function(

0 commit comments

Comments (0)