Skip to content

Fix model templates #8595

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 7 commits into from
Nov 17, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 12 additions & 5 deletions src/transformers/commands/add_new_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ def run(self):
path_to_transformer_root = (
Path(__file__).parent.parent.parent.parent if self._path is None else Path(self._path).parent.parent
)
path_to_cookiecutter = path_to_transformer_root / "templates" / "cookiecutter"
path_to_cookiecutter = path_to_transformer_root / "templates" / "adding_a_new_model"

# Execute cookiecutter
if not self._testing:
Expand Down Expand Up @@ -75,9 +75,16 @@ def run(self):
output_pytorch = "PyTorch" in pytorch_or_tensorflow
output_tensorflow = "TensorFlow" in pytorch_or_tensorflow

model_dir = f"{path_to_transformer_root}/src/transformers/models/{lowercase_model_name}"
os.makedirs(model_dir, exist_ok=True)

shutil.move(
f"{directory}/__init__.py",
f"{model_dir}/__init__.py",
)
shutil.move(
f"{directory}/configuration_{lowercase_model_name}.py",
f"{path_to_transformer_root}/src/transformers/configuration_{lowercase_model_name}.py",
f"{model_dir}/configuration_{lowercase_model_name}.py",
)

def remove_copy_lines(path):
Expand All @@ -94,7 +101,7 @@ def remove_copy_lines(path):

shutil.move(
f"{directory}/modeling_{lowercase_model_name}.py",
f"{path_to_transformer_root}/src/transformers/modeling_{lowercase_model_name}.py",
f"{model_dir}/modeling_{lowercase_model_name}.py",
)

