Skip to content

Commit

Permalink
[Sentencepiece] make sure legacy do not require protobuf (#25684)
Browse files Browse the repository at this point in the history
make sure legacy does not require `protobuf`
  • Loading branch information
ArthurZucker authored Aug 25, 2023
1 parent 0770ce6 commit dd8b7d2
Show file tree
Hide file tree
Showing 3 changed files with 21 additions and 12 deletions.
7 changes: 5 additions & 2 deletions src/transformers/convert_slow_tokenizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,17 +27,20 @@
from tokenizers.models import BPE, Unigram, WordPiece

from .utils import is_protobuf_available, requires_backends
from .utils.import_utils import PROTOBUF_IMPORT_ERROR


def import_protobuf():
def import_protobuf(error_message=""):
if is_protobuf_available():
import google.protobuf

if version.parse(google.protobuf.__version__) < version.parse("4.0.0"):
from transformers.utils import sentencepiece_model_pb2
else:
from transformers.utils import sentencepiece_model_pb2_new as sentencepiece_model_pb2
return sentencepiece_model_pb2
return sentencepiece_model_pb2
else:
raise ImportError(PROTOBUF_IMPORT_ERROR.format(error_message))


class SentencePieceExtractor:
Expand Down
13 changes: 8 additions & 5 deletions src/transformers/models/llama/tokenization_llama.py
Original file line number Diff line number Diff line change
Expand Up @@ -162,14 +162,17 @@ def unk_token_length(self):
# Copied from transformers.models.t5.tokenization_t5.T5Tokenizer.get_spm_processor
def get_spm_processor(self):
tokenizer = spm.SentencePieceProcessor(**self.sp_model_kwargs)
if self.legacy: # no dependency on protobuf
tokenizer.Load(self.vocab_file)
return tokenizer

with open(self.vocab_file, "rb") as f:
sp_model = f.read()
model_pb2 = import_protobuf()
model_pb2 = import_protobuf(f"The new behaviour of {self.__class__.__name__} (with `self.legacy = False`)")
model = model_pb2.ModelProto.FromString(sp_model)
if not self.legacy:
normalizer_spec = model_pb2.NormalizerSpec()
normalizer_spec.add_dummy_prefix = False
model.normalizer_spec.MergeFrom(normalizer_spec)
normalizer_spec = model_pb2.NormalizerSpec()
normalizer_spec.add_dummy_prefix = False
model.normalizer_spec.MergeFrom(normalizer_spec)
sp_model = model.SerializeToString()
tokenizer.LoadFromSerializedProto(sp_model)
return tokenizer
Expand Down
13 changes: 8 additions & 5 deletions src/transformers/models/t5/tokenization_t5.py
Original file line number Diff line number Diff line change
Expand Up @@ -195,14 +195,17 @@ def __init__(

def get_spm_processor(self):
tokenizer = spm.SentencePieceProcessor(**self.sp_model_kwargs)
if self.legacy: # no dependency on protobuf
tokenizer.Load(self.vocab_file)
return tokenizer

with open(self.vocab_file, "rb") as f:
sp_model = f.read()
model_pb2 = import_protobuf()
model_pb2 = import_protobuf(f"The new behaviour of {self.__class__.__name__} (with `self.legacy = False`)")
model = model_pb2.ModelProto.FromString(sp_model)
if not self.legacy:
normalizer_spec = model_pb2.NormalizerSpec()
normalizer_spec.add_dummy_prefix = False
model.normalizer_spec.MergeFrom(normalizer_spec)
normalizer_spec = model_pb2.NormalizerSpec()
normalizer_spec.add_dummy_prefix = False
model.normalizer_spec.MergeFrom(normalizer_spec)
sp_model = model.SerializeToString()
tokenizer.LoadFromSerializedProto(sp_model)
return tokenizer
Expand Down

0 comments on commit dd8b7d2

Please sign in to comment.