
Commit 86822a3

T5 & mT5 (#8552)
* add mt5 and t5v1_1 model
* fix tests
* correct some imports
* add tf model
* finish tf t5
* improve examples
* fix copies
* clean doc
1 parent 9e01f98 commit 86822a3

21 files changed: +680 -25 lines

docs/source/index.rst

Lines changed: 1 addition & 0 deletions
@@ -248,6 +248,7 @@ conversion utilities for the following models:
     model_doc/marian
     model_doc/mbart
     model_doc/mobilebert
+    model_doc/mt5
     model_doc/gpt
     model_doc/gpt2
     model_doc/pegasus

docs/source/model_doc/mt5.rst

Lines changed: 53 additions & 0 deletions
@@ -0,0 +1,53 @@
MT5
-----------------------------------------------------------------------------------------------------------------------

Overview
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

The mT5 model was presented in `mT5: A massively multilingual pre-trained text-to-text transformer
<https://arxiv.org/abs/2010.11934>`_ by Linting Xue, Noah Constant, Adam Roberts, Mihir Kale, Rami Al-Rfou, Aditya
Siddhant, Aditya Barua, Colin Raffel.

The abstract from the paper is the following:

*The recent "Text-to-Text Transfer Transformer" (T5) leveraged a unified text-to-text format and scale to attain
state-of-the-art results on a wide variety of English-language NLP tasks. In this paper, we introduce mT5, a
multilingual variant of T5 that was pre-trained on a new Common Crawl-based dataset covering 101 languages. We describe
the design and modified training of mT5 and demonstrate its state-of-the-art performance on many multilingual
benchmarks. All of the code and model checkpoints used in this work are publicly available.*

The original code can be found `here <https://github.com/google-research/multilingual-t5>`__.

MT5Config
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autoclass:: transformers.MT5Config
    :members:


MT5Model
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autoclass:: transformers.MT5Model
    :members:


MT5ForConditionalGeneration
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autoclass:: transformers.MT5ForConditionalGeneration
    :members:


TFMT5Model
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autoclass:: transformers.TFMT5Model
    :members:


TFMT5ForConditionalGeneration
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autoclass:: transformers.TFMT5ForConditionalGeneration
    :members:
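
A minimal end-to-end sketch of how the documented classes fit together (illustrative only, not part of the committed docs; note that google/mt5-small is only pretrained on mC4, so meaningful generations require fine-tuning on a downstream task first):

    from transformers import MT5ForConditionalGeneration, T5Tokenizer

    # mT5 reuses the T5 SentencePiece tokenizer class, as set via tokenizer_class in MT5Config
    tokenizer = T5Tokenizer.from_pretrained("google/mt5-small")
    model = MT5ForConditionalGeneration.from_pretrained("google/mt5-small")

    # Exercise the API with a German sentence; this only demonstrates the call pattern
    inputs = tokenizer("UN Offizier sagt, dass weiter verhandelt werden muss in Syrien.", return_tensors="pt")
    generated = model.generate(**inputs, max_length=20)
    print(tokenizer.decode(generated[0], skip_special_tokens=True))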

src/transformers/__init__.py

Lines changed: 2 additions & 0 deletions
@@ -498,6 +498,7 @@
     MobileBertPreTrainedModel,
     load_tf_weights_in_mobilebert,
 )
+from .models.mt5 import MT5Config, MT5ForConditionalGeneration, MT5Model
 from .models.openai import (
     OPENAI_GPT_PRETRAINED_MODEL_ARCHIVE_LIST,
     OpenAIGPTDoubleHeadsModel,
@@ -791,6 +792,7 @@
     TFMobileBertModel,
     TFMobileBertPreTrainedModel,
 )
+from .models.mt5 import TFMT5ForConditionalGeneration, TFMT5Model
 from .models.openai import (
     TF_OPENAI_GPT_PRETRAINED_MODEL_ARCHIVE_LIST,
     TFOpenAIGPTDoubleHeadsModel,

src/transformers/models/auto/configuration_auto.py

Lines changed: 3 additions & 0 deletions
@@ -40,6 +40,7 @@
 from ..marian.configuration_marian import MarianConfig
 from ..mbart.configuration_mbart import MBART_PRETRAINED_CONFIG_ARCHIVE_MAP, MBartConfig
 from ..mobilebert.configuration_mobilebert import MobileBertConfig
+from ..mt5.configuration_mt5 import MT5Config
 from ..openai.configuration_openai import OPENAI_GPT_PRETRAINED_CONFIG_ARCHIVE_MAP, OpenAIGPTConfig
 from ..pegasus.configuration_pegasus import PegasusConfig
 from ..prophetnet.configuration_prophetnet import PROPHETNET_PRETRAINED_CONFIG_ARCHIVE_MAP, ProphetNetConfig
@@ -101,6 +102,7 @@
     [
         # Add configs here
         ("retribert", RetriBertConfig),
+        ("mt5", MT5Config),
         ("t5", T5Config),
         ("mobilebert", MobileBertConfig),
         ("distilbert", DistilBertConfig),
@@ -178,6 +180,7 @@
         ("rag", "RAG"),
         ("xlm-prophetnet", "XLMProphetNet"),
         ("prophetnet", "ProphetNet"),
+        ("mt5", "mT5"),
     ]
 )

src/transformers/models/auto/modeling_auto.py

Lines changed: 4 additions & 0 deletions
@@ -120,6 +120,7 @@
     MobileBertForTokenClassification,
     MobileBertModel,
 )
+from ..mt5.modeling_mt5 import MT5ForConditionalGeneration, MT5Model
 from ..openai.modeling_openai import OpenAIGPTForSequenceClassification, OpenAIGPTLMHeadModel, OpenAIGPTModel
 from ..pegasus.modeling_pegasus import PegasusForConditionalGeneration
 from ..prophetnet.modeling_prophetnet import ProphetNetForCausalLM, ProphetNetForConditionalGeneration, ProphetNetModel
@@ -209,6 +210,7 @@
     MarianConfig,
     MBartConfig,
     MobileBertConfig,
+    MT5Config,
     OpenAIGPTConfig,
     PegasusConfig,
     ProphetNetConfig,
@@ -235,6 +237,7 @@
     [
         # Base model mapping
         (RetriBertConfig, RetriBertModel),
+        (MT5Config, MT5Model),
         (T5Config, T5Model),
         (DistilBertConfig, DistilBertModel),
         (AlbertConfig, AlbertModel),
@@ -376,6 +379,7 @@
 MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING = OrderedDict(
     [
         # Model for Seq2Seq Causal LM mapping
+        (MT5Config, MT5ForConditionalGeneration),
         (T5Config, T5ForConditionalGeneration),
         (PegasusConfig, PegasusForConditionalGeneration),
         (MarianConfig, MarianMTModel),
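
A hedged sketch of what these mapping entries enable (the auto classes shown already exist in the library; google/mt5-small declares model_type "mt5" in its config.json): the auto factories can now resolve mT5 checkpoints directly, and registering "mt5" ahead of "t5" presumably keeps any name-based fallback from matching the "t5" pattern first.

    from transformers import AutoConfig, AutoModelForSeq2SeqLM

    config = AutoConfig.from_pretrained("google/mt5-small")
    print(type(config).__name__)   # MT5Config, resolved via the "mt5" entry in CONFIG_MAPPING

    model = AutoModelForSeq2SeqLM.from_pretrained("google/mt5-small")
    print(type(model).__name__)    # MT5ForConditionalGeneration, via MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING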

src/transformers/models/auto/modeling_tf_auto.py

Lines changed: 4 additions & 0 deletions
@@ -106,6 +106,7 @@
     TFMobileBertForTokenClassification,
     TFMobileBertModel,
 )
+from ..mt5.modeling_tf_mt5 import TFMT5ForConditionalGeneration, TFMT5Model
 from ..openai.modeling_tf_openai import TFOpenAIGPTLMHeadModel, TFOpenAIGPTModel
 from ..pegasus.modeling_tf_pegasus import TFPegasusForConditionalGeneration
 from ..roberta.modeling_tf_roberta import (
@@ -161,6 +162,7 @@
     MarianConfig,
     MBartConfig,
     MobileBertConfig,
+    MT5Config,
     OpenAIGPTConfig,
     PegasusConfig,
     RobertaConfig,
@@ -182,6 +184,7 @@
     [
         # Base model mapping
         (LxmertConfig, TFLxmertModel),
+        (MT5Config, TFMT5Model),
         (T5Config, TFT5Model),
         (DistilBertConfig, TFDistilBertModel),
         (AlbertConfig, TFAlbertModel),
@@ -294,6 +297,7 @@
 TF_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING = OrderedDict(
     [
         # Model for Seq2Seq Causal LM mapping
+        (MT5Config, TFMT5ForConditionalGeneration),
         (T5Config, TFT5ForConditionalGeneration),
         (MarianConfig, TFMarianMTModel),
         (MBartConfig, TFMBartForConditionalGeneration),
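
The TensorFlow side mirrors the PyTorch mappings; a minimal sketch (assuming the checkpoint ships TF weights, otherwise from_pt=True converts the PyTorch ones):

    from transformers import TFAutoModelForSeq2SeqLM

    model = TFAutoModelForSeq2SeqLM.from_pretrained("google/mt5-small")  # add from_pt=True if only PyTorch weights are published
    print(type(model).__name__)   # TFMT5ForConditionalGeneration, via TF_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING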

src/transformers/models/mt5/__init__.py

Lines changed: 13 additions & 0 deletions
@@ -0,0 +1,13 @@
# flake8: noqa
# There's no way to ignore "F401 '...' imported but unused" warnings in this
# module, but to preserve other warnings. So, don't check this module at all.

from ...file_utils import is_tf_available, is_torch_available
from .configuration_mt5 import MT5Config


if is_torch_available():
    from .modeling_mt5 import MT5ForConditionalGeneration, MT5Model

if is_tf_available():
    from .modeling_tf_mt5 import TFMT5ForConditionalGeneration, TFMT5Model

src/transformers/models/mt5/configuration_mt5.py

Lines changed: 122 additions & 0 deletions
@@ -0,0 +1,122 @@
# coding=utf-8
# Copyright 2020, The T5 Authors and HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" mT5 model configuration """

from ...configuration_utils import PretrainedConfig
from ...utils import logging


logger = logging.get_logger(__name__)


class MT5Config(PretrainedConfig):
    r"""
    This is the configuration class to store the configuration of a :class:`~transformers.MT5Model` or a
    :class:`~transformers.TFMT5Model`. It is used to instantiate a mT5 model according to the specified arguments,
    defining the model architecture. Instantiating a configuration with the defaults will yield a similar
    configuration to that of the mT5 `google/mt5-small <https://huggingface.co/google/mt5-small>`__ architecture.

    Configuration objects inherit from :class:`~transformers.PretrainedConfig` and can be used to control the model
    outputs. Read the documentation from :class:`~transformers.PretrainedConfig` for more information.

    Arguments:
        vocab_size (:obj:`int`, `optional`, defaults to 250112):
            Vocabulary size of the mT5 model. Defines the number of different tokens that can be represented by the
            :obj:`inputs_ids` passed when calling :class:`~transformers.MT5Model` or :class:`~transformers.TFMT5Model`.
        d_model (:obj:`int`, `optional`, defaults to 512):
            Size of the encoder layers and the pooler layer.
        d_kv (:obj:`int`, `optional`, defaults to 64):
            Size of the key, query, value projections per attention head. :obj:`d_kv` has to be equal to :obj:`d_model
            // num_heads`.
        d_ff (:obj:`int`, `optional`, defaults to 1024):
            Size of the intermediate feed forward layer in each :obj:`T5Block`.
        num_layers (:obj:`int`, `optional`, defaults to 8):
            Number of hidden layers in the Transformer encoder.
        num_decoder_layers (:obj:`int`, `optional`):
            Number of hidden layers in the Transformer decoder. Will use the same value as :obj:`num_layers` if not
            set.
        num_heads (:obj:`int`, `optional`, defaults to 6):
            Number of attention heads for each attention layer in the Transformer encoder.
        relative_attention_num_buckets (:obj:`int`, `optional`, defaults to 32):
            The number of buckets to use for each attention layer.
        dropout_rate (:obj:`float`, `optional`, defaults to 0.1):
            The ratio for all dropout layers.
        layer_norm_epsilon (:obj:`float`, `optional`, defaults to 1e-6):
            The epsilon used by the layer normalization layers.
        initializer_factor (:obj:`float`, `optional`, defaults to 1):
            A factor for initializing all weight matrices (should be kept to 1, used internally for initialization
            testing).
        feed_forward_proj (:obj:`string`, `optional`, defaults to :obj:`"gated-gelu"`):
            Type of feed forward layer to be used. Should be one of :obj:`"relu"` or :obj:`"gated-gelu"`.
    """
    model_type = "mt5"

    def __init__(
        self,
        vocab_size=250112,
        d_model=512,
        d_kv=64,
        d_ff=1024,
        num_layers=8,
        num_decoder_layers=None,
        num_heads=6,
        relative_attention_num_buckets=32,
        dropout_rate=0.1,
        layer_norm_epsilon=1e-6,
        initializer_factor=1.0,
        feed_forward_proj="gated-gelu",
        is_encoder_decoder=True,
        tokenizer_class="T5Tokenizer",
        tie_word_embeddings=False,
        pad_token_id=0,
        eos_token_id=1,
        decoder_start_token_id=0,
        **kwargs
    ):
        super().__init__(
            is_encoder_decoder=is_encoder_decoder,
            tokenizer_class=tokenizer_class,
            tie_word_embeddings=tie_word_embeddings,
            pad_token_id=pad_token_id,
            eos_token_id=eos_token_id,
            decoder_start_token_id=decoder_start_token_id,
            **kwargs,
        )
        self.vocab_size = vocab_size
        self.d_model = d_model
        self.d_kv = d_kv
        self.d_ff = d_ff
        self.num_layers = num_layers
        self.num_decoder_layers = (
            num_decoder_layers if num_decoder_layers is not None else self.num_layers
        )  # default = symmetry
        self.num_heads = num_heads
        self.relative_attention_num_buckets = relative_attention_num_buckets
        self.dropout_rate = dropout_rate
        self.layer_norm_epsilon = layer_norm_epsilon
        self.initializer_factor = initializer_factor
        self.feed_forward_proj = feed_forward_proj

    @property
    def hidden_size(self):
        return self.d_model

    @property
    def num_attention_heads(self):
        return self.num_heads

    @property
    def num_hidden_layers(self):
        return self.num_layers
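
A quick sanity-check sketch of the configuration defaults and the property aliases above (illustrative only, not part of the commit):

    from transformers import MT5Config, MT5Model

    config = MT5Config()               # defaults reproduce the google/mt5-small architecture
    print(config.vocab_size)           # 250112
    print(config.hidden_size)          # 512, alias for d_model via the hidden_size property
    print(config.num_attention_heads)  # 6, alias for num_heads

    model = MT5Model(config)           # randomly initialised model built from this configuration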

src/transformers/models/mt5/modeling_mt5.py

Lines changed: 83 additions & 0 deletions
@@ -0,0 +1,83 @@
# coding=utf-8
# Copyright 2020 Mesh TensorFlow authors, T5 Authors and HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" PyTorch mT5 model. """

from ...utils import logging
from ..t5.modeling_t5 import T5ForConditionalGeneration, T5Model
from .configuration_mt5 import MT5Config


logger = logging.get_logger(__name__)

_CONFIG_FOR_DOC = "T5Config"
_TOKENIZER_FOR_DOC = "T5Tokenizer"


class MT5Model(T5Model):
    r"""
    This class overrides :class:`~transformers.T5Model`. Please check the superclass for the appropriate
    documentation alongside usage examples.

    Examples::

        >>> from transformers import MT5Model, T5Tokenizer
        >>> model = MT5Model.from_pretrained("google/mt5-small")
        >>> tokenizer = T5Tokenizer.from_pretrained("google/mt5-small")
        >>> article = "UN Offizier sagt, dass weiter verhandelt werden muss in Syrien."
        >>> summary = "Weiter Verhandlung in Syrien."
        >>> batch = tokenizer.prepare_seq2seq_batch(src_texts=[article], tgt_texts=[summary], return_tensors="pt")
        >>> outputs = model(input_ids=batch.input_ids, decoder_input_ids=batch.labels)
        >>> hidden_states = outputs.last_hidden_state
    """

    model_type = "mt5"
    config_class = MT5Config
    authorized_missing_keys = [
        r"encoder\.embed_tokens\.weight",
        r"decoder\.embed_tokens\.weight",
        r"decoder\.block\.0\.layer\.1\.EncDecAttention\.relative_attention_bias\.weight",
    ]
    keys_to_never_save = [
        r"encoder\.embed_tokens\.weight",
        r"decoder\.embed_tokens\.weight",
    ]


class MT5ForConditionalGeneration(T5ForConditionalGeneration):
    r"""
    This class overrides :class:`~transformers.T5ForConditionalGeneration`. Please check the superclass for the
    appropriate documentation alongside usage examples.

    Examples::

        >>> from transformers import MT5ForConditionalGeneration, T5Tokenizer
        >>> model = MT5ForConditionalGeneration.from_pretrained("google/mt5-small")
        >>> tokenizer = T5Tokenizer.from_pretrained("google/mt5-small")
        >>> article = "UN Offizier sagt, dass weiter verhandelt werden muss in Syrien."
        >>> summary = "Weiter Verhandlung in Syrien."
        >>> batch = tokenizer.prepare_seq2seq_batch(src_texts=[article], tgt_texts=[summary], return_tensors="pt")
        >>> outputs = model(**batch)
        >>> loss = outputs.loss
    """

    model_type = "mt5"
    config_class = MT5Config
    authorized_missing_keys = [
        r"encoder\.embed_tokens\.weight",
        r"decoder\.embed_tokens\.weight",
        r"lm_head\.weight",
        r"decoder\.block\.0\.layer\.1\.EncDecAttention\.relative_attention_bias\.weight",
    ]
    keys_to_never_save = [
        r"encoder\.embed_tokens\.weight",
        r"decoder\.embed_tokens\.weight",
    ]
