66import json
77import os
88import tempfile
9+ import time
910from collections import defaultdict
1011from typing import Any , Callable , Dict , Generator , List , Optional , Tuple , Union
1112
1415import huggingface_hub .constants
1516import numpy as np
1617import torch
17- from huggingface_hub import HfFileSystem , hf_hub_download , snapshot_download
18+ from huggingface_hub import (HfFileSystem , hf_hub_download , scan_cache_dir ,
19+ snapshot_download )
1820from safetensors .torch import load_file , safe_open , save_file
1921from tqdm .auto import tqdm
2022
@@ -253,6 +255,8 @@ def download_weights_from_hf(
253255 # Use file lock to prevent multiple processes from
254256 # downloading the same model weights at the same time.
255257 with get_lock (model_name_or_path , cache_dir ):
258+ start_size = scan_cache_dir ().size_on_disk
259+ start_time = time .perf_counter ()
256260 hf_folder = snapshot_download (
257261 model_name_or_path ,
258262 allow_patterns = allow_patterns ,
@@ -262,6 +266,11 @@ def download_weights_from_hf(
262266 revision = revision ,
263267 local_files_only = huggingface_hub .constants .HF_HUB_OFFLINE ,
264268 )
269+ end_time = time .perf_counter ()
270+ end_size = scan_cache_dir ().size_on_disk
271+ if end_size != start_size :
272+ logger .info ("Time took to download weights for %s: %.6f seconds" ,
273+ model_name_or_path , end_time - start_time )
265274 return hf_folder
266275
267276
0 commit comments