Add "Leveraging Pretrained Checkpoints for Generation" Seq2Seq models. #6594
Conversation
Force-pushed from cf1f3fc to 89538ff
If I understand correctly, the BERT model used here is slightly different because:

Doesn't that just mean we could use an additional architecture instead of an entire model class? Something like the following:

```python
@add_start_docstrings(
    """Bert Model with a `language modeling` head on top that acts as a decoder in a seq2seq setting.""",
    BERT_START_DOCSTRING,
)
class CausalBertModel(BertPreTrainedModel):
    def __init__(self, config):
        super().__init__(config)
        if not config.is_decoder:
            logger.warning("If you want to use `BertLMHeadModel` as a standalone, add `is_decoder=True`.")
        self.bert = BertModel(config)
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
        self.init_weights()
```

And this module wouldn't accept token type IDs as an input. I don't know what to do regarding the tokenizer, though.

This ^ approach could probably leverage @julien-c's #6995.
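For context, a minimal sketch of using the existing `BertLMHeadModel` (the class referenced in the warning above) as a standalone causal decoder; the checkpoint name is only an example and this is not part of the PR's own code:

```python
from transformers import BertConfig, BertLMHeadModel, BertTokenizer

tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")

# `is_decoder=True` switches the self-attention to a causal mask,
# which is exactly what the warning in the snippet above guards against
config = BertConfig.from_pretrained("bert-base-uncased")
config.is_decoder = True
model = BertLMHeadModel.from_pretrained("bert-base-uncased", config=config)

inputs = tokenizer("Hello, my dog is cute", return_tensors="pt")
outputs = model(**inputs)
lm_logits = outputs[0]  # language-modeling logits, shape (batch_size, sequence_length, vocab_size)
```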
Naming ideas:

Reasoning:

Anyways, your signatures look super clean, easy and consistent!
There are a couple of problems with that:

So overall, it seems to me that a separate model class is the cleaner way to go - what do you think?

@sshleifer - very much agree here! Think the naming should be different...

BertForGenerationEncoder and BertForGenerationDecoder and BertForGenerationConfig 👍 I do see Lysandre's point though and would be fine with you setting
This looks great to me, renaming aside. Since the names have been in a release already, I think we need proper deprecation warnings before removing those old names.
src/transformers/__init__.py (outdated)

```diff
@@ -22,7 +22,7 @@
from .configuration_auto import ALL_PRETRAINED_CONFIG_ARCHIVE_MAP, CONFIG_MAPPING, AutoConfig
from .configuration_bart import BartConfig
from .configuration_bert import BERT_PRETRAINED_CONFIG_ARCHIVE_MAP, BertConfig
from .configuration_causal_bert import CausalBertConfig
```
This class has been in a release already. We can't remove it without proper deprecation warnings.
This was a weird diff -> the config, tokenizer and model CausalBert... were never in the library - I added them yesterday to the PR. If you look at the changed files now, you can see that no previous model names are removed :-)
Will still need 1-2 days to finish the PR, including integration tests, model cards, etc., so no need to review yet :-)
Oh then in that case, no problem with renaming things :-)
src/transformers/__init__.py (outdated)

```diff
@@ -418,9 +418,9 @@
    TransfoXLPreTrainedModel,
    load_tf_weights_in_transfo_xl,
)
from .modeling_causal_bert import (
    CausalBertModel,
```
Same for those names.
src/transformers/__init__.py (outdated)

```diff
@@ -144,7 +144,7 @@
 from .tokenization_bert import BasicTokenizer, BertTokenizer, BertTokenizerFast, WordpieceTokenizer
 from .tokenization_bert_japanese import BertJapaneseTokenizer, CharacterTokenizer, MecabTokenizer
 from .tokenization_camembert import CamembertTokenizer
-from .tokenization_causal_bert import CausalBertTokenizer
+from .tokenization_bert import BertForSeqGenerationTokenizer
```
Same here
The summarization models seem to work and are uploaded here: https://huggingface.co/models?search=google%2Froberta2roberta
Force-pushed from 1b1a2c8 to a6392eb
```python
elif (
    hasattr(self.config, "decoder")
    and hasattr(self.config.decoder, "bos_token_id")
    and self.config.decoder.bos_token_id is not None
```
need one more check for this
(out of scope) I would be down for a helper method `determine_decoder_start_token_id` to get this out of the main block.
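For illustration, a minimal sketch of what such a helper could look like (the name comes from the suggestion above; the exact fallback order is an assumption, not the library's actual implementation):

```python
def determine_decoder_start_token_id(config, decoder_start_token_id=None, bos_token_id=None):
    """Hypothetical helper that pulls the decoder start-token resolution out of generate()."""
    # an explicitly passed value always wins
    if decoder_start_token_id is not None:
        return decoder_start_token_id
    # next, a top-level config attribute
    if getattr(config, "decoder_start_token_id", None) is not None:
        return config.decoder_start_token_id
    # encoder-decoder configs carry a nested decoder config (the `elif` branch above)
    if hasattr(config, "decoder") and getattr(config.decoder, "bos_token_id", None) is not None:
        return config.decoder.bos_token_id
    # finally, fall back to the plain BOS token
    if bos_token_id is not None:
        return bos_token_id
    raise ValueError(
        "`decoder_start_token_id` or `bos_token_id` has to be defined for encoder-decoder generation."
    )
```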
```diff
@@ -22,7 +22,6 @@
import sentencepiece as spm

from .tokenization_utils import PreTrainedTokenizer
from .tokenization_xlnet import SPIECE_UNDERLINE
```
small clean-up
UPDATE: The PR is ready for review @sshleifer @LysandreJik @sgugger. Would be awesome if you could take a look!
Codecov Report
```
@@            Coverage Diff             @@
##           master    #6594      +/-   ##
==========================================
+ Coverage   78.37%   80.49%    +2.11%
==========================================
  Files         164      167        +3
  Lines       31026    31314      +288
==========================================
+ Hits        24318    25207      +889
+ Misses       6708     6107      -601
```

Continue to review full report at Codecov.
Nice!
*Unsupervised pre-training of large neural models has recently revolutionized Natural Language Processing. By warm-starting from the publicly released checkpoints, NLP practitioners have pushed the state-of-the-art on multiple benchmarks while saving significant amounts of compute time. So far the focus has been mainly on the Natural Language Understanding tasks. In this paper, we demonstrate the efficacy of pre-trained checkpoints for Sequence Generation. We developed a Transformer-based sequence-to-sequence model that is compatible with publicly available pre-trained BERT, GPT-2 and RoBERTa checkpoints and conducted an extensive empirical study on the utility of initializing our model, both encoder and decoder, with these checkpoints. Our models result in new state-of-the-art results on Machine Translation, Text Summarization, Sentence Splitting, and Sentence Fusion.*

Tips:
I would make sure a usage example is before Tips (or right after).
```diff
@@ -0,0 +1,44 @@
BertForSeqGeneration
```
Don't see why `Seq` should be in the name. What other kind of generation might a confused person be thinking of? Don't feel strongly.
changed it to BertGeneration
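For reference, warm-starting an encoder-decoder model from the renamed classes would roughly look like this (a sketch based on the library's documented usage; the checkpoint name and token ids are just example values):

```python
from transformers import BertGenerationDecoder, BertGenerationEncoder, EncoderDecoderModel

# leverage a plain BERT checkpoint as the encoder
encoder = BertGenerationEncoder.from_pretrained(
    "bert-large-uncased", bos_token_id=101, eos_token_id=102
)
# the decoder additionally needs a causal mask and cross-attention layers
decoder = BertGenerationDecoder.from_pretrained(
    "bert-large-uncased", add_cross_attention=True, is_decoder=True, bos_token_id=101, eos_token_id=102
)
model = EncoderDecoderModel(encoder=encoder, decoder=decoder)
```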
```python
}

@slow
def test_roberta2roberta_summarization(self):
```
👍
Does generation with `model.half()` work?
Maybe, not sure - will check.
Looks great to me! I mostly have annoying nits about the docs, cause I'm an annoying person.
Haha, no you are 100% right - sorry for being so sloppy with the docs! I should have learnt it by now...
@sshleifer @sgugger - thanks a lot for your suggestions. I went for the name `BertGeneration` in the end.
Great, very cool!!
Have the "share" models been implemented? In the paper, in many tasks they achieve the best results. |
Yes, you can find them under google/roberta2roberta
Thank you. How do I tie the weights in the code when training my own model?
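A minimal sketch of one way to get the tied ("share") setup with the EncoderDecoder API, assuming the `tie_encoder_decoder` config flag (checkpoint names are placeholders):

```python
from transformers import EncoderDecoderModel

# ties the decoder parameters to the encoder parameters wherever the shapes match,
# mirroring the "share" configuration from the paper
shared_model = EncoderDecoderModel.from_encoder_decoder_pretrained(
    "roberta-base", "roberta-base", tie_encoder_decoder=True
)
```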
huggingface#6594)

* add conversion script
* improve conversion script
* make style
* add tryout files
* fix
* update
* add causal bert
* better names
* add tokenizer file as well
* finish causal_bert
* fix small bugs
* improve generate
* change naming
* renaming
* renaming
* renaming
* remove leftover files
* clean files
* add fix tokenizer
* finalize
* correct slow test
* update docs
* small fixes
* fix link
* adapt check repo
* apply sams and sylvains recommendations
* fix import
* implement Lysandres recommendations
* fix logger warn
This PR adds the models from the following paper:
Paper: https://arxiv.org/pdf/1907.12461.pdf
The paper does a great job of showing how pretrained BERT & RoBERTa models can be leveraged for Seq2Seq tasks and yields good results on many of them. It fits very well with the current implementation of the EncoderDecoder framework.
This PR adds code to port all pretrained encoder-decoder models found here: https://tfhub.dev/s?module-type=text-generation&subtype=module,placeholder
The ported checkpoints can be found here: https://huggingface.co/models?search=google%2Froberta
and here: https://huggingface.co/models?search=google%2Fbert2
An example of how a model can be used is here:
https://huggingface.co/google/roberta2roberta_L-24_bbc
Big thanks to @shashiongithub for providing me with the tokenizer files and giving valuable insights on setting the correct generation parameters!
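A rough usage sketch for the linked checkpoint (modeled on its model card; the article text is a placeholder):

```python
from transformers import AutoTokenizer, EncoderDecoderModel

tokenizer = AutoTokenizer.from_pretrained("google/roberta2roberta_L-24_bbc")
model = EncoderDecoderModel.from_pretrained("google/roberta2roberta_L-24_bbc")

article = "Some long BBC news article to summarize ..."
input_ids = tokenizer(article, return_tensors="pt").input_ids

# generate a summary with the ported seq2seq model
output_ids = model.generate(input_ids)
print(tokenizer.decode(output_ids[0], skip_special_tokens=True))
```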