Skip to content

Commit 48395d6

Browse files
authored
Fix init for MT5 (#8591)
1 parent a6cf9ca commit 48395d6

File tree

2 files changed

+2
-6
lines changed

2 files changed

+2
-6
lines changed

src/transformers/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -145,6 +145,7 @@
145145
from .models.mbart import MBartConfig
146146
from .models.mmbt import MMBTConfig
147147
from .models.mobilebert import MOBILEBERT_PRETRAINED_CONFIG_ARCHIVE_MAP, MobileBertConfig, MobileBertTokenizer
148+
from .models.mt5 import MT5Config
148149
from .models.openai import OPENAI_GPT_PRETRAINED_CONFIG_ARCHIVE_MAP, OpenAIGPTConfig, OpenAIGPTTokenizer
149150
from .models.pegasus import PegasusConfig
150151
from .models.phobert import PhobertTokenizer
@@ -498,7 +499,7 @@
498499
MobileBertPreTrainedModel,
499500
load_tf_weights_in_mobilebert,
500501
)
501-
from .models.mt5 import MT5Config, MT5ForConditionalGeneration, MT5Model
502+
from .models.mt5 import MT5ForConditionalGeneration, MT5Model
502503
from .models.openai import (
503504
OPENAI_GPT_PRETRAINED_MODEL_ARCHIVE_LIST,
504505
OpenAIGPTDoubleHeadsModel,

src/transformers/utils/dummy_pt_objects.py

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1361,11 +1361,6 @@ def load_tf_weights_in_mobilebert(*args, **kwargs):
13611361
requires_pytorch(load_tf_weights_in_mobilebert)
13621362

13631363

1364-
class MT5Config:
1365-
def __init__(self, *args, **kwargs):
1366-
requires_pytorch(self)
1367-
1368-
13691364
class MT5ForConditionalGeneration:
13701365
def __init__(self, *args, **kwargs):
13711366
requires_pytorch(self)

0 commit comments

Comments
 (0)