forked from huggingface/transformers
Commit
* Results same as fairseq
* Wrote a ton of tests
* Struggled with api signatures
* added some docs
Showing 20 changed files with 1,766 additions and 59 deletions.
@@ -0,0 +1,52 @@
Bart
----------------------------------------------------
**DISCLAIMER:** This model is still a work in progress; if you see something strange,
file a `Github Issue <https://github.com/huggingface/transformers/issues/new?assignees=&labels=&template=bug-report.md&title>`__ and assign
@sshleifer

The Bart model was `proposed <https://arxiv.org/abs/1910.13461>`_ by Mike Lewis, Yinhan Liu, Naman Goyal, Marjan Ghazvininejad, Abdelrahman Mohamed, Omer Levy, Ves Stoyanov and Luke Zettlemoyer on 29 Oct, 2019.
It is a sequence-to-sequence model in which both the encoder and the decoder are transformers. The paper also introduces a novel pretraining objective and demonstrates excellent summarization results.
The authors released their code `here <https://github.com/pytorch/fairseq/tree/master/examples/bart>`_.

**Abstract:**

*We present BART, a denoising autoencoder for pretraining sequence-to-sequence models. BART is trained by (1) corrupting text with an arbitrary noising function, and (2) learning a model to reconstruct the original text. It uses a standard Transformer-based neural machine translation architecture which, despite its simplicity, can be seen as generalizing BERT (due to the bidirectional encoder), GPT (with the left-to-right decoder), and many other more recent pretraining schemes. We evaluate a number of noising approaches, finding the best performance by both randomly shuffling the order of the original sentences and using a novel in-filling scheme, where spans of text are replaced with a single mask token. BART is particularly effective when fine-tuned for text generation but also works well for comprehension tasks. It matches the performance of RoBERTa with comparable training resources on GLUE and SQuAD, achieves new state-of-the-art results on a range of abstractive dialogue, question answering, and summarization tasks, with gains of up to 6 ROUGE. BART also provides a 1.1 BLEU increase over a back-translation system for machine translation, with only target language pretraining. We also report ablation experiments that replicate other pretraining schemes within the BART framework, to better measure which factors most influence end-task performance.*

`BART: Denoising Sequence-to-Sequence Pre-training for Natural Language Generation, Translation, and Comprehension`

Notes:

- Bart doesn't use :obj:`token_type_ids`; for sequence classification, just use BartTokenizer.encode to get the proper splitting.
- Inputs to the decoder are created by BartModel.forward if they are not passed. This differs from some other model APIs (see the sketch below).
- Model predictions are intended to be identical to the original implementation. This only holds, however, if the string you pass to fairseq.encode starts with a space.
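
A minimal usage sketch (this assumes the ``bart-large`` weights are registered under that
name, as the conversion script in this commit suggests)::

    from transformers import BartModel, BartTokenizer

    tokenizer = BartTokenizer.from_pretrained("bart-large")
    model = BartModel.from_pretrained("bart-large")

    # No token_type_ids are needed; encode adds the special tokens itself.
    input_ids = tokenizer.encode("Hello, my dog is cute", return_tensors="pt")

    # Decoder inputs are created by BartModel.forward because none are passed here.
    outputs = model(input_ids)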

BartModel
~~~~~~~~~~~~~~~~~~~~

.. autoclass:: transformers.BartModel
    :members: forward


BartForMaskedLM
~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autoclass:: transformers.BartForMaskedLM
    :members: forward


BartForSequenceClassification
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autoclass:: transformers.BartForSequenceClassification
    :members: forward


BartConfig
~~~~~~~~~~~~~~~~~~~~~

.. autoclass:: transformers.BartConfig
    :members:


Automatic Creation of Decoder Inputs
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
This is enabled by default.

.. autofunction:: transformers.modeling_bart._prepare_bart_decoder_inputs
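
Continuing the sketch above, both calls below are valid; in the first, the decoder inputs are
derived from ``input_ids`` by this helper (the ``decoder_input_ids`` keyword name is an
assumption here)::

    # Decoder inputs built automatically from input_ids
    outputs = model(input_ids)

    # Decoder inputs supplied explicitly (reusing input_ids purely for illustration)
    outputs = model(input_ids, decoder_input_ids=input_ids)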

@@ -0,0 +1,101 @@
# coding=utf-8
# Copyright 2020 The Fairseq Authors and The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" BART configuration """


import logging

from .configuration_utils import PretrainedConfig


logger = logging.getLogger(__name__)

_bart_large_url = "https://s3.amazonaws.com/models.huggingface.co/bert/facebook/bart-large/config.json"
BART_PRETRAINED_CONFIG_ARCHIVE_MAP = {
    "bart-large": _bart_large_url,
    "bart-large-mnli": _bart_large_url,  # same config as bart-large
    "bart-cnn": None,  # not available yet
}


class BartConfig(PretrainedConfig):
    r"""
    Configuration class for Bart. Parameters are renamed from the fairseq implementation.
    """
    model_type = "bart"
    pretrained_config_archive_map = BART_PRETRAINED_CONFIG_ARCHIVE_MAP

    def __init__(
        self,
        activation_dropout=0.0,
        vocab_size=50265,
        pad_token_id=1,
        eos_token_id=2,
        d_model=1024,
        encoder_ffn_dim=4096,
        encoder_layers=12,
        encoder_attention_heads=16,
        decoder_ffn_dim=4096,
        decoder_layers=12,
        decoder_attention_heads=16,
        encoder_layerdrop=0.0,
        decoder_layerdrop=0.0,
        attention_dropout=0.0,
        dropout=0.1,
        max_position_embeddings=1024,
        init_std=0.02,
        classifier_dropout=0.0,
        output_past=False,
        num_labels=3,
        **common_kwargs
    ):
        r"""
        :class:`~transformers.BartConfig` is the configuration class for `BartModel`.

        Examples::

            config = BartConfig.from_pretrained('bart-large')
            model = BartModel(config)
        """
        super().__init__(num_labels=num_labels, output_past=output_past, pad_token_id=pad_token_id, **common_kwargs)

        self.vocab_size = vocab_size
        self.d_model = d_model  # encoder_embed_dim and decoder_embed_dim in fairseq
        self.eos_token_id = eos_token_id

        self.encoder_ffn_dim = encoder_ffn_dim
        self.encoder_layers = self.num_hidden_layers = encoder_layers
        self.encoder_attention_heads = encoder_attention_heads
        self.encoder_layerdrop = encoder_layerdrop
        self.decoder_layerdrop = decoder_layerdrop
        self.decoder_ffn_dim = decoder_ffn_dim
        self.decoder_layers = decoder_layers
        self.decoder_attention_heads = decoder_attention_heads
        self.max_position_embeddings = max_position_embeddings
        self.init_std = init_std  # weights are initialized from Normal(0, init_std)

        # Three types of dropout
        self.attention_dropout = attention_dropout
        self.activation_dropout = activation_dropout
        self.dropout = dropout

        # Classification head dropout
        self.classif_dropout = classifier_dropout

    @property
    def num_attention_heads(self):
        return self.encoder_attention_heads

    @property
    def hidden_size(self):
        return self.d_model
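
A short usage sketch for this configuration (illustrative only; the default values referenced below are read off the signature above):

    from transformers import BartConfig, BartModel

    # Defaults mirror bart-large: 12 encoder and 12 decoder layers, d_model=1024.
    config = BartConfig(num_labels=3, dropout=0.1)
    model = BartModel(config)

    # hidden_size and num_attention_heads are read-only aliases of the encoder settings.
    assert config.hidden_size == config.d_model == 1024
    assert config.num_attention_heads == config.encoder_attention_heads == 16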

src/transformers/convert_bart_original_pytorch_checkpoint_to_pytorch.py (100 additions, 0 deletions)
@@ -0,0 +1,100 @@
# coding=utf-8
# Copyright 2020 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Convert BART checkpoint."""


import argparse
import logging
from pathlib import Path

import fairseq
import torch
from packaging import version

from transformers import BartConfig, BartForSequenceClassification, BartModel, BartTokenizer


if version.parse(fairseq.__version__) < version.parse("0.9.0"):
    raise Exception("requires fairseq >= 0.9.0")


logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

SAMPLE_TEXT = "Hello world! cécé herlolip"

rename_keys = [
    ("model.classification_heads.mnli.dense.weight", "classification_head.dense.weight"),
    ("model.classification_heads.mnli.dense.bias", "classification_head.dense.bias"),
    ("model.classification_heads.mnli.out_proj.weight", "classification_head.out_proj.weight"),
    ("model.classification_heads.mnli.out_proj.bias", "classification_head.out_proj.bias"),
]
IGNORE_KEYS = ["encoder.version", "decoder.version", "model.encoder.version", "model.decoder.version"]


def rename_key(dct, old, new):
    val = dct.pop(old)
    dct[new] = val


def convert_bart_checkpoint(checkpoint_path, pytorch_dump_folder_path):
    """
    Copy/paste/tweak the fairseq model's weights into our BART structure.
    """
    b2 = torch.hub.load("pytorch/fairseq", checkpoint_path)
    b2.eval()  # disable dropout
    b2.model.upgrade_state_dict(b2.model.state_dict())
    config = BartConfig()
    # Sanity check: our tokenizer must agree with fairseq's encoding of the sample text.
    tokens = b2.encode(SAMPLE_TEXT).unsqueeze(0)
    tokens2 = BartTokenizer.from_pretrained("bart-large").encode(SAMPLE_TEXT).unsqueeze(0)
    assert torch.eq(tokens, tokens2).all()

    # assert their_output.size() == (1, 11, 1024)

    if checkpoint_path == "bart.large":
        state_dict = b2.model.state_dict()
        state_dict["shared.weight"] = state_dict["decoder.embed_tokens.weight"]
        model = BartModel(config)
        their_output = b2.extract_features(tokens)
    else:  # MNLI case
        state_dict = b2.state_dict()
        state_dict["model.shared.weight"] = state_dict["model.decoder.embed_tokens.weight"]
        for src, dest in rename_keys:
            rename_key(state_dict, src, dest)
        state_dict.pop("_float_tensor", None)
        model = BartForSequenceClassification(config)
        their_output = b2.predict("mnli", tokens, return_logits=True)
    for k in IGNORE_KEYS:
        state_dict.pop(k, None)
    model.load_state_dict(state_dict)
    model.eval()
    our_outputs = model.forward(tokens)[0]

    # The converted model must reproduce fairseq's outputs exactly before we save it.
    assert their_output.shape == our_outputs.shape
    assert (their_output == our_outputs).all().item()
    Path(pytorch_dump_folder_path).mkdir(exist_ok=True)
    model.save_pretrained(pytorch_dump_folder_path)


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    # Required parameters
    parser.add_argument("fairseq_path", choices=["bart.large", "bart.large.mnli"], type=str, help="Which fairseq checkpoint to convert.")
    parser.add_argument("pytorch_dump_folder_path", default=None, type=str, help="Path to the output PyTorch model.")
    args = parser.parse_args()
    convert_bart_checkpoint(args.fairseq_path, args.pytorch_dump_folder_path)
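
A sketch of how this converter might be run (the output directory name here is arbitrary); it downloads the checkpoint via torch.hub, converts the weights, checks the outputs against fairseq, and saves the result:

    from transformers.convert_bart_original_pytorch_checkpoint_to_pytorch import convert_bart_checkpoint

    # Equivalent to: python convert_bart_original_pytorch_checkpoint_to_pytorch.py bart.large ./bart-large-hf
    convert_bart_checkpoint("bart.large", "./bart-large-hf")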