Tokenizers tokenizer #1261


Merged · 6 commits · Nov 5, 2024
32 changes: 32 additions & 0 deletions tokenizer/base.py
@@ -0,0 +1,32 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
"""
Abstract base class for all tokenizer classes in Python, matching the C++ interface.
"""

# Standard
from abc import ABC, abstractmethod
from typing import List


class TokenizerBase(ABC):
__doc__ = __doc__  # reuse the module docstring as the class docstring

@abstractmethod
def encode(self, s: str, *, bos: bool = False, eos: bool = False) -> List[int]:
"""Encode the given string and optionally include bos/eos tokens"""

@abstractmethod
def decode(self, ids: List[int]) -> str:
"""Decode the given token ids into a string"""

@abstractmethod
def bos_id(self) -> int:
"""The id of the begin-of-string token"""

@abstractmethod
def eos_id(self) -> int:
"""The id of the end-of-string token"""
92 changes: 92 additions & 0 deletions tokenizer/hf_tokenizer.py
@@ -0,0 +1,92 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# Standard
from typing import List, Optional
import json
import os

# Third Party
from tokenizers import Tokenizer

# Local
from .base import TokenizerBase


class HFTokenizer(TokenizerBase):
"""
Wrapper around the Huggingface `tokenizers` library for API compatibility
"""

def __init__(self, file_path: str):
# If the path is a directory, look for "tokenizer.json" which is
# standard for transformers checkpoints and also look for the
# "tokenizer_config.json" file to parse eos/bos tokens
if os.path.isdir(file_path):
tokenizer_path = os.path.join(file_path, "tokenizer.json")
tokenizer_config_path = os.path.join(file_path, "tokenizer_config.json")
else:
tokenizer_path = file_path
tokenizer_config_path = os.path.join(os.path.dirname(file_path), "tokenizer_config.json")
if not os.path.isfile(tokenizer_config_path):
    tokenizer_config_path = None

# Load the tokenizer itself
self._tokenizer = Tokenizer.from_file(tokenizer_path)

# If available, parse bos/eos tokens from the tokenizer config
self._bos_id, self._eos_id = None, None
if tokenizer_config_path is not None:
with open(tokenizer_config_path, "r") as handle:
tok_config = json.load(handle)
bos_token = tok_config.get("bos_token")
eos_token = tok_config.get("eos_token")
if bos_token is not None:
self._bos_id = self._tokenizer.token_to_id(bos_token)
if eos_token is not None:
self._eos_id = self._tokenizer.token_to_id(eos_token)

# If no eos/bos tokens found, go looking for them!
if None in [self._bos_id, self._eos_id]:
tok_content = json.loads(self._tokenizer.to_str())
added_tokens = tok_content.get("added_tokens", [])
if self._bos_id is None:
    self._bos_id = self._look_for_special_token(added_tokens, ["begin", "text"])
if self._eos_id is None:
    self._eos_id = self._look_for_special_token(added_tokens, ["end", "text"])

assert None not in [self._bos_id, self._eos_id], "Unable to find BOS/EOS tokens"

@staticmethod
def _look_for_special_token(added_tokens: List[dict], search_strs: List[str]) -> Optional[int]:
    """Search the added tokens for a single special token whose content
    contains every one of the given search strings; return its id if
    exactly one candidate matches, None otherwise.
    """
    candidate_toks = added_tokens
    for search_str in search_strs:
        candidate_toks = [
            tok for tok in candidate_toks
            if tok["special"] and search_str in tok["content"]
        ]
    if len(candidate_toks) == 1:
        return candidate_toks[0]["id"]

def encode(
self,
s: str,
*,
bos: bool = False,
eos: bool = False,
) -> List[int]:
res = self._tokenizer.encode(s, add_special_tokens=bos).ids
if eos and (not res or res[-1] != self._eos_id):
    res.append(self._eos_id)
return res

def decode(self, ids: List[int]) -> str:
return self._tokenizer.decode(ids)

def bos_id(self) -> int:
return self._bos_id

def eos_id(self) -> int:
return self._eos_id
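
Usage then mirrors the other tokenizers. A hedged sketch (the checkpoint path is illustrative; it assumes a transformers-style checkpoint containing tokenizer.json and, optionally, tokenizer_config.json):

from tokenizer.hf_tokenizer import HFTokenizer

# Accepts either a checkpoint directory or a direct path to tokenizer.json
tok = HFTokenizer("checkpoints/my-model")

ids = tok.encode("hello world", bos=True, eos=True)
assert ids[-1] == tok.eos_id()
print(tok.decode(ids))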
4 changes: 3 additions & 1 deletion tokenizer/tiktoken.py
@@ -23,6 +23,8 @@
import tiktoken
from tiktoken.load import load_tiktoken_bpe

from .base import TokenizerBase
Contributor:

Any reason not to use the full path?

Suggested change:
-from .base import TokenizerBase
+from tokenizer.base import TokenizerBase

Contributor (Author):

Heh, no, I have tended towards relative imports for local files (the mental equivalent of #include "foo.h" vs #include <string> for local vs standard/third-party). Definitely no strong preference, though! I'd much rather stay consistent with the rest of the project.


logger = getLogger(__name__)

@@ -38,7 +40,7 @@ class Message(TypedDict):
Dialog = Sequence[Message]


-class Tokenizer:
+class Tokenizer(TokenizerBase):
"""
tokenizing and encoding/decoding text using the Tiktoken tokenizer.
"""
40 changes: 35 additions & 5 deletions torchchat/cli/builder.py
@@ -204,6 +204,7 @@ class TokenizerArgs:
tokenizer_path: Optional[Union[Path, str]] = None
is_sentencepiece: bool = False
is_tiktoken: bool = False
is_hf_tokenizer: bool = False
t: Optional[Any] = None

def __post_init__(self):
@@ -213,6 +214,7 @@ def __post_init__(self):
self.t = TiktokenTokenizer(model_path=str(self.tokenizer_path))
self.is_tiktoken = True
self.is_sentencepiece = False
self.is_hf_tokenizer = False
return
except:
pass
@@ -223,12 +225,25 @@ def __post_init__(self):
self.t = SentencePieceProcessor(model_file=str(self.tokenizer_path))
self.is_tiktoken = False
self.is_sentencepiece = True
self.is_hf_tokenizer = False
return
except:
pass

try:
from tokenizer.hf_tokenizer import HFTokenizer

self.t = HFTokenizer(str(self.tokenizer_path))
self.is_tiktoken = False
self.is_sentencepiece = False
self.is_hf_tokenizer = True
return
except:
pass

self.is_tiktoken = False
self.is_sentencepiece = False
self.is_hf_tokenizer = False
self.t = None
return

@@ -240,16 +255,27 @@ def validate_model(
if model is None:
return

-if self.is_tiktoken == self.is_sentencepiece:
+if sum([self.is_tiktoken, self.is_hf_tokenizer, self.is_sentencepiece]) != 1:
raise RuntimeError(f"no tokenizer was found at {self.tokenizer_path}")

is_tiktoken = self.is_tiktoken
is_sentencepiece = self.is_sentencepiece
is_hf_tokenizer = self.is_hf_tokenizer
use_tiktoken = model.config.use_tiktoken
use_hf_tokenizer = model.config.use_hf_tokenizer
use_sentencepiece = not (use_tiktoken or use_hf_tokenizer)

-if not (is_tiktoken == use_tiktoken) or not (is_sentencepiece != use_tiktoken):
+if (
+    (is_tiktoken and not use_tiktoken) or
+    (is_hf_tokenizer and not use_hf_tokenizer) or
+    (is_sentencepiece and not use_sentencepiece)
+):
raise RuntimeError(
f"model-specified tokenizer ({tokenizer_setting_to_name(use_tiktoken)}) does not match provided tokenizer ({tokenizer_setting_to_name(is_tiktoken)}) for {model_description}"
"model-specified tokenizer ({}) does not match provided tokenizer ({}) for {}".format(
tokenizer_setting_to_name(use_tiktoken, use_hf_tokenizer),
tokenizer_setting_to_name(is_tiktoken, is_hf_tokenizer),
model_description,
)
)

return
@@ -605,5 +631,9 @@ def _initialize_model(
return model


-def tokenizer_setting_to_name(tiktoken: bool = False) -> str:
-    return "TikToken" if tiktoken else "SentencePiece"
+def tokenizer_setting_to_name(tiktoken: bool, tokenizers: bool) -> str:
+    if tiktoken:
+        return "TikToken"
+    if tokenizers:
+        return "Tokenizers"
+    return "SentencePiece"
8 changes: 7 additions & 1 deletion torchchat/model.py
@@ -270,7 +270,9 @@ class TransformerArgs:
norm_eps: float = 1e-5
multiple_of: int = 256
ffn_dim_multiplier: Optional[int] = None
# Select the desired tokenizer. Defaults to sentencepiece
use_tiktoken: bool = False
use_hf_tokenizer: bool = False
max_seq_length: int = 8192
rope_scaling: Optional[Dict[str, Any]] = None
# For pipeline parallel
@@ -327,12 +329,14 @@ class ModelArgs:
model_type: ModelType
transformer_args: Dict[str, Dict[str, Any]]
use_tiktoken: bool
use_hf_tokenizer: bool

def __init__(
self,
transformer_args: Dict[str, Dict[str, Any]],
model_type: ModelType = ModelType.TextOnly,
use_tiktoken: bool = False,
use_hf_tokenizer: bool = False,
) -> None:
self._sanity_check(transformer_args, model_type)

@@ -341,6 +345,7 @@ def __init__(

# Model-level attributes
self.use_tiktoken = use_tiktoken
self.use_hf_tokenizer = use_hf_tokenizer

def _sanity_check(
self,
@@ -367,7 +372,8 @@ def from_params(cls, params_path):
}

use_tiktoken = loaded_params.get("use_tiktoken", False)
-return cls(transformer_args, model_type, use_tiktoken)
+use_hf_tokenizer = loaded_params.get("use_hf_tokenizer", False)
+return cls(transformer_args, model_type, use_tiktoken, use_hf_tokenizer)

@classmethod
def from_table(cls, name: str):
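
For completeness, a sketch of how a checkpoint opts into the new tokenizer from its params file; both flags default to False, so SentencePiece remains the fallback. All keys besides the two tokenizer flags are illustrative:

import json

params = json.loads('{"dim": 288, "n_heads": 6, "use_hf_tokenizer": true}')

# Mirrors the .get(..., False) defaulting in ModelArgs.from_params
use_tiktoken = params.get("use_tiktoken", False)
use_hf_tokenizer = params.get("use_hf_tokenizer", False)
assert (use_tiktoken, use_hf_tokenizer) == (False, True)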