diff --git a/src/transformers/models/biogpt/modeling_biogpt.py b/src/transformers/models/biogpt/modeling_biogpt.py
index 977e54ddac25b9..221c97c86885c5 100755
--- a/src/transformers/models/biogpt/modeling_biogpt.py
+++ b/src/transformers/models/biogpt/modeling_biogpt.py
@@ -71,13 +71,7 @@ def forward(self, attention_mask: torch.LongTensor, past_key_values_length: int
         attention_mask = attention_mask.long()
 
         # create positions depending on attention_mask
-        if len(attention_mask.shape) == 2:
-            positions = (torch.cumsum(attention_mask, dim=-1).type_as(attention_mask) * attention_mask).long() - 1
-        elif len(attention_mask.shape) == 4:
-            # assumes 4D mask for efficient beam search
-            token_positions = torch.cumsum(attention_mask, dim=-1).amax(dim=(1, 2))
-            used_tokens_mask = attention_mask.amax(dim=(1, 2))
-            positions = (token_positions * used_tokens_mask).long() - 1
+        positions = (torch.cumsum(attention_mask, dim=1).type_as(attention_mask) * attention_mask).long() - 1
 
         # cut positions if `past_key_values_length` is > 0
         positions = positions[:, past_key_values_length:]
diff --git a/src/transformers/models/llama/modeling_llama.py b/src/transformers/models/llama/modeling_llama.py
index 545db24c40e3b1..e53f89276f3f48 100644
--- a/src/transformers/models/llama/modeling_llama.py
+++ b/src/transformers/models/llama/modeling_llama.py
@@ -872,18 +872,11 @@ def forward(
             past_key_values_length = past_key_values[0][0].shape[2]
 
         if position_ids is None:
-            if attention_mask is not None and len(attention_mask.shape) == 4:
-                # assumes 4D mask for efficient beam search
-                token_positions = torch.cumsum(attention_mask, dim=-1).amax(dim=(1, 2))
-                used_tokens_mask = attention_mask.amax(dim=(1, 2))
-                position_ids = (token_positions * used_tokens_mask).long() - 1
-                position_ids = position_ids[:, past_key_values_length:]
-            else:
-                device = input_ids.device if input_ids is not None else inputs_embeds.device
-                position_ids = torch.arange(
-                    past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device
-                )
-                position_ids = position_ids.unsqueeze(0)
+            device = input_ids.device if input_ids is not None else inputs_embeds.device
+            position_ids = torch.arange(
+                past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device
+            )
+            position_ids = position_ids.unsqueeze(0)
 
         if inputs_embeds is None:
             inputs_embeds = self.embed_tokens(input_ids)
diff --git a/src/transformers/models/opt/modeling_opt.py b/src/transformers/models/opt/modeling_opt.py
index dd571a01464a38..2192f327bc49f9 100644
--- a/src/transformers/models/opt/modeling_opt.py
+++ b/src/transformers/models/opt/modeling_opt.py
@@ -100,13 +100,7 @@ def forward(self, attention_mask: torch.LongTensor, past_key_values_length: int
         attention_mask = attention_mask.long()
 
         # create positions depending on attention_mask
-        if len(attention_mask.shape) == 2:
-            positions = (torch.cumsum(attention_mask, dim=-1).type_as(attention_mask) * attention_mask).long() - 1
-        elif len(attention_mask.shape) == 4:
-            # assumes 4D mask for efficient beam search
-            token_positions = torch.cumsum(attention_mask, dim=-1).amax(dim=(1, 2))
-            used_tokens_mask = attention_mask.amax(dim=(1, 2))
-            positions = (token_positions * used_tokens_mask).long() - 1
+        positions = (torch.cumsum(attention_mask, dim=1).type_as(attention_mask) * attention_mask).long() - 1
 
         # cut positions if `past_key_values_length` is > 0
         positions = positions[:, past_key_values_length:]
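
Not part of the patch: a minimal, runnable sketch of the 2D-mask position computation that the hunks above keep, using a made-up toy attention mask for illustration. Cumulative-summing the mask along the sequence dimension, re-masking, and subtracting one yields 0-based positions for real tokens and -1 for left-padding slots (which are masked out downstream anyway).

    import torch

    # Toy 2D attention mask: row 0 is left-padded, row 1 is full length.
    attention_mask = torch.tensor(
        [
            [0, 0, 1, 1, 1],
            [1, 1, 1, 1, 1],
        ]
    )

    # Same expression the patch retains in BioGPT's and OPT's learned positional embeddings.
    positions = (torch.cumsum(attention_mask, dim=1).type_as(attention_mask) * attention_mask).long() - 1
    print(positions)
    # tensor([[-1, -1,  0,  1,  2],
    #         [ 0,  1,  2,  3,  4]])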