Improve bert-japanese tokenizer handling (huggingface#8659)
* Make CI fail

* Try to make tests actually run?

* CI finally failing?

* Fix CI

* Revert "Fix CI"

This reverts commit ca7923b.

* Oops, wrong one

* One more try

* OK, OK, let's move this elsewhere

* Alternative to globals() (huggingface#8667)

* Alternative to globals()

* Error is raised later so return None

* Sentencepiece not installed makes some tokenizers None

* Apply Lysandre's wisdom

* Slightly clearer comment?

cc @sgugger

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
julien-c and sgugger authored Nov 23, 2020
1 parent eec7661 commit 0cc5ab1
Showing 3 changed files with 36 additions and 9 deletions.
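
In outline: AutoTokenizer used to resolve a tokenizer class name via globals(), which only finds classes defined in that very module and cannot represent tokenizers whose optional backend (e.g. sentencepiece) is not installed. The commit replaces this with an explicit lookup function backed by the tokenizer mappings plus a NO_CONFIG_TOKENIZER list. A minimal sketch of the idea, not the actual implementation (DummyTokenizer and KNOWN_TOKENIZERS are hypothetical stand-ins):

# Sketch of the lookup pattern adopted here, heavily simplified.
# DummyTokenizer is hypothetical; the real registry is TOKENIZER_MAPPING
# plus NO_CONFIG_TOKENIZER in tokenization_auto.py.

class DummyTokenizer:
    pass

# Entries may be None when an optional dependency is not installed.
KNOWN_TOKENIZERS = [DummyTokenizer, None]

def tokenizer_class_from_name(class_name):
    # Explicit registry walk; unlike globals().get(class_name), this does
    # not depend on which module the class happens to be defined in.
    for c in KNOWN_TOKENIZERS:
        if c is not None and c.__name__ == class_name:
            return c
    return None  # the caller raises a ValueError later

assert tokenizer_class_from_name("DummyTokenizer") is DummyTokenizer
assert tokenizer_class_from_name("MissingTokenizer") is None
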
2 changes: 1 addition & 1 deletion .circleci/config.yml
@@ -221,7 +221,7 @@ jobs:
     run_tests_custom_tokenizers:
         working_directory: ~/transformers
         docker:
-            - image: circleci/python:3.6
+            - image: circleci/python:3.7
         environment:
             RUN_CUSTOM_TOKENIZERS: yes
         steps:
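
The Python bump aside, this job is what runs the new test: it sets RUN_CUSTOM_TOKENIZERS, and tests decorated with @custom_tokenizers are skipped unless that variable is set. Roughly, the decorator (which lives in transformers.testing_utils; this sketch may differ from the real code in detail) works like this:

# Rough sketch of an env-gated skip decorator in the style of
# transformers.testing_utils.custom_tokenizers (details assumed).
import os
import unittest

_run_custom_tokenizers = os.environ.get("RUN_CUSTOM_TOKENIZERS", "").lower() in ("yes", "true", "1")

def custom_tokenizers(test_case):
    # Skip the decorated test or test class unless RUN_CUSTOM_TOKENIZERS is set.
    return unittest.skipUnless(_run_custom_tokenizers, "test of custom tokenizers")(test_case)
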
34 changes: 26 additions & 8 deletions src/transformers/models/auto/tokenization_auto.py
@@ -185,8 +185,6 @@
         (LongformerConfig, (LongformerTokenizer, LongformerTokenizerFast)),
         (BartConfig, (BartTokenizer, BartTokenizerFast)),
         (LongformerConfig, (LongformerTokenizer, LongformerTokenizerFast)),
-        (RobertaConfig, (BertweetTokenizer, None)),
-        (RobertaConfig, (PhobertTokenizer, None)),
         (RobertaConfig, (RobertaTokenizer, RobertaTokenizerFast)),
         (ReformerConfig, (ReformerTokenizer, ReformerTokenizerFast)),
         (ElectraConfig, (ElectraTokenizer, ElectraTokenizerFast)),
@@ -195,7 +193,6 @@
         (LayoutLMConfig, (LayoutLMTokenizer, LayoutLMTokenizerFast)),
         (DPRConfig, (DPRQuestionEncoderTokenizer, DPRQuestionEncoderTokenizerFast)),
         (SqueezeBertConfig, (SqueezeBertTokenizer, SqueezeBertTokenizerFast)),
-        (BertConfig, (HerbertTokenizer, HerbertTokenizerFast)),
         (BertConfig, (BertTokenizer, BertTokenizerFast)),
         (OpenAIGPTConfig, (OpenAIGPTTokenizer, OpenAIGPTTokenizerFast)),
         (GPT2Config, (GPT2Tokenizer, GPT2TokenizerFast)),
@@ -213,13 +210,34 @@
     ]
 )
 
+# For tokenizers which are not directly mapped from a config
+NO_CONFIG_TOKENIZER = [
+    BertJapaneseTokenizer,
+    BertweetTokenizer,
+    HerbertTokenizer,
+    HerbertTokenizerFast,
+    PhobertTokenizer,
+]
+
+
 SLOW_TOKENIZER_MAPPING = {
     k: (v[0] if v[0] is not None else v[1])
     for k, v in TOKENIZER_MAPPING.items()
     if (v[0] is not None or v[1] is not None)
 }
 
 
+def tokenizer_class_from_name(class_name: str):
+    all_tokenizer_classes = (
+        [v[0] for v in TOKENIZER_MAPPING.values() if v[0] is not None]
+        + [v[1] for v in TOKENIZER_MAPPING.values() if v[1] is not None]
+        + NO_CONFIG_TOKENIZER
+    )
+    for c in all_tokenizer_classes:
+        if c.__name__ == class_name:
+            return c
+
+
 class AutoTokenizer:
     r"""
     This is a generic tokenizer class that will be instantiated as one of the tokenizer classes of the library when
@@ -307,17 +325,17 @@ def from_pretrained(cls, pretrained_model_name_or_path, *inputs, **kwargs):
         if not isinstance(config, PretrainedConfig):
             config = AutoConfig.from_pretrained(pretrained_model_name_or_path, **kwargs)
 
-        if "bert-base-japanese" in str(pretrained_model_name_or_path):
-            return BertJapaneseTokenizer.from_pretrained(pretrained_model_name_or_path, *inputs, **kwargs)
-
         use_fast = kwargs.pop("use_fast", True)
 
         if config.tokenizer_class is not None:
+            tokenizer_class = None
             if use_fast and not config.tokenizer_class.endswith("Fast"):
                 tokenizer_class_candidate = f"{config.tokenizer_class}Fast"
-            else:
+                tokenizer_class = tokenizer_class_from_name(tokenizer_class_candidate)
+            if tokenizer_class is None:
                 tokenizer_class_candidate = config.tokenizer_class
-            tokenizer_class = globals().get(tokenizer_class_candidate)
+                tokenizer_class = tokenizer_class_from_name(tokenizer_class_candidate)
+
             if tokenizer_class is None:
                 raise ValueError(
                     "Tokenizer class {} does not exist or is not currently imported.".format(tokenizer_class_candidate)
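
With the special case gone, AutoTokenizer.from_pretrained no longer matches on the checkpoint name; it reads tokenizer_class from the checkpoint's config and resolves it through tokenizer_class_from_name, which also covers the config-less tokenizers in NO_CONFIG_TOKENIZER. A usage sketch (assumes the Japanese tokenizer dependencies, fugashi and ipadic, are installed, and that the hub checkpoint's config sets tokenizer_class):

# Usage sketch for the new resolution path (dependencies assumed installed).
from transformers import AutoTokenizer
from transformers.models.auto.tokenization_auto import tokenizer_class_from_name

# Direct name lookup now works for tokenizers that have no dedicated config:
cls = tokenizer_class_from_name("BertJapaneseTokenizer")
print(cls.__name__)  # BertJapaneseTokenizer

# The class is picked from the config, not from the model id string:
tokenizer = AutoTokenizer.from_pretrained("cl-tohoku/bert-base-japanese")
print(type(tokenizer).__name__)  # BertJapaneseTokenizer
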
9 changes: 9 additions & 0 deletions tests/test_tokenization_bert_japanese.py
@@ -18,6 +18,7 @@
 import pickle
 import unittest
 
+from transformers import AutoTokenizer
 from transformers.models.bert_japanese.tokenization_bert_japanese import (
     VOCAB_FILES_NAMES,
     BertJapaneseTokenizer,
@@ -267,3 +268,11 @@ def test_sequence_builders(self):
         # 2 is for "[CLS]", 3 is for "[SEP]"
         assert encoded_sentence == [2] + text + [3]
         assert encoded_pair == [2] + text + [3] + text_2 + [3]
+
+
+@custom_tokenizers
+class AutoTokenizerCustomTest(unittest.TestCase):
+    def test_tokenizer_bert_japanese(self):
+        EXAMPLE_BERT_JAPANESE_ID = "cl-tohoku/bert-base-japanese"
+        tokenizer = AutoTokenizer.from_pretrained(EXAMPLE_BERT_JAPANESE_ID)
+        self.assertIsInstance(tokenizer, BertJapaneseTokenizer)
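
Because of the @custom_tokenizers gate, this new test is skipped by default; RUN_CUSTOM_TOKENIZERS must be set before the test module is imported. A sketch of a local run from the repository root (invocation assumed, mirroring the CI environment):

# Sketch: run the gated test locally, mirroring CI's RUN_CUSTOM_TOKENIZERS=yes.
import os
import unittest

os.environ["RUN_CUSTOM_TOKENIZERS"] = "yes"  # must precede the test import

suite = unittest.TestLoader().loadTestsFromName(
    "tests.test_tokenization_bert_japanese.AutoTokenizerCustomTest"
)
unittest.TextTestRunner(verbosity=2).run(suite)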
