Roberta's Positional Embedding Offset #5285

Closed
h324yang opened this issue Jun 25, 2020 · 4 comments
@h324yang
num_embeddings += padding_idx + 1 # WHY?

positions = create_position_ids_from_input_ids(input, self.padding_idx)

So this offset is added because the function create_position_ids_from_input_ids shifts the position ids by padding_idx + 1. However, I wonder whether other models should also include this offset.
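
For context, that helper behaves roughly as follows (a simplified sketch of create_position_ids_from_input_ids; the token ids below are just for illustration):

```python
import torch

def create_position_ids_from_input_ids(input_ids, padding_idx):
    # Non-pad tokens get positions padding_idx + 1, padding_idx + 2, ...
    # while pad tokens all get position padding_idx itself.
    mask = input_ids.ne(padding_idx).int()
    incremental_indices = torch.cumsum(mask, dim=1) * mask
    return incremental_indices.long() + padding_idx

input_ids = torch.tensor([[0, 31414, 232, 2, 1, 1]])  # 1 = RoBERTa's pad token id
print(create_position_ids_from_input_ids(input_ids, padding_idx=1))
# tensor([[2, 3, 4, 5, 1, 1]])
# For a sequence with no padding, the largest position id is seq_len + padding_idx,
# so the embedding table needs max_positions + padding_idx + 1 rows -- hence the offset.
```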

config.max_position_embeddings, config.hidden_size, padding_idx=self.padding_idx

For instance, when I am using Longformer, it looks like the offset is not added the way it is in Roberta, so I need to add such an offset to config.max_position_embeddings myself.
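
As a concrete illustration of that point (the 4096 target length below is only an example, not from this thread): roberta-base ships with max_position_embeddings = 514 for a 512-token context because padding_idx = 1, so an extended config needs the same padding_idx + 1 headroom.

```python
from transformers import RobertaConfig

padding_idx = 1          # RoBERTa's pad token id
target_seq_len = 4096    # hypothetical Longformer-style context length

# roberta-base uses max_position_embeddings = 514, i.e. 512 + padding_idx + 1,
# so an extended model must reserve the same extra rows:
config = RobertaConfig(max_position_embeddings=target_seq_len + padding_idx + 1)  # 4098
```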

@sshleifer
Contributor

That's certainly possible. As you can see from my comment and PR #5188, I don't fully understand the motivation for the offset. It is very tricky.

sshleifer changed the title from "Positional Embedding Offset" to "Roberta's Positional Embedding Offset" on Jun 25, 2020
@cccntu
Contributor

cccntu commented Aug 26, 2020

I figured out why. See facebookresearch/fairseq#1177.
Basically, the purpose is to make the positional embedding equal 0 at padding positions (positions where the token is the padding token), using the padding_idx parameter of torch.nn.Embedding.
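
For reference, this is the torch.nn.Embedding behaviour being relied on (a minimal standalone sketch, not the library code):

```python
import torch
import torch.nn as nn

padding_idx = 1
emb = nn.Embedding(num_embeddings=10, embedding_dim=4, padding_idx=padding_idx)

print(emb.weight[padding_idx])           # the row at padding_idx is initialized to zeros
print(emb(torch.tensor([padding_idx])))  # so lookups at padding positions return zeros
# That row also receives no gradient updates, so it stays zero during training.
```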

I think we could simply use masked_fill() to set the positional embedding to 0 at padding positions, so the code would be easier to understand (no need for the offset).
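
A minimal sketch of that alternative (variable names and shapes are illustrative, not the actual transformers code):

```python
import torch
import torch.nn as nn

padding_idx, hidden_size, max_positions = 1, 8, 16
input_ids = torch.tensor([[0, 31414, 232, 2, 1, 1]])   # 1 = pad token id
pos_emb = nn.Embedding(max_positions, hidden_size)      # no padding_idx + 1 offset

# Plain 0..seq_len-1 positions, then zero out the rows at padding positions.
positions = torch.arange(input_ids.size(1), device=input_ids.device)
embeddings = pos_emb(positions).unsqueeze(0).expand(input_ids.size(0), -1, -1)
pad_mask = input_ids.eq(padding_idx).unsqueeze(-1)      # (batch, seq_len, 1)
embeddings = embeddings.masked_fill(pad_mask, 0.0)      # zeros at padding positions
```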

@sshleifer
Contributor

Exactly!
Would love to do that, but migrating the existing bart state dicts is non-trivial, since they already store the extra position embedding. Even if we tracked down all bart models with config.static_position_embeddings=False and resized their positional embeddings, we would break code that is not up to date with master (lots of code).

So I think we must settle for better documenting what is going on in LearnedPositionalEmbedding and accept the unfortunate reality that we are stuck with the offset forever (or until we have some futuristic model hub tooling to version state dicts).

