model forwards can take an inputs_embeds param #1695
Conversation
Codecov Report
@@ Coverage Diff @@
## master #1695 +/- ##
==========================================
+ Coverage 83.95% 84.03% +0.07%
==========================================
Files 94 94
Lines 13951 14021 +70
==========================================
+ Hits 11713 11782 +69
- Misses 2238 2239 +1
Continue to review full report at Codecov.
elif inputs_embeds is not None:
    input_shape = inputs_embeds.size()[:-1]
else:
    raise ValueError("You have to specify either input_ids or inputs_embeds")
Should we raise an error if both are provided?
If not, we should be explicit in the docstring about which one takes priority (this input should be documented in the docstring anyway).
indeed
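For reference, a minimal sketch of the check being discussed, assuming the same variable names as the snippet above and a forward() that accepts both arguments (not shown in this diff):

# Sketch only (not the final implementation): add the guard asked about above.
if input_ids is not None and inputs_embeds is not None:
    raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
elif input_ids is not None:
    input_shape = input_ids.size()
elif inputs_embeds is not None:
    input_shape = inputs_embeds.size()[:-1]
else:
    raise ValueError("You have to specify either input_ids or inputs_embeds")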
position_ids = torch.arange(past_length, input_ids.size(-1) + past_length, dtype=torch.long, device=input_ids.device)
position_ids = position_ids.unsqueeze(0).expand_as(input_ids)
device = input_ids.device if input_ids is not None else inputs_embeds.device
position_ids = torch.arange(past_length, input_shape[-1] + past_length, dtype=torch.long, device=device)
👍
Ok for me; just address the comment above and add the input to the docstring.
Also, I think we should add this input to all the models.
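A hedged sketch of the docstring entry this could get (the wording and shape here are assumptions, not the merged text):

inputs_embeds: (optional) torch.FloatTensor of shape (batch_size, sequence_length, hidden_size)
    Optionally, instead of passing input_ids, you can pass an embedded representation directly.
    Useful when you want more control over how input ids are turned into vectors than the
    model's internal embedding lookup provides.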
a4007f2 to 9eddf44 (Compare)
    input_ids = input_ids.transpose(0, 1).contiguous()
    qlen, bsz = input_ids.size()
elif inputs_embeds is not None:
    inputs_embeds = inputs_embeds.transpose(0, 1).contiguous()
👍
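The visible hunk does not show how qlen and bsz are recovered in the new branch; a plausible sketch, assuming the usual sequence-first layout after the transpose:

elif inputs_embeds is not None:
    inputs_embeds = inputs_embeds.transpose(0, 1).contiguous()
    qlen, bsz = inputs_embeds.shape[0], inputs_embeds.shape[1]  # (seq_len, batch) after transpose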
model.eval()

wte = model.get_input_embeddings()
inputs_dict["inputs_embeds"] = wte(input_ids)
perfect
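A sketch of how the shared test could exercise the new argument; model and inputs_dict come from the test harness, and popping input_ids is an assumption:

import torch

model.eval()
wte = model.get_input_embeddings()               # the model's input embedding module
input_ids = inputs_dict.pop("input_ids")         # swap token ids...
inputs_dict["inputs_embeds"] = wte(input_ids)    # ...for their embedded representation
with torch.no_grad():
    outputs = model(**inputs_dict)               # forward pass driven by inputs_embeds only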
LGTM.