diff --git a/src/transformers/tokenization_utils.py b/src/transformers/tokenization_utils.py index 0146c0ef0d4f61..4cc93120ebd966 100644 --- a/src/transformers/tokenization_utils.py +++ b/src/transformers/tokenization_utils.py @@ -453,7 +453,9 @@ def tokenize(self, text: TextInput, **kwargs) -> List[str]: # TODO: should this be in the base class? if hasattr(self, "do_lower_case") and self.do_lower_case: # convert non-special tokens to lowercase - escaped_special_toks = [re.escape(s_tok) for s_tok in self.all_special_tokens] + escaped_special_toks = [ + re.escape(s_tok) for s_tok in (self.unique_no_split_tokens + self.all_special_tokens) + ] pattern = r"(" + r"|".join(escaped_special_toks) + r")|" + r"(.+?)" text = re.sub(pattern, lambda m: m.groups()[0] or m.groups()[1].lower(), text) diff --git a/tests/test_tokenization_canine.py b/tests/test_tokenization_canine.py index a0c5a09727762f..e6b786ccc902ae 100644 --- a/tests/test_tokenization_canine.py +++ b/tests/test_tokenization_canine.py @@ -160,6 +160,27 @@ def test_add_special_tokens(self): decoded = tokenizer.decode(encoded, skip_special_tokens=True) self.assertTrue(special_token not in decoded) + def test_tokenize_special_tokens(self): + tokenizers = self.get_tokenizers(do_lower_case=True) + for tokenizer in tokenizers: + with self.subTest(f"{tokenizer.__class__.__name__}"): + SPECIAL_TOKEN_1 = chr(0xE005) + SPECIAL_TOKEN_2 = chr(0xE006) + + # `add_tokens` method stores special tokens only in `tokenizer.unique_no_split_tokens`. (in tokenization_utils.py) + tokenizer.add_tokens([SPECIAL_TOKEN_1], special_tokens=True) + # `add_special_tokens` method stores special tokens in `tokenizer.additional_special_tokens`, + # which also occur in `tokenizer.all_special_tokens`. (in tokenization_utils_base.py) + tokenizer.add_special_tokens({"additional_special_tokens": [SPECIAL_TOKEN_2]}) + + token_1 = tokenizer.tokenize(SPECIAL_TOKEN_1) + token_2 = tokenizer.tokenize(SPECIAL_TOKEN_2) + + self.assertEqual(len(token_1), 1) + self.assertEqual(len(token_2), 1) + self.assertEqual(token_1[0], SPECIAL_TOKEN_1) + self.assertEqual(token_2[0], SPECIAL_TOKEN_2) + @require_tokenizers def test_added_token_serializable(self): tokenizers = self.get_tokenizers(do_lower_case=False) diff --git a/tests/test_tokenization_common.py b/tests/test_tokenization_common.py index 5cea9a8c4a51a2..50aca9c4c8af8c 100644 --- a/tests/test_tokenization_common.py +++ b/tests/test_tokenization_common.py @@ -310,6 +310,33 @@ def convert_batch_encode_plus_format_to_encode_plus(batch_encode_plus_sequences) for i in range(len(batch_encode_plus_sequences["input_ids"])) ] + # TODO: this test can be combined with `test_sentencepiece_tokenize_and_convert_tokens_to_string` after the latter is extended to all tokenizers. + def test_tokenize_special_tokens(self): + """Test `tokenize` with special tokens.""" + tokenizers = self.get_tokenizers(fast=True) + for tokenizer in tokenizers: + with self.subTest(f"{tokenizer.__class__.__name__}"): + SPECIAL_TOKEN_1 = "[SPECIAL_TOKEN_1]" + SPECIAL_TOKEN_2 = "[SPECIAL_TOKEN_2]" + + # TODO: + # Can we combine `unique_no_split_tokens` and `all_special_tokens`(and properties related to it) + # with one variable(property) for a better maintainability? + + # `add_tokens` method stores special tokens only in `tokenizer.unique_no_split_tokens`. (in tokenization_utils.py) + tokenizer.add_tokens([SPECIAL_TOKEN_1], special_tokens=True) + # `add_special_tokens` method stores special tokens in `tokenizer.additional_special_tokens`, + # which also occur in `tokenizer.all_special_tokens`. (in tokenization_utils_base.py) + tokenizer.add_special_tokens({"additional_special_tokens": [SPECIAL_TOKEN_2]}) + + token_1 = tokenizer.tokenize(SPECIAL_TOKEN_1) + token_2 = tokenizer.tokenize(SPECIAL_TOKEN_2) + + self.assertEqual(len(token_1), 1) + self.assertEqual(len(token_2), 1) + self.assertEqual(token_1[0], SPECIAL_TOKEN_1) + self.assertEqual(token_2[0], SPECIAL_TOKEN_2) + # TODO: this test could be extended to all tokenizers - not just the sentencepiece def test_sentencepiece_tokenize_and_convert_tokens_to_string(self): """Test ``_tokenize`` and ``convert_tokens_to_string``."""