Skip to content

Commit

Permalink
[Fix Download] update converted logic & fix hf hub download subfolder bug (#7911)
Browse files Browse the repository at this point in the history

* update converted logic & fix hf hub download subfolder bug
  • Loading branch information
JunnYu authored Jan 29, 2024
1 parent fff730e commit 797efa6
Show file tree
Hide file tree
Showing 2 changed files with 27 additions and 29 deletions.
54 changes: 26 additions & 28 deletions paddlenlp/transformers/model_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@
from paddle.utils.download import is_url as is_remote_url
from tqdm.auto import tqdm

from paddlenlp.utils.downloader import get_path_from_url_with_filelock, hf_file_exists
from paddlenlp.utils.downloader import get_path_from_url_with_filelock
from paddlenlp.utils.env import (
CONFIG_NAME,
LEGACY_CONFIG_NAME,
Expand Down Expand Up @@ -367,28 +367,7 @@ def resolve_weight_file_from_hf_hub(repo_id: str, cache_dir: str, support_conver
support_conversion (bool): whether support converting pytorch weight file to paddle weight file
subfolder (str, optional) An optional value corresponding to a folder inside the repo.
"""
is_local = os.path.isdir(repo_id)
if not is_local:
if hf_file_exists(repo_id, PADDLE_WEIGHTS_NAME, subfolder=subfolder):
file_name = PADDLE_WEIGHTS_NAME
assert (
support_conversion is False
), "Please call set convert_from_torch for paddle weights on huggingface hub, eg. Model.from_pretrained(model_name, from_hf_hub=True, convert_from_torch=False)"
elif hf_file_exists(repo_id, PYTORCH_WEIGHTS_NAME, subfolder=subfolder):
if not support_conversion:
raise EntryNotFoundError(
f"can not download `{PADDLE_WEIGHTS_NAME} from https://huggingface.co/{repo_id}` "
"and current model doesn't support conversion from pytorch weight file to paddle weight file"
)
file_name = PYTORCH_WEIGHTS_NAME
else:
raise EntryNotFoundError(
message=f"can not find the paddle/pytorch weight file from: https://huggingface.co/{repo_id}",
response=None,
)
else:
# for local file, we use support_conversion to select paddle or torch weight.
file_name = PYTORCH_WEIGHTS_NAME if support_conversion else PADDLE_WEIGHTS_NAME
file_name = PYTORCH_WEIGHTS_NAME if support_conversion else PADDLE_WEIGHTS_NAME

file_name_list = [SAFE_WEIGHTS_NAME] + [file_name] + [PYTORCH_WEIGHTS_INDEX_NAME] + [SAFE_WEIGHTS_INDEX_NAME]
resolved_file = None
Expand Down Expand Up @@ -2156,12 +2135,31 @@ def from_pretrained(cls, pretrained_model_name_or_path, *args, **kwargs):
or resolved_archive_file.endswith(SAFE_WEIGHTS_NAME)
or resolved_archive_file.endswith(SAFE_WEIGHTS_INDEX_NAME)
):
# try to get the name-mapping info
logger.info(
f"Starting to convert pytorch weight file<{resolved_archive_file}> to "
f"paddle weight file<{os.path.join(cache_dir, PADDLE_WEIGHTS_NAME)}> ..."
converted_paddle_weights = os.path.join(
os.path.dirname(resolved_archive_file), PADDLE_WEIGHTS_NAME
)
state_dict = cls.convert(resolved_archive_file, config, cache_dir)
if not os.path.exists(converted_paddle_weights):
# try to get the name-mapping info
logger.info(
f"Starting to convert pytorch weight file <{resolved_archive_file}> to "
f"paddle weight file <{converted_paddle_weights}> ..."
)
state_dict = cls.convert(resolved_archive_file, config, os.path.dirname(resolved_archive_file))
else:
# try to load the converted paddle weight file
resolved_archive_file = converted_paddle_weights
sharded_metadata = None
is_sharded = False
logger.info(
f"Detect the converted Paddle weight file <{converted_paddle_weights}>. We intend to reuse this file."
)
if config.tensor_parallel_degree > 1 and resolved_archive_file.endswith(
"model_state.pdparams"
):
state_dict = cls.convert_tensor_parallel(resolved_archive_file, config)
else:
state_dict = load_state_dict(resolved_archive_file)
logger.info("Loaded weights file from disk, setting weights to model.")
else:
raise ValueError(f"Unexpected file: {resolved_archive_file} for weight conversion.")
else:
Expand Down
2 changes: 1 addition & 1 deletion paddlenlp/transformers/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -587,7 +587,7 @@ def cached_file_for_hf_hub(
download_check(path_or_repo_id, full_filename, addition="from_hf_hub")
resolved_file = hf_hub_download(
repo_id=path_or_repo_id,
filename=full_filename,
filename=filename,
cache_dir=cache_dir,
subfolder=subfolder,
library_name="PaddleNLP",
Expand Down

0 comments on commit 797efa6

Please sign in to comment.