115 changes: 83 additions & 32 deletions src/transformers/tokenization_utils_base.py
@@ -2099,12 +2099,13 @@ def from_pretrained(
template = template.removesuffix(".jinja")
vocab_files[f"chat_template_{template}"] = f"{CHAT_TEMPLATE_DIR}/{template}.jinja"

remote_files = []
if not is_local and not local_files_only:
try:
remote_files = list_repo_files(pretrained_model_name_or_path)
except Exception:
remote_files = []
else:
elif pretrained_model_name_or_path and os.path.isdir(pretrained_model_name_or_path):
remote_files = os.listdir(pretrained_model_name_or_path)

if "tokenizer_file" in vocab_files and not re.search(vocab_files["tokenizer_file"], "".join(remote_files)):
@@ -2437,57 +2438,108 @@ def _from_pretrained(
except NotImplementedError:
vocab_size = 0

# Optionally patches mistral tokenizers with wrong regex
if (
vocab_size > 100000
and hasattr(tokenizer, "_tokenizer")
and getattr(tokenizer._tokenizer, "pre_tokenizer", None) is not None
):
from huggingface_hub import model_info
tokenizer = cls._patch_mistral_regex(
tokenizer,
pretrained_model_name_or_path,
token=token,
cache_dir=cache_dir,
local_files_only=local_files_only,
_commit_hash=_commit_hash,
_is_local=_is_local,
init_kwargs=init_kwargs,
fix_mistral_regex=kwargs.get("fix_mistral_regex"),
)

def is_base_mistral(model_id: str) -> bool:
model = model_info(model_id)
if model.tags is not None:
if re.search("base_model:.*mistralai", "".join(model.tags)):
return True
return False
return tokenizer

if _is_local or is_base_mistral(pretrained_model_name_or_path):
_config_file = cached_file(
pretrained_model_name_or_path,
"config.json",
cache_dir=cache_dir,
token=token,
local_files_only=local_files_only,
_raise_exceptions_for_missing_entries=False,
_raise_exceptions_for_connection_errors=False,
_commit_hash=_commit_hash,
)
if _config_file is not None:
with open(_config_file, encoding="utf-8") as f:
_config = json.load(f)
transformers_version = _config.get("transformers_version")
@classmethod
def _patch_mistral_regex(
cls,
tokenizer,
pretrained_model_name_or_path,
token=None,
cache_dir=None,
local_files_only=False,
_commit_hash=None,
_is_local=False,
init_kwargs=None,
fix_mistral_regex=None,
):
"""
Patches Mistral-related tokenizers whose pre-tokenizer regex is incorrect. Detection paths:
1) Local tokenizer with an associated config saved next to it
>> model type is one of the Mistral models (saved with an older transformers version)
2) Remote model on the Hub derived from an official Mistral model
>> tags include `base_model:.*mistralai`
"""
from huggingface_hub import model_info

if transformers_version and version.parse(transformers_version) <= version.parse("4.57.2"):
if _is_local and _config.model_type not in [
def is_base_mistral(model_id: str) -> bool:
model = model_info(model_id)
if model.tags is not None:
if re.search("base_model:.*mistralai", "".join(model.tags)):
return True
return False

if _is_local or is_base_mistral(pretrained_model_name_or_path):
_config_file = cached_file(
pretrained_model_name_or_path,
"config.json",
cache_dir=cache_dir,
token=token,
local_files_only=local_files_only,
_raise_exceptions_for_missing_entries=False,
_raise_exceptions_for_connection_errors=False,
_commit_hash=_commit_hash,
)

# Whether a (local) mistral config was detected
mistral_config_detected = False
if _config_file is not None:
with open(_config_file, encoding="utf-8") as f:
_config = json.load(f)
transformers_version = _config.get("transformers_version")
transformers_model_type = _config.get("model_type")

# Detect whether we can skip the mistral fix because of
# a) a non-mistral tokenizer
# b) a transformers version that already includes the fix
if transformers_version and version.parse(transformers_version) <= version.parse("4.57.2"):
if (
_is_local
and transformers_model_type is not None
and transformers_model_type
not in [
"mistral",
"mistral3",
"voxstral",
"voxtral",
"ministral",
"pixtral",
]:
return tokenizer
]
):
return tokenizer
elif transformers_version and version.parse(transformers_version) >= version.parse("5.0.0"):
return tokenizer

mistral_config_detected = True

if mistral_config_detected or (not _is_local and is_base_mistral(pretrained_model_name_or_path)):
# Expose the `fix_mistral_regex` flag on the tokenizer when provided, even if no correction is applied.
if "fix_mistral_regex" in init_kwargs:
if init_kwargs and "fix_mistral_regex" in init_kwargs:
setattr(tokenizer, "fix_mistral_regex", init_kwargs["fix_mistral_regex"])

fix_mistral_regex = kwargs.get("fix_mistral_regex") # not init kwargs
# only warn if it's not explicitly passed
if fix_mistral_regex is None and not getattr(tokenizer, "fix_mistral_regex", False):
setattr(tokenizer, "fix_mistral_regex", False)
logger.warning(
f"The tokenizer you are loading from '{pretrained_model_name_or_path}'"
f" with an incorrect regex pattern: https://huggingface.co/mistralai/Mistral-Small-3.1-24B-Instruct-2503/discussions/84#69121093e8b480e709447d5e. "
f" with an incorrect regex pattern: https://huggingface.co/mistralai/Mistral-Small-3.1-24B-Instruct-2503/discussions/84#69121093e8b480e709447d5e."
" This will lead to incorrect tokenization. You should set the `fix_mistral_regex=True` flag when loading this tokenizer to fix this issue."
)
elif fix_mistral_regex is True or getattr(tokenizer, "fix_mistral_regex", False):
@@ -2500,7 +2552,6 @@ def is_base_mistral(model_id: str) -> bool:
),
behavior="isolated",
)

return tokenizer

@staticmethod
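For context (outside the diff itself), a minimal usage sketch of the new flag; the checkpoint id is illustrative and assumed to be one of the Mistral-derived tokenizers saved with the old regex:

from transformers import AutoTokenizer

# Without the flag, the stored (incorrect) pre-tokenizer regex is kept and the warning above is emitted;
# passing fix_mistral_regex=True opts into the corrected regex.
tokenizer = AutoTokenizer.from_pretrained(
    "mistralai/Mistral-Small-3.1-24B-Instruct-2503",  # assumed affected checkpoint
    fix_mistral_regex=True,
)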
39 changes: 39 additions & 0 deletions tests/models/auto/test_tokenization_auto.py
@@ -34,9 +34,13 @@
GPT2Tokenizer,
GPT2TokenizerFast,
PreTrainedTokenizerFast,
Qwen2Tokenizer,
Qwen2TokenizerFast,
Qwen3MoeConfig,
RobertaTokenizer,
RobertaTokenizerFast,
is_tokenizers_available,
logging,
)
from transformers.models.auto.configuration_auto import CONFIG_MAPPING, AutoConfig
from transformers.models.auto.tokenization_auto import (
@@ -49,6 +53,7 @@
DUMMY_DIFF_TOKENIZER_IDENTIFIER,
DUMMY_UNKNOWN_IDENTIFIER,
SMALL_MODEL_IDENTIFIER,
CaptureLogger,
RequestCounter,
is_flaky,
require_tokenizers,
@@ -229,6 +234,40 @@ def test_auto_tokenizer_from_local_folder(self):
self.assertIsInstance(tokenizer2, tokenizer.__class__)
self.assertEqual(tokenizer2.vocab_size, 12)

def test_auto_tokenizer_from_local_folder_mistral_detection(self):
"""See #42374 for reference, ensuring proper mistral detection on local tokenizers"""
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen3-235B-A22B-Thinking-2507")
config = Qwen3MoeConfig.from_pretrained("Qwen/Qwen3-235B-A22B-Thinking-2507")
self.assertIsInstance(tokenizer, (Qwen2Tokenizer, Qwen2TokenizerFast))

with tempfile.TemporaryDirectory() as tmp_dir:
tokenizer.save_pretrained(tmp_dir)

# Case 1: Tokenizer with no config associated
logger = logging.get_logger("transformers.tokenization_utils_base")
with CaptureLogger(logger) as cl:
AutoTokenizer.from_pretrained(tmp_dir)
self.assertNotIn(
"with an incorrect regex pattern: https://huggingface.co/mistralai/Mistral-Small-3.1-24B-Instruct-2503/discussions/84#69121093e8b480e709447d5e",
cl.out,
)

# Case 2: Tokenizer with an associated config
# The config must be saved alongside the tokenizer for the (non-)mistral detection to run,
# with a transformers version affected by the regex bug
config_dict = config.to_diff_dict()
config_dict["transformers_version"] = "4.57.2"

# Manually saving to avoid versioning clashes
config_path = os.path.join(tmp_dir, "config.json")
with open(config_path, "w", encoding="utf-8") as f:
json.dump(config_dict, f, indent=2, sort_keys=True)

tokenizer2 = AutoTokenizer.from_pretrained(tmp_dir)

self.assertIsInstance(tokenizer2, tokenizer.__class__)
self.assertTrue(tokenizer2.vocab_size > 100_000)

def test_auto_tokenizer_fast_no_slow(self):
tokenizer = AutoTokenizer.from_pretrained("Salesforce/ctrl")
# There is no fast CTRL so this always gives us a slow tokenizer.
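For context (outside the diff), a standalone sketch of the Hub-tag check behind the remote detection path, mirroring the `is_base_mistral` helper above; the model id in the usage line is illustrative and the result depends on that repo's tags:

import re

from huggingface_hub import model_info


def looks_mistral_derived(model_id: str) -> bool:
    # Mistral-derived checkpoints typically carry a "base_model:<org>/<name>" tag pointing at a mistralai repo.
    tags = model_info(model_id).tags or []
    return re.search("base_model:.*mistralai", "".join(tags)) is not None


print(looks_mistral_derived("mistralai/Mistral-Small-3.1-24B-Instruct-2503"))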
6 changes: 5 additions & 1 deletion tests/models/llama/test_tokenization_llama.py
@@ -49,7 +49,11 @@
@require_sentencepiece
@require_tokenizers
class LlamaTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
from_pretrained_id = ["hf-internal-testing/llama-tokenizer", "meta-llama/Llama-2-7b-hf"]
from_pretrained_id = [
"hf-internal-testing/llama-tokenizer",
"meta-llama/Llama-2-7b-hf",
"meta-llama/Meta-Llama-3-8B",
]
tokenizer_class = LlamaTokenizer
rust_tokenizer_class = LlamaTokenizerFast

22 changes: 22 additions & 0 deletions tests/test_tokenization_common.py
@@ -4670,3 +4670,25 @@ def test_empty_input_string(self):
for return_type, target_type in zip(tokenizer_return_type, output_tensor_type):
output = tokenizer(empty_input_string, return_tensors=return_type)
self.assertEqual(output.input_ids.dtype, target_type)

def test_local_files_only(self):
from transformers import AutoTokenizer

pretrained_list = getattr(self, "from_pretrained_id", []) or []
for pretrained_name in pretrained_list:
with self.subTest(f"AutoTokenizer ({pretrained_name})"):
# First cache the tokenizer files
try:
tokenizer_cached = AutoTokenizer.from_pretrained(pretrained_name)

# Now load with local_files_only=True
tokenizer_local = AutoTokenizer.from_pretrained(pretrained_name, local_files_only=True)

# Check that the two tokenizers are identical
self.assertEqual(tokenizer_cached.get_vocab(), tokenizer_local.get_vocab())
self.assertEqual(
tokenizer_cached.all_special_tokens_extended,
tokenizer_local.all_special_tokens_extended,
)
except Exception as _:
pass # if the pretrained model is not loadable how could it pass locally :)
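
The caching pattern the test exercises, as a user-facing sketch using one of the checkpoint ids from the test list above: warm the cache with one online load, then reload with local_files_only=True.

from transformers import AutoTokenizer

# First load downloads the tokenizer files and populates the local cache.
tok = AutoTokenizer.from_pretrained("hf-internal-testing/llama-tokenizer")

# Later loads can resolve everything from the cache without contacting the Hub.
tok_offline = AutoTokenizer.from_pretrained("hf-internal-testing/llama-tokenizer", local_files_only=True)

assert tok.get_vocab() == tok_offline.get_vocab()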