Commit c5c1c5f

Implement dry run mode in download CLI (#3407)
* Implement dry run mode
* docs
* more docs
* quality
* quality
* fix cli test
* Apply suggestions from code review
* fix test on windows

Co-authored-by: célina <hanouticelina@gmail.com>
1 parent 379c06a commit c5c1c5f

File tree

15 files changed: +652 -64 lines changed

docs/source/en/guides/cli.md

Lines changed: 40 additions & 0 deletions

````diff
@@ -254,6 +254,46 @@ A `.cache/huggingface/` folder is created at the root of your local directory co
 fuyu/model-00001-of-00002.safetensors
 ```
 
+### Dry-run mode
+
+In some cases, you may want to know which files would be downloaded before actually downloading them. You can check this with the `--dry-run` flag: it lists all files in the repo, reports whether each one is already downloaded, and gives an idea of how many files remain to be fetched and their total size.
+
+```sh
+>>> hf download openai-community/gpt2 --dry-run
+[dry-run] Fetching 26 files: 100%|█████████████| 26/26 [00:04<00:00, 6.26it/s]
+[dry-run] Will download 11 files (out of 26) totalling 5.6G.
+File                              Bytes to download
+--------------------------------- -----------------
+.gitattributes                    -
+64-8bits.tflite                   125.2M
+64-fp16.tflite                    248.3M
+64.tflite                         495.8M
+README.md                         -
+config.json                       -
+flax_model.msgpack                497.8M
+generation_config.json            -
+merges.txt                        -
+model.safetensors                 548.1M
+onnx/config.json                  -
+onnx/decoder_model.onnx           653.7M
+onnx/decoder_model_merged.onnx    655.2M
+onnx/decoder_with_past_model.onnx 653.7M
+onnx/generation_config.json       -
+onnx/merges.txt                   -
+onnx/special_tokens_map.json      -
+onnx/tokenizer.json               -
+onnx/tokenizer_config.json        -
+onnx/vocab.json                   -
+pytorch_model.bin                 548.1M
+rust_model.ot                     702.5M
+tf_model.h5                       497.9M
+tokenizer.json                    -
+tokenizer_config.json             -
+vocab.json                        -
+```
+
+For more details, check out the [download guide](./download.md#dry-run-mode).
+
 ### Specify cache directory
 
 If not using `--local-dir`, all files will be downloaded by default to the cache directory defined by the `HF_HOME` [environment variable](../package_reference/environment_variables#hfhome). You can specify a custom cache using `--cache-dir`:
````
docs/source/en/guides/download.md

Lines changed: 83 additions & 0 deletions

````diff
@@ -158,6 +158,89 @@ Fetching 2 files: 100%|███████████████████
 
 For more details about the CLI download command, please refer to the [CLI guide](./cli#hf-download).
 
+## Dry-run mode
+
+In some cases, you may want to know which files would be downloaded before actually downloading them. You can check this with the `--dry-run` flag: it lists all files in the repo, reports whether each one is already downloaded, and gives an idea of how many files remain to be fetched and their total size.
+
+Here is an example, checking on a single file:
+
+```sh
+>>> hf download openai-community/gpt2 onnx/decoder_model_merged.onnx --dry-run
+[dry-run] Will download 1 files (out of 1) totalling 655.2M
+File                           Bytes to download
+------------------------------ -----------------
+onnx/decoder_model_merged.onnx 655.2M
+```
+
+And if the file is already cached:
+
+```sh
+>>> hf download openai-community/gpt2 onnx/decoder_model_merged.onnx --dry-run
+[dry-run] Will download 0 files (out of 1) totalling 0.0.
+File                           Bytes to download
+------------------------------ -----------------
+onnx/decoder_model_merged.onnx -
+```
+
+You can also execute a dry run on an entire repository:
+
+```sh
+>>> hf download openai-community/gpt2 --dry-run
+[dry-run] Fetching 26 files: 100%|█████████████| 26/26 [00:04<00:00, 6.26it/s]
+[dry-run] Will download 11 files (out of 26) totalling 5.6G.
+File                              Bytes to download
+--------------------------------- -----------------
+.gitattributes                    -
+64-8bits.tflite                   125.2M
+64-fp16.tflite                    248.3M
+64.tflite                         495.8M
+README.md                         -
+config.json                       -
+flax_model.msgpack                497.8M
+generation_config.json            -
+merges.txt                        -
+model.safetensors                 548.1M
+onnx/config.json                  -
+onnx/decoder_model.onnx           653.7M
+onnx/decoder_model_merged.onnx    655.2M
+onnx/decoder_with_past_model.onnx 653.7M
+onnx/generation_config.json       -
+onnx/merges.txt                   -
+onnx/special_tokens_map.json      -
+onnx/tokenizer.json               -
+onnx/tokenizer_config.json        -
+onnx/vocab.json                   -
+pytorch_model.bin                 548.1M
+rust_model.ot                     702.5M
+tf_model.h5                       497.9M
+tokenizer.json                    -
+tokenizer_config.json             -
+vocab.json                        -
+```
+
+And with file filtering:
+
+```sh
+>>> hf download openai-community/gpt2 --include "*.json" --dry-run
+[dry-run] Fetching 11 files: 100%|█████████████| 11/11 [00:00<00:00, 80518.92it/s]
+[dry-run] Will download 0 files (out of 11) totalling 0.0.
+File                         Bytes to download
+---------------------------- -----------------
+config.json                  -
+generation_config.json       -
+onnx/config.json             -
+onnx/generation_config.json  -
+onnx/special_tokens_map.json -
+onnx/tokenizer.json          -
+onnx/tokenizer_config.json   -
+onnx/vocab.json              -
+tokenizer.json               -
+tokenizer_config.json        -
+vocab.json                   -
+```
+
+Finally, you can also perform a dry run programmatically by passing `dry_run=True` to [`hf_hub_download`] and [`snapshot_download`]. They return a [`DryRunFileInfo`] (respectively a list of [`DryRunFileInfo`]) describing, for each file, its commit hash, file name and file size, whether the file is cached, and whether it would be downloaded. In practice, a file is downloaded if it is not cached or if `force_download=True` is passed.
+
 ## Faster downloads
 
 There are two options to speed up downloads. Both involve installing a Python package written in Rust.
````
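To give a feel for the programmatic API introduced by this diff, here is a minimal usage sketch. The `dry_run=True` argument and the `DryRunFileInfo` return type come straight from this commit; the attribute names used below (`filename`, `size`, `would_download`) are assumptions for illustration, since the dataclass itself is defined in `file_download.py` outside the hunks shown here.

```python
from huggingface_hub import snapshot_download

# Dry run over a whole repo: nothing is fetched, only metadata is returned.
infos = snapshot_download("openai-community/gpt2", dry_run=True)

# Attribute names below are assumed for illustration; check the
# DryRunFileInfo reference for the actual field names.
to_fetch = [info for info in infos if info.would_download]
total = sum(info.size for info in to_fetch)
print(f"{len(to_fetch)} files to download, {total} bytes in total")
```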

docs/source/en/package_reference/hf_api.md

Lines changed: 4 additions & 0 deletions

```diff
@@ -45,6 +45,10 @@ models = hf_api.list_models()
 
 [[autodoc]] huggingface_hub.hf_api.DatasetInfo
 
+### DryRunFileInfo
+
+[[autodoc]] huggingface_hub.hf_api.DryRunFileInfo
+
 ### GitRefInfo
 
 [[autodoc]] huggingface_hub.hf_api.GitRefInfo
```

src/huggingface_hub/__init__.py

Lines changed: 3 additions & 0 deletions

```diff
@@ -138,6 +138,7 @@
         "push_to_hub_fastai",
     ],
     "file_download": [
+        "DryRunFileInfo",
         "HfFileMetadata",
         "_CACHED_NO_EXIST",
         "get_hf_file_metadata",
@@ -625,6 +626,7 @@
     "DocumentQuestionAnsweringInputData",
     "DocumentQuestionAnsweringOutputElement",
     "DocumentQuestionAnsweringParameters",
+    "DryRunFileInfo",
     "EvalResult",
     "FLAX_WEIGHTS_NAME",
     "FeatureExtractionInput",
@@ -1147,6 +1149,7 @@ def __dir__():
     )
     from .file_download import (
         _CACHED_NO_EXIST,  # noqa: F401
+        DryRunFileInfo,  # noqa: F401
         HfFileMetadata,  # noqa: F401
         get_hf_file_metadata,  # noqa: F401
         hf_hub_download,  # noqa: F401
```
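For context, the string registry edited above feeds huggingface_hub's lazy-import machinery, which resolves attributes on first access instead of importing every submodule eagerly. Below is a minimal, self-contained sketch of that general pattern (PEP 562 module `__getattr__`), not the library's actual implementation.

```python
# Minimal sketch of lazy attribute loading (PEP 562), assuming a registry
# mapping submodule name -> exported names, similar in spirit to the dict
# edited above. Not huggingface_hub's actual implementation.
import importlib

_SUBMOD_ATTRS = {
    "file_download": ["DryRunFileInfo", "HfFileMetadata", "hf_hub_download"],
}
_ATTR_TO_MODULE = {attr: mod for mod, attrs in _SUBMOD_ATTRS.items() for attr in attrs}

def __getattr__(name: str):
    # Called only when `name` is not found in the module's globals.
    if name in _ATTR_TO_MODULE:
        module = importlib.import_module(f".{_ATTR_TO_MODULE[name]}", __name__)
        return getattr(module, name)
    raise AttributeError(f"module {__name__!r} has no attribute {name!r}")

def __dir__():
    # Advertise lazy names so dir() and autocompletion still see them.
    return sorted(set(globals()) | set(_ATTR_TO_MODULE))
```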

src/huggingface_hub/_snapshot_download.py

Lines changed: 119 additions & 21 deletions

```diff
@@ -1,20 +1,21 @@
 import os
 from pathlib import Path
-from typing import Iterable, Optional, Union
+from typing import Iterable, List, Literal, Optional, Union, overload
 
 import httpx
 from tqdm.auto import tqdm as base_tqdm
 from tqdm.contrib.concurrent import thread_map
 
 from . import constants
 from .errors import (
+    DryRunError,
     GatedRepoError,
     HfHubHTTPError,
     LocalEntryNotFoundError,
     RepositoryNotFoundError,
     RevisionNotFoundError,
 )
-from .file_download import REGEX_COMMIT_HASH, hf_hub_download, repo_folder_name
+from .file_download import REGEX_COMMIT_HASH, DryRunFileInfo, hf_hub_download, repo_folder_name
 from .hf_api import DatasetInfo, HfApi, ModelInfo, RepoFile, SpaceInfo
 from .utils import OfflineModeIsEnabled, filter_repo_objects, logging, validate_hf_hub_args
 from .utils import tqdm as hf_tqdm
@@ -25,6 +26,81 @@
 VERY_LARGE_REPO_THRESHOLD = 50000  # After this limit, we don't consider `repo_info.siblings` to be reliable enough
 
 
+@overload
+def snapshot_download(
+    repo_id: str,
+    *,
+    repo_type: Optional[str] = None,
+    revision: Optional[str] = None,
+    cache_dir: Union[str, Path, None] = None,
+    local_dir: Union[str, Path, None] = None,
+    library_name: Optional[str] = None,
+    library_version: Optional[str] = None,
+    user_agent: Optional[Union[dict, str]] = None,
+    etag_timeout: float = constants.DEFAULT_ETAG_TIMEOUT,
+    force_download: bool = False,
+    token: Optional[Union[bool, str]] = None,
+    local_files_only: bool = False,
+    allow_patterns: Optional[Union[list[str], str]] = None,
+    ignore_patterns: Optional[Union[list[str], str]] = None,
+    max_workers: int = 8,
+    tqdm_class: Optional[type[base_tqdm]] = None,
+    headers: Optional[dict[str, str]] = None,
+    endpoint: Optional[str] = None,
+    dry_run: Literal[False] = False,
+) -> str: ...
+
+
+@overload
+def snapshot_download(
+    repo_id: str,
+    *,
+    repo_type: Optional[str] = None,
+    revision: Optional[str] = None,
+    cache_dir: Union[str, Path, None] = None,
+    local_dir: Union[str, Path, None] = None,
+    library_name: Optional[str] = None,
+    library_version: Optional[str] = None,
+    user_agent: Optional[Union[dict, str]] = None,
+    etag_timeout: float = constants.DEFAULT_ETAG_TIMEOUT,
+    force_download: bool = False,
+    token: Optional[Union[bool, str]] = None,
+    local_files_only: bool = False,
+    allow_patterns: Optional[Union[list[str], str]] = None,
+    ignore_patterns: Optional[Union[list[str], str]] = None,
+    max_workers: int = 8,
+    tqdm_class: Optional[type[base_tqdm]] = None,
+    headers: Optional[dict[str, str]] = None,
+    endpoint: Optional[str] = None,
+    dry_run: Literal[True] = True,
+) -> list[DryRunFileInfo]: ...
+
+
+@overload
+def snapshot_download(
+    repo_id: str,
+    *,
+    repo_type: Optional[str] = None,
+    revision: Optional[str] = None,
+    cache_dir: Union[str, Path, None] = None,
+    local_dir: Union[str, Path, None] = None,
+    library_name: Optional[str] = None,
+    library_version: Optional[str] = None,
+    user_agent: Optional[Union[dict, str]] = None,
+    etag_timeout: float = constants.DEFAULT_ETAG_TIMEOUT,
+    force_download: bool = False,
+    token: Optional[Union[bool, str]] = None,
+    local_files_only: bool = False,
+    allow_patterns: Optional[Union[list[str], str]] = None,
+    ignore_patterns: Optional[Union[list[str], str]] = None,
+    max_workers: int = 8,
+    tqdm_class: Optional[type[base_tqdm]] = None,
+    headers: Optional[dict[str, str]] = None,
+    endpoint: Optional[str] = None,
+    dry_run: bool = False,
+) -> Union[str, list[DryRunFileInfo]]: ...
+
+
 @validate_hf_hub_args
 def snapshot_download(
     repo_id: str,
@@ -46,7 +122,8 @@ def snapshot_download(
     tqdm_class: Optional[type[base_tqdm]] = None,
     headers: Optional[dict[str, str]] = None,
     endpoint: Optional[str] = None,
-) -> str:
+    dry_run: bool = False,
+) -> Union[str, list[DryRunFileInfo]]:
     """Download repo files.
 
     Download a whole snapshot of a repo's files at the specified revision. This is useful when you want all files from
@@ -109,9 +186,14 @@ def snapshot_download(
             Note that the `tqdm_class` is not passed to each individual download.
             Defaults to the custom HF progress bar that can be disabled by setting
             `HF_HUB_DISABLE_PROGRESS_BARS` environment variable.
+        dry_run (`bool`, *optional*, defaults to `False`):
+            If `True`, perform a dry run without actually downloading the files. Returns a list of
+            [`DryRunFileInfo`] objects containing information about what would be downloaded.
 
     Returns:
-        `str`: folder path of the repo snapshot.
+        `str` or list of [`DryRunFileInfo`]:
+            - If `dry_run=False`: Local snapshot path.
+            - If `dry_run=True`: A list of [`DryRunFileInfo`] objects containing download information.
 
     Raises:
         [`~utils.RepositoryNotFoundError`]
@@ -187,6 +269,11 @@ def snapshot_download(
     # - if the specified revision is a branch or tag, look inside "refs".
     # => if local_dir is not None, we will return the path to the local folder if it exists.
     if repo_info is None:
+        if dry_run:
+            raise DryRunError(
+                "Dry run cannot be performed as the repository cannot be accessed. Please check your internet connection or authentication token."
+            ) from api_call_error
+
         # Try to get which commit hash corresponds to the specified revision
         commit_hash = None
         if REGEX_COMMIT_HASH.match(revision):
@@ -273,6 +360,8 @@ def snapshot_download(
         tqdm_desc = f"Fetching {len(filtered_repo_files)} files"
     else:
         tqdm_desc = "Fetching ... files"
+    if dry_run:
+        tqdm_desc = "[dry-run] " + tqdm_desc
 
     commit_hash = repo_info.sha
     snapshot_folder = os.path.join(storage_folder, "snapshots", commit_hash)
@@ -288,28 +377,33 @@ def snapshot_download(
         except OSError as e:
             logger.warning(f"Ignored error while writing commit hash to {ref_path}: {e}.")
 
+    results: List[Union[str, DryRunFileInfo]] = []
+
     # we pass the commit_hash to hf_hub_download
     # so no network call happens if we already
     # have the file locally.
-    def _inner_hf_hub_download(repo_file: str):
-        return hf_hub_download(
-            repo_id,
-            filename=repo_file,
-            repo_type=repo_type,
-            revision=commit_hash,
-            endpoint=endpoint,
-            cache_dir=cache_dir,
-            local_dir=local_dir,
-            library_name=library_name,
-            library_version=library_version,
-            user_agent=user_agent,
-            etag_timeout=etag_timeout,
-            force_download=force_download,
-            token=token,
-            headers=headers,
+    def _inner_hf_hub_download(repo_file: str) -> None:
+        results.append(
+            hf_hub_download(  # type: ignore[no-matching-overload] # ty not happy, don't know why :/
+                repo_id,
+                filename=repo_file,
+                repo_type=repo_type,
+                revision=commit_hash,
+                endpoint=endpoint,
+                cache_dir=cache_dir,
+                local_dir=local_dir,
+                library_name=library_name,
+                library_version=library_version,
+                user_agent=user_agent,
+                etag_timeout=etag_timeout,
+                force_download=force_download,
+                token=token,
+                headers=headers,
+                dry_run=dry_run,
+            )
         )
 
-    if constants.HF_HUB_ENABLE_HF_TRANSFER:
+    if constants.HF_HUB_ENABLE_HF_TRANSFER and not dry_run:
         # when using hf_transfer we don't want extra parallelism
         # from the one hf_transfer provides
         for file in filtered_repo_files:
@@ -324,6 +418,10 @@ def _inner_hf_hub_download(repo_file: str):
         tqdm_class=tqdm_class or hf_tqdm,
     )
 
+    if dry_run:
+        assert all(isinstance(r, DryRunFileInfo) for r in results)
+        return results  # type: ignore
+
     if local_dir is not None:
         return str(os.path.realpath(local_dir))
     return snapshot_folder
```
src/huggingface_hub/_upload_large_folder.py

Lines changed: 1 addition & 2 deletions

```diff
@@ -31,8 +31,7 @@
 from ._commit_api import CommitOperationAdd, UploadInfo, _fetch_upload_modes
 from ._local_folder import LocalUploadFileMetadata, LocalUploadFilePaths, get_local_upload_paths, read_upload_metadata
 from .constants import DEFAULT_REVISION, REPO_TYPES
-from .utils import DEFAULT_IGNORE_PATTERNS, filter_repo_objects, tqdm
-from .utils._cache_manager import _format_size
+from .utils import DEFAULT_IGNORE_PATTERNS, _format_size, filter_repo_objects, tqdm
 from .utils._runtime import is_xet_available
 from .utils.sha import sha_fileobj
```
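The `_format_size` helper whose import is tidied here is the same private utility that renders byte counts like `655.2M` in the dry-run tables. A hypothetical stand-in showing the general idea (not the actual huggingface_hub implementation):

```python
# Hypothetical human-readable size formatter, in the spirit of the private
# `_format_size` helper re-imported above. Not the actual implementation.
def format_size(num: float) -> str:
    for unit in ("", "K", "M", "G", "T"):
        if abs(num) < 1000.0:
            return f"{num:.1f}{unit}"
        num /= 1000.0
    return f"{num:.1f}P"

print(format_size(655_200_000))  # -> 655.2M
```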
