Skip to content

Commit

Permalink
Fix shape issue of return_all_hiddens in roberta (#1438)
Browse files Browse the repository at this point in the history
Summary:
By default, `return_all_hiddens` is False and the shape of `features` will be B x T x C.
If `return_all_hiddens` is used, the shape of `features` will be T x B x C.
See
https://github.com/pytorch/fairseq/blob/9398a2829596393b73f5c5f1b99edf4c2d8f9316/fairseq/modules/transformer_sentence_encoder.py#L227
Pull Request resolved: #1438

Differential Revision: D18809509

Pulled By: myleott

fbshipit-source-id: 696b395934e2b7e5807387069fe1da49a4df98c7
  • Loading branch information
HMJiangGatech authored and facebook-github-bot committed Dec 4, 2019
1 parent 72bcb9d commit f8180af
Show file tree
Hide file tree
Showing 2 changed files with 5 additions and 8 deletions.
5 changes: 3 additions & 2 deletions fairseq/models/roberta/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -346,7 +346,8 @@ def forward(self, src_tokens, features_only=False, return_all_hiddens=False, mas
tuple:
- the LM output of shape `(batch, src_len, vocab)`
- a dictionary of additional data, where 'inner_states'
is a list of hidden states.
is a list of hidden states. Note that the hidden
states have shape `(src_len, batch, vocab)`.
"""
x, extra = self.extract_features(src_tokens, return_all_hiddens=return_all_hiddens)
if not features_only:
Expand All @@ -358,7 +359,7 @@ def extract_features(self, src_tokens, return_all_hiddens=False, **unused):
src_tokens,
last_state_only=not return_all_hiddens,
)
features = inner_states[-1]
features = inner_states[-1].transpose(0, 1) # T x B x C -> B x T x C
return features, {'inner_states': inner_states if return_all_hiddens else None}

def output_layer(self, features, masked_tokens=None, **unused):
Expand Down
8 changes: 2 additions & 6 deletions fairseq/modules/transformer_sentence_encoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ class TransformerSentenceEncoder(nn.Module):
Output:
- a tuple of the following:
- a list of internal model states used to compute the
predictions where each tensor has shape B x T x C
predictions where each tensor has shape T x B x C
- sentence representation associated with first input token
in format B x C.
"""
Expand Down Expand Up @@ -222,11 +222,7 @@ def forward(
if not last_state_only:
inner_states.append(x)


# T x B x C -> B x T x C
x = x.transpose(0, 1)

sentence_rep = x[:, 0, :]
sentence_rep = x[0, :, :]

if last_state_only:
inner_states = [x]
Expand Down

0 comments on commit f8180af

Please sign in to comment.