40 changes: 40 additions & 0 deletions docs/source/en/guides/cli.md
@@ -254,6 +254,46 @@ A `.cache/huggingface/` folder is created at the root of your local directory co
fuyu/model-00001-of-00002.safetensors
```

### Dry-run mode

In some cases, you may want to check which files would be downloaded before actually downloading them. You can do this with the `--dry-run` option. It lists all the files targeted by the download and checks whether they are already cached locally, giving you an idea of how many files would be downloaded and their total size.

```sh
>>> hf download openai-community/gpt2 --dry-run
[dry-run] Fetching 26 files: 100%|█████████████| 26/26 [00:04<00:00, 6.26it/s]
[dry-run] Will download 11 files (out of 26) totalling 5.6G.
File Bytes to download
--------------------------------- -----------------
.gitattributes -
64-8bits.tflite 125.2M
64-fp16.tflite 248.3M
64.tflite 495.8M
README.md -
config.json -
flax_model.msgpack 497.8M
generation_config.json -
merges.txt -
model.safetensors 548.1M
onnx/config.json -
onnx/decoder_model.onnx 653.7M
onnx/decoder_model_merged.onnx 655.2M
onnx/decoder_with_past_model.onnx 653.7M
onnx/generation_config.json -
onnx/merges.txt -
onnx/special_tokens_map.json -
onnx/tokenizer.json -
onnx/tokenizer_config.json -
onnx/vocab.json -
pytorch_model.bin 548.1M
rust_model.ot 702.5M
tf_model.h5 497.9M
tokenizer.json -
tokenizer_config.json -
vocab.json -
```

Files marked with `-` are already available locally and would not be downloaded. For more details, check out the [download guide](./download.md#dry-run-mode).

### Specify cache directory

If not using `--local-dir`, all files will be downloaded by default to the cache directory defined by the `HF_HOME` [environment variable](../package_reference/environment_variables#hfhome). You can specify a custom cache using `--cache-dir`.
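A minimal sketch (the cache path below is a placeholder, for illustration):

```sh
>>> hf download openai-community/gpt2 --cache-dir ./path/to/cache
```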
83 changes: 83 additions & 0 deletions docs/source/en/guides/download.md
@@ -158,6 +158,89 @@ Fetching 2 files: 100%|███████████████████

For more details about the CLI download command, please refer to the [CLI guide](./cli#hf-download).

## Dry-run mode

In some cases, you may want to check which files would be downloaded before actually downloading them. You can do this with the `--dry-run` option. It lists all the files targeted by the download and checks whether they are already cached locally, giving you an idea of how many files would be downloaded and their total size.

Here is an example that checks a single file:

```sh
>>> hf download openai-community/gpt2 onnx/decoder_model_merged.onnx --dry-run
[dry-run] Will download 1 files (out of 1) totalling 655.2M
File Bytes to download
------------------------------ -----------------
onnx/decoder_model_merged.onnx 655.2M
```

And if the file is already cached, nothing would be downloaded (shown as `-`):

```sh
>>> hf download openai-community/gpt2 onnx/decoder_model_merged.onnx --dry-run
[dry-run] Will download 0 files (out of 1) totalling 0.0.
File Bytes to download
------------------------------ -----------------
onnx/decoder_model_merged.onnx -
```

You can also execute a dry-run on an entire repository:

```sh
>>> hf download openai-community/gpt2 --dry-run
[dry-run] Fetching 26 files: 100%|█████████████| 26/26 [00:04<00:00, 6.26it/s]
[dry-run] Will download 11 files (out of 26) totalling 5.6G.
File Bytes to download
--------------------------------- -----------------
.gitattributes -
64-8bits.tflite 125.2M
64-fp16.tflite 248.3M
64.tflite 495.8M
README.md -
config.json -
flax_model.msgpack 497.8M
generation_config.json -
merges.txt -
model.safetensors 548.1M
onnx/config.json -
onnx/decoder_model.onnx 653.7M
onnx/decoder_model_merged.onnx 655.2M
onnx/decoder_with_past_model.onnx 653.7M
onnx/generation_config.json -
onnx/merges.txt -
onnx/special_tokens_map.json -
onnx/tokenizer.json -
onnx/tokenizer_config.json -
onnx/vocab.json -
pytorch_model.bin 548.1M
rust_model.ot 702.5M
tf_model.h5 497.9M
tokenizer.json -
tokenizer_config.json -
vocab.json -
```

And with file filtering:

```sh
>>> hf download openai-community/gpt2 --include "*.json" --dry-run
[dry-run] Fetching 11 files: 100%|█████████████| 11/11 [00:00<00:00, 80518.92it/s]
[dry-run] Will download 0 files (out of 11) totalling 0.0.
File Bytes to download
---------------------------- -----------------
config.json -
generation_config.json -
onnx/config.json -
onnx/generation_config.json -
onnx/special_tokens_map.json -
onnx/tokenizer.json -
onnx/tokenizer_config.json -
onnx/vocab.json -
tokenizer.json -
tokenizer_config.json -
vocab.json -
```

Finally, you can also perform a dry run programmatically by passing `dry_run=True` to [`hf_hub_download`] and [`snapshot_download`]. The former returns a single [`DryRunFileInfo`] while the latter returns a list of [`DryRunFileInfo`] objects. Each info contains the file's commit hash, name, and size, whether the file is already cached, and whether it would be downloaded. In practice, a file is downloaded if it is not cached or if `force_download=True` is passed.
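For example, here is a minimal sketch of how the result could be inspected (the attribute names `filename`, `size`, and `would_download` are illustrative assumptions; refer to the [`DryRunFileInfo`] reference for the actual fields):

```py
>>> from huggingface_hub import snapshot_download

>>> # Dry run: nothing is downloaded, a list of DryRunFileInfo is returned instead
>>> files = snapshot_download("openai-community/gpt2", dry_run=True)

>>> # `would_download`, `filename` and `size` are assumed attribute names, for illustration
>>> for file in files:
...     if file.would_download:
...         print(f"{file.filename}: {file.size} bytes to download")
```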

## Faster downloads

There are two options to speed up downloads. Both involve installing a Python package written in Rust.
4 changes: 4 additions & 0 deletions docs/source/en/package_reference/hf_api.md
@@ -45,6 +45,10 @@ models = hf_api.list_models()

[[autodoc]] huggingface_hub.hf_api.DatasetInfo

### DryRunFileInfo

[[autodoc]] huggingface_hub.hf_api.DryRunFileInfo

### GitRefInfo

[[autodoc]] huggingface_hub.hf_api.GitRefInfo
3 changes: 3 additions & 0 deletions src/huggingface_hub/__init__.py
@@ -138,6 +138,7 @@
"push_to_hub_fastai",
],
"file_download": [
"DryRunFileInfo",
"HfFileMetadata",
"_CACHED_NO_EXIST",
"get_hf_file_metadata",
@@ -625,6 +626,7 @@
"DocumentQuestionAnsweringInputData",
"DocumentQuestionAnsweringOutputElement",
"DocumentQuestionAnsweringParameters",
"DryRunFileInfo",
"EvalResult",
"FLAX_WEIGHTS_NAME",
"FeatureExtractionInput",
@@ -1147,6 +1149,7 @@ def __dir__():
)
from .file_download import (
_CACHED_NO_EXIST, # noqa: F401
DryRunFileInfo, # noqa: F401
HfFileMetadata, # noqa: F401
get_hf_file_metadata, # noqa: F401
hf_hub_download, # noqa: F401
140 changes: 119 additions & 21 deletions src/huggingface_hub/_snapshot_download.py
@@ -1,20 +1,21 @@
import os
from pathlib import Path
from typing import Iterable, Optional, Union
from typing import Iterable, Literal, Optional, Union, overload

import httpx
from tqdm.auto import tqdm as base_tqdm
from tqdm.contrib.concurrent import thread_map

from . import constants
from .errors import (
DryRunError,
GatedRepoError,
HfHubHTTPError,
LocalEntryNotFoundError,
RepositoryNotFoundError,
RevisionNotFoundError,
)
from .file_download import REGEX_COMMIT_HASH, hf_hub_download, repo_folder_name
from .file_download import REGEX_COMMIT_HASH, DryRunFileInfo, hf_hub_download, repo_folder_name
from .hf_api import DatasetInfo, HfApi, ModelInfo, RepoFile, SpaceInfo
from .utils import OfflineModeIsEnabled, filter_repo_objects, logging, validate_hf_hub_args
from .utils import tqdm as hf_tqdm
@@ -25,6 +26,81 @@
VERY_LARGE_REPO_THRESHOLD = 50000 # After this limit, we don't consider `repo_info.siblings` to be reliable enough


@overload
def snapshot_download(
repo_id: str,
*,
repo_type: Optional[str] = None,
revision: Optional[str] = None,
cache_dir: Union[str, Path, None] = None,
local_dir: Union[str, Path, None] = None,
library_name: Optional[str] = None,
library_version: Optional[str] = None,
user_agent: Optional[Union[dict, str]] = None,
etag_timeout: float = constants.DEFAULT_ETAG_TIMEOUT,
force_download: bool = False,
token: Optional[Union[bool, str]] = None,
local_files_only: bool = False,
allow_patterns: Optional[Union[list[str], str]] = None,
ignore_patterns: Optional[Union[list[str], str]] = None,
max_workers: int = 8,
tqdm_class: Optional[type[base_tqdm]] = None,
headers: Optional[dict[str, str]] = None,
endpoint: Optional[str] = None,
dry_run: Literal[False] = False,
) -> str: ...


@overload
def snapshot_download(
repo_id: str,
*,
repo_type: Optional[str] = None,
revision: Optional[str] = None,
cache_dir: Union[str, Path, None] = None,
local_dir: Union[str, Path, None] = None,
library_name: Optional[str] = None,
library_version: Optional[str] = None,
user_agent: Optional[Union[dict, str]] = None,
etag_timeout: float = constants.DEFAULT_ETAG_TIMEOUT,
force_download: bool = False,
token: Optional[Union[bool, str]] = None,
local_files_only: bool = False,
allow_patterns: Optional[Union[list[str], str]] = None,
ignore_patterns: Optional[Union[list[str], str]] = None,
max_workers: int = 8,
tqdm_class: Optional[type[base_tqdm]] = None,
headers: Optional[dict[str, str]] = None,
endpoint: Optional[str] = None,
dry_run: Literal[True] = True,
) -> list[DryRunFileInfo]: ...


@overload
def snapshot_download(
repo_id: str,
*,
repo_type: Optional[str] = None,
revision: Optional[str] = None,
cache_dir: Union[str, Path, None] = None,
local_dir: Union[str, Path, None] = None,
library_name: Optional[str] = None,
library_version: Optional[str] = None,
user_agent: Optional[Union[dict, str]] = None,
etag_timeout: float = constants.DEFAULT_ETAG_TIMEOUT,
force_download: bool = False,
token: Optional[Union[bool, str]] = None,
local_files_only: bool = False,
allow_patterns: Optional[Union[list[str], str]] = None,
ignore_patterns: Optional[Union[list[str], str]] = None,
max_workers: int = 8,
tqdm_class: Optional[type[base_tqdm]] = None,
headers: Optional[dict[str, str]] = None,
endpoint: Optional[str] = None,
dry_run: bool = False,
) -> Union[str, list[DryRunFileInfo]]: ...


@validate_hf_hub_args
def snapshot_download(
repo_id: str,
@@ -46,7 +122,8 @@ def snapshot_download(
tqdm_class: Optional[type[base_tqdm]] = None,
headers: Optional[dict[str, str]] = None,
endpoint: Optional[str] = None,
) -> str:
dry_run: bool = False,
) -> Union[str, list[DryRunFileInfo]]:
"""Download repo files.

Download a whole snapshot of a repo's files at the specified revision. This is useful when you want all files from
@@ -109,9 +186,14 @@ def snapshot_download(
Note that the `tqdm_class` is not passed to each individual download.
Defaults to the custom HF progress bar that can be disabled by setting
`HF_HUB_DISABLE_PROGRESS_BARS` environment variable.
dry_run (`bool`, *optional*, defaults to `False`):
If `True`, perform a dry run without actually downloading the files. Returns a list of
[`DryRunFileInfo`] objects containing information about what would be downloaded.

Returns:
`str`: folder path of the repo snapshot.
`str` or list of [`DryRunFileInfo`]:
- If `dry_run=False`: Local snapshot path.
- If `dry_run=True`: A list of [`DryRunFileInfo`] objects containing download information.

Raises:
[`~utils.RepositoryNotFoundError`]
@@ -187,6 +269,11 @@
# - if the specified revision is a branch or tag, look inside "refs".
# => if local_dir is not None, we will return the path to the local folder if it exists.
if repo_info is None:
if dry_run:
raise DryRunError(
"Dry run cannot be performed as the repository cannot be accessed. Please check your internet connection or authentication token."
) from api_call_error

# Try to get which commit hash corresponds to the specified revision
commit_hash = None
if REGEX_COMMIT_HASH.match(revision):
@@ -273,6 +360,8 @@
tqdm_desc = f"Fetching {len(filtered_repo_files)} files"
else:
tqdm_desc = "Fetching ... files"
if dry_run:
tqdm_desc = "[dry-run] " + tqdm_desc

commit_hash = repo_info.sha
snapshot_folder = os.path.join(storage_folder, "snapshots", commit_hash)
@@ -288,28 +377,33 @@
except OSError as e:
logger.warning(f"Ignored error while writing commit hash to {ref_path}: {e}.")

# Collected by the download workers below; with dry_run=True these are DryRunFileInfo objects.
results: list[Union[str, DryRunFileInfo]] = []

# we pass the commit_hash to hf_hub_download
# so no network call happens if we already
# have the file locally.
def _inner_hf_hub_download(repo_file: str):
return hf_hub_download(
repo_id,
filename=repo_file,
repo_type=repo_type,
revision=commit_hash,
endpoint=endpoint,
cache_dir=cache_dir,
local_dir=local_dir,
library_name=library_name,
library_version=library_version,
user_agent=user_agent,
etag_timeout=etag_timeout,
force_download=force_download,
token=token,
headers=headers,
def _inner_hf_hub_download(repo_file: str) -> None:
results.append(
hf_hub_download( # type: ignore[no-matching-overload] # ty cannot resolve the overload here; reason unclear
repo_id,
filename=repo_file,
repo_type=repo_type,
revision=commit_hash,
endpoint=endpoint,
cache_dir=cache_dir,
local_dir=local_dir,
library_name=library_name,
library_version=library_version,
user_agent=user_agent,
etag_timeout=etag_timeout,
force_download=force_download,
token=token,
headers=headers,
dry_run=dry_run,
)
)

if constants.HF_HUB_ENABLE_HF_TRANSFER:
if constants.HF_HUB_ENABLE_HF_TRANSFER and not dry_run:
# when using hf_transfer we don't want extra parallelism
# from the one hf_transfer provides
for file in filtered_repo_files:
@@ -324,6 +418,10 @@ def _inner_hf_hub_download(repo_file: str):
tqdm_class=tqdm_class or hf_tqdm,
)

if dry_run:
assert all(isinstance(r, DryRunFileInfo) for r in results)
return results # type: ignore

if local_dir is not None:
return str(os.path.realpath(local_dir))
return snapshot_folder
3 changes: 1 addition & 2 deletions src/huggingface_hub/_upload_large_folder.py
@@ -31,8 +31,7 @@
from ._commit_api import CommitOperationAdd, UploadInfo, _fetch_upload_modes
from ._local_folder import LocalUploadFileMetadata, LocalUploadFilePaths, get_local_upload_paths, read_upload_metadata
from .constants import DEFAULT_REVISION, REPO_TYPES
from .utils import DEFAULT_IGNORE_PATTERNS, filter_repo_objects, tqdm
from .utils._cache_manager import _format_size
from .utils import DEFAULT_IGNORE_PATTERNS, _format_size, filter_repo_objects, tqdm
from .utils._runtime import is_xet_available
from .utils.sha import sha_fileobj
