Add BartModel #2745

Merged: 168 commits merged into huggingface:master on Feb 20, 2020
Conversation

sshleifer
Contributor

@sshleifer sshleifer commented Feb 5, 2020

This ports BART, a "sequence-to-sequence model trained with denoising as pretraining objective", from https://github.com/pytorch/fairseq/tree/master/examples/bart.
The decoder is left-to-right and the encoder is bidirectional. As such, the code only uses a causal attention mask in the decoder.
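For readers unfamiliar with the term, a causal (future-masking) mask simply blocks each position from attending to later positions. A minimal PyTorch sketch of such an additive mask, for illustration only (not the exact helper used in this PR):

import torch

def make_causal_mask(seq_len: int) -> torch.Tensor:
    # Additive mask: 0.0 where attention is allowed, -inf where it is blocked,
    # so position i can only attend to positions 0..i (no future tokens).
    mask = torch.full((seq_len, seq_len), float("-inf"))
    return torch.triu(mask, diagonal=1)

print(make_causal_mask(4))
# row i keeps 0.0 up to column i and -inf for every later column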

TODO:


  • Docstrings
  • More comments for code readers

Future PRs

  • example with correct pretraining objective
  • BartForSummarization.from_pretrained('bart-large-cnn')

@LysandreJik LysandreJik self-requested a review February 18, 2020 17:07
Member

@LysandreJik LysandreJik left a comment

I'm re-reviewing this to add a few comments related to the documentation and what should be updated for this model to be correctly displayed in the docs.

Left a few comments at the appropriate places; you will have to adapt them for the three models (base, masked LM, and sequence classification).

Comment on lines 63 to 81
Inputs:
    **input_ids**: ``torch.LongTensor`` of shape ``(batch_size, sequence_length)``:
        Indices of input sequence tokens in the vocabulary. Use BartTokenizer.encode to produce them.
        Padding will be ignored by default should you provide it.
        Indices can be obtained using :class:`transformers.BartTokenizer.encode(text)`.
        Also see :func:`transformers.PreTrainedTokenizer.convert_tokens_to_ids` for details.
    **attention_mask**: (`optional`) ``torch.FloatTensor`` of shape ``(batch_size, sequence_length)``:
        Mask to avoid performing attention on padding token indices in the encoder inputs.
        Default: a mask will be created that ignores config.pad_token_id.
        Mask values selected in ``[0, 1]``:
        ``1`` for tokens that are NOT MASKED, ``0`` for MASKED tokens.
    **decoder_input_ids**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size, sequence_length)``:
        Only used for translation and summarization. Otherwise, use the default, which shifts the encoder's
        input_ids to the right.
    **decoder_attention_mask**: (`optional`) ``torch.FloatTensor`` of shape ``(batch_size, sequence_length)``:
        Default behavior: ignore pad tokens and future tokens.
        See diagram 1 in the paper for more info on the default strategy.

    Read `prepare_bart_inputs` for more information on the default behavior.
Member


We've switched to a more uniform format, where you would have to

  1. rename the section from "Inputs" to "Args" for readthedocs/sphinx (our doc generator) to understand it
  2. link to the glossary
  3. If possible, re-use docstrings as similar as possible to those of the other models. Using different docstrings with different vocabulary is bound to confuse users.
  4. (Optional) The glossary currently doesn't contain any information related to the seq2seq models. It would be great if it did, but it is a lengthy process, so it might be a better idea to do it once BART is wrapped up. Let me know if this is something that would be interesting for you.

You can check an example in the BERT file.
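For illustration, the renamed section could look roughly like this in the Args style (the wording and glossary links are a sketch modeled on the other models' docstrings at the time, not the final merged text):

BART_INPUTS_DOCSTRING = r"""
    Args:
        input_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`):
            Indices of input sequence tokens in the vocabulary.
            Indices can be obtained using :class:`transformers.BartTokenizer`.
            See :func:`transformers.PreTrainedTokenizer.encode` for details.

            `What are input IDs? <../glossary.html#input-ids>`__
        attention_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
            Mask to avoid performing attention on padding token indices.
            Mask values selected in ``[0, 1]``: ``1`` for tokens that are NOT MASKED, ``0`` for MASKED tokens.

            `What are attention masks? <../glossary.html#attention-mask>`__
"""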

Comment on lines 826 to 830
@add_start_docstrings(
    "The bare BART Model outputting raw hidden-states without any specific head on top.",
    BART_START_DOCSTRING,
    BART_INPUTS_DOCSTRING,
)
Member


We now only link to the start docstring in the add_start_docstrings decorator.
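Concretely, the class-level decorator would then carry only the start docstring, roughly like this (a sketch of the convention with a stub docstring, not the exact merged code):

from transformers import PreTrainedModel
from transformers.file_utils import add_start_docstrings

BART_START_DOCSTRING = r"""Model-level documentation shared by all BART variants goes here."""

@add_start_docstrings(
    "The bare BART Model outputting raw hidden-states without any specific head on top.",
    BART_START_DOCSTRING,
)
class BartModel(PreTrainedModel):
    ...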

def get_output_embeddings(self):
    return _make_linear_from_emb(self.shared)

def forward(
Member


We now link to the inputs in the forward method, cf. BERT file
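For example, roughly (a sketch that assumes the add_start_docstrings_to_callable helper the BERT file used around this time; the signature is trimmed for brevity):

from transformers import PreTrainedModel
from transformers.file_utils import add_start_docstrings_to_callable  # helper name as in the BERT file (assumption)

BART_INPUTS_DOCSTRING = r"""Per-call input documentation (input_ids, attention_mask, ...) goes here."""

class BartModel(PreTrainedModel):
    @add_start_docstrings_to_callable(BART_INPUTS_DOCSTRING)
    def forward(self, input_ids, attention_mask=None, decoder_input_ids=None, decoder_attention_mask=None):
        ...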

Comment on lines 897 to 926
r"""
**lm_labels**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size, sequence_length)``:
Labels for computing the masked language modeling loss.
Indices should either be in ``[0, ..., config.vocab_size]`` or -100 (see ``input_ids`` docstring).
Tokens with indices set to ``-100`` are ignored (masked), the loss is only computed for the tokens with labels
in ``[0, ..., config.vocab_size]``.

Outputs: `Tuple` comprising various elements depending on the configuration (config) and inputs:
**loss**: (`optional`, returned when ``lm_labels`` is provided) ``torch.FloatTensor`` of shape ``(1,)``:
Masked language modeling loss.
**prediction_scores**: ``torch.FloatTensor`` of shape ``(batch_size, sequence_length, config.vocab_size)``
Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
**hidden_states**: (`optional`, returned when ``config.output_hidden_states=True``)
list of ``torch.FloatTensor`` (one for the output of each layer + the output of the embeddings)
of shape ``(batch_size, sequence_length, hidden_size)``:
Hidden-states of the model at the output of each layer plus the initial embedding outputs.
**attentions**: (`optional`, returned when ``config.output_attentions=True``)
list of ``torch.FloatTensor`` (one for each layer) of shape ``(batch_size, num_heads, sequence_length, sequence_length)``:
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention heads.

Examples::

tokenizer = BartTokenizer.from_pretrained('bart-large')
model = BartForMaskedLM.from_pretrained('bart-large')
input_ids = torch.tensor(tokenizer.encode("Hello, my dog is cute")).unsqueeze(0) # Batch size 1
outputs = model(input_ids=input_ids, lm_labels=input_ids)
loss, prediction_scores = outputs[:2]

"""
base_model_prefix = "model"
Member


Additional inputs/outputs not detailed in the START/INPUTS docstrings are now added to the forward method as well, cf. BertForPreTraining.

Member


Also, mind how the format changed from "attentions: (optional, returned when config.output_attentions=True)" to "attentions (:obj:tuple(torch.FloatTensor), optional, returned when config.output_attentions=True):".

I believe you can copy and paste most of it.
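For reference, an attentions entry in that newer format would read roughly as follows (illustrative wording modeled on the library's other docstrings at the time, not the exact merged text):

        attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``config.output_attentions=True``):
            Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape
            :obj:`(batch_size, num_heads, sequence_length, sequence_length)`.
            Attention weights after the softmax, used to compute the weighted average in the self-attention heads.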

@sshleifer sshleifer merged commit 53ce385 into huggingface:master Feb 20, 2020
@sshleifer sshleifer deleted the bart branch February 20, 2020 23:11
jplu pushed a commit to jplu/transformers that referenced this pull request Mar 25, 2020
* Results same as fairseq
* Wrote a ton of tests
* Struggled with api signatures
* added some docs
self.dropout = dropout

# Classifier stuff
self.classif_dropout = classifier_dropout
Contributor


Won't this name mismatch cause the value saved by save_pretrained() to not be loaded back into the config by the from_pretrained() method?
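For illustration, a minimal round-trip check along these lines would surface it (the save path is made up, and whether the value survives is exactly what is being asked):

from transformers import BartConfig

config = BartConfig(classifier_dropout=0.3)        # __init__ kwarg is `classifier_dropout`
config.save_pretrained("/tmp/bart-config-check")   # config.json is written from attribute names, e.g. `classif_dropout`
reloaded = BartConfig.from_pretrained("/tmp/bart-config-check")
print(reloaded.classif_dropout)                    # does 0.3 survive the save/load round trip?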

Contributor Author


I have no clue what problem you are trying to describe. Please file an issue with a pasteable code snippet that has a different output than you expected.

Contributor


ok filed #7591

Development

Successfully merging this pull request may close these issues.

🌟 BART
8 participants