
reverted changes in individual models
poedator committed Nov 30, 2023
1 parent ec830a7 commit 204b4b8
Showing 3 changed files with 7 additions and 26 deletions.
8 changes: 1 addition & 7 deletions src/transformers/models/biogpt/modeling_biogpt.py
@@ -71,13 +71,7 @@ def forward(self, attention_mask: torch.LongTensor, past_key_values_length: int
         attention_mask = attention_mask.long()
 
         # create positions depending on attention_mask
-        if len(attention_mask.shape) == 2:
-            positions = (torch.cumsum(attention_mask, dim=-1).type_as(attention_mask) * attention_mask).long() - 1
-        elif len(attention_mask.shape) == 4:
-            # assumes 4D mask for efficient beam search
-            token_positions = torch.cumsum(attention_mask, dim=-1).amax(dim=(1, 2))
-            used_tokens_mask = attention_mask.amax(dim=(1, 2))
-            positions = (token_positions * used_tokens_mask).long() - 1
+        positions = (torch.cumsum(attention_mask, dim=1).type_as(attention_mask) * attention_mask).long() - 1
 
         # cut positions if `past_key_values_length` is > 0
         positions = positions[:, past_key_values_length:]
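For context, the 2-D path that this revert keeps derives position indices from the padding mask with a cumulative sum, so padding tokens land at -1 before the model's positional-embedding offset is applied. A minimal illustrative sketch (tensor values assumed, not from the commit):

import torch

attention_mask = torch.tensor(
    [[1, 1, 1, 1],   # no padding
     [0, 0, 1, 1]],  # two left-padding tokens
    dtype=torch.long,
)
# cumulative count of real tokens, with padding forced to -1
positions = (torch.cumsum(attention_mask, dim=1).type_as(attention_mask) * attention_mask).long() - 1
# positions: [[ 0,  1,  2,  3],
#             [-1, -1,  0,  1]]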
17 changes: 5 additions & 12 deletions src/transformers/models/llama/modeling_llama.py
@@ -872,18 +872,11 @@ def forward(
             past_key_values_length = past_key_values[0][0].shape[2]
 
         if position_ids is None:
-            if attention_mask is not None and len(attention_mask.shape) == 4:
-                # assumes 4D mask for efficient beam search
-                token_positions = torch.cumsum(attention_mask, dim=-1).amax(dim=(1, 2))
-                used_tokens_mask = attention_mask.amax(dim=(1, 2))
-                position_ids = (token_positions * used_tokens_mask).long() - 1
-                position_ids = position_ids[:, past_key_values_length:]
-            else:
-                device = input_ids.device if input_ids is not None else inputs_embeds.device
-                position_ids = torch.arange(
-                    past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device
-                )
-                position_ids = position_ids.unsqueeze(0)
+            device = input_ids.device if input_ids is not None else inputs_embeds.device
+            position_ids = torch.arange(
+                past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device
+            )
+            position_ids = position_ids.unsqueeze(0)
 
         if inputs_embeds is None:
             inputs_embeds = self.embed_tokens(input_ids)
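With the 4-D-mask branch removed, Llama again builds position_ids as a plain arange shifted by the KV-cache length whenever the caller does not supply them. A small sketch of that fallback (assumed values, not part of the commit):

import torch

past_key_values_length = 5  # tokens already held in the KV cache
seq_length = 1              # one new token per step during generation
device = torch.device("cpu")

position_ids = torch.arange(
    past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device
).unsqueeze(0)
# position_ids: [[5]] -> the new token continues right after the cached ones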
8 changes: 1 addition & 7 deletions src/transformers/models/opt/modeling_opt.py
@@ -100,13 +100,7 @@ def forward(self, attention_mask: torch.LongTensor, past_key_values_length: int
         attention_mask = attention_mask.long()
 
         # create positions depending on attention_mask
-        if len(attention_mask.shape) == 2:
-            positions = (torch.cumsum(attention_mask, dim=-1).type_as(attention_mask) * attention_mask).long() - 1
-        elif len(attention_mask.shape) == 4:
-            # assumes 4D mask for efficient beam search
-            token_positions = torch.cumsum(attention_mask, dim=-1).amax(dim=(1, 2))
-            used_tokens_mask = attention_mask.amax(dim=(1, 2))
-            positions = (token_positions * used_tokens_mask).long() - 1
+        positions = (torch.cumsum(attention_mask, dim=1).type_as(attention_mask) * attention_mask).long() - 1
 
         # cut positions if `past_key_values_length` is > 0
         positions = positions[:, past_key_values_length:]
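The OPT hunk mirrors the BioGPT one. For the `past_key_values_length` slicing that both files keep, here is an illustrative sketch (values assumed) of how only the newest token's position survives during cached decoding:

import torch

attention_mask = torch.tensor([[0, 1, 1, 1, 1]], dtype=torch.long)  # 1 pad + 4 real tokens
past_key_values_length = 4  # 4 tokens already cached, 1 new token this step

positions = (torch.cumsum(attention_mask, dim=1).type_as(attention_mask) * attention_mask).long() - 1
positions = positions[:, past_key_values_length:]
# positions: [[3]] -> position of the single new token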
