 61 |  61 | from .modeling_tf_bert import (
 62 |  62 |     TFBertForMaskedLM,
 63 |  63 |     TFBertForMultipleChoice,
    |  64 | +    TFBertForNextSentencePrediction,
 64 |  65 |     TFBertForPreTraining,
 65 |  66 |     TFBertForQuestionAnswering,
 66 |  67 |     TFBertForSequenceClassification,

120 | 121 | from .modeling_tf_mobilebert import (
121 | 122 |     TFMobileBertForMaskedLM,
122 | 123 |     TFMobileBertForMultipleChoice,
    | 124 | +    TFMobileBertForNextSentencePrediction,
123 | 125 |     TFMobileBertForPreTraining,
124 | 126 |     TFMobileBertForQuestionAnswering,
125 | 127 |     TFMobileBertForSequenceClassification,

355 | 357 |     ]
356 | 358 | )
357 | 359 |
    | 360 | +TF_MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING = OrderedDict(
    | 361 | +    [
    | 362 | +        (BertConfig, TFBertForNextSentencePrediction),
    | 363 | +        (MobileBertConfig, TFMobileBertForNextSentencePrediction),
    | 364 | +    ]
    | 365 | +)
    | 366 | +
358 | 367 |
359 | 368 | TF_AUTO_MODEL_PRETRAINED_DOCSTRING = r"""
360 | 369 |
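The new TF_MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING follows the same convention as the existing TF_MODEL_FOR_*_MAPPING tables: the configuration class is the key and the TF head class is the value, so the auto class added below can dispatch on type(config). A minimal sketch of that lookup pattern, using hypothetical stand-in classes rather than the real transformers ones:

from collections import OrderedDict

# Hypothetical stand-ins for BertConfig / TFBertForNextSentencePrediction,
# used only to illustrate the config-class -> model-class dispatch.
class FakeBertConfig:
    pass

class FakeTFBertForNextSentencePrediction:
    def __init__(self, config):
        self.config = config

FAKE_NSP_MAPPING = OrderedDict([(FakeBertConfig, FakeTFBertForNextSentencePrediction)])

config = FakeBertConfig()
model_cls = FAKE_NSP_MAPPING[type(config)]  # keyed on the exact config class
model = model_cls(config)
print(model_cls.__name__)  # FakeTFBertForNextSentencePrediction

An OrderedDict keeps the entries in their declared order, which the docstring decorators appear to rely on when listing the supported model types.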
@@ -1412,3 +1421,101 @@ def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
1412 | 1421 |                 ", ".join(c.__name__ for c in TF_MODEL_FOR_MULTIPLE_CHOICE_MAPPING.keys()),
1413 | 1422 |             )
1414 | 1423 |         )
     | 1424 | +
     | 1425 | +
     | 1426 | +class TFAutoModelForNextSentencePrediction:
     | 1427 | +    r"""
     | 1428 | +    This is a generic model class that will be instantiated as one of the model classes of the library---with a
     | 1429 | +    next sentence prediction head---when created with the
     | 1430 | +    :meth:`~transformers.TFAutoModelForNextSentencePrediction.from_pretrained` class method or the
     | 1431 | +    :meth:`~transformers.TFAutoModelForNextSentencePrediction.from_config` class method.
     | 1432 | +
     | 1433 | +    This class cannot be instantiated directly using ``__init__()`` (throws an error).
     | 1434 | +    """
     | 1435 | +
     | 1436 | +    def __init__(self):
     | 1437 | +        raise EnvironmentError(
     | 1438 | +            "TFAutoModelForNextSentencePrediction is designed to be instantiated "
     | 1439 | +            "using the `TFAutoModelForNextSentencePrediction.from_pretrained(pretrained_model_name_or_path)` or "
     | 1440 | +            "`TFAutoModelForNextSentencePrediction.from_config(config)` methods."
     | 1441 | +        )
     | 1442 | +
     | 1443 | +    @classmethod
     | 1444 | +    @replace_list_option_in_docstrings(TF_MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING, use_model_types=False)
     | 1445 | +    def from_config(cls, config):
     | 1446 | +        r"""
     | 1447 | +        Instantiates one of the model classes of the library---with a next sentence prediction head---from a
     | 1448 | +        configuration.
     | 1449 | +
     | 1450 | +        Note:
     | 1451 | +            Loading a model from its configuration file does **not** load the model weights. It only affects the
     | 1452 | +            model's configuration. Use :meth:`~transformers.TFAutoModelForNextSentencePrediction.from_pretrained` to
     | 1453 | +            load the model weights.
     | 1454 | +
     | 1455 | +        Args:
     | 1456 | +            config (:class:`~transformers.PretrainedConfig`):
     | 1457 | +                The model class to instantiate is selected based on the configuration class:
     | 1458 | +
     | 1459 | +                List options
     | 1460 | +
     | 1461 | +        Examples::
     | 1462 | +
     | 1463 | +            >>> from transformers import AutoConfig, TFAutoModelForNextSentencePrediction
     | 1464 | +            >>> # Download configuration from S3 and cache.
     | 1465 | +            >>> config = AutoConfig.from_pretrained('bert-base-uncased')
     | 1466 | +            >>> model = TFAutoModelForNextSentencePrediction.from_config(config)
     | 1467 | +        """
     | 1468 | +        if type(config) in TF_MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING.keys():
     | 1469 | +            return TF_MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING[type(config)](config)
     | 1470 | +        raise ValueError(
     | 1471 | +            "Unrecognized configuration class {} for this kind of TFAutoModel: {}.\n"
     | 1472 | +            "Model type should be one of {}.".format(
     | 1473 | +                config.__class__,
     | 1474 | +                cls.__name__,
     | 1475 | +                ", ".join(c.__name__ for c in TF_MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING.keys()),
     | 1476 | +            )
     | 1477 | +        )
     | 1478 | +
     | 1479 | +    @classmethod
     | 1480 | +    @replace_list_option_in_docstrings(TF_MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING)
     | 1481 | +    @add_start_docstrings(
     | 1482 | +        "Instantiate one of the model classes of the library---with a next sentence prediction head---from a "
     | 1483 | +        "pretrained model.",
     | 1484 | +        TF_AUTO_MODEL_PRETRAINED_DOCSTRING,
     | 1485 | +    )
     | 1486 | +    def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
     | 1487 | +        r"""
     | 1488 | +        Examples::
     | 1489 | +
     | 1490 | +            >>> from transformers import AutoConfig, TFAutoModelForNextSentencePrediction
     | 1491 | +
     | 1492 | +            >>> # Download model and configuration from S3 and cache.
     | 1493 | +            >>> model = TFAutoModelForNextSentencePrediction.from_pretrained('bert-base-uncased')
     | 1494 | +
     | 1495 | +            >>> # Update configuration during loading
     | 1496 | +            >>> model = TFAutoModelForNextSentencePrediction.from_pretrained('bert-base-uncased', output_attentions=True)
     | 1497 | +            >>> model.config.output_attentions
     | 1498 | +            True
     | 1499 | +
     | 1500 | +            >>> # Loading from a PyTorch checkpoint file instead of a TensorFlow model (slower)
     | 1501 | +            >>> config = AutoConfig.from_json_file('./pt_model/bert_pt_model_config.json')
     | 1502 | +            >>> model = TFAutoModelForNextSentencePrediction.from_pretrained('./pt_model/bert_pytorch_model.bin', from_pt=True, config=config)
     | 1503 | +        """
     | 1504 | +        config = kwargs.pop("config", None)
     | 1505 | +        if not isinstance(config, PretrainedConfig):
     | 1506 | +            config, kwargs = AutoConfig.from_pretrained(
     | 1507 | +                pretrained_model_name_or_path, return_unused_kwargs=True, **kwargs
     | 1508 | +            )
     | 1509 | +
     | 1510 | +        if type(config) in TF_MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING.keys():
     | 1511 | +            return TF_MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING[type(config)].from_pretrained(
     | 1512 | +                pretrained_model_name_or_path, *model_args, config=config, **kwargs
     | 1513 | +            )
     | 1514 | +        raise ValueError(
     | 1515 | +            "Unrecognized configuration class {} for this kind of TFAutoModel: {}.\n"
     | 1516 | +            "Model type should be one of {}.".format(
     | 1517 | +                config.__class__,
     | 1518 | +                cls.__name__,
     | 1519 | +                ", ".join(c.__name__ for c in TF_MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING.keys()),
     | 1520 | +            )
     | 1521 | +        )
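Taken together, the new class gives next sentence prediction the same auto-model entry point the other TF tasks already have. A hedged usage sketch follows: the tokenizer call pattern matches the transformers 3.x TF examples, while treating the first output as next-sentence logits of shape (1, 2), with index 0 meaning the second sentence follows the first, is an assumption about BERT's NSP head rather than something shown in this diff.

import tensorflow as tf
from transformers import BertTokenizer, TFAutoModelForNextSentencePrediction

tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
model = TFAutoModelForNextSentencePrediction.from_pretrained("bert-base-uncased")

prompt = "The sky appears blue because shorter wavelengths scatter more."
candidate = "Cows eat grass."
inputs = tokenizer(prompt, candidate, return_tensors="tf")

outputs = model(inputs)
logits = outputs[0]  # assumed: next-sentence logits, shape (1, 2)
# Assumed label convention: index 0 = "candidate follows the prompt".
is_next = bool(tf.argmax(logits, axis=-1).numpy()[0] == 0)
print("is next sentence:", is_next)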