From f8180af3555ba8b185af95f71b14863ea9146cc3 Mon Sep 17 00:00:00 2001
From: Haoming Jiang
Date: Wed, 4 Dec 2019 15:37:25 -0800
Subject: [PATCH] Fix shape issue of `return_all_hiddens` in roberta (#1438)

Summary:
By default `return_all_hiddens` is False, so the shape of `features` will be B x T x C. If `return_all_hiddens` is used, the shape of `features` will be T x B x C. See https://github.com/pytorch/fairseq/blob/9398a2829596393b73f5c5f1b99edf4c2d8f9316/fairseq/modules/transformer_sentence_encoder.py#L227

Pull Request resolved: https://github.com/pytorch/fairseq/pull/1438

Differential Revision: D18809509

Pulled By: myleott

fbshipit-source-id: 696b395934e2b7e5807387069fe1da49a4df98c7
---
 fairseq/models/roberta/model.py                 | 5 +++--
 fairseq/modules/transformer_sentence_encoder.py | 8 ++------
 2 files changed, 5 insertions(+), 8 deletions(-)

diff --git a/fairseq/models/roberta/model.py b/fairseq/models/roberta/model.py
index b9ca1c012e..c72fd2e54a 100644
--- a/fairseq/models/roberta/model.py
+++ b/fairseq/models/roberta/model.py
@@ -346,7 +346,8 @@ def forward(self, src_tokens, features_only=False, return_all_hiddens=False, mas
             tuple:
                 - the LM output of shape `(batch, src_len, vocab)`
                 - a dictionary of additional data, where 'inner_states'
-                  is a list of hidden states.
+                  is a list of hidden states. Note that the hidden
+                  states have shape `(src_len, batch, vocab)`.
         """
         x, extra = self.extract_features(src_tokens, return_all_hiddens=return_all_hiddens)
         if not features_only:
@@ -358,7 +359,7 @@ def extract_features(self, src_tokens, return_all_hiddens=False, **unused):
             src_tokens,
             last_state_only=not return_all_hiddens,
         )
-        features = inner_states[-1]
+        features = inner_states[-1].transpose(0, 1)  # T x B x C -> B x T x C
         return features, {'inner_states': inner_states if return_all_hiddens else None}
 
     def output_layer(self, features, masked_tokens=None, **unused):
diff --git a/fairseq/modules/transformer_sentence_encoder.py b/fairseq/modules/transformer_sentence_encoder.py
index f35babb899..674a725052 100644
--- a/fairseq/modules/transformer_sentence_encoder.py
+++ b/fairseq/modules/transformer_sentence_encoder.py
@@ -64,7 +64,7 @@ class TransformerSentenceEncoder(nn.Module):
     Output:
         - a tuple of the following:
             - a list of internal model states used to compute the
-              predictions where each tensor has shape B x T x C
+              predictions where each tensor has shape T x B x C
             - sentence representation associated with first input
               token in format B x C.
     """
@@ -222,11 +222,7 @@ def forward(
 
         if not last_state_only:
            inner_states.append(x)
-
-        # T x B x C -> B x T x C
-        x = x.transpose(0, 1)
-
-        sentence_rep = x[:, 0, :]
+        sentence_rep = x[0, :, :]
 
         if last_state_only:
            inner_states = [x]
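
For illustration, a minimal self-contained sketch of the shape contract this patch enforces, using plain PyTorch tensors as stand-ins for the encoder's inner states. The dimensions T=5, B=2, C=8 and the three-element `inner_states` list are made up for the example.

import torch

T, B, C = 5, 2, 8  # src_len, batch, model dim (illustrative values)

# TransformerSentenceEncoder keeps every inner state in T x B x C layout.
inner_states = [torch.randn(T, B, C) for _ in range(3)]

# extract_features now transposes the last state once, so `features` is
# B x T x C regardless of return_all_hiddens:
features = inner_states[-1].transpose(0, 1)  # T x B x C -> B x T x C
assert features.shape == (B, T, C)

# The sentence representation is taken from the un-transposed tensor:
# the first time step of every batch element.
sentence_rep = inner_states[-1][0, :, :]  # B x C
assert sentence_rep.shape == (B, C)

Doing the transpose once in `extract_features` keeps every element of `inner_states` in the encoder's native T x B x C layout, so the hidden states returned with `return_all_hiddens` now match the documented shape.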