doc: improve dev doc #488

Merged — 6 commits — Mar 4, 2025
96 changes: 89 additions & 7 deletions src/litdata/streaming/item_loader.py
@@ -53,7 +53,7 @@ def setup(
self._chunks = chunks
self._serializers = {**serializers}
self._data_format = self._config["data_format"]
self._shift_idx = len(self._data_format) * 4  # each data-format entry's size is stored as a uint32 (4 bytes)
self.region_of_interest = region_of_interest
self._force_download_queue = force_download_queue

@@ -147,6 +147,27 @@ def load_item_from_chunk(
filesize_bytes: int,
encryption: Optional[Encryption] = None,
) -> bytes:
#
# Say a chunk contains items with indices in [5, 9],
# and the index of the item we want to load is 7.
# begin = 5
# index = 7
#
# The chunk's binary format is structured as follows:
#
# +------------+---------------+-------------+
# | num_items | offset_array | item_data |
# +------------+---------------+-------------+
# | uint32 | uint32[N+1] | bytes |
# | 4 bytes | 4*(N+1) bytes | variable |
# +------------+---------------+-------------+
#
# To reach the offset entry of the item we want to load, we need to jump by:
# => 1 + (index - begin)  # 1 is added since the first 4 bytes store `num_items` (1 uint32)
# => 1 + (7 - 5) = 3
# => 3 * 4 = 12  # each offset entry takes 4 bytes
# => offset = 12
#
offset = (1 + (index - begin) if index >= begin else index + 1) * 4

if chunk_filepath in self._chunk_filepaths and not os.path.isfile(chunk_filepath):
@@ -175,6 +196,7 @@ def load_item_from_chunk(
data = self._load_encrypted_data(chunk_filepath, chunk_index, offset, encryption)
else:
with open(chunk_filepath, "rb", 0) as fp:
# load the data from raw bytes using the offset for the item we want to load
data = self._load_data(fp, offset)

# check for mosaic mds format
@@ -214,11 +236,16 @@ def _load_encrypted_data(

def _load_data(self, fp: Union[FileIO, BytesIO], offset: int) -> bytes:
"""Load the data from the file pointer."""
fp.seek(offset) # move the file pointer to the offset

# Refer to `writer.py::_create_chunk` for more details on the chunk's binary format
# We want to read the `offset_start` and `offset_end` for the item we want to load:
# 2 uint32 values (4 bytes each => 8 bytes) giving offset_start and offset_end
pair = fp.read(8)
begin, end = np.frombuffer(pair, np.uint32)
fp.seek(begin) # move the file pointer to the offset_start where the item starts
return fp.read(end - begin) # read the item
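
For reference, a minimal self-contained sketch (synthetic chunk, not litdata's API) of the read path described above: locate the item's offset entry, read the (offset_start, offset_end) pair, then read the item bytes in between.

import io
import numpy as np

# build a fake chunk holding items with global indices [5, 9]:
# header = num_items (uint32) + N+1 cumulative offsets (uint32 each)
items = [bytes([65 + i]) * (i + 1) for i in range(5)]      # b'A', b'BB', b'CCC', ...
sizes = [len(it) for it in items]
offsets = np.array([0] + sizes).cumsum().astype(np.uint32)
offsets += 4 + offsets.nbytes                              # shift past the header
chunk = np.uint32(len(items)).tobytes() + offsets.tobytes() + b"".join(items)

begin, index = 5, 7
fp = io.BytesIO(chunk)
fp.seek((1 + (index - begin)) * 4)         # jump to the item's offset entry
start, end = np.frombuffer(fp.read(8), np.uint32)
fp.seek(start)
assert fp.read(end - start) == b"CCC"      # item at global index 7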

def mds_deserialize(self, raw_item_data: bytes, chunk_index: int) -> "PyTree":
"""Deserialize the mds raw bytes into their python equivalent."""
@@ -268,7 +295,30 @@ def _validate_encryption(self, encryption: Optional[Encryption]) -> None:

@classmethod
def encode_data(cls, data: List[bytes], sizes: List[int], flattened: List[Any]) -> Tuple[bytes, Optional[int]]:
"""Encodes multiple serialized objects into a single binary format with size metadata.

This method combines multiple serialized objects into a single byte array, prefixed with their sizes.
The resulting format is: [size_header][concatenated_data], where size_header contains the byte sizes
of each object encoded as uint32.

Args:
data: List of serialized objects as bytes
sizes: List of integers representing the byte size of each object
flattened: List of flattened pytree leaves

Returns:
Tuple containing:
- bytes: Combined binary data with header
- Optional[int]: dimension of the item (None for PyTreeLoader)

Example:
For a row containing [int, image, tensor]:
- sizes might be [4, 100000, 1000] (number of bytes for each object)
- data would be their respective serialized bytes
The method combines these into:

[size_bytes][int_bytes][image_bytes][tensor_bytes]
"""
head = np.array(sizes, np.uint32).tobytes()
body = b"".join(data)
return head + body, None
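
A small round-trip sketch (synthetic byte strings, not the litdata API) showing how the uint32 size header written by `encode_data` is enough to split the concatenated body back into the original objects:

import numpy as np

objects = [b"abcd", b"x" * 7, b"yz"]          # pretend-serialized pytree leaves
head = np.array([len(o) for o in objects], np.uint32).tobytes()
encoded = head + b"".join(objects)            # what encode_data would return

sizes = np.frombuffer(encoded[: 4 * len(objects)], np.uint32)
body, pos, decoded = encoded[4 * len(objects) :], 0, []
for size in sizes:
    decoded.append(body[pos : pos + size])    # slice out each object
    pos += int(size)
assert decoded == objects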
@@ -325,7 +375,7 @@ def generate_intervals(self) -> List[Interval]:
begin = 0
end = 0
for idx, chunk in enumerate(self._chunks):
dim = chunk["dim"]
dim = chunk["dim"] # number of tokens in the chunk
num_blocks = dim // self._block_size
end += num_blocks
start_idx, end_idx = begin, end
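
A hedged sketch of the block counting above (assuming two chunks of 100 and 50 tokens with a block size of 10, and ignoring the region_of_interest handling): each chunk contributes dim // block_size items.

block_size = 10
chunks = [{"dim": 100}, {"dim": 50}]          # tokens per chunk (assumed)
begin = end = 0
intervals = []
for chunk in chunks:
    end += chunk["dim"] // block_size         # whole blocks in this chunk
    intervals.append((begin, end))
    begin = end
assert intervals == [(0, 10), (10, 15)]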
@@ -343,8 +393,9 @@ def _load_chunk(self, chunk_index: int, chunk_filepath: str) -> None:
chunk = self._chunks[chunk_index]

# Skip the header
# Header: 1 uint32 for the number of items + (chunk_size + 1) uint32 offsets
# (the offset array starts at 0), each taking 4 bytes (np.uint32)
# for more details on the chunk's binary format, see `writer.py::_create_chunk`
offset = (1 + chunk["chunk_size"] + 1) * 4
mmap = np.memmap(chunk_filepath, mode="r", order="C", offset=offset)
self._mmaps[chunk_index] = mmap
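
A self-contained sketch (synthetic file `chunk.bin`, uint16 tokens assumed) of the header arithmetic above: the memory map starts right after num_items plus the chunk_size + 1 offsets.

import numpy as np

chunk_size = 4                                           # items in the chunk
tokens = np.arange(32, dtype=np.uint16)                  # fake token payload
header = np.zeros(1 + chunk_size + 1, dtype=np.uint32)   # num_items + offsets
with open("chunk.bin", "wb") as f:
    f.write(header.tobytes() + tokens.tobytes())

offset = (1 + chunk_size + 1) * 4                        # header bytes to skip
mmap = np.memmap("chunk.bin", mode="r", order="C", offset=offset)
assert bytes(mmap) == tokens.tobytes()                   # only raw tokens remain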
@@ -394,11 +445,21 @@ def load_item_from_chunk(
assert self._dtype

buffer: bytes = self._buffers[chunk_index]

# offset: how many bytes to skip to reach the item we want to load
# -> if the chunk begins at index 5 and we want the item at index 7,
# -> we need to skip 2 items, each holding `self._block_size` tokens,
# -> and each token takes `self._dtype.itemsize` bytes
#
# Note: the header bytes were already skipped in `_load_chunk` when the
# memory map was created.
offset = self._dtype.itemsize * (index - begin) * self._block_size

if self._serializer_name == "no_header_tensor":
# count: number of tokens to read from buffer => `self._block_size`
data = torch.frombuffer(buffer, dtype=self._dtype, count=self._block_size, offset=offset)
else:
# count: number of tokens to read from buffer => `self._block_size`
data = np.frombuffer(buffer, dtype=self._dtype, count=self._block_size, offset=offset) # type: ignore
return data
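
A minimal sketch of the buffer arithmetic above (assuming uint16 tokens, a block size of 8, a chunk beginning at global index 5, and loading index 7):

import numpy as np

begin, index, block_size = 5, 7, 8
dtype = np.dtype(np.uint16)
buffer = np.arange(4 * block_size, dtype=dtype).tobytes()   # 4 fake items

offset = dtype.itemsize * (index - begin) * block_size      # skip 2 items
data = np.frombuffer(buffer, dtype=dtype, count=block_size, offset=offset)
assert data.tolist() == list(range(16, 24))                 # the third item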

@@ -426,6 +487,27 @@ def close(self, chunk_index: int) -> None:

@classmethod
def encode_data(cls, data: List[bytes], _: List[int], flattened: List[Any]) -> Tuple[bytes, Optional[int]]:
r"""Encodes tokenized data into a raw byte format while preserving dimensional information.

Args:
- data (List[bytes]): A list containing a single element, which is the raw byte
representation of tokenized data.
- _ (List[int]): A list containing sizes of each PyTree leaf in the item.
Since only one item (tokens) is present, this argument is ignored.
- flattened (List[Any]): A list containing a single element, which is the list of tokens.

Example:
- Original data: "hello world"
- Tokenized data: [1, 2] (word tokenizer)
- Data (raw bytes): [b'\x01\x00\x00\x00\x02\x00\x00\x00']
(raw bytes representing the tokenized data)
- Flattened data: [[1, 2]] (returned by PyTree's `flatten` function)

Returns:
- Tuple[bytes, Optional[int]]:
- bytes: The raw byte representation of tokenized data.
- dimension: The number of tokens in the data (extracted from `flattened[0].shape[0]`).
"""
return data[0], flattened[0].shape[0]
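
A short sketch of this fast path, mirroring the docstring's example (int32 tokens and a little-endian platform assumed):

import numpy as np

tokens = np.array([1, 2], dtype=np.int32)     # e.g. a word tokenizer's output
data = [tokens.tobytes()]                     # a single raw-bytes pytree leaf
flattened = [tokens]

encoded, dim = data[0], flattened[0].shape[0]
assert encoded == b"\x01\x00\x00\x00\x02\x00\x00\x00" and dim == 2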


40 changes: 35 additions & 5 deletions src/litdata/streaming/writer.py
@@ -206,8 +206,34 @@ def _serialize_with_data_format(
sizes.append(serializer.size if hasattr(serializer, "size") else len(serialized_item))

def _create_chunk(self, filename: str, on_done: bool = False) -> bytes:
"""Create a binary chunk from all the binarized items."""
items = []
"""Creates a binary chunk file from serialized items."""
# The chunk's binary format is structured as follows:

# +------------+---------------+-------------+
# | num_items | offset_array | item_data |
# +------------+---------------+-------------+
# | uint32 | uint32[N+1] | bytes |
# | 4 bytes | 4*(N+1) bytes | variable |
# +------------+---------------+-------------+

# Where:
# - num_items: Number of items in the chunk (N)
# - offset_array: Array of N+1 offsets indicating where each item begins/ends
# - item_data: Concatenated binary data of all items

# Example:
# For a chunk with 3 items of sizes [10, 20, 15] bytes:
# - num_items = 3 (4 bytes)
# - offset_array = [start, start+10, start+30, start+45]
# where start = 4 + (4 * 4) = 20 bytes (header size)
# - item_data = concatenated bytes of all items

# This format allows direct access to any item: read its start and end
# offsets from `offset_array[i]` and `offset_array[i+1]`, then read the
# bytes in between. The item loader can deserialize the item from these
# raw bytes.

items: List[Item] = []

if on_done:
indices = sorted(self._serialized_items.keys())
@@ -228,11 +254,15 @@ def _create_chunk(self, filename: str, on_done: bool = False) -> bytes:
f" Found {self._pretty_serialized_items()} with boundaries: {self._min_index}, {self._max_index}."
)

num_items = np.uint32(len(items))  # total number of items in the chunk
sizes = list(map(len, items))  # byte size of each item
offsets = np.array([0] + sizes).cumsum().astype(np.uint32)  # e.g. [0, 10, 30, 45]

# shift by the header size (num_items + offsets = 4 + 4*(N+1) bytes = 20 here): offsets -> [20, 30, 50, 65]
offsets += len(num_items.tobytes()) + len(offsets.tobytes())
sample_data = b"".join([item.data for item in items])

# combine all bytes data which will be written to the chunk file
data = num_items.tobytes() + offsets.tobytes() + sample_data

# Whether to encrypt the data at the chunk level
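To tie the writer and loader together, a hedged end-to-end sketch (synthetic items, not the writer's full code path) that builds a chunk exactly as diagrammed above and reads one item back:

import numpy as np

items = [b"0123456789", b"x" * 20, b"y" * 15]           # sizes [10, 20, 15]
num_items = np.uint32(len(items))
offsets = np.array([0] + [len(i) for i in items]).cumsum().astype(np.uint32)
offsets += len(num_items.tobytes()) + len(offsets.tobytes())  # header = 20 bytes
chunk = num_items.tobytes() + offsets.tobytes() + b"".join(items)

# read back item 1 via offset_array[1] and offset_array[2]
start, end = np.frombuffer(chunk, np.uint32, count=2, offset=4 + 1 * 4)
assert chunk[start:end] == b"x" * 20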