[Tokenizer] Add replace_additional_special_tokens parameter to add_special_tokens (PaddlePaddle#9144)
lvdongyi authored Sep 19, 2024
1 parent 7faad55 commit 90cef20
Showing 3 changed files with 86 additions and 14 deletions.
31 changes: 28 additions & 3 deletions paddlenlp/transformers/luke/tokenizer.py
@@ -482,6 +482,12 @@ def __call__(

return encode_output

def __len__(self):
"""
Size of the full vocabulary with the added tokens.
"""
return len(self.encoder) + len(self.added_tokens_encoder)

def tokenize(self, text, add_prefix_space=False):
"""
Tokenize a string.
@@ -608,22 +614,41 @@ def _convert_token_to_id_with_added_voc(self, token):

return self._convert_token_to_id(token)

def add_special_tokens(self, token_list: Union[List[int], Dict]):
def add_special_tokens(self, token_list: Union[List[int], Dict], replace_additional_special_tokens: bool = True):
"""
Add special tokens if you need them.
Args:
token_list (List[int], Dict[List[int]]):
The special token list you provided. If you provide a Dict, the key of the Dict must
be "additional_special_tokens" and the value must be token list.
replace_additional_special_tokens (bool, optional, defaults to True):
If True, the existing list of additional special tokens will be replaced by the list provided in
`token_list`. Otherwise, `self._additional_special_tokens` is just extended. In the former
case, the tokens will NOT be removed from the tokenizer's full vocabulary - they are only being flagged
as non-special tokens. Remember, this only affects which tokens are skipped during decoding, not the
`added_tokens_encoder` and `added_tokens_decoder`. This means that the previous
`additional_special_tokens` are still added tokens, and will not be split by the model.
"""
if isinstance(token_list, dict):
token_list = token_list["additional_special_tokens"]

if replace_additional_special_tokens:
self._additional_special_tokens = list(token_list)
else:
self._additional_special_tokens.extend(
[token for token in token_list if token not in self._additional_special_tokens]
)
encoder_dict = dict()
decoder_dict = dict()

token_id_counter = len(self)
for token in token_list:
encoder_dict[token] = len(self.encoder.keys())
decoder_dict[len(self.decoder.keys())] = token
if token not in self.added_tokens_encoder:
encoder_dict[token] = token_id_counter
decoder_dict[token_id_counter] = token
token_id_counter += 1

self.added_tokens_encoder.update(encoder_dict)
self.added_tokens_decoder.update(decoder_dict)

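The hunk above only touches the LUKE tokenizer; the following is a minimal usage sketch of the new behaviour (not part of the commit, and it assumes a `luke-base` checkpoint is available to `from_pretrained`):

from paddlenlp.transformers import LukeTokenizer

tokenizer = LukeTokenizer.from_pretrained("luke-base")
size_before = len(tokenizer)  # new __len__: len(encoder) + len(added_tokens_encoder)

# Default (replace_additional_special_tokens=True): each call replaces the list.
tokenizer.add_special_tokens({"additional_special_tokens": ["<ent>"]})
tokenizer.add_special_tokens({"additional_special_tokens": ["<ent2>"]})
print(tokenizer._additional_special_tokens)  # ['<ent2>'] -- '<ent>' is no longer flagged as special

# With replace_additional_special_tokens=False the list is extended instead.
tokenizer.add_special_tokens(
    {"additional_special_tokens": ["<ent>"]},
    replace_additional_special_tokens=False,
)
print(tokenizer._additional_special_tokens)  # ['<ent2>', '<ent>']

# Replaced tokens stay in added_tokens_encoder/decoder, so the vocabulary only grows.
assert len(tokenizer) == size_before + 2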
44 changes: 33 additions & 11 deletions paddlenlp/transformers/tokenizer_utils_base.py
@@ -801,14 +801,16 @@ def sanitize_special_tokens(self) -> int:
"""
return self.add_tokens(self.all_special_tokens_extended, special_tokens=True)

def add_special_tokens(self, special_tokens_dict: Dict[str, Union[str, AddedToken]]) -> int:
def add_special_tokens(
self, special_tokens_dict: Dict[str, Union[str, AddedToken]], replace_additional_special_tokens=True
) -> int:
"""
Add a dictionary of special tokens (eos, pad, cls, etc.) to the encoder and link them to class attributes. If
special tokens are NOT in the vocabulary, they are added to it (indexed starting from the last index of the
current vocabulary).
Note,None When adding new tokens to the vocabulary, you should make sure to also resize the token embedding
matrix of the model so that its embedding matrix matches the tokenizer.
When adding new tokens to the vocabulary, you should make sure to also resize the token embedding matrix of the
model so that its embedding matrix matches the tokenizer.
In order to do that, please use the [`~PreTrainedModel.resize_token_embeddings`] method.
@@ -829,6 +831,13 @@ def add_special_tokens(self, special_tokens_dict: Dict[str, Union[str, AddedToke
Tokens are only added if they are not already in the vocabulary (tested by checking if the tokenizer
assigns the index of the `unk_token` to them).
replace_additional_special_tokens (`bool`, *optional*, defaults to `True`):
If `True`, the existing list of additional special tokens will be replaced by the list provided in
`special_tokens_dict`. Otherwise, `self._additional_special_tokens` is just extended. In the former
case, the tokens will NOT be removed from the tokenizer's full vocabulary - they are only being flagged
as non-special tokens. Remember, this only affects which tokens are skipped during decoding, not the
`added_tokens_encoder` and `added_tokens_decoder`. This means that the previous
`additional_special_tokens` are still added tokens, and will not be split by the model.
Returns:
`int`: Number of tokens added to the vocabulary.
@@ -852,25 +861,38 @@ def add_special_tokens(self, special_tokens_dict: Dict[str, Union[str, AddedToke
if not special_tokens_dict:
return 0

added_tokens = 0
added_tokens = []
for key, value in special_tokens_dict.items():
assert key in self.SPECIAL_TOKENS_ATTRIBUTES, f"Key {key} is not a special token"

if self.verbose:
logger.info(f"Assigning {value} to the {key} key of the tokenizer")
setattr(self, key, value)

if key == "additional_special_tokens":
assert isinstance(value, (list, tuple)) and all(
isinstance(t, (str, AddedToken)) for t in value
), f"Tokens {value} for key {key} should all be str or AddedToken instances"
added_tokens += self.add_tokens(value, special_tokens=True)
else:
assert isinstance(
value, (str, AddedToken)
), f"Token {value} for key {key} should be a str or an AddedToken instance"
added_tokens += self.add_tokens([value], special_tokens=True)

to_add = []
for token in value:
if not replace_additional_special_tokens and str(token) in self.additional_special_tokens:
continue
to_add.append(token)
if replace_additional_special_tokens and len(to_add) > 0:
setattr(self, key, list(to_add))
else:
self._additional_special_tokens.extend(to_add)
added_tokens += to_add

else:
if not isinstance(value, (str, AddedToken)):
raise ValueError(f"Token {value} for key {key} should be a str or an AddedToken instance")
setattr(self, key, value)
if value not in added_tokens:
added_tokens.append(value)

# if we are adding tokens that were not part of the vocab, we ought to add them
added_tokens = self.add_tokens(added_tokens, special_tokens=True)
return added_tokens

def add_tokens(
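For the base-class change, a short sketch of the intended call pattern (again not from the commit; BERT and `bert-base-uncased` are used as stand-ins, and any PaddleNLP tokenizer/model pair would do):

from paddlenlp.transformers import BertModel, BertTokenizer

tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
model = BertModel.from_pretrained("bert-base-uncased")

# add_special_tokens still returns the number of tokens that were actually new.
num_added = tokenizer.add_special_tokens({"additional_special_tokens": ["<ctx>"]})
print(num_added)  # 1

# Extend the existing list instead of replacing it.
tokenizer.add_special_tokens(
    {"additional_special_tokens": ["<qry>"]},
    replace_additional_special_tokens=False,
)
print(tokenizer.additional_special_tokens)  # ['<ctx>', '<qry>']

# As the docstring advises, resize the embedding matrix after adding tokens.
model.resize_token_embeddings(len(tokenizer))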
25 changes: 25 additions & 0 deletions tests/transformers/test_tokenizer_common.py
@@ -1156,6 +1156,31 @@ def test_maximum_encoding_length_pair_input(self):

# self.assertEqual(encoded_masked, encoded_1)

def test_special_token_addition(self):
for tokenizer, pretrained_name, kwargs in self.tokenizers_list:
with self.subTest(f"{tokenizer.__class__.__name__} ({pretrained_name})"):
# Create tokenizer and add an additional special token
tokenizer_1 = tokenizer.from_pretrained(pretrained_name)
tokenizer_1.add_special_tokens({"additional_special_tokens": ["<tok>"]})
self.assertEqual(tokenizer_1.additional_special_tokens, ["<tok>"])
with tempfile.TemporaryDirectory() as tmp_dir:
tokenizer_1.save_pretrained(tmp_dir)
# Load the above tokenizer and add the same special token a second time
tokenizer_2 = tokenizer.from_pretrained(pretrained_name)
tokenizer_2.add_special_tokens({"additional_special_tokens": ["<tok>"]})
self.assertEqual(tokenizer_2.additional_special_tokens, ["<tok>"])

tokenizer_2.add_special_tokens({"additional_special_tokens": ["<tok>", "<other>"]})
self.assertEqual(tokenizer_2.additional_special_tokens, ["<tok>", "<other>"])
tokenizer_2.add_special_tokens({"additional_special_tokens": ["<other>", "<another>"]})
self.assertEqual(tokenizer_2.additional_special_tokens, ["<other>", "<another>"])

tokenizer_2.add_special_tokens(
{"additional_special_tokens": ["<tok>"]},
replace_additional_special_tokens=False,
)
self.assertEqual(tokenizer_2.additional_special_tokens, ["<other>", "<another>", "<tok>"])

def test_special_tokens_mask(self):
tokenizers = self.get_tokenizers(do_lower_case=False)
for tokenizer in tokenizers:
