Skip to content

Commit

Permalink
Fix special tokens not correctly tokenized (#13489)
Browse files Browse the repository at this point in the history
* Fix special tokens not correctly tokenized

* Add testing

* Fix

* Fix

* Use user workflows instead of directly assigning variables

* Enable test of fast tokenizers

* Update test of canine tokenizer
  • Loading branch information
qqaatw authored Sep 17, 2021
1 parent 1f9dcfc commit da8beaa
Show file tree
Hide file tree
Showing 3 changed files with 51 additions and 1 deletion.
4 changes: 3 additions & 1 deletion src/transformers/tokenization_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -453,7 +453,9 @@ def tokenize(self, text: TextInput, **kwargs) -> List[str]:
# TODO: should this be in the base class?
if hasattr(self, "do_lower_case") and self.do_lower_case:
# convert non-special tokens to lowercase
escaped_special_toks = [re.escape(s_tok) for s_tok in self.all_special_tokens]
escaped_special_toks = [
re.escape(s_tok) for s_tok in (self.unique_no_split_tokens + self.all_special_tokens)
]
pattern = r"(" + r"|".join(escaped_special_toks) + r")|" + r"(.+?)"
text = re.sub(pattern, lambda m: m.groups()[0] or m.groups()[1].lower(), text)

Expand Down
21 changes: 21 additions & 0 deletions tests/test_tokenization_canine.py
Original file line number Diff line number Diff line change
Expand Up @@ -160,6 +160,27 @@ def test_add_special_tokens(self):
decoded = tokenizer.decode(encoded, skip_special_tokens=True)
self.assertTrue(special_token not in decoded)

def test_tokenize_special_tokens(self):
tokenizers = self.get_tokenizers(do_lower_case=True)
for tokenizer in tokenizers:
with self.subTest(f"{tokenizer.__class__.__name__}"):
SPECIAL_TOKEN_1 = chr(0xE005)
SPECIAL_TOKEN_2 = chr(0xE006)

# `add_tokens` method stores special tokens only in `tokenizer.unique_no_split_tokens`. (in tokenization_utils.py)
tokenizer.add_tokens([SPECIAL_TOKEN_1], special_tokens=True)
# `add_special_tokens` method stores special tokens in `tokenizer.additional_special_tokens`,
# which also occur in `tokenizer.all_special_tokens`. (in tokenization_utils_base.py)
tokenizer.add_special_tokens({"additional_special_tokens": [SPECIAL_TOKEN_2]})

token_1 = tokenizer.tokenize(SPECIAL_TOKEN_1)
token_2 = tokenizer.tokenize(SPECIAL_TOKEN_2)

self.assertEqual(len(token_1), 1)
self.assertEqual(len(token_2), 1)
self.assertEqual(token_1[0], SPECIAL_TOKEN_1)
self.assertEqual(token_2[0], SPECIAL_TOKEN_2)

@require_tokenizers
def test_added_token_serializable(self):
tokenizers = self.get_tokenizers(do_lower_case=False)
Expand Down
27 changes: 27 additions & 0 deletions tests/test_tokenization_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -310,6 +310,33 @@ def convert_batch_encode_plus_format_to_encode_plus(batch_encode_plus_sequences)
for i in range(len(batch_encode_plus_sequences["input_ids"]))
]

# TODO: this test can be combined with `test_sentencepiece_tokenize_and_convert_tokens_to_string` after the latter is extended to all tokenizers.
def test_tokenize_special_tokens(self):
"""Test `tokenize` with special tokens."""
tokenizers = self.get_tokenizers(fast=True)
for tokenizer in tokenizers:
with self.subTest(f"{tokenizer.__class__.__name__}"):
SPECIAL_TOKEN_1 = "[SPECIAL_TOKEN_1]"
SPECIAL_TOKEN_2 = "[SPECIAL_TOKEN_2]"

# TODO:
# Can we combine `unique_no_split_tokens` and `all_special_tokens`(and properties related to it)
# with one variable(property) for a better maintainability?

# `add_tokens` method stores special tokens only in `tokenizer.unique_no_split_tokens`. (in tokenization_utils.py)
tokenizer.add_tokens([SPECIAL_TOKEN_1], special_tokens=True)
# `add_special_tokens` method stores special tokens in `tokenizer.additional_special_tokens`,
# which also occur in `tokenizer.all_special_tokens`. (in tokenization_utils_base.py)
tokenizer.add_special_tokens({"additional_special_tokens": [SPECIAL_TOKEN_2]})

token_1 = tokenizer.tokenize(SPECIAL_TOKEN_1)
token_2 = tokenizer.tokenize(SPECIAL_TOKEN_2)

self.assertEqual(len(token_1), 1)
self.assertEqual(len(token_2), 1)
self.assertEqual(token_1[0], SPECIAL_TOKEN_1)
self.assertEqual(token_2[0], SPECIAL_TOKEN_2)

# TODO: this test could be extended to all tokenizers - not just the sentencepiece
def test_sentencepiece_tokenize_and_convert_tokens_to_string(self):
"""Test ``_tokenize`` and ``convert_tokens_to_string``."""
Expand Down

0 comments on commit da8beaa

Please sign in to comment.