Add BartModel #2745
@@ -36,50 +36,42 @@
}

BART_START_DOCSTRING = r"""
    "BART is a sequence to sequence model which uses a standard Transformer based Translation architecture.

    This model is a PyTorch `torch.nn.Module`_ sub-class. Use it as a regular PyTorch Module and
    refer to the PyTorch documentation for all matter related to general usage and behavior.

    .. _`Exploring the Limits of Transfer Learning with a Unified Text-to-Text Transformer`:
        https://arxiv.org/abs/1910.10683

    .. _`torch.nn.Module`:
        https://pytorch.org/docs/stable/nn.html#module
    Paper: BART: Denoising Sequence-to-Sequence Pre-training for Natural Language Generation, Translation, and Comprehension
    https://arxiv.org/abs/1910.13461
    Authors: Mike Lewis, Yinhan Liu, Naman Goyal, Marjan Ghazvininejad, Abdelrahman Mohamed, Omer Levy, Ves Stoyanov, Luke Zettlemoyer
    (Submitted on 29 Oct 2019)
    Code Ported from https://github.com/pytorch/fairseq/tree/master/examples/bart
    An encoder decoder transformer pre-trained in a text-to-text denoising generative setting.

    'BART is an encoder decoder transformer pre-trained in a text-to-text denoising generative setting.'
    This model is a PyTorch `torch.nn.Module <https://pytorch.org/docs/stable/nn.html#torch.nn.Module>`_ sub-class. Use it as a regular PyTorch Module and
    refer to the PyTorch documentation for all matters related to general usage and behavior.

    `Paper <https://arxiv.org/abs/1910.13461>`_: BART: Denoising Sequence-to-Sequence Pre-training for Natural Language Generation, Translation, and Comprehension
    Authors: Mike Lewis, Yinhan Liu, Naman Goyal, Marjan Ghazvininejad, Abdelrahman Mohamed, Omer Levy, Ves Stoyanov, Luke Zettlemoyer
    (Submitted on 29 Oct 2019)
    Code Ported from https://github.com/pytorch/fairseq/tree/master/examples/bart

    Parameters:
        config (:class:`~transformers.BartConfig`): Model configuration class with all the parameters of the model.
            Initializing with a config file does not load the weights associated with the model, only the configuration.
            Check out the :meth:`~transformers.PreTrainedModel.from_pretrained` method to load the model weights.

"""

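A minimal sketch of what the Parameters note above implies, assuming the BartConfig/BartModel names described in this docstring (not a definitive API reference):

    # Illustrative sketch only -- class names taken from the docstring above.
    from transformers import BartConfig, BartModel

    config = BartConfig()            # configuration only: no pretrained weights are loaded
    model = BartModel(config)        # randomly initialized BART with that configuration

    # Pretrained weights are loaded through from_pretrained instead
    # ('bart-large' is the checkpoint name used in the examples of this PR).
    pretrained_model = BartModel.from_pretrained('bart-large')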
BART_INPUTS_DOCSTRING = r"""
    Inputs:
        **input_ids**: ``torch.LongTensor`` of shape ``(batch_size, sequence_length)``:
            Indices of input sequence tokens in the vocabulary. Use BartTokenizer.encode to produce them.
    Args:
        input_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`):
            Indices of input sequence tokens in the vocabulary. Use BartTokenizer.encode to produce them.
            Padding will be ignored by default should you provide it.
            Indices can be obtained using :class:`transformers.BartTokenizer.encode(text)`.
            Also see :func:`transformers.PreTrainedTokenizer.convert_tokens_to_ids` for details.
        **attention_mask**: (`optional`) ``torch.FloatTensor`` of shape ``(batch_size, sequence_length)``:
            Mask to avoid performing attention on padding token indices in the encoder inputs.
            Default: a mask will be created that ignores config.pad_token_id

        attention_mask (:obj:`torch.Tensor` of shape :obj:`(batch_size, sequence_length)`, `optional`, defaults to :obj:`None`):
            Warning: this parameter is different from other attention_mask parameters and should be used with caution.
            Mask to avoid performing attention on padding token indices (in input_ids).
            Mask values selected in ``[0, 1]``:
            ``1`` for tokens that are NOT MASKED, ``0`` for MASKED tokens.
        **decoder_input_ids**: (`optional`) ``torch.FloatTensor`` of shape ``(batch_size, sequence_length)``:
        decoder_input_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, target_sequence_length)`, `optional`, defaults to :obj:`None`):
            Only used for translation and summarization. Otherwise use the default, which shifts the encoder's
            input_ids right.
        **decoder_attention_mask**: (`optional`) ``torch.FloatTensor`` of shape ``(batch_size, sequence_length)``:
            Default behavior ignores pad tokens and future tokens.
        decoder_attention_mask (:obj:`torch.Tensor` of shape :obj:`(batch_size, target_sequence_length)`, `optional`, defaults to :obj:`None`):
            Default behavior (if :obj:`None` is passed) is to ignore pad tokens and future tokens.
            See diagram 1 in the paper for more info on the default strategy.

        Read `prepare_bart_inputs` for more information on the default behavior.

"""
LARGE_NEGATIVE = -1e4

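To make the argument descriptions above concrete, a hedged sketch of how a caller might build these tensors by hand; BartTokenizer, BartModel and the 'bart-large' checkpoint name come from this PR, while the explicit mask construction simply mirrors the documented default of ignoring config.pad_token_id:

    import torch
    from transformers import BartModel, BartTokenizer

    tokenizer = BartTokenizer.from_pretrained('bart-large')
    model = BartModel.from_pretrained('bart-large')

    # Encoder inputs: token ids plus a padding mask (1 = attend, 0 = padding).
    input_ids = torch.tensor([tokenizer.encode("Hello, my dog is cute")])
    attention_mask = (input_ids != tokenizer.pad_token_id).long()

    # decoder_input_ids / decoder_attention_mask are only needed for translation
    # and summarization; by default the model shifts input_ids right internally.
    outputs = model(input_ids, attention_mask=attention_mask)
    last_hidden_state = outputs[0]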
@@ -841,7 +833,6 @@ def _filter_out_falsey_values(tup) -> Tuple:
    "The bare BART Model outputting raw hidden-states without any specific head on top.", BART_START_DOCSTRING,
)
class BartModel(PretrainedBartModel):
    """"""

    def __init__(self, config: BartConfig):
        super().__init__(config)

@@ -856,15 +847,6 @@ def __init__(self, config: BartConfig):

        self.init_weights()

    def get_input_embeddings(self):
        return self.shared

    def set_input_embeddings(self, value):
        self.shared = value

    def get_output_embeddings(self):
        return _make_linear_from_emb(self.shared)

    @add_start_docstrings_to_callable(BART_INPUTS_DOCSTRING)
    def forward(
        self,

@@ -902,6 +884,17 @@ def forward(
        encoder_outputs = _filter_out_falsey_values(encoder_outputs)  # type: tuple
        return decoder_outputs + encoder_outputs

Review comment: for language generation we would need the following variables from
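A hedged sketch of what the concatenated return value means for a caller. The exact tuple layout depends on the output_hidden_states/output_attentions flags and on _filter_out_falsey_values, so treat the index below as illustrative only:

    import torch
    from transformers import BartModel, BartTokenizer

    tokenizer = BartTokenizer.from_pretrained('bart-large')
    model = BartModel.from_pretrained('bart-large')
    input_ids = torch.tensor([tokenizer.encode("Hello, my dog is cute")])

    # forward returns decoder outputs followed by encoder outputs.
    # With default config flags the leading element is the decoder's last
    # hidden state; the encoder's entries follow any optional decoder
    # hidden states / attentions that were enabled.
    outputs = model(input_ids)
    decoder_last_hidden_state = outputs[0]  # (batch, target_seq_len, hidden_size)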
    def get_input_embeddings(self):
        return self.shared

    def set_input_embeddings(self, value):
        self.shared = value

    def get_output_embeddings(self):
        return _make_linear_from_emb(self.shared)

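Since the output projection is built from the shared embedding via _make_linear_from_emb, a hedged sketch of what these accessors imply; the weight-sharing claim is an inference from the helper name, not verified here:

    from transformers import BartModel

    model = BartModel.from_pretrained('bart-large')

    emb = model.get_input_embeddings()       # the shared nn.Embedding used by encoder and decoder
    lm_head = model.get_output_embeddings()  # an nn.Linear built from the same shared weights
    print(emb.weight.shape, lm_head.weight.shape)  # expected to match if weights are shared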
@add_start_docstrings(
    "The bare BART Model with a language modeling head", BART_START_DOCSTRING,

@@ -927,29 +920,32 @@ def forward(
        **unused
    ):
        r"""
        **lm_labels**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size, sequence_length)``:
            Labels for computing the masked language modeling loss.
            Indices should either be in ``[0, ..., config.vocab_size]`` or -100 (see ``input_ids`` docstring).
            Tokens with indices set to ``-100`` are ignored (masked), the loss is only computed for the tokens
            with labels in ``[0, ..., config.vocab_size]``.

        Outputs: `Tuple` comprising various elements depending on the configuration (config) and inputs:
            **loss**: (`optional`, returned when ``lm_labels`` is provided) ``torch.FloatTensor`` of shape ``(1,)``:
                Masked language modeling loss.
            **prediction_scores**: ``torch.FloatTensor`` of shape ``(batch_size, sequence_length, config.vocab_size)``
                Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
            **hidden_states**: (`optional`, returned when ``config.output_hidden_states=True``)
                list of ``torch.FloatTensor`` (one for the output of each layer + the output of the embeddings)
                of shape ``(batch_size, sequence_length, hidden_size)``:
                Hidden-states of the model at the output of each layer plus the initial embedding outputs.
            **attentions**: (`optional`, returned when ``config.output_attentions=True``)
                list of ``torch.FloatTensor`` (one for each layer) of shape ``(batch_size, num_heads,
                sequence_length, sequence_length)``:
                Attentions weights after the attention softmax, used to compute the weighted average in the
                self-attention heads.

        masked_lm_labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`, defaults to :obj:`None`):
            Labels for computing the masked language modeling loss.
            Indices should either be in ``[0, ..., config.vocab_size]`` or -100 (see ``input_ids`` docstring).
            Tokens with indices set to ``-100`` are ignored (masked), the loss is only computed for the tokens
            with labels in ``[0, ..., config.vocab_size]``.

        Returns:
            :obj:`tuple(torch.FloatTensor)` comprising various elements depending on the configuration (:class:`~transformers.BartConfig`) and inputs:
            masked_lm_loss (`optional`, returned when ``masked_lm_labels`` is provided) ``torch.FloatTensor`` of shape ``(1,)``:
                Masked language modeling loss.
            prediction_scores (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, config.vocab_size)`):
                Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
            hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``config.output_hidden_states=True``):
                Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer)
                of shape :obj:`(batch_size, sequence_length, hidden_size)`.
                Hidden-states of the model at the output of each layer plus the initial embedding outputs.
            attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``config.output_attentions=True``):
                Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape
                :obj:`(batch_size, num_heads, sequence_length, sequence_length)`.
                Attentions weights after the attention softmax, used to compute the weighted average in the
                self-attention heads.

        Examples::

            tokenizer = BartTokenizer.from_pretrained('bart-large')
            model = BartForMaskedLM.from_pretrained('bart-large')
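The Examples block above is cut off at the hunk boundary. As a separate, hedged sketch of the same idea (the -100 ignore-index convention and the masked_lm_labels keyword name follow the docstring at this revision; the rest is illustrative, not the PR's own example):

    import torch
    from transformers import BartForMaskedLM, BartTokenizer

    tokenizer = BartTokenizer.from_pretrained('bart-large')
    model = BartForMaskedLM.from_pretrained('bart-large')

    input_ids = torch.tensor([tokenizer.encode("Hello, my dog is cute")])

    # Per the docstring: positions labeled -100 are ignored by the loss,
    # all other positions must hold ids in [0, ..., config.vocab_size].
    masked_lm_labels = input_ids.clone()
    masked_lm_labels[:, :2] = -100          # e.g. skip the loss on the first two tokens

    outputs = model(input_ids, masked_lm_labels=masked_lm_labels)
    loss, prediction_scores = outputs[:2]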
@@ -1008,46 +1004,39 @@ def forward(
        labels=None,
    ):
        r"""
        labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`, defaults to :obj:`None`):
            Labels for computing the sequence classification/regression loss.
            Indices should be in :obj:`[0, ..., config.num_labels - 1]`.
            If :obj:`config.num_labels == 1` a regression loss is computed (Mean-Square loss),
            If :obj:`config.num_labels > 1` a classification loss is computed (Cross-Entropy).
        Returns:
            :obj:`tuple(torch.FloatTensor)` comprising various elements depending on the configuration (:class:`~transformers.BartConfig`) and inputs:
            loss (:obj:`torch.FloatTensor` of shape :obj:`(1,)`, `optional`, returned when :obj:`label` is provided):
                Classification loss.
        labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`, defaults to :obj:`None`):
            Labels for computing the sequence classification/regression loss.
            Indices should be in :obj:`[0, ..., config.num_labels - 1]`.
            If :obj:`config.num_labels > 1` a classification loss is computed (Cross-Entropy).

        Returns:
            :obj:`tuple(torch.FloatTensor)` comprising various elements depending on the configuration (:class:`~transformers.BartConfig`) and inputs:
            loss (:obj:`torch.FloatTensor` of shape :obj:`(1,)`, `optional`, returned when :obj:`labels` is provided):
                Classification loss (cross entropy)
            logits (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, config.num_labels)`):
                Classification (or regression if config.num_labels==1) scores (before SoftMax).
            hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``config.output_hidden_states=True``):
                Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer)
                of shape :obj:`(batch_size, sequence_length, hidden_size)`.
                Hidden-states of the model at the output of each layer plus the initial embedding outputs.
            attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``config.output_attentions=True``):
                Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape :obj:`(batch_size, num_heads, sequence_length, sequence_length)`.
                Attentions weights after the attention softmax, used to compute the weighted average in the
                self-attention heads.

        Examples::

            from transformers import BartTokenizer, BartForSequenceClassification
            import torch

            tokenizer = BartTokenizer.from_pretrained('bart-large')
            model = BartForSequenceClassification.from_pretrained('bart-large')
            input_ids = torch.tensor(tokenizer.encode("Hello, my dog is cute",
                                                      add_special_tokens=True)).unsqueeze(0)  # Batch size 1
            labels = torch.tensor([1]).unsqueeze(0)  # Batch size 1
            outputs = model(input_ids, labels=labels)
            loss, logits = outputs[:2]

        """
        outputs = self.model.forward(

Review comment: We now link to the inputs in the forward method, cf. BERT file
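To round out the sequence-classification docstring above, a hedged sketch of how the returned logits are typically consumed; this is standard PyTorch post-processing, not part of this PR, and assumes the convention that the loss is only prepended when labels are passed:

    import torch
    from transformers import BartForSequenceClassification, BartTokenizer

    tokenizer = BartTokenizer.from_pretrained('bart-large')
    model = BartForSequenceClassification.from_pretrained('bart-large')

    input_ids = torch.tensor(tokenizer.encode("Hello, my dog is cute",
                                              add_special_tokens=True)).unsqueeze(0)  # Batch size 1
    logits = model(input_ids)[0]                     # no labels passed, so logits come first
    probabilities = torch.softmax(logits, dim=-1)    # (batch_size, config.num_labels)
    predicted_class = probabilities.argmax(dim=-1)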