shutil.move(
Expand All @@ -111,7 +118,7 @@ def remove_copy_lines(path):

shutil.move(
f"{directory}/modeling_tf_{lowercase_model_name}.py",
f"{path_to_transformer_root}/src/transformers/modeling_tf_{lowercase_model_name}.py",
f"{model_dir}/modeling_tf_{lowercase_model_name}.py",
)

shutil.move(
Expand All @@ -129,7 +136,7 @@ def remove_copy_lines(path):

shutil.move(
f"{directory}/tokenization_{lowercase_model_name}.py",
f"{path_to_transformer_root}/src/transformers/tokenization_{lowercase_model_name}.py",
f"{model_dir}/tokenization_{lowercase_model_name}.py",
)

from os import fdopen, remove
Expand Down
4 changes: 2 additions & 2 deletions src/transformers/models/auto/modeling_auto.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,8 @@
from ...configuration_utils import PretrainedConfig
from ...file_utils import add_start_docstrings
from ...utils import logging

# Add modeling imports here
from ..albert.modeling_albert import (
AlbertForMaskedLM,
AlbertForMultipleChoice,
Expand Down Expand Up @@ -228,8 +230,6 @@
)


# Add modeling imports here

logger = logging.get_logger(__name__)


Expand Down
4 changes: 2 additions & 2 deletions src/transformers/models/auto/modeling_tf_auto.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,8 @@
from ...configuration_utils import PretrainedConfig
from ...file_utils import add_start_docstrings
from ...utils import logging

# Add modeling imports here
from ..albert.modeling_tf_albert import (
TFAlbertForMaskedLM,
TFAlbertForMultipleChoice,
Expand Down Expand Up @@ -175,8 +177,6 @@
)


# Add modeling imports here

logger = logging.get_logger(__name__)


Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
# flake8: noqa
# There's no way to ignore "F401 '...' imported but unused" warnings in this
# module, but to preserve other warnings. So, don't check this module at all.

{%- if cookiecutter.generate_tensorflow_and_pytorch == "PyTorch & TensorFlow" %}
from ...file_utils import is_tf_available, is_torch_available
{%- elif cookiecutter.generate_tensorflow_and_pytorch == "PyTorch" %}
from ...file_utils import is_torch_available
{%- elif cookiecutter.generate_tensorflow_and_pytorch == "TensorFlow" %}
from ...file_utils import is_tf_available
{% endif %}
from .configuration_{{cookiecutter.lowercase_modelname}} import {{cookiecutter.uppercase_modelname}}_PRETRAINED_CONFIG_ARCHIVE_MAP, {{cookiecutter.camelcase_modelname}}Config
from .tokenization_{{cookiecutter.lowercase_modelname}} import {{cookiecutter.camelcase_modelname}}Tokenizer

{%- if (cookiecutter.generate_tensorflow_and_pytorch == "PyTorch & TensorFlow" or cookiecutter.generate_tensorflow_and_pytorch == "PyTorch") %}
if is_torch_available():
from .modeling_{{cookiecutter.lowercase_modelname}} import (
{{cookiecutter.uppercase_modelname}}_PRETRAINED_MODEL_ARCHIVE_LIST,
{{cookiecutter.camelcase_modelname}}ForMaskedLM,
{{cookiecutter.camelcase_modelname}}ForMultipleChoice,
{{cookiecutter.camelcase_modelname}}ForQuestionAnswering,
{{cookiecutter.camelcase_modelname}}ForSequenceClassification,
{{cookiecutter.camelcase_modelname}}ForTokenClassification,
{{cookiecutter.camelcase_modelname}}Layer,
{{cookiecutter.camelcase_modelname}}Model,
{{cookiecutter.camelcase_modelname}}PreTrainedModel,
load_tf_weights_in_{{cookiecutter.lowercase_modelname}},
)
{% endif %}
{%- if (cookiecutter.generate_tensorflow_and_pytorch == "PyTorch & TensorFlow" or cookiecutter.generate_tensorflow_and_pytorch == "TensorFlow") %}
if is_tf_available():
from .modeling_tf_{{cookiecutter.lowercase_modelname}} import (
TF_{{cookiecutter.uppercase_modelname}}_PRETRAINED_MODEL_ARCHIVE_LIST,
TF{{cookiecutter.camelcase_modelname}}ForMaskedLM,
TF{{cookiecutter.camelcase_modelname}}ForMultipleChoice,
TF{{cookiecutter.camelcase_modelname}}ForQuestionAnswering,
TF{{cookiecutter.camelcase_modelname}}ForSequenceClassification,
TF{{cookiecutter.camelcase_modelname}}ForTokenClassification,
TF{{cookiecutter.camelcase_modelname}}Layer,
TF{{cookiecutter.camelcase_modelname}}Model,
TF{{cookiecutter.camelcase_modelname}}PreTrainedModel,
)
{% endif %}
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,8 @@
# limitations under the License.
""" {{cookiecutter.modelname}} model configuration """

from .configuration_utils import PretrainedConfig
from .utils import logging
from ...configuration_utils import PretrainedConfig
from ...utils import logging


logger = logging.get_logger(__name__)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,15 +17,14 @@

import tensorflow as tf

from .activations_tf import get_tf_activation
from .configuration_{{cookiecutter.lowercase_modelname}} import {{cookiecutter.camelcase_modelname}}Config
from .file_utils import (
from ...activations_tf import get_tf_activation
from ...file_utils import (
MULTIPLE_CHOICE_DUMMY_INPUTS,
add_code_sample_docstrings,
add_start_docstrings,
add_start_docstrings_to_model_forward,
)
from .modeling_tf_outputs import (
from ...modeling_tf_outputs import (
TFBaseModelOutput,
TFBaseModelOutputWithPooling,
TFMaskedLMOutput,
Expand All @@ -34,7 +33,7 @@
TFSequenceClassifierOutput,
TFTokenClassifierOutput,
)
from .modeling_tf_utils import (
from ...modeling_tf_utils import (
TFMaskedLanguageModelingLoss,
TFMultipleChoiceLoss,
TFPreTrainedModel,
Expand All @@ -46,8 +45,9 @@
keras_serializable,
shape_list,
)
from .tokenization_utils import BatchEncoding
from .utils import logging
from ...tokenization_utils import BatchEncoding
from ...utils import logging
from .configuration_{{cookiecutter.lowercase_modelname}} import {{cookiecutter.camelcase_modelname}}Config


logger = logging.get_logger(__name__)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,13 +25,12 @@
from torch import nn
from torch.nn import CrossEntropyLoss, MSELoss

from .configuration_{{cookiecutter.lowercase_modelname}} import {{cookiecutter.camelcase_modelname}}Config
from .file_utils import (
from ...file_utils import (
add_code_sample_docstrings,
add_start_docstrings,
add_start_docstrings_to_model_forward,
)
from .modeling_outputs import (
from ...modeling_outputs import (
BaseModelOutput,
BaseModelOutputWithPooling,
MaskedLMOutput,
Expand All @@ -40,15 +39,16 @@
SequenceClassifierOutput,
TokenClassifierOutput,
)
from .modeling_utils import (
from ...modeling_utils import (
PreTrainedModel,
SequenceSummary,
apply_chunking_to_forward,
find_pruneable_heads_and_indices,
prune_linear_layer,
)
from .utils import logging
from .activations import ACT2FN
from ...utils import logging
from ...activations import ACT2FN
from .configuration_{{cookiecutter.lowercase_modelname}} import {{cookiecutter.camelcase_modelname}}Config


logger = logging.get_logger(__name__)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
# To replace in: "src/transformers/__init__.py"
# Below: "if is_torch_available():" if generating PyTorch
# Replace with:
from .modeling_{{cookiecutter.lowercase_modelname}} import (
from .models.{{cookiecutter.lowercase_modelname}} import (
{{cookiecutter.uppercase_modelname}}_PRETRAINED_MODEL_ARCHIVE_LIST,
{{cookiecutter.camelcase_modelname}}ForMaskedLM,
{{cookiecutter.camelcase_modelname}}ForMultipleChoice,
Expand All @@ -30,7 +30,7 @@

# Below: "if is_tf_available():" if generating TensorFlow
# Replace with:
from .modeling_tf_{{cookiecutter.lowercase_modelname}} import (
from .models.{{cookiecutter.lowercase_modelname}} import (
TF_{{cookiecutter.uppercase_modelname}}_PRETRAINED_MODEL_ARCHIVE_LIST,
TF{{cookiecutter.camelcase_modelname}}ForMaskedLM,
TF{{cookiecutter.camelcase_modelname}}ForMultipleChoice,
Expand All @@ -44,14 +44,14 @@
# End.


# Below: "from .configuration_albert import ALBERT_PRETRAINED_CONFIG_ARCHIVE_MAP, AlbertConfig"
# Below: "from .models.albert import ALBERT_PRETRAINED_CONFIG_ARCHIVE_MAP, AlbertConfig"
# Replace with:
from .configuration_{{cookiecutter.lowercase_modelname}} import {{cookiecutter.uppercase_modelname}}_PRETRAINED_CONFIG_ARCHIVE_MAP, {{cookiecutter.camelcase_modelname}}Config
from .models.{{cookiecutter.lowercase_modelname}} import {{cookiecutter.uppercase_modelname}}_PRETRAINED_CONFIG_ARCHIVE_MAP, {{cookiecutter.camelcase_modelname}}Config
# End.



# To replace in: "src/transformers/configuration_auto.py"
# To replace in: "src/transformers/models/auto/configuration_auto.py"
# Below: "# Add configs here"
# Replace with:
("{{cookiecutter.lowercase_modelname}}", {{cookiecutter.camelcase_modelname}}Config),
Expand All @@ -62,9 +62,9 @@
{{cookiecutter.uppercase_modelname}}_PRETRAINED_CONFIG_ARCHIVE_MAP,
# End.

# Below: "from .configuration_albert import ALBERT_PRETRAINED_CONFIG_ARCHIVE_MAP, AlbertConfig",
# Below: "from ..albert.configuration_albert import ALBERT_PRETRAINED_CONFIG_ARCHIVE_MAP, AlbertConfig",
# Replace with:
from .configuration_{{cookiecutter.lowercase_modelname}} import {{cookiecutter.uppercase_modelname}}_PRETRAINED_CONFIG_ARCHIVE_MAP, {{cookiecutter.camelcase_modelname}}Config
from ..{{cookiecutter.lowercase_modelname}}.configuration_{{cookiecutter.lowercase_modelname}} import {{cookiecutter.uppercase_modelname}}_PRETRAINED_CONFIG_ARCHIVE_MAP, {{cookiecutter.camelcase_modelname}}Config
# End.

# Below: "# Add full (and cased) model names here"
Expand All @@ -83,7 +83,7 @@
# Below: "# Add modeling imports here"
# Replace with:

from .modeling_{{cookiecutter.lowercase_modelname}} import (
from ..{{cookiecutter.lowercase_modelname}}.modeling_{{cookiecutter.lowercase_modelname}} import (
{{cookiecutter.camelcase_modelname}}ForMaskedLM,
{{cookiecutter.camelcase_modelname}}ForMultipleChoice,
{{cookiecutter.camelcase_modelname}}ForQuestionAnswering,
Expand Down Expand Up @@ -138,7 +138,7 @@
# Below: "# Add modeling imports here"
# Replace with:

from .modeling_tf_{{cookiecutter.lowercase_modelname}} import (
from ..{{cookiecutter.lowercase_modelname}}.modeling_tf_{{cookiecutter.lowercase_modelname}} import (
TF{{cookiecutter.camelcase_modelname}}ForMaskedLM,
TF{{cookiecutter.camelcase_modelname}}ForMultipleChoice,
TF{{cookiecutter.camelcase_modelname}}ForQuestionAnswering,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,9 @@
"""Tokenization classes for {{cookiecutter.modelname}}."""

{%- if cookiecutter.tokenizer_type == "Based on BERT" %}
from .tokenization_bert import BertTokenizer, BertTokenizerFast
from .utils import logging
from ...utils import logging
from ..bert.tokenization_bert import BertTokenizer
from ..bert.tokenization_bert_fast import BertTokenizerFast


logger = logging.get_logger(__name__)
Expand Down Expand Up @@ -73,14 +74,14 @@ class {{cookiecutter.camelcase_modelname}}TokenizerFast(BertTokenizerFast):
pretrained_init_configuration = PRETRAINED_INIT_CONFIGURATION
{%- elif cookiecutter.tokenizer_type == "Standalone" %}
import warnings
from typing import List, Optional

from tokenizers import ByteLevelBPETokenizer

from .tokenization_utils import AddedToken, PreTrainedTokenizer
from .tokenization_utils_base import BatchEncoding
from .tokenization_utils_fast import PreTrainedTokenizerFast
from typing import List, Optional
from .utils import logging
from ...tokenization_utils import AddedToken, PreTrainedTokenizer
from ...tokenization_utils_base import BatchEncoding
from ...tokenization_utils_fast import PreTrainedTokenizerFast
from ...utils import logging


logger = logging.get_logger(__name__)
Expand Down