Current BART Position Embeddings Implementation Seems Wrong #19240

Closed
kshitizgupta21 opened this issue Sep 28, 2022 · 5 comments

kshitizgupta21 commented Sep 28, 2022

System Info

  • transformers version: 4.22.2
  • Platform: Linux-5.10.16.3-microsoft-standard-WSL2-x86_64-with-glibc2.10
  • Python version: 3.8.13
  • Huggingface_hub version: 0.10.0
  • PyTorch version (GPU?): 1.13.0a0+08820cb (True)
  • Tensorflow version (GPU?): not installed (NA)
  • Flax version (CPU?/GPU?/TPU?): not installed (NA)
  • Jax version: not installed
  • JaxLib version: not installed
  • Using GPU in script?: no
  • Using distributed or parallel set-up in script?: no

Who can help?

@patil-suraj

Information

  • The official example scripts
  • My own modified scripts

Tasks

  • An officially supported task in the examples folder (such as GLUE/SQuAD, ...)
  • My own task or dataset (give details below)

Reproduction

I understand that both BART and RoBERTa came from Facebook, with the original implementations in fairseq, and that BART's offset in the position embeddings was copied from RoBERTa's implementation. When Facebook first implemented RoBERTa in fairseq they had the offset, and HF copied that in their implementation as well. Based on this issue and this other issue, I understand that the motivation was to use nn.Embedding's padding_idx to make sure we don't learn position vectors for padding tokens. Since both RoBERTa and BART have a pad token id of 1, we pass padding_idx=1 to nn.Embedding for the position embedding table. Therefore the non-padding tokens get offset by padding_idx + 1 and num_embeddings += padding_idx + 1. Going through the old transformers BART code here and here, the code makes sense. And on an example input the behavior of create_position_ids_from_input_ids makes sense: we offset the position ids of non-padding tokens, and padding tokens get assigned position 1.

import torch
from transformers import AutoTokenizer, AutoModel

def create_position_ids_from_input_ids(input_ids, padding_idx):
    """Replace non-padding symbols with their position numbers. Position numbers begin at
    padding_idx+1. Padding symbols are ignored. This is modified from fairseq's
    `utils.make_positions`.

    :param torch.Tensor input_ids:
    :return torch.Tensor:
    """
    # The series of casts and type-conversions here are carefully balanced to both work with ONNX export and XLA.
    mask = input_ids.ne(padding_idx).int()
    incremental_indices = torch.cumsum(mask, dim=1).type_as(mask) * mask
    return incremental_indices.long() + padding_idx

tokenizer = AutoTokenizer.from_pretrained("facebook/bart-base")

bart = AutoModel.from_pretrained("facebook/bart-base")

batch_sentences = [
    "But what about second breakfast?",
    "Don't think he knows about second breakfast, Pip.",
    "What about elevensies?",
]
inputs = tokenizer(batch_sentences, padding=True, return_tensors="pt")
input_ids = inputs.input_ids
create_position_ids_from_input_ids(input_ids, padding_idx=1)
tensor([[ 2,  3,  4,  5,  6,  7,  8,  9,  1,  1,  1,  1,  1],
        [ 2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14],
        [ 2,  3,  4,  5,  6,  7,  8,  9,  1,  1,  1,  1,  1]]) 
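
To make the old scheme concrete, here is a minimal sketch (not the actual old LearnedPositionalEmbedding class; the embedding dimension is illustrative) of feeding those position ids into an nn.Embedding built with padding_idx, as the old implementation described above does. The pad positions all map to the zeroed, non-learnable row at index 1:

from torch import nn

# Sketch of the old scheme: the table is built with padding_idx=1 and
# num_embeddings enlarged by padding_idx + 1.
padding_idx = 1
old_pos_table = nn.Embedding(1024 + padding_idx + 1, 768, padding_idx=padding_idx)

position_ids = create_position_ids_from_input_ids(input_ids, padding_idx=padding_idx)
pos_vectors = old_pos_table(position_ids)

# The pad slots (position id 1) come out as all-zero vectors and receive no gradient.
print(pos_vectors[0, -1])  # tensor of zeros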

But starting with some transformers version (perhaps 4.8.0) the BART code was changed, and LearnedPositionalEmbedding and create_position_ids_from_input_ids were replaced with just this: BartLearnedPositionalEmbedding, which gets used in both the encoder and the decoder. The hard-coded offset of 2 there makes sense because BART's pad token id is 1, so padding_idx + 1 = 1 + 1 = 2, and we offset the non-padding tokens by 2. But what is not clear is why we no longer even pass padding_idx to the nn.Embedding constructor, unlike the old implementation, because I thought the whole point of the offset was to use nn.Embedding's padding_idx for the padding token.

Also, in the current implementation we add the offset to all tokens (including pad tokens), which means padding positions are also learned in the current version of BART. Doesn't that defeat the whole point of having the offset? Is the current implementation of BART position embeddings wrong?

Here is an example of a modified BartLearnedPositionalEmbedding which returns positions + offset (the raw position ids rather than the embeddings), so we can see which positions every token gets:

from torch import nn

class BartLearnedPositionalEmbedding(nn.Embedding):
    """
    This module learns positional embeddings up to a fixed maximum size.
    """

    def __init__(self, num_embeddings: int, embedding_dim: int):
        # Bart is set up so that if padding_idx is specified then offset the embedding ids by 2
        # and adjust num_embeddings appropriately. Other models don't have this hack
        self.offset = 2
        super().__init__(num_embeddings + self.offset, embedding_dim)

    def forward(self, input_ids: torch.Tensor, past_key_values_length: int = 0):
        """`input_ids` shape is expected to be [bsz x seqlen]."""

        bsz, seq_len = input_ids.shape[:2]
        positions = torch.arange(
            past_key_values_length, past_key_values_length + seq_len, dtype=torch.long, device=self.weight.device
        ).expand(bsz, -1)
        # Modified to return the offset position ids instead of super().forward(positions + self.offset),
        # so we can see which rows of the table would be looked up.
        return positions + self.offset
pos_embed = BartLearnedPositionalEmbedding(num_embeddings=1024, embedding_dim=512)      
pos_embed(input_ids)
tensor([[ 2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14],
        [ 2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14],
        [ 2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14]])
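
For context on what padding_idx in nn.Embedding actually does (standard PyTorch behavior, shown here as a small standalone check rather than transformers code): the row at padding_idx is initialized to zeros and never receives gradient, while without padding_idx every row is an ordinary learnable vector.

import torch
from torch import nn

with_pad = nn.Embedding(10, 4, padding_idx=1)
no_pad = nn.Embedding(10, 4)

print(with_pad.weight[1])  # all zeros: the padding row is zero-initialized
print(no_pad.weight[1])    # ordinary random values: a normal learnable row

# Gradients never flow into the padding row either.
with_pad(torch.tensor([1, 2, 3])).sum().backward()
print(with_pad.weight.grad[1])  # all zeros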

Expected behavior

Expected the new BartLearnedPositionalEmbedding to pass padding_idx=1 to the nn.Embedding constructor and add the offset only to non-pad tokens, so that padding positions are not learned, like the old implementation (see the sketch below).
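
A minimal sketch of the variant the report expects (illustrative only, not a proposed patch; the class name is made up, it reuses create_position_ids_from_input_ids from the snippet above, and it ignores past_key_values_length for brevity):

import torch
from torch import nn

class PaddingAwarePositionalEmbedding(nn.Embedding):
    """Hypothetical variant: keeps padding_idx and offsets only non-pad positions."""

    def __init__(self, num_embeddings: int, embedding_dim: int, padding_idx: int = 1):
        # Enlarge the table as in the old implementation: num_embeddings += padding_idx + 1
        super().__init__(num_embeddings + padding_idx + 1, embedding_dim, padding_idx=padding_idx)

    def forward(self, input_ids: torch.Tensor):
        # Non-pad tokens get positions starting at padding_idx + 1; pad tokens get padding_idx,
        # whose embedding row is zero and never updated.
        positions = create_position_ids_from_input_ids(input_ids, self.padding_idx)
        return super().forward(positions)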

@ayaka14732 (Contributor)

This comment was posted because the issue had been automatically marked as stale due to lack of recent activity, but it was flagged by a user as an issue that still needs to be addressed.

sgugger (Collaborator) commented Oct 31, 2022

Might be of interest to @ArthurZucker

ArthurZucker self-assigned this Oct 31, 2022
@ArthurZucker (Collaborator)

Hey! So this was answered in #10200, where the padding_idx was removed. It explains that passing padding_idx prevents the model from ever learning the positional embedding at that index (for example the first position), which stays zeroed out.

Tell me if this answers your question!

huggingface deleted a comment from github-actions bot Nov 17, 2022
@github-actions (bot)

This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.

Please note that issues that do not follow the contributing guidelines are likely to be ignored.

@ArthurZucker (Collaborator)

Longer answer:
The reason why we are not using

>>> embed_positions = nn.Embedding(max_positions, hidden_dim, padding_idx)

is that this makes the positional embedding at index padding_idx un-learnable: it is zeroed out and never updated.

What if you change the padding index to something bigger, say 4? Then the embedding at index 4 will be zeroed out (basically erased), but for the model that means it will never receive the positional embedding that should be at position 4.

→ Potential usage: imagine you need a new starting token in your BartModel, so the padding token id is no longer 1 but 4. This means you just want to keep shifting the input positions by the fixed offset of 2, not zero out the learned position embedding at position 4.
Snippet:

# during training
>>> input_ids = [  3, 13, 25,  1,  1,  1,  1]
>>> pad_token_id = 1
>>> positions = [  0,  1,  2,  3,  4,  5,  6]
>>> pw_offset = [  2,  3,  4,  5,  6,  7,  8]
>>> embedding = [ X2, X3, X4, X5, X6, X7, X8]

# finetuning with one more token, without the code fix (padding_idx=4 passed to nn.Embedding)
>>> new_pad_token_id = 4  # but the padding tokens are not necessarily at position 4
>>> input_ids = [  1,  2, 13, 25,  4,  4,  4,  4]
>>> positions = [  0,  1,  2,  3,  4,  5,  6,  7]
>>> pw_offset = [  2,  3,  4,  5,  6,  7,  8,  9]
>>> embedding = [ X2, X3,  0, X5, X6, X7, X8, X9]

# With the code fix:
# finetuning with one more token
>>> new_pad_token_id = 4  # but the padding tokens are not necessarily at position 4
>>> input_ids = [  1,  2, 13, 25,  4,  4,  4,  4]
>>> positions = [  0,  1,  2,  3,  4,  5,  6,  7]
>>> pw_offset = [  2,  3,  4,  5,  6,  7,  8,  9]
>>> embedding = [ X2, X3, X4, X5, X6, X7, X8, X9]

If you zero out the embedding corresponding to the index of the padding token, then changing the id of the padding token changes which input positions end up with a zeroed embedding.

The subtle difference is that it should not matter whether your padding token has id 0, 1, or 999.

The token that sits at the position equal to that index (say, the 999th position) should not have a zeroed-out embedding. But if the token at that position happens to be a padding token, the loss simply will not make it contribute.

If we zero out index 4, the token at position 4 will never have a learned positional embedding.
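
To make that concrete, here is a small standalone check (plain PyTorch, not transformers code; the table sizes are made up) showing that if the position table were built with padding_idx=4, a real token sitting at position 4 would get an all-zero positional vector even in a sequence with no padding at all:

import torch
from torch import nn

# Hypothetical position table built the "old" way, but with padding_idx bumped to 4.
pos_table = nn.Embedding(32, 8, padding_idx=4)

# A fully unpadded sequence of 6 tokens, with positions offset by 2 as in BART.
positions = torch.arange(6) + 2  # tensor([2, 3, 4, 5, 6, 7])
pos_vectors = pos_table(positions)

print(pos_vectors[2])  # all zeros: the real token at position 4 gets no positional signal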
