Support reading tiktoken tokenizer.model file (#31656)
* use existing TikTokenConverter to read tiktoken tokenizer.model file

* del test file

* create tiktoken integration file

* adding tiktoken llama test

* ALTERNATIVE IMPLEMENTATION: supports llama 405B

* fix one char

* remove redundant line

* small fix

* rm unused import

* flag for converting from tiktoken

* remove unneeded file

* ruff

* remove llamatiktokenconverter, stick to general converter

* tiktoken support v2

* update test

* remove stale changes

* update doc

* protect import

* use is_protobuf_available

* add templateprocessor in tiktokenconverter

* reverting templateprocessor from tiktoken support

* update test

* add require_tiktoken

* dev-ci

* trigger build

* trigger build again

* dev-ci

* [build-ci-image] tiktoken

* dev-ci

* dev-ci

* dev-ci

* dev-ci

* change tiktoken file name

* feedback review

* feedback rev

* applying feedback, removing tiktoken converters

* conform test

* adding docs for review

* add doc file for review

* add doc file for review

* add doc file for review

* support loading model without config.json file

* Revert "support loading model without config.json file"

This reverts commit 2753602.

* remove dev var

* updating docs

* safely import protobuf

* fix protobuf import error

* fix protobuf import error

* trying isort to fix ruff error

* fix ruff error

* try to fix ruff again

* try to fix ruff again

* try to fix ruff again

* doc table of contents

* add fix for consistency.dockerfile torchaudio

* ruff

* applying feedback

* minor typo

* merging with push-ci-image

* clean up imports

* revert dockerfile consistency
itazap authored Sep 6, 2024
1 parent 342e800 commit e48e5f1
Showing 13 changed files with 195 additions and 21 deletions.
2 changes: 1 addition & 1 deletion docker/consistency.dockerfile
@@ -13,4 +13,4 @@ RUN uv pip install --no-cache-dir "git+https://github.com/huggingface/transforme
RUN git lfs install

RUN pip uninstall -y transformers
- RUN apt-get clean && rm -rf /var/lib/apt/lists/* && apt-get autoremove && apt-get autoclean
+ RUN apt-get clean && rm -rf /var/lib/apt/lists/* && apt-get autoremove && apt-get autoclean
4 changes: 2 additions & 2 deletions docker/torch-light.dockerfile
@@ -6,6 +6,6 @@ RUN apt-get update && apt-get install -y --no-install-recommends libsndfile1-de
ENV UV_PYTHON=/usr/local/bin/python
RUN pip --no-cache-dir install uv && uv venv && uv pip install --no-cache-dir -U pip setuptools
RUN pip install --no-cache-dir 'torch' 'torchvision' 'torchaudio' --index-url https://download.pytorch.org/whl/cpu
- RUN uv pip install --no-deps timm accelerate --extra-index-url https://download.pytorch.org/whl/cpu
- RUN uv pip install --no-cache-dir librosa "git+https://github.com/huggingface/transformers.git@${REF}#egg=transformers[sklearn,sentencepiece,vision,testing]"
+ RUN uv pip install --no-deps timm accelerate --extra-index-url https://download.pytorch.org/whl/cpu
+ RUN uv pip install --no-cache-dir librosa "git+https://github.com/huggingface/transformers.git@${REF}#egg=transformers[sklearn,sentencepiece,vision,testing,tiktoken]"
RUN pip uninstall -y transformers
2 changes: 2 additions & 0 deletions docs/source/en/_toctree.yml
@@ -145,6 +145,8 @@
title: Troubleshoot
- local: gguf
title: Interoperability with GGUF files
+ - local: tiktoken
+   title: Interoperability with TikToken files
title: Developer guides
- sections:
- local: quantization/overview
38 changes: 38 additions & 0 deletions docs/source/en/tiktoken.md
@@ -0,0 +1,38 @@
<!--Copyright 2024 The HuggingFace Team. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.
⚠️ Note that this file is in Markdown but contains specific syntax for our doc-builder (similar to MDX) that may not be
rendered properly in your Markdown viewer.
-->

# Tiktoken and interaction with Transformers

Support for tiktoken model files is seamlessly integrated in 🤗 transformers: when `from_pretrained` finds a tiktoken
`tokenizer.model` file on the Hub, it is automatically converted into our
[fast tokenizer](https://huggingface.co/docs/transformers/main/en/main_classes/tokenizer#transformers.PreTrainedTokenizerFast).

## Known models released with a tiktoken `tokenizer.model`
- gpt2
- llama3

## Example usage

To load `tiktoken` files in `transformers`, ensure that the `tokenizer.model` file is a tiktoken file; it will then be
loaded automatically by `from_pretrained`. Here is how to load a tokenizer from a tiktoken file (for Llama 3, the file
lives in the `original` subfolder of the repository):

```py
from transformers import AutoTokenizer

model_id = "meta-llama/Meta-Llama-3-8B-Instruct"
tokenizer = AutoTokenizer.from_pretrained(model_id, subfolder="original")
```
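As a quick illustration of what the converted tokenizer does (a sketch only; the exact token IDs depend on the
checkpoint's vocabulary, and Llama 3 prepends special tokens such as `<|begin_of_text|>`):

```py
ids = tokenizer("Hello, world!")["input_ids"]
print(tokenizer.decode(ids))  # round-trips the text, plus any special tokens the tokenizer adds
```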
3 changes: 3 additions & 0 deletions setup.py
@@ -99,6 +99,7 @@
"accelerate>=0.26.0",
"av==9.2.0", # Latest version of PyAV (10.0.0) has issues with audio stream.
"beautifulsoup4",
"blobfile",
"codecarbon==1.2.0",
"cookiecutter==1.7.3",
"dataclasses",
@@ -177,6 +178,7 @@
"tensorflow-probability<0.24",
"tf2onnx",
"timeout-decorator",
"tiktoken",
"timm<=0.9.16",
"tokenizers>=0.19,<0.20",
"torch",
@@ -311,6 +313,7 @@ def run(self):
extras["video"] = deps_list("decord", "av")

extras["sentencepiece"] = deps_list("sentencepiece", "protobuf")
extras["tiktoken"] = deps_list("tiktoken", "blobfile")
extras["testing"] = (
deps_list(
"pytest",
41 changes: 29 additions & 12 deletions src/transformers/convert_slow_tokenizer.py
@@ -26,10 +26,13 @@
from tokenizers import AddedToken, Regex, Tokenizer, decoders, normalizers, pre_tokenizers, processors
from tokenizers.models import BPE, Unigram, WordPiece

- from .utils import is_protobuf_available, requires_backends
+ from .utils import is_protobuf_available, logging, requires_backends
from .utils.import_utils import PROTOBUF_IMPORT_ERROR


+ logger = logging.get_logger(__name__)


def import_protobuf(error_message=""):
if is_protobuf_available():
import google.protobuf
@@ -1451,12 +1454,15 @@ def __init__(
vocab_file=None,
pattern=r"""(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+""",
add_prefix_space=False,
+ additional_special_tokens=None,
*args,
**kwargs,
):
super().__init__(*args)
self.vocab_file = vocab_file
self.pattern = pattern
self.add_prefix_space = add_prefix_space
+ self.additional_special_tokens = additional_special_tokens

def extract_vocab_merges_from_model(self, tiktoken_url: str):
try:
@@ -1505,7 +1511,10 @@ def converted(self) -> Tokenizer:
]
)
tokenizer.decoder = decoders.ByteLevel()
+ tokenizer.add_special_tokens(self.additional_special_tokens)

tokenizer.post_processor = processors.ByteLevel(trim_offsets=False)

return tokenizer
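For context on what `extract_vocab_merges_from_model` consumes: a tiktoken `tokenizer.model` file is a plain-text BPE
dump with one `<base64-encoded token> <rank>` pair per line. A minimal standalone reader (a sketch only; the converter
itself goes through `tiktoken` and `blobfile`) would look like:

```py
import base64


def read_tiktoken_vocab(path: str) -> dict[bytes, int]:
    """Parse a tiktoken BPE dump: one '<base64 token> <rank>' pair per line."""
    ranks = {}
    with open(path, "rb") as f:
        for line in f:
            if not line.strip():
                continue
            token_b64, rank = line.split()
            ranks[base64.b64decode(token_b64)] = int(rank)
    return ranks
```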


@@ -1569,29 +1578,37 @@ def converted(self) -> Tokenizer:
}


- def convert_slow_tokenizer(transformer_tokenizer) -> Tokenizer:
+ def convert_slow_tokenizer(transformer_tokenizer, from_tiktoken=False) -> Tokenizer:
"""
Utilities to convert a slow tokenizer instance in a fast tokenizer instance.
Args:
transformer_tokenizer ([`~tokenization_utils_base.PreTrainedTokenizer`]):
Instance of a slow tokenizer to convert in the backend tokenizer for
[`~tokenization_utils_base.PreTrainedTokenizerFast`].
+ from_tiktoken (bool, optional): Whether to use the `tiktoken` library to convert the tokenizer instead of sentencepiece.
+     Defaults to False.
Return:
A instance of [`~tokenizers.Tokenizer`] to be used as the backend tokenizer of a
[`~tokenization_utils_base.PreTrainedTokenizerFast`]
"""

tokenizer_class_name = transformer_tokenizer.__class__.__name__
+ if tokenizer_class_name in SLOW_TO_FAST_CONVERTERS and not from_tiktoken:
+     converter_class = SLOW_TO_FAST_CONVERTERS[tokenizer_class_name]
+     return converter_class(transformer_tokenizer).converted()

- if tokenizer_class_name not in SLOW_TO_FAST_CONVERTERS:
-     raise ValueError(
-         f"An instance of tokenizer class {tokenizer_class_name} cannot be converted in a Fast tokenizer instance."
-         " No converter was found. Currently available slow->fast convertors:"
-         f" {list(SLOW_TO_FAST_CONVERTERS.keys())}"
-     )
-
- converter_class = SLOW_TO_FAST_CONVERTERS[tokenizer_class_name]
-
- return converter_class(transformer_tokenizer).converted()

+ else:
+     try:
+         logger.info("Converting from Tiktoken")
+         return TikTokenConverter(
+             vocab_file=transformer_tokenizer.vocab_file,
+             additional_special_tokens=transformer_tokenizer.additional_special_tokens,
+         ).converted()
+     except Exception:
+         raise ValueError(
+             f"Converting from Tiktoken failed, if a converter for SentencePiece is available, provide a model path "
+             f"with a SentencePiece tokenizer.model file."
+             f"Currently available slow->fast convertors: {list(SLOW_TO_FAST_CONVERTERS.keys())}"
+         )
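To illustrate the new code path, the converter can also be driven directly (a sketch; the vocab path and special
tokens below are placeholders, and `tiktoken`/`blobfile` must be installed):

```py
from transformers.convert_slow_tokenizer import TikTokenConverter

converter = TikTokenConverter(
    vocab_file="path/to/tokenizer.model",  # placeholder: a real tiktoken BPE dump
    additional_special_tokens=["<|begin_of_text|>", "<|end_of_text|>"],
)
backend = converter.converted()  # a tokenizers.Tokenizer, usable as a PreTrainedTokenizerFast backend
```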
2 changes: 2 additions & 0 deletions src/transformers/dependency_versions_table.py
@@ -6,6 +6,7 @@
"accelerate": "accelerate>=0.26.0",
"av": "av==9.2.0",
"beautifulsoup4": "beautifulsoup4",
"blobfile": "blobfile",
"codecarbon": "codecarbon==1.2.0",
"cookiecutter": "cookiecutter==1.7.3",
"dataclasses": "dataclasses",
@@ -82,6 +83,7 @@
"tensorflow-probability": "tensorflow-probability<0.24",
"tf2onnx": "tf2onnx",
"timeout-decorator": "timeout-decorator",
"tiktoken": "tiktoken",
"timm": "timm<=0.9.16",
"tokenizers": "tokenizers>=0.19,<0.20",
"torch": "torch",
8 changes: 8 additions & 0 deletions src/transformers/testing_utils.py
@@ -114,6 +114,7 @@
is_tensorflow_text_available,
is_tf2onnx_available,
is_tf_available,
+ is_tiktoken_available,
is_timm_available,
is_tokenizers_available,
is_torch_available,
@@ -1228,6 +1229,13 @@ def require_cython(test_case):
return unittest.skipUnless(is_cython_available(), "test requires cython")(test_case)


+ def require_tiktoken(test_case):
+     """
+     Decorator marking a test that requires TikToken. These tests are skipped when TikToken isn't installed.
+     """
+     return unittest.skipUnless(is_tiktoken_available(), "test requires TikToken")(test_case)
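A hypothetical test using the new decorator (shown only to illustrate how it composes with `unittest`; the test class
and body are made up):

```py
import unittest

from transformers.testing_utils import require_tiktoken


class TikTokenAvailabilityTest(unittest.TestCase):
    @require_tiktoken
    def test_encoding_loads(self):
        import tiktoken  # safe here: the decorator skips the test otherwise

        self.assertIsNotNone(tiktoken.get_encoding("gpt2"))
```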


def get_gpu_count():
"""
Return the number of available gpus (regardless of whether torch, tf or jax is used)
25 changes: 25 additions & 0 deletions src/transformers/tokenization_utils_base.py
@@ -53,6 +53,7 @@
is_mlx_available,
is_numpy_array,
is_offline_mode,
+ is_protobuf_available,
is_remote_url,
is_tf_available,
is_tf_tensor,
@@ -65,6 +66,7 @@
to_py_obj,
)
from .utils.chat_template_utils import _compile_jinja_template, _render_with_assistant_indices
+ from .utils.import_utils import PROTOBUF_IMPORT_ERROR


if TYPE_CHECKING:
@@ -75,6 +77,16 @@
if is_flax_available():
import jax.numpy as jnp # noqa: F401


+ def import_protobuf_decode_error(error_message=""):
+     if is_protobuf_available():
+         from google.protobuf.message import DecodeError
+
+         return DecodeError
+     else:
+         raise ImportError(PROTOBUF_IMPORT_ERROR.format(error_message))


if is_tokenizers_available():
from tokenizers import AddedToken
from tokenizers import Encoding as EncodingFast
@@ -2434,6 +2446,19 @@ def _from_pretrained(
# Instantiate the tokenizer.
try:
tokenizer = cls(*init_inputs, **init_kwargs)
+ except import_protobuf_decode_error():
+     logger.info(
+         "Unable to load tokenizer model from SPM, loading from TikToken will be attempted instead."
+         "(Google protobuf error: Tried to load SPM model with non-SPM vocab file).",
+     )
+     return False
+ except RuntimeError as e:
+     if "sentencepiece_processor.cc" in str(e):
+         logger.info(
+             "Unable to load tokenizer model from SPM, loading from TikToken will be attempted instead."
+             "(SentencePiece RuntimeError: Tried to load SPM model with non-SPM vocab file).",
+         )
+         return False
except OSError:
raise OSError(
"Unable to load vocabulary from file. "
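The two new `except` branches encode the assumption that a protobuf `DecodeError` or a sentencepiece `RuntimeError`
during the slow load means the `tokenizer.model` on disk is a tiktoken dump rather than an SPM protobuf. A standalone
heuristic along the same lines (illustrative only, not the library's actual check):

```py
import base64


def looks_like_tiktoken(path: str) -> bool:
    """Rough check: tiktoken dumps are text lines of '<base64 token> <rank>';
    SentencePiece models are binary protobufs whose first line won't parse this way."""
    try:
        with open(path, "rb") as f:
            token_b64, rank = f.readline().split()
        base64.b64decode(token_b64, validate=True)
        int(rank)
        return True
    except (ValueError, OSError):
        return False
```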
17 changes: 11 additions & 6 deletions src/transformers/tokenization_utils_fast.py
@@ -54,6 +54,7 @@
TOKENIZER_FILE = "tokenizer.json"
SPECIAL_TOKENS_MAP_FILE = "special_tokens_map.json"
TOKENIZER_CONFIG_FILE = "tokenizer_config.json"
+ TIKTOKEN_VOCAB_FILE = "tokenizer.model"

# Slow tokenizers have an additional added tokens files
ADDED_TOKENS_FILE = "added_tokens.json"
@@ -74,7 +75,7 @@
"WordPiece": WordPieceTrainer,
}

VOCAB_FILES_NAMES = {"tokenizer_file": TOKENIZER_FILE}
VOCAB_FILES_NAMES = {"tokenizer_file": TOKENIZER_FILE, "vocab_file": TIKTOKEN_VOCAB_FILE}


@add_end_docstrings(INIT_TOKENIZER_DOCSTRING)
@@ -113,7 +114,7 @@ def __init__(self, *args, **kwargs):
elif fast_tokenizer_file is not None and not from_slow:
# We have a serialization from tokenizers which let us directly build the backend
fast_tokenizer = TokenizerFast.from_file(fast_tokenizer_file)
- elif slow_tokenizer is not None:
+ elif slow_tokenizer:
# We need to convert a slow tokenizer to build the backend
fast_tokenizer = convert_slow_tokenizer(slow_tokenizer)
elif gguf_file is not None:
@@ -123,22 +124,26 @@
tokenizer_dict = gguf_param["tokenizer"]
tokenizer_config = gguf_param["tokenizer_config"]
fast_tokenizer, additional_kwargs = convert_gguf_tokenizer(architecture, tokenizer_dict)

kwargs.update(tokenizer_config)
if len(additional_kwargs) > 0:
    kwargs.update(additional_kwargs)

- elif self.slow_tokenizer_class is not None:
+ elif self.slow_tokenizer_class is not None and slow_tokenizer is not False:
      # We need to create and convert a slow tokenizer to build the backend
      slow_tokenizer = self.slow_tokenizer_class(*args, **kwargs)
      fast_tokenizer = convert_slow_tokenizer(slow_tokenizer)
+ elif not slow_tokenizer:
+     # We tried loading a slow_tokenizer with spm and failed, try to load with tiktoken
+     self.vocab_file = kwargs.get("vocab_file", None)
+     self.additional_special_tokens = kwargs.get("additional_special_tokens", [])
+     fast_tokenizer = convert_slow_tokenizer(self, from_tiktoken=True)
+     slow_tokenizer = None
else:
    raise ValueError(
        "Couldn't instantiate the backend tokenizer from one of: \n"
        "(1) a `tokenizers` library serialization file, \n"
        "(2) a slow tokenizer instance to convert or \n"
        "(3) an equivalent slow tokenizer class to instantiate and convert. \n"
-       "You need to have sentencepiece installed to convert a slow tokenizer to a fast one."
+       "You need to have sentencepiece or tiktoken installed to convert a slow tokenizer to a fast one."
    )

self._tokenizer = fast_tokenizer
1 change: 1 addition & 0 deletions src/transformers/utils/__init__.py
@@ -188,6 +188,7 @@
is_tensorflow_text_available,
is_tf2onnx_available,
is_tf_available,
+ is_tiktoken_available,
is_timm_available,
is_tokenizers_available,
is_torch_available,
6 changes: 6 additions & 0 deletions src/transformers/utils/import_utils.py
@@ -179,6 +179,8 @@ def _is_package_available(pkg_name: str, return_version: bool = False) -> Union[
_torchvision_available = _is_package_available("torchvision")
_mlx_available = _is_package_available("mlx")
_hqq_available = _is_package_available("hqq")
+ _tiktoken_available = _is_package_available("tiktoken")
+ _blobfile_available = _is_package_available("blobfile")
_liger_kernel_available = _is_package_available("liger_kernel")


@@ -1171,6 +1173,10 @@ def is_mlx_available():
return _mlx_available


+ def is_tiktoken_available():
+     return _tiktoken_available and _blobfile_available
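Note that the check requires both packages, since reading remote `tokenizer.model` files goes through `blobfile`. A
typical guarded import (illustrative):

```py
from transformers.utils import is_tiktoken_available

if is_tiktoken_available():
    import tiktoken  # both tiktoken and blobfile are importable at this point
```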


def is_liger_kernel_available():
if not _liger_kernel_available:
return False