Current BART Position Embeddings Implementation Seems Wrong #19240
Comments
This comment was posted because the issue had been automatically marked as stale due to lack of recent activity, but it was flagged by a user as an issue that still needs to be addressed.
Might be of interest to @ArthurZucker
Hey! So this was answered in #10200. Tell me if this answers your question!
This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread. Please note that issues that do not follow the contributing guidelines are likely to be ignored.
Longer answer: the point of `embed_tokens = nn.Embedding(vocab_dim, hidden_dim, padding_idx)` is that the embedding at index `padding_idx` is zeroed out and never receives gradient updates. What if you change the padding index to something bigger? Let's say 4. → Potential usage: imagine you need a new starting token in your BartModel. The padding token will no longer be 2 but 4. This means you just want to shift the inputs' learned positions by 2, not that you want to zero out the learned position embedding at position 4.
# during training
>>> input_ids = [ 3, 13, 25, 1, 1 ,1 ,1]
>>> pad_token_id = 1
>>> positions = [ 0, 1, 2, 3, 4, 5, 6]
>>> pw_offset = [ 2, 3, 4, 5, 6, 7, 8]
>>> embedding = [ X2, X3, X4, X5, X6, X7, X8]
# finetuning with one more token
>>> new_pad_token_id = 4 # but the position of the padding token is not necessarily 2
>>> input_ids = [ 1, 2, 13, 25, 1, 1, 1, 1]
>>> positions = [ 0, 1, 2, 3, 4, 5, 6, 7]
>>> pw_offset = [ 2, 3, 4, 5, 6, 7, 8, 9]
>>> embedding = [ X2, X3, 0, X5, X6, X7, X8, X9]
# With the code fix:
# finetuning with one more token
>>> new_pad_token_id = 4 # but the position of the padding token is not necessarily 2
>>> input_ids = [ 1, 2, 13, 25, 1, 1, 1, 1]
>>> positions = [ 0, 1, 2, 3, 4, 5, 6, 7]
>>> pw_offset = [ 2, 3, 4, 5, 6, 7, 8, 9]
>>> embedding = [ X2, X3, X4, X5, X6, X7, X8, X9]
If you zero out the embeddings corresponding to the index of the padding token, changing the ID of the padding token will result in a change of the inputs that are positioned at this index. The subtle difference is that it does not matter if your padding token has index 0, 1, or 999. The tokens that are at the position of the index (let's say the 999th token) should not have a zeroed-out embedding. But if the token at that position is a padding token, then the loss will not make it contribute. If we zero out at index 4, the 4th token will never have a learned positional embedding.
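To make the `padding_idx` behaviour described above concrete, here is a small, self-contained PyTorch sketch (the sizes and tensor values are my own, purely illustrative): the embedding row at `padding_idx` starts as a zero vector and never receives gradient updates, so whatever lands at that position index can never get a learned positional embedding.

```python
import torch
import torch.nn as nn

# Illustrative sizes, not BART's real configuration.
num_positions, hidden_dim, padding_idx = 10, 4, 4

pos_embed = nn.Embedding(num_positions, hidden_dim, padding_idx=padding_idx)
print(pos_embed.weight[padding_idx])  # all zeros at the padding index

positions = torch.tensor([[0, 1, 2, 3, 4]])
pos_embed(positions).sum().backward()
print(pos_embed.weight.grad[padding_idx])  # still zeros: index 4 is never updated
```

This is exactly the situation described above: zeroing out index 4 means whichever token sits at that position gets no learned positional embedding, whether or not it is a padding token.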
System Info
transformers version: 4.22.2

Who can help?
@patil-suraj

Information
Tasks
examples folder (such as GLUE/SQuAD, ...)

Reproduction
I understand that both BART and RoBERTa came from Facebook, with the original implementations in fairseq, and that BART's offset in the position embeddings is copied from RoBERTa's implementation. When Facebook first implemented RoBERTa in fairseq they had the offset, and HF copied that in their implementation as well. Based on this issue and this other issue, I understand that the motivation for this was to use `nn.Embedding`'s `padding_idx` to make sure we don't learn position vectors for padding tokens. Since both RoBERTa and BART have a padding id of 1, we pass `padding_idx=1` to `nn.Embedding` for the position embedding table. Therefore, the non-padding tokens get offset by `padding_idx + 1` and `num_embeddings += padding_idx + 1`.

Going through the old transformers BART code here and here, the code makes sense. And on an example input, the behavior of `create_position_ids_from_input_ids` makes sense: we offset the position ids of non-padding tokens, and padding tokens get assigned position 1, as sketched below.
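The sketch below illustrates that old behaviour; it mirrors the logic of `create_position_ids_from_input_ids` as I understand it (the helper name and example values here are mine, not the library's):

```python
import torch

def position_ids_from_input_ids(input_ids: torch.Tensor, padding_idx: int) -> torch.Tensor:
    # Non-pad tokens get sequential positions starting at padding_idx + 1;
    # every pad token is mapped to padding_idx itself.
    mask = input_ids.ne(padding_idx).int()
    incremental_indices = torch.cumsum(mask, dim=1) * mask
    return incremental_indices.long() + padding_idx

input_ids = torch.tensor([[3, 13, 25, 1, 1, 1, 1]])  # 1 is the pad token
print(position_ids_from_input_ids(input_ids, padding_idx=1))
# tensor([[2, 3, 4, 1, 1, 1, 1]]) -> pad tokens all land on padding_idx
```

Combined with `padding_idx=1` in the `nn.Embedding`, every padding token looks up the frozen zero row, so no positional vector is ever learned for padding.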
But starting with some transformers version (perhaps 4.8.0) the BART code was changed: it replaced `LearnedPositionEmbedding` and the `create_position_ids_from_input_ids` function with just `BartLearnedPositionalEmbedding`, which gets used in both the encoder and the decoder. The hard-coded offset of 2 there makes sense, because BART's pad token is 1, so `padding_idx + 1 = 2` and we offset the non-padding tokens by 2. But what is not clear is why we no longer pass `padding_idx` to the `nn.Embedding` constructor, unlike the old implementation, since I thought the whole point of the offset was to use `padding_idx` in `nn.Embedding` for the padding token. Also, in the current implementation we add the offset to all tokens (including pad tokens), which means padding positions are also learned in the current version of BART. Doesn't that defeat the whole point of having the offset? Is the current implementation of BART position embeddings wrong?
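For reference, here is how I read the current behaviour, condensed into a small sketch (a paraphrase for illustration, not the library's actual class): positions come from a plain `arange` over the sequence, the offset of 2 is added to every one of them, and no `padding_idx` is passed.

```python
import torch
import torch.nn as nn

class CurrentStylePositionalEmbedding(nn.Embedding):
    """Sketch of the current behaviour: offset added to every position, no padding_idx."""

    def __init__(self, num_embeddings: int, embedding_dim: int):
        self.offset = 2
        super().__init__(num_embeddings + self.offset, embedding_dim)

    def forward(self, input_ids: torch.Tensor) -> torch.Tensor:
        bsz, seq_len = input_ids.shape
        positions = torch.arange(seq_len, device=self.weight.device).expand(bsz, -1)
        # Positions holding pad tokens get ordinary learned embeddings too.
        return super().forward(positions + self.offset)
```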
Here is an example of a modified `BartLearnedPositionalEmbedding` which returns `positions + offset`; see the sketch below.
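This is my own sketch of what such a modification could look like, based on the expected behaviour stated below (the class body is an assumption, not the author's actual snippet): position ids are derived from the input ids, `padding_idx` is passed to `nn.Embedding`, and only non-pad tokens get shifted by the offset.

```python
import torch
import torch.nn as nn

class ModifiedBartLearnedPositionalEmbedding(nn.Embedding):
    """Sketch only: pad positions map to padding_idx, whose row stays zero and unlearned."""

    def __init__(self, num_embeddings: int, embedding_dim: int, padding_idx: int = 1):
        # Reserve padding_idx + 1 extra slots, as the old implementation did.
        super().__init__(num_embeddings + padding_idx + 1, embedding_dim, padding_idx=padding_idx)

    def forward(self, input_ids: torch.Tensor) -> torch.Tensor:
        mask = input_ids.ne(self.padding_idx).int()
        # Non-pad tokens get sequential positions offset by padding_idx + 1;
        # pad tokens all map to padding_idx, which nn.Embedding never updates.
        positions = torch.cumsum(mask, dim=1) * mask + self.padding_idx
        return super().forward(positions.long())
```

With `padding_idx=1`, an input like `[[3, 13, 25, 1, 1]]` yields position ids `[[2, 3, 4, 1, 1]]`, so padding positions always read the frozen zero row, matching the old behaviour.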
Expected behavior
Expected the new `BartLearnedPositionalEmbedding` to pass `padding_idx=1` to the `nn.Embedding` constructor and to add the offset only to non-pad tokens, so that padding positions are not learned, like in the old implementation.