Skip to content

Commit 7efd761

Browse files
mohanreddypmr and MohanReddy authored
azure storage options (#365)
Co-authored-by: MohanReddy <mohanreddy@Mohans-Macbook-Pro.local>
1 parent 92df8af commit 7efd761

File tree

1 file changed

+7
-5
lines changed

1 file changed

+7
-5
lines changed

src/litdata/utilities/dataset_utilities.py

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,9 @@ def subsample_streaming_dataset(
3838

3939
# Make sure input_dir contains cache path and remote url
4040
if _should_replace_path(input_dir.path):
41-
cache_path = _try_create_cache_dir(input_dir=input_dir.path if input_dir.path else input_dir.url)
41+
cache_path = _try_create_cache_dir(
42+
input_dir=input_dir.path if input_dir.path else input_dir.url, storage_options=storage_options
43+
)
4244
if cache_path is not None:
4345
input_dir.path = cache_path
4446

@@ -96,7 +98,7 @@ def _should_replace_path(path: Optional[str]) -> bool:
9698
return path.startswith("/teamspace/datasets/") or path.startswith("/teamspace/s3_connections/")
9799

98100

99-
def _read_updated_at(input_dir: Optional[Dir]) -> str:
101+
def _read_updated_at(input_dir: Optional[Dir], storage_options: Optional[Dict] = {}) -> str:
100102
"""Read last updated timestamp from index.json file."""
101103
last_updation_timestamp = "0"
102104
index_json_content = None
@@ -110,7 +112,7 @@ def _read_updated_at(input_dir: Optional[Dir]) -> str:
110112
# download index.json file and read last_updation_timestamp
111113
with tempfile.TemporaryDirectory() as tmp_directory:
112114
temp_index_filepath = os.path.join(tmp_directory, _INDEX_FILENAME)
113-
downloader = get_downloader_cls(input_dir.url, tmp_directory, [])
115+
downloader = get_downloader_cls(input_dir.url, tmp_directory, [], storage_options)
114116
downloader.download_file(os.path.join(input_dir.url, _INDEX_FILENAME), temp_index_filepath)
115117

116118
index_json_content = load_index_file(tmp_directory)
@@ -135,9 +137,9 @@ def _clear_cache_dir_if_updated(input_dir_hash_filepath: str, updated_at_hash: s
135137
shutil.rmtree(input_dir_hash_filepath)
136138

137139

138-
def _try_create_cache_dir(input_dir: Optional[str]) -> Optional[str]:
140+
def _try_create_cache_dir(input_dir: Optional[str], storage_options: Optional[Dict] = {}) -> Optional[str]:
139141
resolved_input_dir = _resolve_dir(input_dir)
140-
updated_at = _read_updated_at(resolved_input_dir)
142+
updated_at = _read_updated_at(resolved_input_dir, storage_options)
141143

142144
if updated_at == "0" and input_dir is not None:
143145
updated_at = hashlib.md5(input_dir.encode()).hexdigest() # noqa: S324

0 commit comments

Comments
 (0)