
Commit f03ddb7

feat: fast random access for streamingDataset without chunk downloading (#631)
* fast random access for s3 works
* no_chunk_download supports gcloud
* update
* update
* update
* nitpick
* tests
* update
* update
* update
* update
* update
* rename `no_store` to `on_demand_bytes`
* make sure client exists
* fallback for `download_bytes` method
1 parent 82bf020 commit f03ddb7
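
For context, a minimal usage sketch of what this commit enables (the dataset URI is a placeholder; only the `on_demand_bytes` behaviour itself comes from this diff):

from litdata import StreamingDataset

ds = StreamingDataset("s3://my-bucket/optimized-dataset")  # placeholder location

# Random access: `on_demand_bytes` defaults to True, so only the requested
# sample's byte range is fetched from the remote chunk instead of the whole file.
sample = ds[42]

# Iteration: `__iter__` switches `on_demand_bytes` off so full chunks are
# downloaded and cached, which is the better trade-off for sequential reads.
for sample in ds:
    break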

File tree

8 files changed: +246, -33 lines


src/litdata/streaming/cache.py

Lines changed: 3 additions & 0 deletions

@@ -50,6 +50,7 @@ def __init__(
         session_options: Optional[Dict] = {},
         max_pre_download: int = 2,
         msg_queue: Optional[Queue] = None,
+        on_demand_bytes: bool = False,
     ):
         """The Cache enables to optimise dataset format for cloud training. This is done by grouping several elements
         together in order to accelerate fetching.
@@ -70,6 +71,7 @@ def __init__(
             session_options: Additional options for the S3 session.
             max_pre_download: Maximum number of chunks that can be pre-downloaded while filling up the cache.
             msg_queue: Optional message queue to send messages to the main process.
+            on_demand_bytes: If True, fetch only the requested sample's bytes instead of downloading the entire chunk.

         """
         super().__init__()
@@ -100,6 +102,7 @@ def __init__(
             storage_options=storage_options,
             session_options=session_options,
             max_pre_download=max_pre_download,
+            on_demand_bytes=on_demand_bytes,
         )
         self._is_done = False
         self._distributed_env = _DistributedEnv.detect()

src/litdata/streaming/config.py

Lines changed: 22 additions & 0 deletions

@@ -151,6 +151,28 @@ def download_chunk_from_index(self, chunk_index: int, skip_lock: bool = False) -

         self.try_decompress(local_chunkpath)

+    def download_chunk_bytes_from_index(self, chunk_index: int, offset: int, length: int) -> bytes:
+        assert self._chunks is not None
+        chunk_filename = self._chunks[chunk_index]["filename"]
+
+        local_chunkpath = os.path.join(self._cache_dir, chunk_filename)
+
+        if os.path.exists(local_chunkpath):
+            with open(local_chunkpath, "rb") as f:
+                f.seek(offset)
+                return f.read(length)
+
+        if self._compressor is not None:
+            raise ValueError(
+                "The `download_chunk_bytes_from_index` method is not supported for compressed chunks. "
+                "Please, use `download_chunk_from_index` instead."
+            )
+
+        if self._downloader is None:
+            raise RuntimeError("The downloader is not initialized. Please, initialize it before downloading chunks.")
+
+        return self._downloader.download_chunk_bytes_from_index(chunk_index, offset, length)
+
     def try_decompress(self, local_chunkpath: str) -> None:
         if self._compressor is None:
             return
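
The control flow above is: serve the range from the local copy when the chunk is already cached, refuse compressed chunks (a byte slice of a compressed chunk is generally not decodable on its own), and otherwise hand the request to the downloader. A standalone sketch of the same pattern, with hypothetical names and assuming nothing beyond this diff:

import os
from typing import Callable

def read_item_bytes(
    local_path: str,
    offset: int,
    length: int,
    fetch_range: Callable[[int, int], bytes],  # e.g. a ranged S3/GCS/HTTP GET
    compressed: bool = False,
) -> bytes:
    """Hypothetical helper mirroring the local-first / remote-range-fallback logic."""
    if os.path.exists(local_path):
        # Chunk already cached locally: read the slice straight from disk.
        with open(local_path, "rb") as f:
            f.seek(offset)
            return f.read(length)
    if compressed:
        raise ValueError("Byte-range reads are not supported for compressed chunks.")
    # Otherwise ask the backend for exactly this byte range.
    return fetch_range(offset, length)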

src/litdata/streaming/dataset.py

Lines changed: 19 additions & 1 deletion

@@ -201,6 +201,19 @@ def __init__(
         if not callable(transform):
             raise ValueError(f"Transform should be a callable. Found {transform}")
         self.transform = transform
+        self._on_demand_bytes = True  # true by default, when iterating, turn this off to store the chunks in the cache
+
+    @property
+    def on_demand_bytes(self) -> bool:
+        return self._on_demand_bytes
+
+    @on_demand_bytes.setter
+    def on_demand_bytes(self, value: bool) -> None:
+        if not isinstance(value, bool):
+            raise ValueError(f"on_demand_bytes should be a boolean. Found {value}")
+        self._on_demand_bytes = value
+        assert self.cache is not None, "Cache must be initialized before setting on_demand_bytes."
+        self.cache._reader.on_demand_bytes = value

     def set_shuffle(self, shuffle: bool) -> None:
         self.shuffle = shuffle
@@ -240,6 +253,7 @@ def _create_cache(self, worker_env: _WorkerEnv) -> Cache:
             storage_options=self.storage_options,
             session_options=self.session_options,
             max_pre_download=self.max_pre_download,
+            on_demand_bytes=self._on_demand_bytes,
         )
         cache._reader._try_load_config()

@@ -287,6 +301,7 @@ def __iter__(self) -> "StreamingDataset":
         self.worker_env = _WorkerEnv.detect()
         self.cache = self._create_cache(worker_env=self.worker_env)
         self.shuffler = self._create_shuffler(self.cache)
+        self.on_demand_bytes = False  # reset on_demand_bytes to False, and store chunks in the cache

         # Handle restart
         if self._state_dict:
@@ -402,14 +417,15 @@ def _resume(self, workers_chunks: List[List[int]], workers_intervals: List[Any])
         # bump the chunk_index
         self.worker_next_chunk_index += 1

-    def __getitem__(self, index: Union[ChunkedIndex, int]) -> Any:
+    def __getitem__(self, index: Union[ChunkedIndex, int, slice]) -> Any:
         if self.cache is None:
             self.worker_env = _WorkerEnv.detect()
             self.cache = self._create_cache(worker_env=self.worker_env)
             self.shuffler = self._create_shuffler(self.cache)
         if isinstance(index, int):
             index = ChunkedIndex(*self.cache._get_chunk_index_from_index(index))
         elif isinstance(index, slice):
+            self.on_demand_bytes = False  # for slices, we always want to store the chunks
             start, stop, step = index.indices(len(self))
             _my_indices = list(range(start, stop, step))
             _my_cache_indices = [ChunkedIndex(*self.cache._get_chunk_index_from_index(idx)) for idx in _my_indices]
@@ -436,6 +452,7 @@ def __next__(self) -> Any:
             self.current_epoch += 1
             self.reset_state_dict()
             logger.debug(_get_log_msg({"name": "iterating_dataset", "ph": "E"}))
+            self.on_demand_bytes = True  # reset on_demand_bytes to True
             raise StopIteration

         # Lazily re-populate the interval to reduce memory usage.
@@ -444,6 +461,7 @@ def __next__(self) -> Any:
         if self.num_chunks is not None and self.worker_next_chunk_index >= self.num_chunks:
             self.current_epoch += 1
             self.reset_state_dict()
+            self.on_demand_bytes = True  # reset on_demand_bytes to True
             raise StopIteration

         # if upcoming_indexes is empty, means either:
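
In practice the flag is toggled automatically (off during iteration and for slices, back on when an epoch ends), but the new property also allows setting it explicitly once a cache exists; a short hedged sketch with a placeholder URI:

from litdata import StreamingDataset

ds = StreamingDataset("s3://my-bucket/optimized-dataset")  # placeholder location

sample = ds[0]              # creates the cache lazily; only this sample's bytes are fetched
ds.on_demand_bytes = False  # the setter forwards the flag to the underlying reader
window = ds[10:20]          # slices force on_demand_bytes off and cache whole chunks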

src/litdata/streaming/downloader.py

Lines changed: 57 additions & 0 deletions

@@ -76,9 +76,28 @@ def download_chunk_from_index(self, chunk_index: int) -> None:

         logger.debug(_get_log_msg({"name": f"download_chunk_from_index_{chunk_index}", "ph": "E"}))

+    def download_chunk_bytes_from_index(self, chunk_index: int, offset: int, length: int) -> bytes:
+        chunk_filename = self._chunks[chunk_index]["filename"]
+        local_chunkpath = os.path.join(self._cache_dir, chunk_filename)
+        remote_chunkpath = os.path.join(self._remote_dir, chunk_filename)
+
+        return self.download_bytes(remote_chunkpath, offset, length, local_chunkpath)
+
     def download_file(self, remote_chunkpath: str, local_chunkpath: str) -> None:
         pass

+    def download_bytes(self, remote_chunkpath: str, offset: int, length: int, local_chunkpath: str) -> bytes:
+        """Download a specific range of bytes from the remote file.
+
+        If this method is not overridden in a subclass, it defaults to downloading the full file
+        by calling `download_file` and then reading the desired byte range from the local copy.
+        """
+        self.download_file(remote_chunkpath, local_chunkpath)
+        # read the specified byte range from the local file
+        with open(local_chunkpath, "rb") as f:
+            f.seek(offset)
+            return f.read(length)
+

 class S3Downloader(Downloader):
     def __init__(
@@ -165,6 +184,24 @@ def download_file(self, remote_filepath: str, local_filepath: str) -> None:
             Config=TransferConfig(use_threads=False),
         )

+    def download_bytes(self, remote_filepath: str, offset: int, length: int, local_chunkpath: str) -> bytes:
+        obj = parse.urlparse(remote_filepath)
+
+        if obj.scheme != "s3":
+            raise ValueError(f"Expected obj.scheme to be `s3`, instead, got {obj.scheme} for remote={remote_filepath}")
+
+        if not hasattr(self, "client"):
+            self._client = S3Client(storage_options=self._storage_options, session_options=self.session_options)
+
+        bucket = obj.netloc
+        key = obj.path.lstrip("/")
+
+        byte_range = f"bytes={offset}-{offset + length - 1}"
+
+        response = self._client.client.get_object(Bucket=bucket, Key=key, Range=byte_range)
+
+        return response["Body"].read()
+

 class GCPDownloader(Downloader):
     def __init__(
@@ -208,6 +245,26 @@ def download_file(self, remote_filepath: str, local_filepath: str) -> None:
         blob = bucket.blob(key)
         blob.download_to_filename(local_filepath)

+    def download_bytes(self, remote_filepath: str, offset: int, length: int, local_chunkpath: str) -> bytes:
+        from google.cloud import storage
+
+        obj = parse.urlparse(remote_filepath)
+
+        if obj.scheme != "gs":
+            raise ValueError(f"Expected scheme 'gs', got '{obj.scheme}' for remote={remote_filepath}")
+
+        bucket_name = obj.netloc
+        key = obj.path.lstrip("/")
+
+        client = storage.Client(**self._storage_options)
+        bucket = client.bucket(bucket_name)
+        blob = bucket.blob(key)
+
+        # GCS uses end as *inclusive*, so end = offset + length - 1
+        end = offset + length - 1
+
+        return blob.download_as_bytes(start=offset, end=end)
+

 class AzureDownloader(Downloader):
     def __init__(
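
The base-class `download_bytes` fallback means a backend that only implements `download_file` still serves byte-range requests, just without the bandwidth saving, while the S3 and GCS overrides issue genuine ranged reads (both ends inclusive, i.e. `bytes=offset-(offset + length - 1)`). A hedged sketch of a hypothetical backend that relies on the fallback:

import shutil

from litdata.streaming.downloader import Downloader


class CopyDownloader(Downloader):
    """Hypothetical backend that only copies whole files (e.g. from a mounted volume)."""

    def download_file(self, remote_chunkpath: str, local_chunkpath: str) -> None:
        shutil.copyfile(remote_chunkpath, local_chunkpath)

    # No `download_bytes` override: the inherited fallback downloads the full file via
    # `download_file`, then reads the requested (offset, length) slice from the local copy.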

src/litdata/streaming/item_loader.py

Lines changed: 24 additions & 0 deletions

@@ -107,6 +107,14 @@ def load_item_from_chunk(
     ) -> Any:
         """Returns an item loaded from a chunk."""

+    def load_item_from_bytes(
+        self,
+        raw_bytes: bytes,
+        chunk_index: int,
+    ) -> Any:
+        """Returns an item loaded from bytes."""
+        raise NotImplementedError("The `load_item_from_bytes` method is not implemented for this item loader.")
+
     @abstractmethod
     def delete(self, chunk_index: int, chunk_filepath: str) -> None:
         """Delete a chunk from the local filesystem."""
@@ -143,6 +151,22 @@ def generate_intervals(self) -> List[Interval]:
     def pre_load_chunk(self, chunk_index: int, chunk_filepath: str) -> None:
         pass

+    def load_item_from_bytes(
+        self,
+        raw_bytes: bytes,
+        chunk_index: int,
+    ) -> bytes:
+        if self._config.get("encryption"):
+            raise ValueError("The `load_item_from_bytes` method does not support encrypted data loading currently.")
+
+        # check for mosaic mds format
+        if "format" in self._config and self._config["format"] == "mds":
+            item_data = self.mds_deserialize(raw_bytes, chunk_index)
+        else:
+            item_data = self.deserialize(raw_bytes)
+
+        return item_data
+
     def load_item_from_chunk(
         self,
         index: int,