
Commit 518a1c3

tchaton and awaelchli authored
Enforce passing item_loader when customizing underlying storage format (#296)
Co-authored-by: awaelchli <aedu.waelchli@gmail.com>
1 parent 44ef1af commit 518a1c3

File tree: 9 files changed (+160, −22 lines)


README.md

Lines changed: 63 additions & 0 deletions
@@ -311,6 +311,69 @@ for batch_idx, batch in enumerate(dataloader):
 </details>
 
 
+<details>
+<summary> ✅ LLM Pre-training </summary>
+&nbsp;
+
+LitData is highly optimized for LLM pre-training. First, tokenize the entire dataset; then you can consume it.
+
+```python
+import json
+from pathlib import Path
+import zstandard as zstd
+from litdata import optimize, TokensLoader
+from tokenizer import Tokenizer
+from functools import partial
+
+# 1. Define a function to convert the text within the jsonl files into tokens
+def tokenize_fn(filepath, tokenizer=None):
+    with zstd.open(open(filepath, "rb"), "rt", encoding="utf-8") as f:
+        for row in f:
+            data = json.loads(row)
+            if data["meta"]["redpajama_set_name"] == "RedPajamaGithub":
+                continue  # exclude the GitHub data since it overlaps with starcoder
+            text_ids = tokenizer.encode(data["text"], bos=False, eos=True)
+            yield text_ids
+
+if __name__ == "__main__":
+    # 2. Generate the inputs (we are going to optimize all the compressed json files from the SlimPajama dataset)
+    input_dir = "./slimpajama-raw"
+    inputs = [str(file) for file in Path(f"{input_dir}/SlimPajama-627B/train").rglob("*.zst")]
+
+    # 3. Store the optimized data wherever you want, e.g. under "/teamspace/datasets" or "/teamspace/s3_connections"
+    optimize(
+        fn=partial(tokenize_fn, tokenizer=Tokenizer(f"{input_dir}/checkpoints/Llama-2-7b-hf")),  # Note: you can use the HF tokenizer or any other
+        inputs=inputs,
+        output_dir="./slimpajama-optimized",
+        chunk_size=(2049 * 8012),
+        # This is important: it informs LitData that we are encoding a contiguous 1D array (tokens).
+        # LitData then skips storing per-sample metadata, i.e. all the tokens are concatenated to form one large tensor.
+        item_loader=TokensLoader(),
+    )
+```
+
+```python
+import os
+from litdata import StreamingDataset, StreamingDataLoader, TokensLoader
+from tqdm import tqdm
+
+# Increase the block size by one because we also need the next token (the target)
+dataset = StreamingDataset(
+    input_dir="./slimpajama-optimized/train",
+    item_loader=TokensLoader(block_size=2048 + 1),
+    shuffle=True,
+    drop_last=True,
+)
+
+train_dataloader = StreamingDataLoader(dataset, batch_size=8, pin_memory=True, num_workers=os.cpu_count())
+
+# Iterate over the SlimPajama dataset
+for batch in tqdm(train_dataloader):
+    pass
+```
+
+</details>
+
 <details>
 <summary> ✅ Combine datasets</summary>
 &nbsp;

src/litdata/processing/data_processor.py

Lines changed: 9 additions & 0 deletions
@@ -50,6 +50,7 @@
 from litdata.streaming.cache import Dir
 from litdata.streaming.client import S3Client
 from litdata.streaming.dataloader import StreamingDataLoader
+from litdata.streaming.item_loader import BaseItemLoader
 from litdata.streaming.resolver import _resolve_dir
 from litdata.utilities._pytree import tree_flatten, tree_unflatten, treespec_loads
 from litdata.utilities.broadcast import broadcast_object
@@ -399,6 +400,7 @@ def __init__(
         use_checkpoint: bool = False,
         checkpoint_chunks_info: Optional[List[Dict[str, Any]]] = None,
         checkpoint_next_index: Optional[int] = None,
+        item_loader: Optional[BaseItemLoader] = None,
     ) -> None:
         """The BaseWorker is responsible to process the user data."""
         self.worker_index = worker_index
@@ -424,6 +426,7 @@ def __init__(
         self.remove_queue: Queue = Queue()
         self.progress_queue: Queue = progress_queue
         self.error_queue: Queue = error_queue
+        self.item_loader = item_loader
         self._counter = 0
         self._last_time = time()
         self._index_counter = 0
@@ -522,6 +525,7 @@ def _create_cache(self) -> None:
             compression=self.data_recipe.compression,
             encryption=self.data_recipe.encryption,
             writer_chunk_index=self.writer_starting_chunk_index,
+            item_loader=self.item_loader,
         )
         self.cache._reader._rank = _get_node_rank() * self.num_workers + self.worker_index
 
@@ -880,6 +884,7 @@ def __init__(
         reader: Optional[BaseReader] = None,
         state_dict: Optional[Dict[int, int]] = None,
         use_checkpoint: bool = False,
+        item_loader: Optional[BaseItemLoader] = None,
         start_method: Optional[str] = None,
     ):
         """The `DatasetOptimiser` provides an efficient way to process data across multiple machine into chunks to make
@@ -902,6 +907,8 @@ def __init__(
             state_dict: The writer state dict. This is used to decide how to append data to an existing dataset.
             use_checkpoint: Whether to create checkpoints while processing the data, which can be used to resume the
                 processing from the last checkpoint if the process is interrupted. (`Default: False`)
+            item_loader: The item loader that will be used during loading in StreamingDataset. Determines
+                the format in which the data is stored and optimized for loading.
             start_method: The start method used by python multiprocessing package. Default to spawn unless running
                 inside an interactive shell like Ipython.
@@ -937,6 +944,7 @@ def __init__(
         self.use_checkpoint = use_checkpoint
         self.checkpoint_chunks_info: Optional[List[List[Dict[str, Any]]]] = None
         self.checkpoint_next_index: Optional[List[int]] = None
+        self.item_loader = item_loader
 
         self.state_dict = state_dict or {rank: 0 for rank in range(self.num_workers)}
 
@@ -1157,6 +1165,7 @@ def _create_process_workers(self, data_recipe: DataRecipe, workers_user_items: L
                 self.use_checkpoint,
                 self.checkpoint_chunks_info[worker_idx] if self.checkpoint_chunks_info else None,
                 self.checkpoint_next_index[worker_idx] if self.checkpoint_next_index else None,
+                self.item_loader,
             )
             worker.start()
             workers.append(worker)
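
Note: the sketch below illustrates how the new argument is threaded through. Only the `item_loader` keyword comes from this diff; the `input_dir`/`output_dir` arguments and the commented `run()` call are assumptions about the surrounding `DataProcessor` API, shown for illustration only.

```python
from litdata.processing.data_processor import DataProcessor
from litdata.streaming.item_loader import TokensLoader

# DataProcessor forwards `item_loader` to every BaseWorker, and each worker's
# `_create_cache` passes it on to its Cache, so the chunks get written in the
# loader's storage format.
processor = DataProcessor(
    input_dir="./slimpajama-raw",         # assumed argument, for illustration
    output_dir="./slimpajama-optimized",  # assumed argument, for illustration
    item_loader=TokensLoader(),           # new in this commit
)
# processor.run(...)  # a DataRecipe would be passed here; omitted in this sketch
```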

src/litdata/processing/functions.py

Lines changed: 5 additions & 0 deletions
@@ -38,6 +38,7 @@
 )
 from litdata.streaming.client import S3Client
 from litdata.streaming.dataloader import StreamingDataLoader
+from litdata.streaming.item_loader import BaseItemLoader
 from litdata.streaming.resolver import (
     Dir,
     _assert_dir_has_index_file,
@@ -311,6 +312,7 @@ def optimize(
     batch_size: Optional[int] = None,
     mode: Optional[Literal["append", "overwrite"]] = None,
     use_checkpoint: bool = False,
+    item_loader: Optional[BaseItemLoader] = None,
     start_method: Optional[str] = None,
 ) -> None:
     """This function converts a dataset into chunks, possibly in a distributed way.
@@ -341,6 +343,8 @@ def optimize(
             Defaults to None.
         use_checkpoint: Whether to create checkpoints while processing the data, which can be used to resume the
             processing from the last checkpoint if the process is interrupted. (`Default: False`)
+        item_loader: The item loader that will be used during loading in StreamingDataset. Determines
+            the format in which the data is stored and optimized for loading.
         start_method: The start method used by python multiprocessing package. Default to spawn unless running
             inside an interactive shell like Ipython.
@@ -433,6 +437,7 @@ def optimize(
         reader=reader,
         state_dict=state_dict,
         use_checkpoint=use_checkpoint,
+        item_loader=item_loader,
         start_method=start_method,
     )
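
Note: most users will reach this parameter through `optimize()`. A minimal sketch of the new call shape (the tokenizer function and paths are placeholders):

```python
from litdata import optimize, TokensLoader

def tokenize_fn(filepath):
    # Placeholder: yield one 1D token tensor per document found in `filepath`.
    ...

optimize(
    fn=tokenize_fn,
    inputs=["./raw/shard-0.jsonl.zst"],  # placeholder input files
    output_dir="./optimized",
    chunk_size=2049 * 8012,
    # Declares the storage format up front; the same loader class must be used
    # when streaming the chunks back (see the config.py validation below).
    item_loader=TokensLoader(),
)
```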

src/litdata/streaming/cache.py

Lines changed: 1 addition & 0 deletions
@@ -76,6 +76,7 @@ def __init__(
             encryption=encryption,
             serializers=serializers,
             chunk_index=writer_chunk_index or 0,
+            item_loader=item_loader,
         )
         self._reader = BinaryReader(
             self._cache_dir,
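
Note: `Cache` now hands the loader to its `BinaryWriter` (shown above) as well as its `BinaryReader`, so both sides agree on the storage format. A hedged sketch of direct usage; the positional directory argument and `chunk_size` keyword are assumptions about the `Cache` signature, only the `item_loader` keyword is taken from this diff:

```python
from litdata.streaming.cache import Cache
from litdata.streaming.item_loader import TokensLoader

# Writer and reader share one loader, so items round-trip in the same format.
cache = Cache("./cache-dir", chunk_size=2049 * 8012, item_loader=TokensLoader(block_size=2049))
```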

src/litdata/streaming/config.py

Lines changed: 11 additions & 6 deletions
@@ -257,12 +257,17 @@ def __len__(self) -> int:
 
     def _validate_item_loader(self) -> None:
         assert self._config
-        if (
-            len(self._config["data_format"]) == 1
-            and self._config["data_format"][0].startswith("no_header_tensor")
-            and not isinstance(self._item_loader, TokensLoader)
-        ):
-            raise ValueError("Please, use Cache(..., item_loader=TokensLoader(block_size=...))")
+        if "item_loader" in self._config:
+            if self._item_loader.__class__.__name__ != self._config["item_loader"]:
+                item_loader = self._config["item_loader"]
+                raise ValueError(f"Please, use Cache(..., item_loader={item_loader}(...))")
+        else:
+            if (
+                len(self._config["data_format"]) == 1
+                and self._config["data_format"][0].startswith("no_header_tensor")
+                and not isinstance(self._item_loader, TokensLoader)
+            ):
+                raise ValueError("Please, use Cache(..., item_loader=TokensLoader(block_size=...))")
 
 
 def load_subsampled_chunks(subsampled_files: List[str], original_chunks: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
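
Note: the effect of the new branch is that a dataset optimized with one loader can no longer be silently read with another. A hedged sketch, assuming a dataset that was optimized with `TokensLoader` as in the README example above:

```python
from litdata import StreamingDataset, TokensLoader

# Matching loader: streams fine.
ds = StreamingDataset("./optimized", item_loader=TokensLoader(block_size=2049))

# Omitting the loader falls back to the default PyTreeLoader; its class name no
# longer matches the "item_loader" recorded in the dataset config, so validation
# now fails with: ValueError: Please, use Cache(..., item_loader=TokensLoader(...))
ds = StreamingDataset("./optimized")
```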

src/litdata/streaming/item_loader.py

Lines changed: 20 additions & 1 deletion
@@ -100,6 +100,10 @@ def delete(self, chunk_index: int, chunk_filepath: str) -> None:
         """Delete a chunk from the local filesystem."""
         pass
 
+    @abstractmethod
+    def encode_data(self, data: List[bytes], sizes: List[int], flattened: List[Any]) -> Any:
+        pass
+
 
 class PyTreeLoader(BaseItemLoader):
     """The Pytree Loader is the default loader of the Cache object."""
@@ -245,9 +249,16 @@ def _validate_encryption(self, encryption: Optional[Encryption]) -> None:
             if encryption.level != self._config["encryption"]["level"]:
                 raise ValueError("Encryption level mismatch.")
 
+    @classmethod
+    def encode_data(cls, data: List[bytes], sizes: List[int], flattened: List[Any]) -> Tuple[bytes, Optional[int]]:
+        # Concatenate into a single byte array
+        head = np.array(sizes, np.uint32).tobytes()
+        body = b"".join(data)
+        return head + body, None
+
 
 class TokensLoader(BaseItemLoader):
-    def __init__(self, block_size: int):
+    def __init__(self, block_size: Optional[int] = None):
         """The Tokens Loader is an optimized item loader for NLP.
 
         Arguments:
@@ -263,6 +274,7 @@ def __init__(self, block_size: Optional[int] = None):
         self._chunk_filepaths: Dict[str, bool] = {}
 
     def state_dict(self) -> Dict:
+        assert self._block_size
         return {
             "block_size": self._block_size,
         }
@@ -280,6 +292,7 @@ def setup(
             raise ValueError("The provided chunks isn't properly setup.")
 
     def generate_intervals(self) -> List[Interval]:
+        assert self._block_size
         intervals = []
         begin = 0
         end = 0
@@ -324,6 +337,8 @@ def load_item_from_chunk(
         begin: int,
         chunk_bytes: int,
     ) -> torch.Tensor:
+        assert self._block_size
+
         if chunk_filepath in self._chunk_filepaths and not os.path.isfile(chunk_filepath):
             del self._chunk_filepaths[chunk_filepath]
 
@@ -350,3 +365,7 @@ def delete(self, chunk_index: int, chunk_filepath: str) -> None:
         if chunk_index in self._mmaps:
             del self._mmaps[chunk_index]
         os.remove(chunk_filepath)
+
+    @classmethod
+    def encode_data(cls, data: List[bytes], _: List[int], flattened: List[Any]) -> Tuple[bytes, Optional[int]]:
+        return data[0], flattened[0].shape[0]
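
Note: a toy illustration of the two `encode_data` strategies defined above (values chosen arbitrarily; both calls exercise the classmethods from this diff):

```python
import numpy as np
import torch

from litdata.streaming.item_loader import PyTreeLoader, TokensLoader

tokens = torch.tensor([1, 2, 3, 4], dtype=torch.int64)
data = [tokens.numpy().tobytes()]  # 4 tokens x 8 bytes = 32 bytes

# PyTreeLoader prepends a uint32 size header so each sample can be sliced
# back out of the chunk later.
blob, dim = PyTreeLoader.encode_data(data, [len(data[0])], [tokens])
assert blob[:4] == np.array([32], np.uint32).tobytes() and dim is None

# TokensLoader stores the raw bytes with no per-sample header and reports the
# token count instead, since the whole chunk is one contiguous 1D array.
blob, dim = TokensLoader.encode_data(data, [], [tokens])
assert blob == data[0] and dim == 4
```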

src/litdata/streaming/writer.py

Lines changed: 5 additions & 13 deletions
@@ -20,11 +20,11 @@
 from typing import Any, Dict, List, Optional, Tuple, Union
 
 import numpy as np
-import torch
 
 from litdata.constants import _INDEX_FILENAME
 from litdata.processing.utilities import get_worker_rank
 from litdata.streaming.compression import _COMPRESSORS, Compressor
+from litdata.streaming.item_loader import BaseItemLoader, PyTreeLoader
 from litdata.streaming.serializers import Serializer, _get_serializers
 from litdata.utilities._pytree import PyTree, tree_flatten, treespec_dumps
 from litdata.utilities.encryption import Encryption, EncryptionLevel
@@ -54,6 +54,7 @@ def __init__(
         follow_tensor_dimension: bool = True,
         serializers: Optional[Dict[str, Serializer]] = None,
         chunk_index: Optional[int] = None,
+        item_loader: Optional[BaseItemLoader] = None,
     ):
         """The BinaryWriter enables to chunk dataset into an efficient streaming format for cloud training.
 
@@ -83,6 +84,7 @@ def __init__(
         self._chunk_bytes = _convert_bytes_to_int(chunk_bytes) if isinstance(chunk_bytes, str) else chunk_bytes
         self._compression = compression
         self._encryption = encryption
+        self._item_loader = item_loader or PyTreeLoader()
 
         self._data_format: Optional[List[str]] = None
         self._data_spec: Optional[PyTree] = None
@@ -148,6 +150,7 @@ def get_config(self) -> Dict[str, Any]:
             "data_format": self._data_format,
             "data_spec": treespec_dumps(self._data_spec) if self._data_spec else None,
             "encryption": self._encryption.state_dict() if self._encryption else None,
+            "item_loader": self._item_loader.__class__.__name__,
         }
 
     def serialize(self, items: Any) -> Tuple[bytes, Optional[int]]:
@@ -156,10 +159,6 @@ def serialize(self, items: Any) -> Tuple[bytes, Optional[int]]:
         # Flatten the items provided by the users
         flattened, data_spec = tree_flatten(items)
 
-        is_single_tensor = (
-            len(flattened) == 1 and isinstance(flattened[0], torch.Tensor) and len(flattened[0].shape) == 1
-        )
-
         # Collect the sizes and associated bytes for each item
         sizes: List[int] = []
         data: List[bytes] = []
@@ -178,14 +177,7 @@ def serialize(self, items: Any) -> Tuple[bytes, Optional[int]]:
         # tiny optimization to avoid looping over all the data format
         self._serialize_with_data_format(flattened, sizes, data, self._data_format)
 
-        # If there is a single element and it is a tensor, enable continuous array.
-        if is_single_tensor:
-            return data[0], flattened[0].shape[0]
-
-        # Concatenate into a single byte array
-        head = np.array(sizes, np.uint32).tobytes()
-        body = b"".join(data)
-        return head + body, None
+        return self._item_loader.encode_data(data, sizes, flattened)
 
     def _serialize(self, item: Any, sizes: List[int], data: List[bytes]) -> str:
         """Serialize a given item and append its size and bytes to the sizes and data array."""

tests/streaming/test_cache.py

Lines changed: 1 addition & 0 deletions
@@ -211,6 +211,7 @@ def test_cache_with_auto_wrapping(tmpdir):
     assert sorted(os.listdir(os.path.join(tmpdir, "cache_1"))) == [
         "chunk-0-0.bin",
         "chunk-0-1.bin",
+        "chunk-0-2.bin",
         "index.json",
     ]
     # Your dataset is optimised for the cloud
