Iterator commit
EntilZha committed Dec 12, 2024
1 parent 6638133 commit b8c8cb8
Showing 16 changed files with 2,167 additions and 1 deletion.
2 changes: 1 addition & 1 deletion .gitignore
@@ -164,5 +164,5 @@ cython_debug/

figures/
.vscode/
data/
.DS_Store

1 change: 1 addition & 0 deletions bytelatent/data/__init__.py
@@ -0,0 +1 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
115 changes: 115 additions & 0 deletions bytelatent/data/data_types.py
@@ -0,0 +1,115 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
import json
from dataclasses import dataclass
from typing import Any, Iterator

import numpy as np
from pydantic import BaseModel, ConfigDict


class BltExample(BaseModel):
    model_config = ConfigDict(extra="forbid")
    sample_id: str
    text: str
    tokens: list[int] | None
    entropies: list[float] | None
    patch_lengths: list[int] | None
    mask: list[bool] | None


class MultiChoiceState(BaseModel):
    model_config = ConfigDict(extra="forbid")
    root_dir: str
    sources: dict[str, float]
    source_to_state: dict[str, Any]
    rng_state: dict[str, Any]


class PrefetchState(BaseModel):
    model_config = ConfigDict(extra="forbid")
    seq_idx: int
    rng_state: dict[str, Any]
    prefetch_size: int
    batch_size: int


class BltPackTokensState(BaseModel):
    model_config = ConfigDict(extra="forbid")
    start_token: int
    output_seq_len: int
    n_views: int = 2


class DataLoaderState(BaseModel):
    model_config = ConfigDict(extra="forbid")
    multi_choice_state: MultiChoiceState
    pack_tokens_state: BltPackTokensState
    prefetch_state: PrefetchState


BltIterator = Iterator[tuple[BltExample, DataLoaderState]]


class BltSequence(BaseModel):
    tokens: list[int]
    mask: list[bool]
    patch_lengths: list[int]


@dataclass
class Batch:
    x: np.ndarray
    y: np.ndarray
    mask: np.ndarray | None = None
    patch_lengths: np.ndarray | None = None
    ngram_ids: np.ndarray | None = None
    is_final: bool = False

    def to_python_dict(self) -> dict:
        x = self.x.tolist()
        y = self.y.tolist()
        if self.mask is None:
            mask = None
        else:
            mask = self.mask.tolist()
        if self.patch_lengths is None:
            patch_lengths = None
        else:
            patch_lengths = self.patch_lengths.tolist()
        if self.ngram_ids is None:
            ngram_ids = None
        else:
            ngram_ids = self.ngram_ids.tolist()
        return {
            "x": x,
            "y": y,
            "mask": mask,
            "patch_lengths": patch_lengths,
            "ngram_ids": ngram_ids,
            "is_final": self.is_final,
        }

    @classmethod
    def from_python_dict(cls, data: dict) -> "Batch":
        x = np.array(data["x"])
        y = np.array(data["y"])
        if data["mask"] is None:
            mask = None
        else:
            mask = np.array(data["mask"])
        if data["patch_lengths"] is None:
            patch_lengths = None
        else:
            patch_lengths = np.array(data["patch_lengths"])
        if data["ngram_ids"] is None:
            ngram_ids = None
        else:
            ngram_ids = np.array(data["ngram_ids"])
        return Batch(
            x=x,
            y=y,
            mask=mask,
            patch_lengths=patch_lengths,
            ngram_ids=ngram_ids,
            is_final=data["is_final"],
        )
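
A minimal usage sketch (not part of this commit) of how the Batch serialization helpers above might round-trip a batch, assuming bytelatent.data.data_types is importable; the toy array values are invented for illustration:

import numpy as np

from bytelatent.data.data_types import Batch

# Build a toy batch: x/y are token id arrays, mask marks positions that count toward the loss.
batch = Batch(
    x=np.array([[1, 2, 3]]),
    y=np.array([[2, 3, 4]]),
    mask=np.array([[True, True, False]]),
)
# Round-trip through plain Python containers (e.g., for JSON-style checkpointing).
as_dict = batch.to_python_dict()
restored = Batch.from_python_dict(as_dict)
assert np.array_equal(batch.x, restored.x) and np.array_equal(batch.y, restored.y)
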
1 change: 1 addition & 0 deletions bytelatent/data/iterators/__init__.py
@@ -0,0 +1 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
23 changes: 23 additions & 0 deletions bytelatent/data/iterators/abstract_iterator.py
@@ -0,0 +1,23 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
import abc
from typing import Any, Generator, Generic, TypeVar

T = TypeVar("T")
C = TypeVar("C")


class StatefulIterator(Generic[T, C], abc.ABC):

    @abc.abstractmethod
    def get_state(self) -> C:
        pass

    @abc.abstractmethod
    def create_iter(self) -> Generator[T, Any, None]:
        pass


class IteratorState(Generic[C]):
    @abc.abstractmethod
    def build(self) -> StatefulIterator[T, C]:
        pass
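
A hypothetical toy pairing (not part of this commit) showing how a StatefulIterator and its IteratorState are meant to fit together; RangeState and RangeIterator are invented purely for illustration:

from bytelatent.data.iterators.abstract_iterator import IteratorState, StatefulIterator


class RangeState(IteratorState):
    def __init__(self, position: int):
        self.position = position

    def build(self) -> "RangeIterator":
        # Reconstruct an iterator that resumes from the saved position.
        return RangeIterator(start=self.position)


class RangeIterator(StatefulIterator[int, RangeState]):
    def __init__(self, start: int = 0):
        self.position = start

    def get_state(self) -> RangeState:
        # Capture just enough information to resume later.
        return RangeState(self.position)

    def create_iter(self):
        while True:
            yield self.position
            self.position += 1


# Usage: iterate, snapshot the state, then rebuild and continue from the same spot.
it = RangeIterator()
gen = it.create_iter()
next(gen)  # 0
resumed = it.get_state().build()
assert next(resumed.create_iter()) == 1
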
216 changes: 216 additions & 0 deletions bytelatent/data/iterators/arrow_iterator.py
@@ -0,0 +1,216 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
import re
from logging import getLogger
from pathlib import Path
from typing import Any, Generator

import pyarrow as pa

# pyarrow needs the initialization from this import
import pyarrow.dataset # pyright: ignore
from pydantic import BaseModel, ConfigDict

from bytelatent import ByteLatentError
from bytelatent.data.data_types import BltExample
from bytelatent.data.iterators.abstract_iterator import IteratorState, StatefulIterator

logger = getLogger(__name__)


class ArrowFileIteratorState(BaseModel, IteratorState):
    model_config = ConfigDict(extra="forbid")
    file_path: str | None
    row_num: int
    num_workers: int
    worker_id: int
    preprocess_dir: str | None
    dataset_files: list[str] | None
    entropy_model_name: str | None
    arrow_batch_size: int = 100

    def build(self) -> "ArrowFileIterator":
        arrow_file = ArrowFileIterator(
            file_path=self.file_path,
            worker_id=self.worker_id,
            num_workers=self.num_workers,
            preprocess_dir=self.preprocess_dir,
            entropy_model_name=self.entropy_model_name,
            arrow_batch_size=self.arrow_batch_size,
            dataset_files=self.dataset_files,
        )
        if self.row_num != 0:
            arrow_file._set_row_num(self.row_num)
        return arrow_file


def shard_sort_key(file: str | Path):
    match = re.search(r".+\.shard_([0-9]+)\.arrow", str(file))
    shard_number = int(match.group(1))
    return shard_number


class ArrowFileIterator(StatefulIterator):
    def __init__(
        self,
        *,
        file_path: str | None,
        worker_id: int,
        num_workers: int,
        preprocess_dir: str | None,
        entropy_model_name: str | None,
        arrow_batch_size: int,
        dataset_files: list[str] | None = None,
    ):
        assert 0 <= worker_id < num_workers, (worker_id, num_workers)
        if file_path is None and dataset_files is None:
            raise ByteLatentError("file_path and dataset_files cannot both be None")
        self.row_num = 0
        self.iter_id = 0
        self.batch_iterator = None
        self.batch_to_consume = None
        self.dataset = None
        self.file_path = file_path
        self.worker_id = worker_id
        self.num_workers = num_workers
        self.preprocess_dir = preprocess_dir
        self.entropy_model_name = entropy_model_name
        self.arrow_batch_size = arrow_batch_size
        if dataset_files is None:
            # Prepare arrow shards
            jsonl_file = Path(file_path)
            parts = re.match(r"(.+)\.chunk\.[0-9]+\.jsonl", jsonl_file.name)
            assert parts is not None
            dataset = parts.group(1)
            data_dir = Path(preprocess_dir) / dataset / entropy_model_name
            shard_files = list(data_dir.glob(f"{jsonl_file.name}.shard_*.arrow"))
            for s in shard_files:
                if not (data_dir / f"{s.name}.complete").exists():
                    raise ValueError(f"Missing .complete for input file: {s}")

            shard_files = sorted(shard_files, key=shard_sort_key)
            if len(shard_files) == 0:
                raise ByteLatentError(
                    f"Zero shard_files found corresponding to: {file_path} using preprocess_dir={preprocess_dir} and entropy_model_name={entropy_model_name}, so the search path is data_dir={data_dir} for matches to {jsonl_file.name}.shard_*.arrow"
                )
            self.dataset_files = [str(f) for f in shard_files]
        else:
            self.preprocess_dir = None
            self.dataset_files = dataset_files

    def get_state(self) -> ArrowFileIteratorState:
        return ArrowFileIteratorState(
            file_path=self.file_path,
            row_num=self.row_num,
            worker_id=self.worker_id,
            num_workers=self.num_workers,
            preprocess_dir=self.preprocess_dir,
            entropy_model_name=self.entropy_model_name,
            arrow_batch_size=self.arrow_batch_size,
            dataset_files=self.dataset_files,
        )

    def create_iter(
        self,
    ) -> Generator[BltExample, Any, None]:
        if self.dataset is None:
            self.dataset = pa.dataset.dataset(self.dataset_files, format="arrow")
            self.batch_iterator = self.dataset.to_batches(
                batch_size=self.arrow_batch_size
            )
        self.iter_id += 1
        if self.batch_to_consume is not None:
            batch_columns: dict[str, list] = self.batch_to_consume
            self.batch_to_consume = None
            sample_ids = batch_columns["sample_id"]
            texts = batch_columns["text"]
            entropies = batch_columns["entropies"]
            for i in range(len(sample_ids)):
                out = BltExample(
                    sample_id=sample_ids[i],
                    entropies=entropies[i],
                    text=texts[i],
                    tokens=None,
                    mask=None,
                    patch_lengths=None,
                )
                self.row_num += 1
                if (self.row_num - 1) % self.num_workers == self.worker_id:
                    yield out

        for batch in self.batch_iterator:
            batch_columns = batch.to_pydict()
            sample_ids = batch_columns["sample_id"]
            texts = batch_columns["text"]
            entropies = batch_columns["entropies"]
            for i in range(len(sample_ids)):
                out = BltExample(
                    sample_id=sample_ids[i],
                    entropies=entropies[i],
                    text=texts[i],
                    tokens=None,
                    mask=None,
                    patch_lengths=None,
                )
                self.row_num += 1
                if (self.row_num - 1) % self.num_workers == self.worker_id:
                    yield out

    def _set_row_num(self, target_row_num: int):
        logger.info(
            f"Setting arrow position to {target_row_num} for {self.dataset_files}"
        )
        if target_row_num is None or target_row_num == 0:
            self.row_num = 0
            self.dataset = None
            self.batch_iterator = None
            self.batch_to_consume = None
        else:
            self.dataset = pa.dataset.dataset(self.dataset_files, format="arrow")
            self.batch_iterator = self.dataset.to_batches(
                batch_size=self.arrow_batch_size
            )
            curr_remaining = target_row_num
            for batch in self.batch_iterator:
                if len(batch) > curr_remaining:
                    batch_columns: dict[str, list] = batch.to_pydict()
                    batch_columns["sample_id"] = batch_columns["sample_id"][
                        curr_remaining:
                    ]
                    batch_columns["entropies"] = batch_columns["entropies"][
                        curr_remaining:
                    ]
                    batch_columns["text"] = batch_columns["text"][curr_remaining:]
                    self.batch_to_consume = batch_columns
                    break
                elif len(batch) == curr_remaining:
                    # We are exactly at the end of the batch,
                    # so the next batch is the right spot
                    break
                else:
                    curr_remaining -= len(batch)
            self.row_num = target_row_num
        logger.info(
            f"Finished setting arrow position to {target_row_num} for {self.dataset_files}"
        )


TRAIN_DATA_FILE_PATTERN = "*.chunk.*.jsonl"


def find_and_sanitize_chunks(
    dataset_path: str, world_size: int, file_pattern: str = TRAIN_DATA_FILE_PATTERN
):
    dataset_chunks = [str(p) for p in Path(dataset_path).glob(file_pattern)]
    n_chunks = len(dataset_chunks)

    if n_chunks > world_size:
        n_discard = n_chunks - world_size
        dataset_chunks = dataset_chunks[:world_size]
    else:
        assert (
            world_size % n_chunks == 0
        ), "World size should be a multiple of number of chunks"

    assert n_chunks > 0, f"No valid chunks in {dataset_path}"

    return dataset_chunks
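
A rough sketch (not part of this commit) of the save/resume pattern that ArrowFileIterator and its get_state()/build() pair support; the shard path below is hypothetical:

from bytelatent.data.iterators.arrow_iterator import ArrowFileIterator

# Read pre-sharded arrow files directly; file_path/preprocess_dir/entropy_model_name
# are only needed when shards must be located from the original jsonl chunk name.
iterator = ArrowFileIterator(
    file_path=None,
    worker_id=0,
    num_workers=1,
    preprocess_dir=None,
    entropy_model_name=None,
    arrow_batch_size=100,
    dataset_files=["data/example.chunk.00.jsonl.shard_00.arrow"],  # hypothetical path
)
examples = iterator.create_iter()
first = next(examples)

# Snapshot progress, then rebuild an equivalent iterator that skips the rows already seen.
state = iterator.get_state()
resumed_iterator = state.build()
second = next(resumed_iterator.create_iter())
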