Add "Leveraging Pretrained Checkpoints for Generation" Seq2Seq models. #6594

Conversation

patrickvonplaten
Contributor

@patrickvonplaten patrickvonplaten commented Aug 19, 2020

This PR adds the models from the following paper:

Paper: https://arxiv.org/pdf/1907.12461.pdf

The paper does a great job of showing how pretrained BERT & RoBERTa models can be leveraged for Seq2Seq tasks, yielding strong results on many of them. It fits very well with the current implementation of the EncoderDecoder framework.

This PR adds code to port all pretrained encoder-decoder models that can be found here: https://tfhub.dev/s?module-type=text-generation&subtype=module,placeholder

The ported checkpoints can be found here: https://huggingface.co/models?search=google%2Froberta
and here: https://huggingface.co/models?search=google%2Fbert2

An example of how a model can be used is here:
https://huggingface.co/google/roberta2roberta_L-24_bbc
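
For readers following along, a minimal sketch of how such a checkpoint might be loaded for summarization (roughly following the linked model card; the article text is a placeholder):

from transformers import AutoTokenizer, EncoderDecoderModel

# load the ported roberta2roberta checkpoint for extreme summarization (BBC XSum)
tokenizer = AutoTokenizer.from_pretrained("google/roberta2roberta_L-24_bbc")
model = EncoderDecoderModel.from_pretrained("google/roberta2roberta_L-24_bbc")

article = "Some BBC-style news article to summarize ..."  # placeholder text
input_ids = tokenizer(article, return_tensors="pt").input_ids
output_ids = model.generate(input_ids)
print(tokenizer.decode(output_ids[0], skip_special_tokens=True))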

Big thanks to @shashiongithub for providing me with the tokenizer files and giving valuable insights on setting the correct generation parameters!

@patrickvonplaten patrickvonplaten changed the title [Seq2Seq] Port google EncoderDecoder pretrained checkpoint models into EncoderDecoder framework [WIP, Seq2Seq] Port google EncoderDecoder pretrained checkpoint models into EncoderDecoder framework Aug 19, 2020
@patrickvonplaten patrickvonplaten force-pushed the add_seq2seq_tf_hub_conversion_script branch from cf1f3fc to 89538ff Compare September 7, 2020 14:43
@LysandreJik
Member

If I understand correctly, the BERT model used here is slightly different because:

  • It doesn't use token type IDs
  • It ties its word embedding layer to its LM head
  • It has no pooling layer

Doesn't that just mean we could use an additional architecture instead of an entire model class? Something like the following, in modeling_bert.py:

@add_start_docstrings(
    """Bert Model with a `language modeling` head on top that acts as a decoder in a seq2seq setting.""", BERT_START_DOCSTRING
)
class CausalBertModel(BertPreTrainedModel):
    def __init__(self, config):
        super().__init__(config)

        if not config.is_decoder:
            logger.warning("If you want to use `CausalBertModel` as a standalone, add `is_decoder=True`.")

        self.bert = BertModel(config)
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)

        self.init_weights()

And this module wouldn't accept token type IDs as an input.

I don't know what to do regarding the tokenizer though. This ^ approach could probably leverage @julien-c's #6995

@sshleifer
Contributor

Naming ideas: BertFor{Conditional}Generation, BertEncoder, BertDecoder.

Reasoning:

  • CausalBert doesn't make sense if the encoder wasn't trained with a causal mask.
  • I think in class naming it's more important to give someone a sense of how to use something than how that thing was trained, but that's not an opinion I hold strongly.

Anyway, your signatures look super clean, easy, and consistent!
Excited to try these out + happy to help check metrics.

@patrickvonplaten
Contributor Author

patrickvonplaten commented Sep 7, 2020

@LysandreJik,

There are a couple of problems with that:

  1. I also need a different BertEmbeddings, or I have to manually set self.token_type_embeddings to a zero matrix. Even if token_type_ids is set to None in BERT, self.token_type_embeddings is always used. This model simply does not have those embeddings (and should not have them, IMO). I could set the self.token_type_embeddings matrix to zero, but then people using this class for training would not realize that a token type embedding matrix is being trained when it shouldn't be. So either way, I think I will need a separate BertEmbeddings class.

  2. A bigger problem is the config class. Because I need both the new CausalBertForCausalLM and BertLMHeadModel in the AUTO_MODELS_FOR_CAUSAL_LM mapping (to leverage both models with the EncoderDecoder framework), the two models have to have different config classes. I guess we could also create a separate config class and overwrite the inherited config class from BertPretrainedModel, but then, IMO, it's cleaner to just create a new PretrainedModel class, in which case we can directly create a completely new model class.

So overall, it seems to me that a separate model class is the cleaner way to go - what do you think?

@sshleifer - very much agree here! I think the naming should be different... BertEncoder is already taken though. I could go for BertForGenerationEncoder, BertForGenerationDecoder, and BertForGenerationConfig - no need for BertForConditionalGeneration, as the `EncoderDecoderModel` will be used for this.

@sshleifer
Contributor

BertForGenerationEncoder and BertForGenerationDecoder and BertForGenerationConfig 👍

I do see Lysandre's point though and would be fine with you setting the token_type matrix to 0 if it's small (which I think it is).
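
For illustration, a minimal sketch of that zeroing workaround on a standard BertModel (illustrative only; the PR ultimately adds a separate class without token type embeddings instead):

import torch
from transformers import BertModel

model = BertModel.from_pretrained("bert-base-uncased")

# zero out the token type embedding matrix and freeze it so it neither
# contributes to the forward pass nor gets updated during fine-tuning
with torch.no_grad():
    model.embeddings.token_type_embeddings.weight.zero_()
model.embeddings.token_type_embeddings.weight.requires_grad = False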

Collaborator

@sgugger sgugger left a comment

This looks great to me, the renaming aside. Since the names have been in a release already, I think we need proper deprecation warnings before removing those old names.

@@ -22,7 +22,7 @@
from .configuration_auto import ALL_PRETRAINED_CONFIG_ARCHIVE_MAP, CONFIG_MAPPING, AutoConfig
from .configuration_bart import BartConfig
from .configuration_bert import BERT_PRETRAINED_CONFIG_ARCHIVE_MAP, BertConfig
from .configuration_causal_bert import CausalBertConfig
Collaborator

This class has been in a release already. We can't remove it without proper deprecation warnings.

Contributor Author

This was a weird diff -> the config, tokenizer and model CausalBert... were never in the library - I added them yesterday to the PR. If you look at the changed files now you can see that no previous model names are removed :-)

Contributor Author

Will still need 1-2 days to finish the PR, including integration tests, model cards, etc... so no need to review yet :-)

Collaborator

Oh then in that case, no problem with renaming things :-)

@@ -418,9 +418,9 @@
TransfoXLPreTrainedModel,
load_tf_weights_in_transfo_xl,
)
from .modeling_causal_bert import (
CausalBertModel,
Collaborator

Same for those names.

@@ -144,7 +144,7 @@
from .tokenization_bert import BasicTokenizer, BertTokenizer, BertTokenizerFast, WordpieceTokenizer
from .tokenization_bert_japanese import BertJapaneseTokenizer, CharacterTokenizer, MecabTokenizer
from .tokenization_camembert import CamembertTokenizer
from .tokenization_causal_bert import CausalBertTokenizer
from .tokenization_bert import BertForSeqGenerationTokenizer
Collaborator

Same here

@patrickvonplaten
Contributor Author

patrickvonplaten commented Sep 8, 2020

The summarization models seem to work and are uploaded here:

https://huggingface.co/models?search=google%2Froberta2roberta

@patrickvonplaten patrickvonplaten force-pushed the add_seq2seq_tf_hub_conversion_script branch from 1b1a2c8 to a6392eb Compare September 9, 2020 17:40
elif (
hasattr(self.config, "decoder")
and hasattr(self.config.decoder, "bos_token_id")
and self.config.decoder.bos_token_id is not None
Contributor Author

need one more check for this

Contributor

(out of scope)
I would be down for a helper method
determine_decoder_start_token_id to get this out of the main block.
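
A rough sketch of what such a helper could look like (hypothetical name and structure, mirroring the checks in the snippet above):

def _determine_decoder_start_token_id(self) -> int:
    # prefer an explicit decoder_start_token_id, then the model's bos_token_id,
    # then fall back to the decoder sub-config of an encoder-decoder model
    # (the case handled by the `elif` above)
    if getattr(self.config, "decoder_start_token_id", None) is not None:
        return self.config.decoder_start_token_id
    if getattr(self.config, "bos_token_id", None) is not None:
        return self.config.bos_token_id
    decoder_config = getattr(self.config, "decoder", None)
    if decoder_config is not None and getattr(decoder_config, "bos_token_id", None) is not None:
        return decoder_config.bos_token_id
    raise ValueError("`decoder_start_token_id` or `bos_token_id` has to be defined for encoder-decoder generation")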

@@ -22,7 +22,6 @@
import sentencepiece as spm

from .tokenization_utils import PreTrainedTokenizer
from .tokenization_xlnet import SPIECE_UNDERLINE
Contributor Author

small clean-up

@patrickvonplaten patrickvonplaten changed the title [WIP, Seq2Seq] Port google EncoderDecoder pretrained checkpoint models into EncoderDecoder framework Add "Leveraging Pretrained Checkpoints for Generation" Seq2Seq models. Sep 9, 2020
@patrickvonplaten
Contributor Author

UPDATE: the PR is ready for review @sshleifer @LysandreJik @sgugger.

Would be awesome if you could take a look.

@codecov

codecov bot commented Sep 9, 2020

Codecov Report

Merging #6594 into master will increase coverage by 2.11%.
The diff coverage is 75.76%.

@@            Coverage Diff             @@
##           master    #6594      +/-   ##
==========================================
+ Coverage   78.37%   80.49%   +2.11%     
==========================================
  Files         164      167       +3     
  Lines       31026    31314     +288     
==========================================
+ Hits        24318    25207     +889     
+ Misses       6708     6107     -601     
Impacted Files Coverage Δ
src/transformers/tokenization_t5.py 95.23% <ø> (-0.05%) ⬇️
src/transformers/tokenization_auto.py 91.52% <40.00%> (-4.78%) ⬇️
src/transformers/modeling_encoder_decoder.py 88.78% <50.00%> (-3.22%) ⬇️
src/transformers/modeling_bert_generation.py 69.19% <69.19%> (ø)
src/transformers/tokenization_bert_generation.py 94.64% <94.64%> (ø)
src/transformers/__init__.py 99.33% <100.00%> (+0.01%) ⬆️
src/transformers/configuration_auto.py 93.61% <100.00%> (+0.13%) ⬆️
src/transformers/configuration_bert_generation.py 100.00% <100.00%> (ø)
src/transformers/file_utils.py 82.41% <100.00%> (-0.26%) ⬇️
src/transformers/generation_utils.py 96.92% <100.00%> (-0.28%) ⬇️
... and 22 more

Continue to review full report at Codecov.

Legend: Δ = absolute <relative> (impact), ø = not affected, ? = missing data
Powered by Codecov. Last update 15478c1...aa953cb.

Contributor

@sshleifer sshleifer left a comment

Nice!


*Unsupervised pre-training of large neural models has recently revolutionized Natural Language Processing. By warm-starting from the publicly released checkpoints, NLP practitioners have pushed the state-of-the-art on multiple benchmarks while saving significant amounts of compute time. So far the focus has been mainly on the Natural Language Understanding tasks. In this paper, we demonstrate the efficacy of pre-trained checkpoints for Sequence Generation. We developed a Transformer-based sequence-to-sequence model that is compatible with publicly available pre-trained BERT, GPT-2 and RoBERTa checkpoints and conducted an extensive empirical study on the utility of initializing our model, both encoder and decoder, with these checkpoints. Our models result in new state-of-the-art results on Machine Translation, Text Summarization, Sentence Splitting, and Sentence Fusion.*

Tips:
Contributor

I would make sure a usage example is before Tips (or right after)

@@ -0,0 +1,44 @@
BertForSeqGeneration
Contributor

Don't see why Seq should be in the name. What other kind of generation might a confused person be thinking of?
Don't feel strongly.

Contributor Author

changed it to BertGeneration

}

@slow
def test_roberta2roberta_summarization(self):
Contributor

👍

Contributor

does generation with model.half() work?

Contributor Author

Maybe, not sure - will check.
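
A quick sketch of how one might check that (assumes a CUDA device and reuses the summarization checkpoint from above):

from transformers import AutoTokenizer, EncoderDecoderModel

tokenizer = AutoTokenizer.from_pretrained("google/roberta2roberta_L-24_bbc")
# cast the model to fp16 and move it to the GPU before generating
model = EncoderDecoderModel.from_pretrained("google/roberta2roberta_L-24_bbc").half().to("cuda")

input_ids = tokenizer("Some article text ...", return_tensors="pt").input_ids.to("cuda")  # placeholder text
summary_ids = model.generate(input_ids)
print(tokenizer.decode(summary_ids[0], skip_special_tokens=True))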

Collaborator

@sgugger sgugger left a comment

Looks great to me! I mostly have annoying nits about the docs, cause I'm an annoying person.

@patrickvonplaten
Contributor Author

Looks great to me! I mostly have annoying nits about the docs, cause I'm an annoying person.

Haha, no you are 100% right - sorry for being so sloppy with the docs! I should have learnt it by now ....

@patrickvonplaten
Contributor Author

@sshleifer @sgugger - thanks a lot for your suggestions. I went for the names BertGenerationEncoder and BertGenerationDecoder. I think that's the best trade-off: short and concise without being confusing.
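
A minimal sketch of how the two new classes could be combined via the EncoderDecoder framework (assuming the final class names from this PR; 101/102 are BERT's [CLS]/[SEP] ids, reused here as BOS/EOS):

from transformers import BertGenerationDecoder, BertGenerationEncoder, BertTokenizer, EncoderDecoderModel

# warm-start the encoder from a public BERT checkpoint
encoder = BertGenerationEncoder.from_pretrained("bert-large-uncased", bos_token_id=101, eos_token_id=102)
# warm-start the decoder, adding cross-attention layers and a causal mask
decoder = BertGenerationDecoder.from_pretrained(
    "bert-large-uncased", add_cross_attention=True, is_decoder=True, bos_token_id=101, eos_token_id=102
)
bert2bert = EncoderDecoderModel(encoder=encoder, decoder=decoder)
tokenizer = BertTokenizer.from_pretrained("bert-large-uncased")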

Member

@LysandreJik LysandreJik left a comment

Great, very cool!!

@patrickvonplaten patrickvonplaten merged commit 7fd1feb into huggingface:master Sep 10, 2020
@djstrong

djstrong commented Sep 24, 2020

Have the "share" models been implemented? In the paper, they achieve the best results on many tasks.

@patrickvonplaten
Contributor Author

Yes, you can find them under google/roberta2roberta

@djstrong

Thank you. How do I tie the weights in the code for training my own model?

@patrickvonplaten
Contributor Author

`tie_encoder_decoder=True` - the code in this model card should show you how to do it :-) https://huggingface.co/patrickvonplaten/roberta2roberta-share-cnn_dailymail-fp16
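
For reference, a minimal sketch of the weight-tying setup (assuming RoBERTa checkpoints; the linked model card has the full fine-tuning code):

from transformers import EncoderDecoderModel, RobertaTokenizer

tokenizer = RobertaTokenizer.from_pretrained("roberta-base")
# tie_encoder_decoder=True ties the decoder weights to the encoder weights,
# giving the parameter-shared ("share") setup from the paper
shared_model = EncoderDecoderModel.from_encoder_decoder_pretrained(
    "roberta-base", "roberta-base", tie_encoder_decoder=True
)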

Zigur pushed a commit to Zigur/transformers that referenced this pull request Oct 26, 2020
Add "Leveraging Pretrained Checkpoints for Generation" Seq2Seq models. (huggingface#6594)

* add conversion script
* improve conversion script
* make style
* add tryout files
* fix
* update
* add causal bert
* better names
* add tokenizer file as well
* finish causal_bert
* fix small bugs
* improve generate
* change naming
* renaming
* renaming
* renaming
* remove leftover files
* clean files
* add fix tokenizer
* finalize
* correct slow test
* update docs
* small fixes
* fix link
* adapt check repo
* apply sams and sylvains recommendations
* fix import
* implement Lysandres recommendations
* fix logger warn