Showing 16 changed files with 2,167 additions and 1 deletion.
.gitignore (file name inferred from the ignore-pattern context; the capture dropped it):

@@ -164,5 +164,5 @@ cython_debug/
figures/
.vscode/
data/
.DS_Store
New file (path not shown in the capture; likely a package __init__.py such as bytelatent/data/__init__.py):

@@ -0,0 +1 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
bytelatent/data/data_types.py (path inferred from `from bytelatent.data.data_types import BltExample` in the Arrow iterator below):

@@ -0,0 +1,115 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
import json
from dataclasses import dataclass
from typing import Any, Iterator

import numpy as np
from pydantic import BaseModel, ConfigDict


class BltExample(BaseModel):
    model_config = ConfigDict(extra="forbid")
    sample_id: str
    text: str
    tokens: list[int] | None
    entropies: list[float] | None
    patch_lengths: list[int] | None
    mask: list[bool] | None


class MultiChoiceState(BaseModel):
    model_config = ConfigDict(extra="forbid")
    root_dir: str
    sources: dict[str, float]
    source_to_state: dict[str, Any]
    rng_state: dict[str, Any]


class PrefetchState(BaseModel):
    model_config = ConfigDict(extra="forbid")
    seq_idx: int
    rng_state: dict[str, Any]
    prefetch_size: int
    batch_size: int


class BltPackTokensState(BaseModel):
    model_config = ConfigDict(extra="forbid")
    start_token: int
    output_seq_len: int
    n_views: int = 2


class DataLoaderState(BaseModel):
    model_config = ConfigDict(extra="forbid")
    multi_choice_state: MultiChoiceState
    pack_tokens_state: BltPackTokensState
    prefetch_state: PrefetchState


BltIterator = Iterator[tuple[BltExample, DataLoaderState]]


class BltSequence(BaseModel):
    tokens: list[int]
    mask: list[bool]
    patch_lengths: list[int]


@dataclass
class Batch:
    x: np.ndarray
    y: np.ndarray
    mask: np.ndarray | None = None
    patch_lengths: np.ndarray | None = None
    ngram_ids: np.ndarray | None = None
    is_final: bool = False

    def to_python_dict(self) -> dict:
        x = self.x.tolist()
        y = self.y.tolist()
        if self.mask is None:
            mask = None
        else:
            mask = self.mask.tolist()
        if self.patch_lengths is None:
            patch_lengths = None
        else:
            patch_lengths = self.patch_lengths.tolist()
        if self.ngram_ids is None:
            ngram_ids = None
        else:
            ngram_ids = self.ngram_ids.tolist()
        return {
            "x": x,
            "y": y,
            "mask": mask,
            "patch_lengths": patch_lengths,
            "ngram_ids": ngram_ids,
            "is_final": self.is_final,
        }

    @classmethod
    def from_python_dict(cls, data: dict) -> "Batch":
        x = np.array(data["x"])
        y = np.array(data["y"])
        if data["mask"] is None:
            mask = None
        else:
            mask = np.array(data["mask"])
        if data["patch_lengths"] is None:
            patch_lengths = None
        else:
            patch_lengths = np.array(data["patch_lengths"])
        if data["ngram_ids"] is None:
            ngram_ids = None
        else:
            ngram_ids = np.array(data["ngram_ids"])
        return Batch(
            x=x,
            y=y,
            mask=mask,
            patch_lengths=patch_lengths,
            ngram_ids=ngram_ids,
            is_final=data["is_final"],
        )
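A minimal round-trip sketch for Batch (the array contents are made up): to_python_dict converts the NumPy arrays to plain Python lists, so the dict is JSON-serializable, and from_python_dict rebuilds the arrays, with unset optional fields surviving as None.

# Illustrative round trip: Batch -> JSON-safe dict -> Batch.
import json

import numpy as np

batch = Batch(
    x=np.array([[1, 2, 3]]),
    y=np.array([[2, 3, 4]]),
    mask=np.array([[True, True, False]]),
)
payload = json.dumps(batch.to_python_dict())        # plain lists only
restored = Batch.from_python_dict(json.loads(payload))
assert np.array_equal(restored.x, batch.x)
assert restored.patch_lengths is None               # unset optionals stay None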
New file (path not shown in the capture; likely bytelatent/data/iterators/__init__.py, given the imports below):

@@ -0,0 +1 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
bytelatent/data/iterators/abstract_iterator.py (path inferred from the import in the Arrow iterator below):

@@ -0,0 +1,23 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
import abc
from typing import Any, Generator, Generic, TypeVar

T = TypeVar("T")
C = TypeVar("C")


class StatefulIterator(Generic[T, C], abc.ABC):
    @abc.abstractmethod
    def get_state(self) -> C:
        pass

    @abc.abstractmethod
    def create_iter(self) -> Generator[T, Any, None]:
        pass


class IteratorState(Generic[C]):
    @abc.abstractmethod
    def build(self) -> StatefulIterator[T, C]:
        pass
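As a concrete illustration of this contract (hypothetical class names, not part of the diff), consider a counter iterator whose entire state is its position: get_state() snapshots it, and the state's build() reconstructs an iterator that resumes from the snapshot.

from typing import Any, Generator


class CountState(IteratorState["CountState"]):
    def __init__(self, position: int):
        self.position = position

    def build(self) -> "CountIterator":
        return CountIterator(position=self.position)


class CountIterator(StatefulIterator[int, CountState]):
    def __init__(self, position: int = 0):
        self.position = position

    def get_state(self) -> CountState:
        return CountState(position=self.position)

    def create_iter(self) -> Generator[int, Any, None]:
        while True:
            self.position += 1
            yield self.position


counter = CountIterator()
stream = counter.create_iter()
assert [next(stream) for _ in range(3)] == [1, 2, 3]
resumed = counter.get_state().build().create_iter()
assert next(resumed) == 4  # continues where the snapshot left off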
bytelatent/data/iterators/arrow_iterator.py (path inferred; the capture dropped the file name):

@@ -0,0 +1,216 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
import re
from logging import getLogger
from pathlib import Path
from typing import Any, Generator

import pyarrow as pa

# pyarrow needs the initialization from this import
import pyarrow.dataset  # pyright: ignore
from pydantic import BaseModel, ConfigDict

from bytelatent import ByteLatentError
from bytelatent.data.data_types import BltExample
from bytelatent.data.iterators.abstract_iterator import IteratorState, StatefulIterator

logger = getLogger(__name__)


class ArrowFileIteratorState(BaseModel, IteratorState):
    model_config = ConfigDict(extra="forbid")
    file_path: str | None
    row_num: int
    num_workers: int
    worker_id: int
    preprocess_dir: str | None
    dataset_files: list[str] | None
    entropy_model_name: str | None
    arrow_batch_size: int = 100

    def build(self) -> "ArrowFileIterator":
        arrow_file = ArrowFileIterator(
            file_path=self.file_path,
            worker_id=self.worker_id,
            num_workers=self.num_workers,
            preprocess_dir=self.preprocess_dir,
            entropy_model_name=self.entropy_model_name,
            arrow_batch_size=self.arrow_batch_size,
            dataset_files=self.dataset_files,
        )
        if self.row_num != 0:
            arrow_file._set_row_num(self.row_num)
        return arrow_file


def shard_sort_key(file: str | Path):
    match = re.search(r".+\.shard_([0-9]+)\.arrow", str(file))
    shard_number = int(match.group(1))
    return shard_number
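# Illustration with hypothetical file names: shard_sort_key orders shards
# numerically, where a plain lexical sort would put shard_10 before shard_2.
#   files = ["d.jsonl.shard_10.arrow", "d.jsonl.shard_2.arrow", "d.jsonl.shard_1.arrow"]
#   sorted(files)                      # shard_1, shard_10, shard_2
#   sorted(files, key=shard_sort_key)  # shard_1, shard_2, shard_10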
class ArrowFileIterator(StatefulIterator):
    def __init__(
        self,
        *,
        file_path: str | None,
        worker_id: int,
        num_workers: int,
        preprocess_dir: str | None,
        entropy_model_name: str | None,
        arrow_batch_size: int,
        dataset_files: list[str] | None = None,
    ):
        assert 0 <= worker_id < num_workers, (worker_id, num_workers)
        if file_path is None and dataset_files is None:
            raise ByteLatentError("file_path and dataset_files cannot both be None")
        self.row_num = 0
        self.iter_id = 0
        self.batch_iterator = None
        self.batch_to_consume = None
        self.dataset = None
        self.file_path = file_path
        self.worker_id = worker_id
        self.num_workers = num_workers
        self.preprocess_dir = preprocess_dir
        self.entropy_model_name = entropy_model_name
        self.arrow_batch_size = arrow_batch_size
        if dataset_files is None:
            # Prepare arrow shards
            jsonl_file = Path(file_path)
            parts = re.match(r"(.+)\.chunk\.[0-9]+\.jsonl", jsonl_file.name)
            assert parts is not None
            dataset = parts.group(1)
            data_dir = Path(preprocess_dir) / dataset / entropy_model_name
            shard_files = list(data_dir.glob(f"{jsonl_file.name}.shard_*.arrow"))
            for s in shard_files:
                if not (data_dir / f"{s.name}.complete").exists():
                    raise ValueError(f"Missing .complete for input file: {s}")

            shard_files = sorted(shard_files, key=shard_sort_key)
            if len(shard_files) == 0:
                raise ByteLatentError(
                    f"Zero shard_files found corresponding to: {file_path} using preprocess_dir={preprocess_dir} and entropy_model_name={entropy_model_name}, so the search path is data_dir={data_dir} for matches to {jsonl_file.name}.shard_*.arrow"
                )
            self.dataset_files = [str(f) for f in shard_files]
        else:
            self.preprocess_dir = None
            self.dataset_files = dataset_files

    def get_state(self) -> ArrowFileIteratorState:
        return ArrowFileIteratorState(
            file_path=self.file_path,
            row_num=self.row_num,
            worker_id=self.worker_id,
            num_workers=self.num_workers,
            preprocess_dir=self.preprocess_dir,
            entropy_model_name=self.entropy_model_name,
            arrow_batch_size=self.arrow_batch_size,
            dataset_files=self.dataset_files,
        )

    def create_iter(
        self,
    ) -> Generator[BltExample, Any, None]:
        if self.dataset is None:
            self.dataset = pa.dataset.dataset(self.dataset_files, format="arrow")
            self.batch_iterator = self.dataset.to_batches(
                batch_size=self.arrow_batch_size
            )
        self.iter_id += 1
        if self.batch_to_consume is not None:
            batch_columns: dict[str, list] = self.batch_to_consume
            self.batch_to_consume = None
            sample_ids = batch_columns["sample_id"]
            texts = batch_columns["text"]
            entropies = batch_columns["entropies"]
            for i in range(len(sample_ids)):
                out = BltExample(
                    sample_id=sample_ids[i],
                    entropies=entropies[i],
                    text=texts[i],
                    tokens=None,
                    mask=None,
                    patch_lengths=None,
                )
                self.row_num += 1
                if (self.row_num - 1) % self.num_workers == self.worker_id:
                    yield out

        for batch in self.batch_iterator:
            batch_columns = batch.to_pydict()
            sample_ids = batch_columns["sample_id"]
            texts = batch_columns["text"]
            entropies = batch_columns["entropies"]
            for i in range(len(sample_ids)):
                out = BltExample(
                    sample_id=sample_ids[i],
                    entropies=entropies[i],
                    text=texts[i],
                    tokens=None,
                    mask=None,
                    patch_lengths=None,
                )
                self.row_num += 1
                if (self.row_num - 1) % self.num_workers == self.worker_id:
                    yield out

    def _set_row_num(self, target_row_num: int):
        logger.info(
            f"Setting arrow position to {target_row_num} for {self.dataset_files}"
        )
        if target_row_num is None or target_row_num == 0:
            self.row_num = 0
            self.dataset = None
            self.batch_iterator = None
            self.batch_to_consume = None
        else:
            self.dataset = pa.dataset.dataset(self.dataset_files, format="arrow")
            self.batch_iterator = self.dataset.to_batches(
                batch_size=self.arrow_batch_size
            )
            curr_remaining = target_row_num
            for batch in self.batch_iterator:
                if len(batch) > curr_remaining:
                    batch_columns: dict[str, list] = batch.to_pydict()
                    batch_columns["sample_id"] = batch_columns["sample_id"][
                        curr_remaining:
                    ]
                    batch_columns["entropies"] = batch_columns["entropies"][
                        curr_remaining:
                    ]
                    batch_columns["text"] = batch_columns["text"][curr_remaining:]
                    self.batch_to_consume = batch_columns
                    break
                elif len(batch) == curr_remaining:
                    # We are exactly at the end of the batch,
                    # so the next batch is the right spot
                    break
                else:
                    curr_remaining -= len(batch)
            self.row_num = target_row_num
        logger.info(
            f"Finished setting arrow position to {target_row_num} for {self.dataset_files}"
        )


TRAIN_DATA_FILE_PATTERN = "*.chunk.*.jsonl"


def find_and_sanitize_chunks(
    dataset_path: str, world_size: int, file_pattern: str = TRAIN_DATA_FILE_PATTERN
):
    dataset_chunks = [str(p) for p in Path(dataset_path).glob(file_pattern)]
    n_chunks = len(dataset_chunks)

    if n_chunks > world_size:
        n_discard = n_chunks - world_size
        dataset_chunks = dataset_chunks[:world_size]
    else:
        assert (
            world_size % n_chunks == 0
        ), "World size should be a multiple of number of chunks"

    assert n_chunks > 0, f"No valid chunks in {dataset_path}"

    return dataset_chunks
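To make the moving parts concrete, a hedged usage sketch follows. Everything in it is illustrative: the shard paths are made up, and it assumes those Arrow files exist with sample_id, text, and entropies columns. It shows the round-robin worker split in create_iter and a checkpoint/resume round trip through ArrowFileIteratorState, which as a pydantic model serializes to JSON with the standard pydantic v2 API.

# Illustrative only: made-up shard paths.
iterator = ArrowFileIterator(
    file_path=None,
    worker_id=0,          # this worker keeps rows where (row_num - 1) % 2 == 0
    num_workers=2,
    preprocess_dir=None,
    entropy_model_name=None,
    arrow_batch_size=100,
    dataset_files=[
        "data/docs.chunk.0.jsonl.shard_1.arrow",
        "data/docs.chunk.0.jsonl.shard_2.arrow",
    ],
)
examples = iterator.create_iter()
first = next(examples)  # a BltExample; tokens/mask/patch_lengths are still None

# Checkpoint: the state round-trips through JSON (pydantic v2).
blob = iterator.get_state().model_dump_json()
restored = ArrowFileIteratorState.model_validate_json(blob)

# Resume: build() replays to the saved row via _set_row_num, and iteration
# continues with this worker's next row.
resumed = restored.build()
second = next(resumed.create_iter())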