Skip to content

Commit

Permalink
[AutoTokenizer] Allow creation of tokenizers by tokenizer type (huggi…
Browse files Browse the repository at this point in the history
…ngface#13668)

* up

* up
  • Loading branch information
patrickvonplaten authored and stas00 committed Oct 12, 2021
1 parent b0bbebf commit 32ce915
Show file tree
Hide file tree
Showing 5 changed files with 81 additions and 1 deletion.
29 changes: 28 additions & 1 deletion src/transformers/models/auto/tokenization_auto.py
Original file line number Diff line number Diff line change
Expand Up @@ -402,6 +402,8 @@ def from_pretrained(cls, pretrained_model_name_or_path, *inputs, **kwargs):
facebook/rag-token-base), specify it here.
use_fast (:obj:`bool`, `optional`, defaults to :obj:`True`):
Whether or not to try to load the fast version of the tokenizer.
tokenizer_type (:obj:`str`, `optional`):
Tokenizer type to be loaded.
kwargs (additional keyword arguments, `optional`):
Will be passed to the Tokenizer ``__init__()`` method. Can be used to set special tokens like
``bos_token``, ``eos_token``, ``unk_token``, ``sep_token``, ``pad_token``, ``cls_token``,
Expand All @@ -425,8 +427,33 @@ def from_pretrained(cls, pretrained_model_name_or_path, *inputs, **kwargs):
kwargs["_from_auto"] = True

use_fast = kwargs.pop("use_fast", True)
tokenizer_type = kwargs.pop("tokenizer_type", None)

# First, let's try to use the tokenizer_config file to get the tokenizer class.
# First, let's see whether the tokenizer_type is passed so that we can leverage it
if tokenizer_type is not None:
tokenizer_class = None
tokenizer_class_tuple = TOKENIZER_MAPPING_NAMES.get(tokenizer_type, None)

if tokenizer_class_tuple is None:
raise ValueError(
f"Passed `tokenizer_type` {tokenizer_type} does not exist. `tokenizer_type` should be one of "
f"{', '.join(c for c in TOKENIZER_MAPPING_NAMES.keys())}."
)

tokenizer_class_name, tokenizer_fast_class_name = tokenizer_class_tuple

if use_fast and tokenizer_fast_class_name is not None:
tokenizer_class = tokenizer_class_from_name(tokenizer_fast_class_name)

if tokenizer_class is None:
tokenizer_class = tokenizer_class_from_name(tokenizer_class_name)

if tokenizer_class is None:
raise ValueError(f"Tokenizer class {tokenizer_class_name} is not currently imported.")

return tokenizer_class.from_pretrained(pretrained_model_name_or_path, *inputs, **kwargs)

# Next, let's try to use the tokenizer_config file to get the tokenizer class.
tokenizer_config = get_tokenizer_config(pretrained_model_name_or_path, **kwargs)
config_tokenizer_class = tokenizer_config.get("tokenizer_class")

Expand Down
5 changes: 5 additions & 0 deletions tests/fixtures/merges.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
#version: 0.2
Ġ l
Ġl o
Ġlo w
e r
1 change: 1 addition & 0 deletions tests/fixtures/vocab.json
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
{"l": 0, "o": 1, "w": 2, "e": 3, "r": 4, "s": 5, "t": 6, "i": 7, "d": 8, "n": 9, "Ġ": 10, "Ġl": 11, "Ġn": 12, "Ġlo": 13, "Ġlow": 14, "er": 15, "Ġlowest": 16, "Ġnewer": 17, "Ġwider": 18, "<unk>": 19, "<|endoftext|>": 20}
10 changes: 10 additions & 0 deletions tests/fixtures/vocab.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
[PAD]
[SEP]
[MASK]
[CLS]
[unused3]
[unused4]
[unused5]
[unused6]
[unused7]
[unused8]
37 changes: 37 additions & 0 deletions tests/test_tokenization_auto.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,13 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import shutil
import tempfile
import unittest

import pytest

from transformers import (
BERT_PRETRAINED_CONFIG_ARCHIVE_MAP,
GPT2_PRETRAINED_CONFIG_ARCHIVE_MAP,
Expand Down Expand Up @@ -78,6 +82,39 @@ def test_tokenizer_from_tokenizer_class(self):
self.assertIsInstance(tokenizer, (BertTokenizer, BertTokenizerFast))
self.assertEqual(tokenizer.vocab_size, 12)

def test_tokenizer_from_type(self):
with tempfile.TemporaryDirectory() as tmp_dir:
shutil.copy("./tests/fixtures/vocab.txt", os.path.join(tmp_dir, "vocab.txt"))

tokenizer = AutoTokenizer.from_pretrained(tmp_dir, tokenizer_type="bert", use_fast=False)
self.assertIsInstance(tokenizer, BertTokenizer)

with tempfile.TemporaryDirectory() as tmp_dir:
shutil.copy("./tests/fixtures/vocab.json", os.path.join(tmp_dir, "vocab.json"))
shutil.copy("./tests/fixtures/merges.txt", os.path.join(tmp_dir, "merges.txt"))

tokenizer = AutoTokenizer.from_pretrained(tmp_dir, tokenizer_type="gpt2", use_fast=False)
self.assertIsInstance(tokenizer, GPT2Tokenizer)

@require_tokenizers
def test_tokenizer_from_type_fast(self):
with tempfile.TemporaryDirectory() as tmp_dir:
shutil.copy("./tests/fixtures/vocab.txt", os.path.join(tmp_dir, "vocab.txt"))

tokenizer = AutoTokenizer.from_pretrained(tmp_dir, tokenizer_type="bert")
self.assertIsInstance(tokenizer, BertTokenizerFast)

with tempfile.TemporaryDirectory() as tmp_dir:
shutil.copy("./tests/fixtures/vocab.json", os.path.join(tmp_dir, "vocab.json"))
shutil.copy("./tests/fixtures/merges.txt", os.path.join(tmp_dir, "merges.txt"))

tokenizer = AutoTokenizer.from_pretrained(tmp_dir, tokenizer_type="gpt2")
self.assertIsInstance(tokenizer, GPT2TokenizerFast)

def test_tokenizer_from_type_incorrect_name(self):
with pytest.raises(ValueError):
AutoTokenizer.from_pretrained("./", tokenizer_type="xxx")

@require_tokenizers
def test_tokenizer_identifier_with_correct_config(self):
for tokenizer_class in [BertTokenizer, BertTokenizerFast, AutoTokenizer]:
Expand Down

0 comments on commit 32ce915

Please sign in to comment.