Add auto next sentence prediction (huggingface#8432)
* Add auto next sentence prediction

* Fix style

* Add mobilebert next sentence prediction
jplu authored and fabiocapsouza committed Nov 15, 2020
1 parent 421c606 commit b6604af
Showing 2 changed files with 107 additions and 2 deletions.
107 changes: 107 additions & 0 deletions src/transformers/modeling_tf_auto.py
@@ -61,6 +61,7 @@
 from .modeling_tf_bert import (
     TFBertForMaskedLM,
     TFBertForMultipleChoice,
+    TFBertForNextSentencePrediction,
     TFBertForPreTraining,
     TFBertForQuestionAnswering,
     TFBertForSequenceClassification,
@@ -120,6 +121,7 @@
 from .modeling_tf_mobilebert import (
     TFMobileBertForMaskedLM,
     TFMobileBertForMultipleChoice,
+    TFMobileBertForNextSentencePrediction,
     TFMobileBertForPreTraining,
     TFMobileBertForQuestionAnswering,
     TFMobileBertForSequenceClassification,
@@ -355,6 +357,13 @@
     ]
 )
 
+TF_MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING = OrderedDict(
+    [
+        (BertConfig, TFBertForNextSentencePrediction),
+        (MobileBertConfig, TFMobileBertForNextSentencePrediction),
+    ]
+)
+
 
 TF_AUTO_MODEL_PRETRAINED_DOCSTRING = r"""
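
As an aside (not part of the diff), this mapping is the dispatch table the new auto class consults: the configuration class is the key, the concrete TF model class is the value. A minimal sketch of that lookup, assuming the mapping is imported from the module touched above:

    # Sketch only: manual dispatch through the new mapping (the auto class does this internally).
    from transformers import BertConfig
    from transformers.modeling_tf_auto import TF_MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING

    config = BertConfig()
    model_class = TF_MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING[type(config)]  # -> TFBertForNextSentencePrediction
    model = model_class(config)  # randomly initialised weights; use from_pretrained() to load a checkpoint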
@@ -1412,3 +1421,101 @@ def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
                 ", ".join(c.__name__ for c in TF_MODEL_FOR_MULTIPLE_CHOICE_MAPPING.keys()),
             )
         )


+class TFAutoModelForNextSentencePrediction:
+    r"""
+    This is a generic model class that will be instantiated as one of the model classes of the library---with a
+    next sentence prediction head---when created with the
+    :meth:`~transformers.TFAutoModelForNextSentencePrediction.from_pretrained` class method or the
+    :meth:`~transformers.TFAutoModelForNextSentencePrediction.from_config` class method.
+
+    This class cannot be instantiated directly using ``__init__()`` (throws an error).
+    """
+
+    def __init__(self):
+        raise EnvironmentError(
+            "TFAutoModelForNextSentencePrediction is designed to be instantiated "
+            "using the `TFAutoModelForNextSentencePrediction.from_pretrained(pretrained_model_name_or_path)` or "
+            "`TFAutoModelForNextSentencePrediction.from_config(config)` methods."
+        )

+    @classmethod
+    @replace_list_option_in_docstrings(TF_MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING, use_model_types=False)
+    def from_config(cls, config):
+        r"""
+        Instantiates one of the model classes of the library---with a next sentence prediction head---from a
+        configuration.
+
+        Note:
+            Loading a model from its configuration file does **not** load the model weights. It only affects the
+            model's configuration. Use :meth:`~transformers.TFAutoModelForNextSentencePrediction.from_pretrained` to
+            load the model weights.
+
+        Args:
+            config (:class:`~transformers.PretrainedConfig`):
+                The model class to instantiate is selected based on the configuration class:
+
+                List options
+
+        Examples::
+
+            >>> from transformers import AutoConfig, TFAutoModelForNextSentencePrediction
+            >>> # Download configuration from S3 and cache.
+            >>> config = AutoConfig.from_pretrained('bert-base-uncased')
+            >>> model = TFAutoModelForNextSentencePrediction.from_config(config)
+        """
+        if type(config) in TF_MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING.keys():
+            return TF_MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING[type(config)](config)
+        raise ValueError(
+            "Unrecognized configuration class {} for this kind of TFAutoModel: {}.\n"
+            "Model type should be one of {}.".format(
+                config.__class__,
+                cls.__name__,
+                ", ".join(c.__name__ for c in TF_MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING.keys()),
+            )
+        )

+    @classmethod
+    @replace_list_option_in_docstrings(TF_MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING)
+    @add_start_docstrings(
+        "Instantiate one of the model classes of the library---with a next sentence prediction head---from a "
+        "pretrained model.",
+        TF_AUTO_MODEL_PRETRAINED_DOCSTRING,
+    )
+    def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
+        r"""
+        Examples::
+
+            >>> from transformers import AutoConfig, TFAutoModelForNextSentencePrediction
+            >>> # Download model and configuration from S3 and cache.
+            >>> model = TFAutoModelForNextSentencePrediction.from_pretrained('bert-base-uncased')
+            >>> # Update configuration during loading
+            >>> model = TFAutoModelForNextSentencePrediction.from_pretrained('bert-base-uncased', output_attentions=True)
+            >>> model.config.output_attentions
+            True
+            >>> # Loading from a PyTorch checkpoint file instead of a TensorFlow model (slower)
+            >>> config = AutoConfig.from_json_file('./pt_model/bert_pt_model_config.json')
+            >>> model = TFAutoModelForNextSentencePrediction.from_pretrained('./pt_model/bert_pytorch_model.bin', from_pt=True, config=config)
+        """
+        config = kwargs.pop("config", None)
+        if not isinstance(config, PretrainedConfig):
+            config, kwargs = AutoConfig.from_pretrained(
+                pretrained_model_name_or_path, return_unused_kwargs=True, **kwargs
+            )
+
+        if type(config) in TF_MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING.keys():
+            return TF_MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING[type(config)].from_pretrained(
+                pretrained_model_name_or_path, *model_args, config=config, **kwargs
+            )
+        raise ValueError(
+            "Unrecognized configuration class {} for this kind of TFAutoModel: {}.\n"
+            "Model type should be one of {}.".format(
+                config.__class__,
+                cls.__name__,
+                ", ".join(c.__name__ for c in TF_MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING.keys()),
+            )
+        )
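
For context (not part of this diff), a minimal end-to-end sketch of scoring a sentence pair with the new auto class using a stock BERT checkpoint; the logits are read positionally since the exact return container differs between library versions:

    import tensorflow as tf
    from transformers import AutoTokenizer, TFAutoModelForNextSentencePrediction

    tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
    model = TFAutoModelForNextSentencePrediction.from_pretrained("bert-base-uncased")

    # Encode a candidate sentence pair; BERT's NSP head predicts whether the
    # second sentence plausibly follows the first.
    inputs = tokenizer("The sky is blue.", "It will probably rain later.", return_tensors="tf")
    outputs = model(inputs)
    logits = outputs[0]                     # shape (1, 2); index 0 ~ "is next", index 1 ~ "not next"
    probs = tf.nn.softmax(logits, axis=-1)
    print(probs.numpy())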
2 changes: 0 additions & 2 deletions utils/check_repo.py
@@ -87,10 +87,8 @@
     "RagSequenceForGeneration",
     "RagTokenForGeneration",
     "T5Stack",
-    "TFBertForNextSentencePrediction",
     "TFFunnelBaseModel",
     "TFGPT2DoubleHeadsModel",
-    "TFMobileBertForNextSentencePrediction",
     "TFOpenAIGPTDoubleHeadsModel",
     "XLMForQuestionAnswering",
     "XLMProphetNetDecoder",
