Skip to content

model forwards can take an inputs_embeds param #1695

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Nov 5, 2019
Merged

Conversation

julien-c
Copy link
Member

@julien-c julien-c commented Nov 1, 2019

No description provided.

@huggingface huggingface deleted a comment from codecov-io Nov 1, 2019
@codecov-io
Copy link

codecov-io commented Nov 1, 2019

Codecov Report

Merging #1695 into master will increase coverage by 0.07%.
The diff coverage is 98.17%.

Impacted file tree graph

@@            Coverage Diff             @@
##           master    #1695      +/-   ##
==========================================
+ Coverage   83.95%   84.03%   +0.07%     
==========================================
  Files          94       94              
  Lines       13951    14021      +70     
==========================================
+ Hits        11713    11782      +69     
- Misses       2238     2239       +1
Impacted Files Coverage Δ
transformers/modeling_tf_distilbert.py 98.59% <ø> (ø) ⬆️
transformers/modeling_tf_gpt2.py 94.79% <ø> (ø) ⬆️
transformers/modeling_tf_bert.py 96.6% <ø> (ø) ⬆️
transformers/modeling_tf_xlnet.py 87.82% <ø> (ø) ⬆️
transformers/modeling_tf_openai.py 96.04% <ø> (ø) ⬆️
transformers/modeling_tf_utils.py 92.4% <ø> (ø) ⬆️
transformers/modeling_tf_ctrl.py 97.75% <ø> (ø) ⬆️
transformers/modeling_tf_transfo_xl.py 92.21% <ø> (ø) ⬆️
transformers/modeling_tf_xlm.py 90.39% <ø> (ø) ⬆️
transformers/modeling_tf_roberta.py 89.9% <ø> (ø) ⬆️
... and 11 more

Continue to review full report at Codecov.

Legend - Click here to learn more
Δ = absolute <relative> (impact), ø = not affected, ? = missing data
Powered by Codecov. Last update 68f7064...00337e9. Read the comment docs.

@julien-c julien-c requested a review from thomwolf November 1, 2019 17:44
julien-c added a commit to w4nderlust/transformers that referenced this pull request Nov 1, 2019
elif inputs_embeds is not None:
input_shape = inputs_embeds.size()[:-1]
else:
raise ValueError("You have to specify either input_ids or inputs_embeds")
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should we raise an error if both are provided?
If not we should be explicit about which one has priority in the docstring (this input should be in the docstring anyway).

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

indeed

position_ids = torch.arange(past_length, input_ids.size(-1) + past_length, dtype=torch.long, device=input_ids.device)
position_ids = position_ids.unsqueeze(0).expand_as(input_ids)
device = input_ids.device if input_ids is not None else inputs_embeds.device
position_ids = torch.arange(past_length, input_shape[-1] + past_length, dtype=torch.long, device=device)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

👍

Copy link
Member

@thomwolf thomwolf left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ok, for me, just check the comment and add the input in the docstring.
Also, I think we should add this input in all the models.

@julien-c julien-c mentioned this pull request Nov 1, 2019
@julien-c julien-c force-pushed the models_inputs_embeds branch from a4007f2 to 9eddf44 Compare November 4, 2019 17:19
input_ids = input_ids.transpose(0, 1).contiguous()
qlen, bsz = input_ids.size()
elif inputs_embeds is not None:
inputs_embeds = inputs_embeds.transpose(0, 1).contiguous()
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

👍

model.eval()

wte = model.get_input_embeddings()
inputs_dict["inputs_embeds"] = wte(input_ids)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

perfect

@thomwolf
Copy link
Member

thomwolf commented Nov 5, 2019

LGTM.
Feel free to add the TF version or merge if you don't want to add them now.

@julien-c julien-c merged commit 7daacf0 into master Nov 5, 2019
@julien-c julien-c deleted the models_inputs_embeds branch November 6, 2019 00:59
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants