Skip to content

Commit 8551a99

Browse files
authored
Add auto next sentence prediction (#8432)
* Add auto next sentence prediction * Fix style * Add mobilebert next sentence prediction
1 parent c314b1f commit 8551a99

File tree

2 files changed

+107
-2
lines changed

2 files changed

+107
-2
lines changed

src/transformers/modeling_tf_auto.py

Lines changed: 107 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,7 @@
6161
from .modeling_tf_bert import (
6262
TFBertForMaskedLM,
6363
TFBertForMultipleChoice,
64+
TFBertForNextSentencePrediction,
6465
TFBertForPreTraining,
6566
TFBertForQuestionAnswering,
6667
TFBertForSequenceClassification,
@@ -120,6 +121,7 @@
120121
from .modeling_tf_mobilebert import (
121122
TFMobileBertForMaskedLM,
122123
TFMobileBertForMultipleChoice,
124+
TFMobileBertForNextSentencePrediction,
123125
TFMobileBertForPreTraining,
124126
TFMobileBertForQuestionAnswering,
125127
TFMobileBertForSequenceClassification,
@@ -355,6 +357,13 @@
355357
]
356358
)
357359

360+
TF_MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING = OrderedDict(
361+
[
362+
(BertConfig, TFBertForNextSentencePrediction),
363+
(MobileBertConfig, TFMobileBertForNextSentencePrediction),
364+
]
365+
)
366+
358367

359368
TF_AUTO_MODEL_PRETRAINED_DOCSTRING = r"""
360369
@@ -1412,3 +1421,101 @@ def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
14121421
", ".join(c.__name__ for c in TF_MODEL_FOR_MULTIPLE_CHOICE_MAPPING.keys()),
14131422
)
14141423
)
1424+
1425+
1426+
class TFAutoModelForNextSentencePrediction:
1427+
r"""
1428+
This is a generic model class that will be instantiated as one of the model classes of the library---with a
1429+
multiple choice classification head---when created with the when created with the
1430+
:meth:`~transformers.TFAutoModelForNextSentencePrediction.from_pretrained` class method or the
1431+
:meth:`~transformers.TFAutoModelForNextSentencePrediction.from_config` class method.
1432+
1433+
This class cannot be instantiated directly using ``__init__()`` (throws an error).
1434+
"""
1435+
1436+
def __init__(self):
1437+
raise EnvironmentError(
1438+
"TFAutoModelForNextSentencePrediction is designed to be instantiated "
1439+
"using the `TFAutoModelForNextSentencePrediction.from_pretrained(pretrained_model_name_or_path)` or "
1440+
"`TFAutoModelForNextSentencePrediction.from_config(config)` methods."
1441+
)
1442+
1443+
@classmethod
1444+
@replace_list_option_in_docstrings(TF_MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING, use_model_types=False)
1445+
def from_config(cls, config):
1446+
r"""
1447+
Instantiates one of the model classes of the library---with a next sentence prediction head---from a
1448+
configuration.
1449+
1450+
Note:
1451+
Loading a model from its configuration file does **not** load the model weights. It only affects the
1452+
model's configuration. Use :meth:`~transformers.TFAutoModelForNextSentencePrediction.from_pretrained` to
1453+
load the model weights.
1454+
1455+
Args:
1456+
config (:class:`~transformers.PretrainedConfig`):
1457+
The model class to instantiate is selected based on the configuration class:
1458+
1459+
List options
1460+
1461+
Examples::
1462+
1463+
>>> from transformers import AutoConfig, TFAutoModelForNextSentencePrediction
1464+
>>> # Download configuration from S3 and cache.
1465+
>>> config = AutoConfig.from_pretrained('bert-base-uncased')
1466+
>>> model = TFAutoModelForNextSentencePrediction.from_config(config)
1467+
"""
1468+
if type(config) in TF_MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING.keys():
1469+
return TF_MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING[type(config)](config)
1470+
raise ValueError(
1471+
"Unrecognized configuration class {} for this kind of TFAutoModel: {}.\n"
1472+
"Model type should be one of {}.".format(
1473+
config.__class__,
1474+
cls.__name__,
1475+
", ".join(c.__name__ for c in TF_MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING.keys()),
1476+
)
1477+
)
1478+
1479+
@classmethod
1480+
@replace_list_option_in_docstrings(TF_MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING)
1481+
@add_start_docstrings(
1482+
"Instantiate one of the model classes of the library---with a next sentence prediction head---from a "
1483+
"pretrained model.",
1484+
TF_AUTO_MODEL_PRETRAINED_DOCSTRING,
1485+
)
1486+
def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
1487+
r"""
1488+
Examples::
1489+
1490+
>>> from transformers import AutoConfig, TFAutoModelForNextSentencePrediction
1491+
1492+
>>> # Download model and configuration from S3 and cache.
1493+
>>> model = TFAutoModelForNextSentencePrediction.from_pretrained('bert-base-uncased')
1494+
1495+
>>> # Update configuration during loading
1496+
>>> model = TFAutoModelForNextSentencePrediction.from_pretrained('bert-base-uncased', output_attentions=True)
1497+
>>> model.config.output_attentions
1498+
True
1499+
1500+
>>> # Loading from a PyTorch checkpoint file instead of a TensorFlow model (slower)
1501+
>>> config = AutoConfig.from_json_file('./pt_model/bert_pt_model_config.json')
1502+
>>> model = TFAutoModelForNextSentencePrediction.from_pretrained('./pt_model/bert_pytorch_model.bin', from_pt=True, config=config)
1503+
"""
1504+
config = kwargs.pop("config", None)
1505+
if not isinstance(config, PretrainedConfig):
1506+
config, kwargs = AutoConfig.from_pretrained(
1507+
pretrained_model_name_or_path, return_unused_kwargs=True, **kwargs
1508+
)
1509+
1510+
if type(config) in TF_MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING.keys():
1511+
return TF_MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING[type(config)].from_pretrained(
1512+
pretrained_model_name_or_path, *model_args, config=config, **kwargs
1513+
)
1514+
raise ValueError(
1515+
"Unrecognized configuration class {} for this kind of TFAutoModel: {}.\n"
1516+
"Model type should be one of {}.".format(
1517+
config.__class__,
1518+
cls.__name__,
1519+
", ".join(c.__name__ for c in TF_MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING.keys()),
1520+
)
1521+
)

utils/check_repo.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -87,10 +87,8 @@
8787
"RagSequenceForGeneration",
8888
"RagTokenForGeneration",
8989
"T5Stack",
90-
"TFBertForNextSentencePrediction",
9190
"TFFunnelBaseModel",
9291
"TFGPT2DoubleHeadsModel",
93-
"TFMobileBertForNextSentencePrediction",
9492
"TFOpenAIGPTDoubleHeadsModel",
9593
"XLMForQuestionAnswering",
9694
"XLMProphetNetDecoder",

0 commit comments

Comments
 (0)