diff --git a/src/transformers/models/auto/tokenization_auto.py b/src/transformers/models/auto/tokenization_auto.py
index a706532e9ed4f9..a30d9818742483 100644
--- a/src/transformers/models/auto/tokenization_auto.py
+++ b/src/transformers/models/auto/tokenization_auto.py
@@ -402,6 +402,8 @@ def from_pretrained(cls, pretrained_model_name_or_path, *inputs, **kwargs):
                 facebook/rag-token-base), specify it here.
             use_fast (:obj:`bool`, `optional`, defaults to :obj:`True`):
                 Whether or not to try to load the fast version of the tokenizer.
+            tokenizer_type (:obj:`str`, `optional`):
+                Tokenizer type to be loaded.
             kwargs (additional keyword arguments, `optional`):
                 Will be passed to the Tokenizer ``__init__()`` method. Can be used to set special tokens like
                 ``bos_token``, ``eos_token``, ``unk_token``, ``sep_token``, ``pad_token``, ``cls_token``,
@@ -425,8 +427,33 @@ def from_pretrained(cls, pretrained_model_name_or_path, *inputs, **kwargs):
         kwargs["_from_auto"] = True
 
         use_fast = kwargs.pop("use_fast", True)
+        tokenizer_type = kwargs.pop("tokenizer_type", None)
 
-        # First, let's try to use the tokenizer_config file to get the tokenizer class.
+        # First, let's see whether the tokenizer_type is passed so that we can leverage it
+        if tokenizer_type is not None:
+            tokenizer_class = None
+            tokenizer_class_tuple = TOKENIZER_MAPPING_NAMES.get(tokenizer_type, None)
+
+            if tokenizer_class_tuple is None:
+                raise ValueError(
+                    f"Passed `tokenizer_type` {tokenizer_type} does not exist. `tokenizer_type` should be one of "
+                    f"{', '.join(c for c in TOKENIZER_MAPPING_NAMES.keys())}."
+                )
+
+            tokenizer_class_name, tokenizer_fast_class_name = tokenizer_class_tuple
+
+            if use_fast and tokenizer_fast_class_name is not None:
+                tokenizer_class = tokenizer_class_from_name(tokenizer_fast_class_name)
+
+            if tokenizer_class is None:
+                tokenizer_class = tokenizer_class_from_name(tokenizer_class_name)
+
+            if tokenizer_class is None:
+                raise ValueError(f"Tokenizer class {tokenizer_class_name} is not currently imported.")
+
+            return tokenizer_class.from_pretrained(pretrained_model_name_or_path, *inputs, **kwargs)
+
+        # Next, let's try to use the tokenizer_config file to get the tokenizer class.
         tokenizer_config = get_tokenizer_config(pretrained_model_name_or_path, **kwargs)
         config_tokenizer_class = tokenizer_config.get("tokenizer_class")
 
diff --git a/tests/fixtures/merges.txt b/tests/fixtures/merges.txt
new file mode 100644
index 00000000000000..d7c5738baaf430
--- /dev/null
+++ b/tests/fixtures/merges.txt
@@ -0,0 +1,5 @@
+#version: 0.2
+Ġ l
+Ġl o
+Ġlo w
+e r
diff --git a/tests/fixtures/vocab.json b/tests/fixtures/vocab.json
new file mode 100644
index 00000000000000..c5d99b8ae9d4d6
--- /dev/null
+++ b/tests/fixtures/vocab.json
@@ -0,0 +1 @@
+{"l": 0, "o": 1, "w": 2, "e": 3, "r": 4, "s": 5, "t": 6, "i": 7, "d": 8, "n": 9, "Ġ": 10, "Ġl": 11, "Ġn": 12, "Ġlo": 13, "Ġlow": 14, "er": 15, "Ġlowest": 16, "Ġnewer": 17, "Ġwider": 18, "<unk>": 19, "<|endoftext|>": 20}
diff --git a/tests/fixtures/vocab.txt b/tests/fixtures/vocab.txt
new file mode 100644
index 00000000000000..ad9f94bc6876d3
--- /dev/null
+++ b/tests/fixtures/vocab.txt
@@ -0,0 +1,10 @@
+[PAD]
+[SEP]
+[MASK]
+[CLS]
+[unused3]
+[unused4]
+[unused5]
+[unused6]
+[unused7]
+[unused8]
diff --git a/tests/test_tokenization_auto.py b/tests/test_tokenization_auto.py
index cd6b12335ec625..b00f68f30032d6 100644
--- a/tests/test_tokenization_auto.py
+++ b/tests/test_tokenization_auto.py
@@ -13,9 +13,13 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+import os
+import shutil
 import tempfile
 import unittest
 
+import pytest
+
 from transformers import (
     BERT_PRETRAINED_CONFIG_ARCHIVE_MAP,
     GPT2_PRETRAINED_CONFIG_ARCHIVE_MAP,
@@ -78,6 +82,39 @@ def test_tokenizer_from_tokenizer_class(self):
         self.assertIsInstance(tokenizer, (BertTokenizer, BertTokenizerFast))
         self.assertEqual(tokenizer.vocab_size, 12)
 
+    def test_tokenizer_from_type(self):
+        with tempfile.TemporaryDirectory() as tmp_dir:
+            shutil.copy("./tests/fixtures/vocab.txt", os.path.join(tmp_dir, "vocab.txt"))
+
+            tokenizer = AutoTokenizer.from_pretrained(tmp_dir, tokenizer_type="bert", use_fast=False)
+            self.assertIsInstance(tokenizer, BertTokenizer)
+
+        with tempfile.TemporaryDirectory() as tmp_dir:
+            shutil.copy("./tests/fixtures/vocab.json", os.path.join(tmp_dir, "vocab.json"))
+            shutil.copy("./tests/fixtures/merges.txt", os.path.join(tmp_dir, "merges.txt"))
+
+            tokenizer = AutoTokenizer.from_pretrained(tmp_dir, tokenizer_type="gpt2", use_fast=False)
+            self.assertIsInstance(tokenizer, GPT2Tokenizer)
+
+    @require_tokenizers
+    def test_tokenizer_from_type_fast(self):
+        with tempfile.TemporaryDirectory() as tmp_dir:
+            shutil.copy("./tests/fixtures/vocab.txt", os.path.join(tmp_dir, "vocab.txt"))
+
+            tokenizer = AutoTokenizer.from_pretrained(tmp_dir, tokenizer_type="bert")
+            self.assertIsInstance(tokenizer, BertTokenizerFast)
+
+        with tempfile.TemporaryDirectory() as tmp_dir:
+            shutil.copy("./tests/fixtures/vocab.json", os.path.join(tmp_dir, "vocab.json"))
+            shutil.copy("./tests/fixtures/merges.txt", os.path.join(tmp_dir, "merges.txt"))
+
+            tokenizer = AutoTokenizer.from_pretrained(tmp_dir, tokenizer_type="gpt2")
+            self.assertIsInstance(tokenizer, GPT2TokenizerFast)
+
+    def test_tokenizer_from_type_incorrect_name(self):
+        with pytest.raises(ValueError):
+            AutoTokenizer.from_pretrained("./", tokenizer_type="xxx")
+
     @require_tokenizers
     def test_tokenizer_identifier_with_correct_config(self):
         for tokenizer_class in [BertTokenizer, BertTokenizerFast, AutoTokenizer]:
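
For context (not part of the diff itself), here is a rough usage sketch of the new `tokenizer_type` argument, mirroring what the tests above exercise. It assumes the fixture files added in this diff are present under `./tests/fixtures/` and that the script runs from the repository root; with the default `use_fast=True`, the fast tokenizer classes would be returned instead whenever the `tokenizers` library is installed.

```python
import os
import shutil
import tempfile

from transformers import AutoTokenizer, BertTokenizer, GPT2Tokenizer

# Force a BERT tokenizer from a directory that only contains a vocab.txt,
# i.e. with no tokenizer_config.json or config.json to infer the class from.
with tempfile.TemporaryDirectory() as tmp_dir:
    shutil.copy("./tests/fixtures/vocab.txt", os.path.join(tmp_dir, "vocab.txt"))
    tokenizer = AutoTokenizer.from_pretrained(tmp_dir, tokenizer_type="bert", use_fast=False)
    assert isinstance(tokenizer, BertTokenizer)

# Same idea for a byte-level BPE tokenizer: vocab.json + merges.txt are enough
# once tokenizer_type tells AutoTokenizer which class to instantiate.
with tempfile.TemporaryDirectory() as tmp_dir:
    shutil.copy("./tests/fixtures/vocab.json", os.path.join(tmp_dir, "vocab.json"))
    shutil.copy("./tests/fixtures/merges.txt", os.path.join(tmp_dir, "merges.txt"))
    tokenizer = AutoTokenizer.from_pretrained(tmp_dir, tokenizer_type="gpt2", use_fast=False)
    assert isinstance(tokenizer, GPT2Tokenizer)

# An unknown type fails fast with a ValueError listing the accepted keys.
try:
    AutoTokenizer.from_pretrained("./", tokenizer_type="xxx")
except ValueError as err:
    print(err)
```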