From 172ed00a5477459b6491225e9dc0f1a1a5f23887 Mon Sep 17 00:00:00 2001
From: Arthur <48595927+ArthurZucker@users.noreply.github.com>
Date: Tue, 30 May 2023 16:50:41 +0200
Subject: [PATCH] [LlamaTokenizerFast] nit update `post_processor` on the fly
 (#23855)

* Update the processor when changing add_eos and add_bos

* fixup

* update

* add a test

* fix failing tests

* fixup
---
 .../models/llama/tokenization_llama_fast.py   | 44 +++++++++++++++++++
 tests/models/llama/test_tokenization_llama.py | 33 ++++++++++++++
 2 files changed, 77 insertions(+)

diff --git a/src/transformers/models/llama/tokenization_llama_fast.py b/src/transformers/models/llama/tokenization_llama_fast.py
index bb2737075ea2ad..c3946d83b0e0b8 100644
--- a/src/transformers/models/llama/tokenization_llama_fast.py
+++ b/src/transformers/models/llama/tokenization_llama_fast.py
@@ -16,6 +16,8 @@
 from shutil import copyfile
 from typing import Optional, Tuple
 
+from tokenizers import processors
+
 from ...tokenization_utils_fast import PreTrainedTokenizerFast
 from ...utils import is_sentencepiece_available, logging
 from ...utils.versions import require_version
@@ -84,6 +86,8 @@ def __init__(
         unk_token="<unk>",
         bos_token="<s>",
         eos_token="</s>",
+        add_bos_token=True,
+        add_eos_token=False,
         **kwargs,
     ):
         super().__init__(
@@ -95,10 +99,50 @@ def __init__(
             eos_token=eos_token,
             **kwargs,
         )
+        self._add_bos_token = add_bos_token
+        self._add_eos_token = add_eos_token
+        self.update_post_processor()
 
         self.vocab_file = vocab_file
         self.can_save_slow_tokenizer = False if not self.vocab_file else True
 
+    def update_post_processor(self):
+        bos = self.bos_token
+        bos_token_id = self.bos_token_id
+
+        eos = self.eos_token
+        eos_token_id = self.eos_token_id
+
+        single = f"{(bos+':0 ') * self.add_bos_token}$A:0{(' '+eos+':0') * self.add_eos_token}"
+        pair = f"{single}{(' '+bos+':1') * self.add_bos_token} $B:1{(' '+eos+':1') * self.add_eos_token}"
+
+        special_tokens = []
+        if self.add_bos_token:
+            special_tokens.append((bos, bos_token_id))
+        if self.add_eos_token:
+            special_tokens.append((eos, eos_token_id))
+        self._tokenizer.post_processor = processors.TemplateProcessing(
+            single=single, pair=pair, special_tokens=special_tokens
+        )
+
+    @property
+    def add_eos_token(self):
+        return self._add_eos_token
+
+    @property
+    def add_bos_token(self):
+        return self._add_bos_token
+
+    @add_eos_token.setter
+    def add_eos_token(self, value):
+        self._add_eos_token = value
+        self.update_post_processor()
+
+    @add_bos_token.setter
+    def add_bos_token(self, value):
+        self._add_bos_token = value
+        self.update_post_processor()
+
     def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]:
         if not self.can_save_slow_tokenizer:
             raise ValueError(
diff --git a/tests/models/llama/test_tokenization_llama.py b/tests/models/llama/test_tokenization_llama.py
index 6ce1bb44c03db2..3a1ec2be93bf4e 100644
--- a/tests/models/llama/test_tokenization_llama.py
+++ b/tests/models/llama/test_tokenization_llama.py
@@ -315,6 +315,39 @@ def integration_tests(self):
             },
         )
 
+    def test_fast_special_tokens(self):
+        slow_tokenizer = self.tokenizer
+        fast_tokenizer = self.rust_tokenizer
+        slow = slow_tokenizer.encode("A sample test", add_special_tokens=True)
+        assert slow == [1, 319, 4559, 1243]
+
+        fast_tokenizer.add_eos_token = False
+        fast = fast_tokenizer.encode("A sample test", add_special_tokens=True)
+        assert fast == [1, 319, 4559, 1243]
+
+        fast_tokenizer.add_eos_token = True
+        fast = fast_tokenizer.encode("A sample test", add_special_tokens=True)
+        assert fast == [1, 319, 4559, 1243, 2]
+
+        slow_tokenizer.add_eos_token = True
+        slow = slow_tokenizer.encode("A sample test", add_special_tokens=True)
+        assert slow == [1, 319, 4559, 1243, 2]
+
+        fast_tokenizer = LlamaTokenizerFast.from_pretrained(
+            "hf-internal-testing/llama-tokenizer", add_eos_token=True, add_bos_token=False
+        )
+        fast = fast_tokenizer.encode("A sample test", add_special_tokens=True)
+        assert fast == [319, 4559, 1243, 2]
+
+        slow_tokenizer = LlamaTokenizer.from_pretrained(
+            "hf-internal-testing/llama-tokenizer", add_eos_token=True, add_bos_token=False
+        )
+        slow = slow_tokenizer.encode("A sample test", add_special_tokens=True)
+        assert slow == [319, 4559, 1243, 2]
+
+        self.tokenizer.add_eos_token = False
+        self.rust_tokenizer.add_eos_token = False
+
     @slow
     def test_conversion(self):
         # This is excruciatingly slow since it has to recreate the entire merge
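
A minimal usage sketch of what this patch enables, assuming the hf-internal-testing/llama-tokenizer checkpoint used in the test above; the expected ids mirror that test:

    from transformers import LlamaTokenizerFast

    # add_bos_token=True / add_eos_token=False are the defaults introduced in the patch.
    tok = LlamaTokenizerFast.from_pretrained("hf-internal-testing/llama-tokenizer")
    print(tok.encode("A sample test"))  # [1, 319, 4559, 1243] -> only <s> is prepended

    # Assigning the flag goes through the new setter, which calls update_post_processor()
    # and rebuilds the TemplateProcessing post-processor on the fly, so the next encode
    # also appends </s>.
    tok.add_eos_token = True
    print(tok.encode("A sample test"))  # [1, 319, 4559, 1243, 2]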