86 changes: 85 additions & 1 deletion src/transformers/convert_slow_tokenizer.py
@@ -19,11 +19,13 @@
"""

import warnings
from functools import lru_cache
from typing import Optional

from packaging import version
from tokenizers import AddedToken, Regex, Tokenizer, decoders, normalizers, pre_tokenizers, processors
from tokenizers.models import BPE, Unigram, WordPiece
from tqdm import tqdm

from .utils import is_protobuf_available, is_sentencepiece_available, logging, requires_backends
from .utils.import_utils import PROTOBUF_IMPORT_ERROR
@@ -1692,6 +1694,85 @@ def converted(self) -> Tokenizer:
return tokenizer


class MistralConverter:
def __init__(
self,
vocab_file=None,
pattern=r"""(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+""",
add_prefix_space=False,
additional_special_tokens=None,
**kwargs,
):
self.vocab_file = vocab_file
self.pattern = pattern
self.add_prefix_space = add_prefix_space
self.additional_special_tokens = (
additional_special_tokens.keys()
if isinstance(additional_special_tokens, dict)
else additional_special_tokens
)

def extract_vocab_merges_from_model(self, tiktoken_url: str):
import base64
import json

with open(self.vocab_file, "r", encoding="utf-8") as f:
untyped = json.load(f)
self.pattern = untyped["config"]["pattern"]
self.additional_special_tokens = [
AddedToken(k["token_str"], special=k["is_control"]) for k in untyped["special_tokens"]
]
bpe_ranks = untyped["vocab"]
byte_encoder = bytes_to_unicode()

@lru_cache
def token_bytes_to_string(b):
return "".join([byte_encoder[ord(char)] for char in b.decode("latin-1")])

merges = []
vocab = {}
for idx, token in enumerate(self.additional_special_tokens):
vocab[token.content] = idx
bpe_ranks = [base64.b64decode(k["token_bytes"]) for k in bpe_ranks]
rank_set = set(bpe_ranks)
for rank, token in enumerate(tqdm(bpe_ranks, desc="Converting tekken.json to tokenizer.json")):
vocab[token_bytes_to_string(token)] = rank
if len(token) == 1:
continue
local = []
for index in range(1, len(token)):
piece_l, piece_r = token[:index], token[index:]
if piece_l in rank_set and piece_r in rank_set and (piece_l + piece_r) in rank_set:
local.append((piece_l, piece_r, rank))
local = sorted(local, key=lambda x: (bpe_ranks.index(x[0]), bpe_ranks.index(x[1])), reverse=False)
merges.extend(local)
merges = sorted(merges, key=lambda val: val[2], reverse=False)
merges = [(token_bytes_to_string(val[0]), token_bytes_to_string(val[1])) for val in merges]
return vocab, merges

def tokenizer(self):
vocab_scores, merges = self.extract_vocab_merges_from_model(self.vocab_file)
tokenizer = Tokenizer(BPE(vocab_scores, merges, fuse_unk=False))
if hasattr(tokenizer.model, "ignore_merges"):
tokenizer.model.ignore_merges = True
return tokenizer

def converted(self) -> Tokenizer:
tokenizer = self.tokenizer()
tokenizer.pre_tokenizer = pre_tokenizers.Sequence(
[
pre_tokenizers.Split(Regex(self.pattern), behavior="isolated", invert=False),
pre_tokenizers.ByteLevel(add_prefix_space=self.add_prefix_space, use_regex=False),
]
)
tokenizer.decoder = decoders.ByteLevel()

tokenizer.add_tokens(self.additional_special_tokens)
tokenizer.post_processor = processors.ByteLevel(trim_offsets=False)

return tokenizer


SLOW_TO_FAST_CONVERTERS = {
"AlbertTokenizer": AlbertConverter,
"BartTokenizer": RobertaConverter,
@@ -1771,7 +1852,10 @@ def convert_slow_tokenizer(transformer_tokenizer, from_tiktoken=False) -> Tokenizer:
if tokenizer_class_name in SLOW_TO_FAST_CONVERTERS and not from_tiktoken:
converter_class = SLOW_TO_FAST_CONVERTERS[tokenizer_class_name]
return converter_class(transformer_tokenizer).converted()

elif transformer_tokenizer.vocab_file.endswith("tekken.json"):
transformer_tokenizer.original_tokenizer = transformer_tokenizer
logger.info("Converting from Mistral tekken.json")
return MistralConverter(transformer_tokenizer.vocab_file).converted()
else:
try:
logger.info("Converting from Tiktoken")
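For anyone trying the new conversion path locally, a minimal usage sketch (the tekken.json path and sample sentence are illustrative only):

from transformers.convert_slow_tokenizer import MistralConverter

# Build a fast `tokenizers.Tokenizer` directly from a Mistral tekken.json file
# (the path below is an assumption for illustration).
fast_tokenizer = MistralConverter("path/to/tekken.json").converted()
print(fast_tokenizer.encode("Hello world").tokens)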
6 changes: 2 additions & 4 deletions src/transformers/models/auto/tokenization_auto.py
@@ -748,10 +748,8 @@
(
"voxtral",
(
"MistralCommonTokenizer"
if is_mistral_common_available()
else ("LlamaTokenizer" if is_sentencepiece_available() else None),
"LlamaTokenizerFast" if is_tokenizers_available() and not is_mistral_common_available() else None,
"MistralCommonTokenizer" if is_mistral_common_available() else None,
"PreTrainedTokenizerFast" if is_tokenizers_available() and not is_mistral_common_available() else None,
),
),
("wav2vec2", ("Wav2Vec2CTCTokenizer", None)),
86 changes: 85 additions & 1 deletion src/transformers/tokenization_utils_base.py
@@ -33,6 +33,7 @@
from typing import TYPE_CHECKING, Any, Literal, NamedTuple, Optional, Union, overload

import numpy as np
from huggingface_hub import list_repo_files
from packaging import version

from . import __version__
@@ -2098,7 +2099,21 @@ def from_pretrained(
template = template.removesuffix(".jinja")
vocab_files[f"chat_template_{template}"] = f"{CHAT_TEMPLATE_DIR}/{template}.jinja"

# Get files from url, cache, or disk depending on the case
if not is_local and not local_files_only:
try:
remote_files = list_repo_files(pretrained_model_name_or_path)
except Exception:
remote_files = []
else:
remote_files = os.listdir(pretrained_model_name_or_path)

if "tokenizer_file" in vocab_files and not re.search(vocab_files["tokenizer_file"], "".join(remote_files)):
# mistral tokenizer names are different, but we can still convert them if
# mistral common is not there
other_pattern = re.escape("tekken.json|tokenizer.model.*")
if match := re.search(other_pattern, "\n".join(remote_files)):
vocab_files["vocab_file"] = match.group()

resolved_vocab_files = {}
for file_id, file_path in vocab_files.items():
if file_path is None:
@@ -2417,6 +2432,75 @@ def _from_pretrained(
"Special tokens have been added in the vocabulary, make sure the associated word embeddings are"
" fine-tuned or trained."
)
try:
vocab_size = tokenizer.vocab_size
except NotImplementedError:
vocab_size = 0

if (
vocab_size > 100000
and hasattr(tokenizer, "_tokenizer")
and getattr(tokenizer._tokenizer, "pre_tokenizer", None) is not None
):
from huggingface_hub import model_info

def is_base_mistral(model_id: str) -> bool:
model = model_info(model_id)
if model.tags is not None:
if re.search("base_model:.*mistralai", "".join(model.tags)):
Review comment (Contributor):
Hmm, so that's only for the mistral org, no? Should we directly check `model_type in ["mistral", ...]` so that it also works for other orgs?

Reply (Collaborator, Author):
We can't do that until we download the config / until the config is there.

return True
return False

if _is_local or is_base_mistral(pretrained_model_name_or_path):
_config_file = cached_file(
pretrained_model_name_or_path,
"config.json",
cache_dir=cache_dir,
token=token,
local_files_only=local_files_only,
_raise_exceptions_for_missing_entries=False,
_raise_exceptions_for_connection_errors=False,
_commit_hash=_commit_hash,
)
if _config_file is not None:
with open(_config_file, encoding="utf-8") as f:
_config = json.load(f)
transformers_version = _config.get("transformers_version")

if transformers_version and version.parse(transformers_version) <= version.parse("4.57.2"):
if _is_local and _config.model_type not in [
"mistral",
"mistral3",
"voxstral",
"ministral",
"pixtral",
]:
return tokenizer
Comment on lines +2470 to +2478

Contributor (@CISC, Nov 25, 2025):
The non-existent attribute use of _config.model_type is causing massive loading failures everywhere (including CIs), please consider making a hotfix ASAP. :)

Change it to _config.get("model_type")

Reply (Collaborator, Author):
yeah sorry

Reply (Collaborator, Author):
I have no idea why the CI was full green
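For context, a tiny standalone reproduction of the failure described above (the config payload is made up for illustration): config.json is parsed with json.load into a plain dict, so attribute access raises while the suggested .get() call works.

import json

# Illustrative config payload; json.load/json.loads return a plain dict.
_config = json.loads('{"model_type": "mistral", "transformers_version": "4.57.1"}')
print(_config.get("model_type"))  # -> "mistral"
print(_config.model_type)         # raises AttributeError: 'dict' object has no attribute 'model_type'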


# Expose the `fix_mistral_regex` flag on the tokenizer when provided, even if no correction is applied.
if "fix_mistral_regex" in init_kwargs:
setattr(tokenizer, "fix_mistral_regex", init_kwargs["fix_mistral_regex"])

fix_mistral_regex = kwargs.get("fix_mistral_regex") # not init kwargs
# only warn if its not explicitly passed
if fix_mistral_regex is None and not getattr(tokenizer, "fix_mistral_regex", False):
setattr(tokenizer, "fix_mistral_regex", False)
logger.warning(
f"The tokenizer you are loading from '{pretrained_model_name_or_path}'"
f" with an incorrect regex pattern: https://huggingface.co/mistralai/Mistral-Small-3.1-24B-Instruct-2503/discussions/84#69121093e8b480e709447d5e. "
" This will lead to incorrect tokenization. You should set the `fix_mistral_regex=True` flag when loading this tokenizer to fix this issue."
)
elif fix_mistral_regex is True or getattr(tokenizer, "fix_mistral_regex", False):
setattr(tokenizer, "fix_mistral_regex", True)
import tokenizers

tokenizer.backend_tokenizer.pre_tokenizer[0] = tokenizers.pre_tokenizers.Split(
pattern=tokenizers.Regex(
r"[^\r\n\p{L}\p{N}]?[\p{Lu}\p{Lt}\p{Lm}\p{Lo}\p{M}]*[\p{Ll}\p{Lm}\p{Lo}\p{M}]+|[^\r\n\p{L}\p{N}]?[\p{Lu}\p{Lt}\p{Lm}\p{Lo}\p{M}]+[\p{Ll}\p{Lm}\p{Lo}\p{M}]*|\p{N}| ?[^\s\p{L}\p{N}]+[\r\n/]*|\s*[\r\n]+|\s+(?!\S)|\s+"
),
behavior="isolated",
)

return tokenizer

@staticmethod
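As the new warning text suggests, users opt in to the corrected pre-tokenizer by passing the flag at load time; a usage sketch (the checkpoint id is the one referenced in the warning and is only an example of an affected tokenizer):

from transformers import AutoTokenizer

# Opt in to the corrected Mistral split regex; the checkpoint id is illustrative.
tokenizer = AutoTokenizer.from_pretrained(
    "mistralai/Mistral-Small-3.1-24B-Instruct-2503",
    fix_mistral_regex=True,
)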