2020
2121from filelock import FileLock , Timeout
2222
23- from litdata .constants import _AZURE_STORAGE_AVAILABLE , _GOOGLE_STORAGE_AVAILABLE , _INDEX_FILENAME
23+ from litdata .constants import (
24+ _AZURE_STORAGE_AVAILABLE ,
25+ _GOOGLE_STORAGE_AVAILABLE ,
26+ _HUGGINGFACE_HUB_AVAILABLE ,
27+ _INDEX_FILENAME ,
28+ )
2429from litdata .streaming .client import S3Client
2530
2631
@@ -164,6 +169,56 @@ def download_file(self, remote_filepath: str, local_filepath: str) -> None:
164169 pass
165170
166171
class HFDownloader(Downloader):
    """Downloader for files stored in Hugging Face Hub dataset repositories (``hf://`` URLs)."""

    def __init__(
        # NOTE(review): the `{}` default is kept for interface compatibility; it is never mutated in this class.
        self, remote_dir: str, cache_dir: str, chunks: List[Dict[str, Any]], storage_options: Optional[Dict] = {}
    ):
        """Check that `huggingface_hub` is usable, then defer to the base `Downloader` initializer.

        Raises:
            ModuleNotFoundError: If `_HUGGINGFACE_HUB_AVAILABLE` is falsy. The message is its string form
                (presumably a description of the missing requirement -- confirm in `litdata.constants`).

        """
        if not _HUGGINGFACE_HUB_AVAILABLE:
            raise ModuleNotFoundError(str(_HUGGINGFACE_HUB_AVAILABLE))

        super().__init__(remote_dir, cache_dir, chunks, storage_options)

    def download_file(self, remote_filepath: str, local_filepath: str) -> None:
        """Download a file from the Hugging Face Hub.

        The remote_filepath should be in the format `hf://datasets/<repo_org>/<repo_name>/path`. The file is
        always fetched with `repo_type="dataset"`. For more information, see
        https://huggingface.co/docs/huggingface_hub/en/guides/hf_file_system#integrations.

        Nothing is done when `local_filepath` already exists, or when another process holds the lock
        `<local_filepath>.lock` (that process is assumed to be performing the download).

        Args:
            remote_filepath: The `hf://` URL of the file to fetch.
            local_filepath: Where the file must end up on the local filesystem.

        Raises:
            ValueError: If the URL scheme is not `hf` or the URL lacks the org / repo / path segments.

        """
        from huggingface_hub import hf_hub_download

        obj = parse.urlparse(remote_filepath)

        if obj.scheme != "hf":
            raise ValueError(f"Expected obj.scheme to be `hf`, instead, got {obj.scheme} for remote={remote_filepath}")

        if os.path.exists(local_filepath):
            return

        # Adapted from https://github.com/mosaicml/streaming/blob/main/streaming/base/storage/download.py#L292
        # "hf://datasets/<repo_org>/<repo_name>/path" splits into
        # ["hf:", "", "datasets", repo_org, repo_name, path]; the path keeps any further "/".
        parts = remote_filepath.split("/", 5)
        if len(parts) != 6 or not all(parts[3:]):
            raise ValueError(
                "Expected remote_filepath to be in the format `hf://datasets/<repo_org>/<repo_name>/path`, "
                f"instead, got {remote_filepath}"
            )
        repo_org, repo_name, path = parts[3:]

        # The index file is waited for briefly; any other file is skipped right away when it is locked.
        lock_timeout = 3 if obj.path.endswith(_INDEX_FILENAME) else 0

        try:
            with FileLock(local_filepath + ".lock", timeout=lock_timeout):
                downloaded_path = hf_hub_download(
                    repo_id=f"{repo_org}/{repo_name}",
                    filename=path,
                    local_dir=self._cache_dir,
                    repo_type="dataset",
                    # Tolerate `None`, which the `Optional[Dict]` annotation allows.
                    **(self._storage_options or {}),
                )

                # Move the downloaded file to the expected location if it's not already there.
                # Paths are normalized so two spellings of the same location are not treated as different.
                source = os.path.abspath(downloaded_path)
                target = os.path.abspath(local_filepath)
                if source != target and os.path.exists(source):
                    # `os.replace` overwrites an existing destination on every platform.
                    os.replace(source, target)

                    leftover_dir = os.path.dirname(source)
                    if leftover_dir != os.path.abspath(self._cache_dir):
                        try:
                            os.rmdir(leftover_dir)
                        except OSError:
                            # Best-effort cleanup: the folder may still hold files fetched by other
                            # workers, which must not fail a download that already succeeded.
                            pass

        except Timeout:
            # another process is responsible to download that file, continue
            pass
220+
221+
167222class LocalDownloader (Downloader ):
168223 def download_file (self , remote_filepath : str , local_filepath : str ) -> None :
169224 if not os .path .exists (remote_filepath ):
@@ -183,6 +238,7 @@ def download_file(self, remote_filepath: str, local_filepath: str) -> None:
183238 "s3://" : S3Downloader ,
184239 "gs://" : GCPDownloader ,
185240 "azure://" : AzureDownloader ,
241+ "hf://" : HFDownloader ,
186242 "local:" : LocalDownloaderWithCache ,
187243 "" : LocalDownloader ,
188244}
0 commit comments