17 changes: 9 additions & 8 deletions src/transformers/generation/utils.py
@@ -410,14 +410,15 @@ def adjust_generation_fn(
             logger.info(
                 "Generation config file not found, using a generation config created from the model config."
             )
-            self.generation_config = GenerationConfig.from_pretrained(
-                pretrained_model_name_or_path,
-                config_file_name="config.json",
-                _from_auto=from_auto_class,
-                _from_pipeline=from_pipeline,
-                _from_model_config=True,
-                **repo_loading_kwargs,
-            )
+            # print(pretrained_model_name_or_path)
+            # self.generation_config = GenerationConfig.from_pretrained(
+            #     pretrained_model_name_or_path,
+            #     config_file_name="config.json",
+            #     _from_auto=from_auto_class,
+            #     _from_pipeline=from_pipeline,
+            #     _from_model_config=True,
+            #     **repo_loading_kwargs,
+            # )
         # Load custom generate function if `pretrained_model_name_or_path` defines it (and override `generate`)
         if hasattr(self, "load_custom_generate") and trust_remote_code:
             try:
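For context on the fallback this hunk comments out: when no `generation_config.json` exists, generation defaults can still be derived from the model's `config.json`. A minimal sketch of that idea, outside this PR, using `gpt2` purely as a stand-in checkpoint:

```python
# Sketch: derive generation defaults straight from a model config, mirroring the
# fallback path that the hunk above disables. "gpt2" is only an example checkpoint.
from transformers import AutoConfig, GenerationConfig

model_config = AutoConfig.from_pretrained("gpt2")
gen_config = GenerationConfig.from_model_config(model_config)
print(gen_config.eos_token_id, gen_config.max_length)
```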
2 changes: 1 addition & 1 deletion src/transformers/models/auto/tokenization_auto.py
@@ -680,7 +680,7 @@ def from_pretrained(
         if not isinstance(config, PreTrainedConfig):
             if gguf_file:
                 gguf_path = cached_file(pretrained_model_name_or_path, gguf_file, **kwargs)
-                config_dict = load_gguf_checkpoint(gguf_path, return_tensors=False)["config"]
+                config_dict = load_gguf_checkpoint(gguf_path)["config"]
                 config = AutoConfig.for_model(**config_dict)
             else:
                 config = AutoConfig.from_pretrained(
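The call above feeds `AutoConfig` when a tokenizer is requested straight from a GGUF checkpoint. A hedged usage sketch of that path (repo and file names below are placeholders, not part of this PR):

```python
# Placeholder repo/file names; illustrates the gguf_file code path touched above.
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained(
    "some-org/some-llama-gguf",      # hypothetical GGUF repository
    gguf_file="model.Q4_K_M.gguf",   # hypothetical GGUF file inside it
)
print(tokenizer("hello world").input_ids)
```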
65 changes: 32 additions & 33 deletions src/transformers/tokenization_utils_base.py
@@ -1622,11 +1622,10 @@ def from_pretrained(
         # For legacy support: allow single-file loading if:
         # 1. Only one vocab file is required, OR
         # 2. It's a fast tokenizer with tokenizer_file (which is optional), OR
-        # 3. It's a GGUF file
         vocab_files_count = len(cls.vocab_files_names)
         has_optional_tokenizer_file = vocab_files_count > 1 and "tokenizer_file" in cls.vocab_files_names

-        if vocab_files_count > 1 and not gguf_file and not has_optional_tokenizer_file:
+        if vocab_files_count > 1 and not has_optional_tokenizer_file:
             raise ValueError(
                 f"Calling {cls.__name__}.from_pretrained() with the path to a single file or url is not "
                 "supported for this tokenizer. Use a model identifier or the path to a directory instead."
@@ -1650,35 +1649,35 @@ def from_pretrained(
             "chat_template_file": CHAT_TEMPLATE_FILE,
         }

-            vocab_files = {**cls.vocab_files_names, **additional_files_names}
-
-            # Check for versioned tokenizer files
-            if "tokenizer_file" in vocab_files:
-                fast_tokenizer_file = FULL_TOKENIZER_FILE
-                try:
-                    resolved_config_file = cached_file(
-                        pretrained_model_name_or_path,
-                        TOKENIZER_CONFIG_FILE,
-                        cache_dir=cache_dir,
-                        force_download=force_download,
-                        proxies=proxies,
-                        token=token,
-                        revision=revision,
-                        local_files_only=local_files_only,
-                        subfolder=subfolder,
-                        user_agent=user_agent,
-                        _raise_exceptions_for_missing_entries=False,
-                        _commit_hash=commit_hash,
-                    )
-                    if resolved_config_file is not None:
-                        with open(resolved_config_file, encoding="utf-8") as reader:
-                            tokenizer_config = json.load(reader)
-                        if "fast_tokenizer_files" in tokenizer_config:
-                            fast_tokenizer_file = get_fast_tokenizer_file(tokenizer_config["fast_tokenizer_files"])
-                        commit_hash = extract_commit_hash(resolved_config_file, commit_hash)
-                except Exception:
-                    pass
-                vocab_files["tokenizer_file"] = fast_tokenizer_file
+        vocab_files = {**cls.vocab_files_names, **additional_files_names}
+
+        # Check for versioned tokenizer files
+        if "tokenizer_file" in vocab_files:
+            fast_tokenizer_file = FULL_TOKENIZER_FILE
+            try:
+                resolved_config_file = cached_file(
+                    pretrained_model_name_or_path,
+                    TOKENIZER_CONFIG_FILE,
+                    cache_dir=cache_dir,
+                    force_download=force_download,
+                    proxies=proxies,
+                    token=token,
+                    revision=revision,
+                    local_files_only=local_files_only,
+                    subfolder=subfolder,
+                    user_agent=user_agent,
+                    _raise_exceptions_for_missing_entries=False,
+                    _commit_hash=commit_hash,
+                )
+                if resolved_config_file is not None:
+                    with open(resolved_config_file, encoding="utf-8") as reader:
+                        tokenizer_config = json.load(reader)
+                    if "fast_tokenizer_files" in tokenizer_config:
+                        fast_tokenizer_file = get_fast_tokenizer_file(tokenizer_config["fast_tokenizer_files"])
+                    commit_hash = extract_commit_hash(resolved_config_file, commit_hash)
+            except Exception:
+                pass
+            vocab_files["tokenizer_file"] = fast_tokenizer_file

         # This block looks for any extra chat template files
         if is_local:
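For readers unfamiliar with the "versioned tokenizer files" the block above resolves: a repo can ship several `tokenizer.<version>.json` variants, and the loader picks the newest one that the installed library version still satisfies. A hedged, simplified stand-in for that selection (not the library's `get_fast_tokenizer_file` implementation; the version string is an assumption):

```python
# Simplified picker for versioned fast-tokenizer files (illustration only).
import re
from packaging import version

def pick_fast_tokenizer_file(available_files: list, current: str = "4.46.0") -> str:
    pattern = re.compile(r"tokenizer\.(.*)\.json")
    best, best_version = "tokenizer.json", version.parse("0")
    for name in available_files:
        match = pattern.fullmatch(name)
        if match is None:
            continue
        file_version = version.parse(match.group(1))
        # Keep the highest version that does not exceed the installed library version.
        if best_version < file_version <= version.parse(current):
            best, best_version = name, file_version
    return best

print(pick_fast_tokenizer_file(["tokenizer.json", "tokenizer.4.0.0.json", "tokenizer.99.0.0.json"]))
# -> tokenizer.4.0.0.json under the assumed 4.46.0 install
```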
@@ -1989,7 +1988,6 @@ def _from_pretrained(
                 added_tokens_decoder[idx] = AddedToken(**serialized_tokens)
                 added_tokens_map[str(added_tokens_decoder[idx])] = added_tokens_decoder[idx]
         # end legacy
-
         # Passing AddedTokens and not strings to the class to prevent it from casting the string to a different AddedToken
         # convert {'__type': 'AddedToken', 'content': '<ent>', 'lstrip': False, 'normalized': True, ...} to AddedTokens
         init_kwargs["added_tokens_decoder"] = added_tokens_decoder
@@ -2003,7 +2001,8 @@
         # for `tokenizers` based tokenizer, we actually want to have vocab and merges pre-extracted from whatever inputs
         # for `none` (PythonBackend) based tokenizer, we also want the vocab file / merge files not extracted.
         # for `sentencepiece` based tokenizer, we pass the sentencepiece model file directly.
-        init_kwargs = cls.convert_to_native_format(**init_kwargs)
+        if not kwargs.get("gguf_file"):
+            init_kwargs = cls.convert_to_native_format(**init_kwargs)

         try:
             tokenizer = cls(*init_inputs, **init_kwargs)
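The guard added above only skips the eager conversion when a GGUF file is being loaded. A toy restatement of that control flow, with `convert_to_native_format` replaced by a stub so it runs standalone (illustration only, not the library's signature):

```python
# Toy guard: skip the eager conversion step when loading from a GGUF file.
def maybe_convert(init_kwargs: dict, kwargs: dict, convert_to_native_format) -> dict:
    if not kwargs.get("gguf_file"):
        return convert_to_native_format(**init_kwargs)
    return init_kwargs

stub = lambda **kw: {**kw, "converted": True}
print(maybe_convert({"vocab_file": "vocab.json"}, {}, stub))                       # converted
print(maybe_convert({"vocab_file": "vocab.json"}, {"gguf_file": "m.gguf"}, stub))  # left untouched
```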
15 changes: 9 additions & 6 deletions src/transformers/utils/import_utils.py
@@ -54,16 +54,19 @@ def _is_package_available(pkg_name: str, return_version: bool = False) -> tuple[
     try:
         # importlib.metadata works with the distribution package, which may be different from the import
         # name (e.g. `PIL` is the import name, but `pillow` is the distribution name)
-        distributions = PACKAGE_DISTRIBUTION_MAPPING[pkg_name]
+        distributions = PACKAGE_DISTRIBUTION_MAPPING.get(pkg_name, None)
         # Per PEP 503, underscores and hyphens are equivalent in package names.
         # Prefer the distribution that matches the (normalized) package name.
         normalized_pkg_name = pkg_name.replace("_", "-")
-        if normalized_pkg_name in distributions:
-            distribution_name = normalized_pkg_name
-        elif pkg_name in distributions:
-            distribution_name = pkg_name
+        if distributions is not None:
+            if normalized_pkg_name in distributions:
+                distribution_name = normalized_pkg_name
+            elif pkg_name in distributions:
+                distribution_name = pkg_name
+            else:
+                distribution_name = distributions[0]
         else:
-            distribution_name = distributions[0]
+            distribution_name = normalized_pkg_name
         package_version = importlib.metadata.version(distribution_name)
     except (importlib.metadata.PackageNotFoundError, KeyError):
         # If we cannot find the metadata (because of editable install for example), try to import directly.
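The change above makes the lookup tolerate packages missing from the distribution mapping by falling back to the PEP 503-normalized import name. A condensed standalone sketch of that behavior, with a toy mapping and without the original's full branch order (illustration only):

```python
# Resolve an installed version from an import name, falling back to the
# PEP 503-normalized name when the mapping has no entry (illustration only).
import importlib.metadata
from typing import Optional

PACKAGE_DISTRIBUTION_MAPPING = {"PIL": ["pillow"], "yaml": ["PyYAML", "pyyaml"]}  # toy mapping

def resolve_version(pkg_name: str) -> Optional[str]:
    distributions = PACKAGE_DISTRIBUTION_MAPPING.get(pkg_name, None)
    normalized = pkg_name.replace("_", "-")
    if distributions is not None:
        distribution_name = normalized if normalized in distributions else distributions[0]
    else:
        distribution_name = normalized
    try:
        return importlib.metadata.version(distribution_name)
    except importlib.metadata.PackageNotFoundError:
        return None

print(resolve_version("PIL"))       # pillow's version if installed, else None
print(resolve_version("some_pkg"))  # falls back to the normalized name "some-pkg"
```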