Dataset from modular configuration #41

Closed · wants to merge 2 commits into from

4 changes: 2 additions & 2 deletions fast_llm/data/config.py
@@ -154,7 +154,7 @@ class Dataset(abc.ABC):

    @property
    @abc.abstractmethod
    def name(self):
    def name(self) -> str:
        """
        A name for the dataset to facilitate identification and debugging.
        """
@@ -169,7 +169,7 @@ class SamplingConfig(Config):


class SamplableDataset(Dataset):
    def sample(self, config: SamplingConfig, data: Data):
    def sample(self, config: SamplingConfig, data: Data) -> "SampledDataset":
        pass

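
For context, a minimal sketch of a concrete dataset satisfying the sharpened interface above: `name` returns a string and `sample` returns a `SampledDataset`. `ToyDataset` is hypothetical and not part of this change; the sketch assumes `Data`, `SamplableDataset`, `SampledDataset`, and `SamplingConfig` are all importable from `fast_llm.data.config` as in this diff.

```python
from fast_llm.data.config import Data, SamplableDataset, SampledDataset, SamplingConfig


class ToyDataset(SamplableDataset):
    """Hypothetical dataset used only to illustrate the typed interface."""

    def __init__(self, name: str):
        self._name = name

    @property
    def name(self) -> str:
        # `name` is now explicitly a string, used for identification and debugging.
        return self._name

    def sample(self, config: SamplingConfig, data: Data) -> "SampledDataset":
        # A real implementation would build and return a SampledDataset from
        # `config` (seed, sample count, ...) and the runtime `data` context.
        raise NotImplementedError()
```
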
9 changes: 5 additions & 4 deletions fast_llm/data/gpt/concatenated.py
@@ -1,5 +1,6 @@
import numpy as np

from fast_llm.data.gpt.config import GPTConcatenatedDatasetConfig
from fast_llm.data.gpt.dataset import GPTIndexedDataset
from fast_llm.utils import padded_cumsum

@@ -8,10 +9,10 @@ class GPTConcatenatedDataset(GPTIndexedDataset):

    def __init__(
        self,
        name: str,
        config: GPTConcatenatedDatasetConfig,
        datasets: list[GPTIndexedDataset],
    ):
        self._name = name
        self._config = config
        self._datasets = datasets
        sizes = [dataset.num_documents for dataset in self._datasets]
        self._dataset_splits = padded_cumsum(sizes)
@@ -24,7 +25,7 @@ def num_tokens(self):
    def num_documents(self):
        return sum(dataset.num_documents for dataset in self._datasets)

    def get_document_sizes(self) -> "np.ndarray":
    def get_document_sizes(self) -> np.ndarray:
        # TODO: This can be really big.
        return np.concatenate([dataset.get_document_sizes() for dataset in self._datasets])

@@ -39,4 +40,4 @@ def get(self, document: int, offset: int = 0, length: int | None = None):

    @property
    def name(self):
        return self._name
        return self._config.name
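
The `_dataset_splits = padded_cumsum(sizes)` line is what lets the concatenated dataset treat its parts as one index space. A small standalone sketch of the idea, assuming `padded_cumsum` behaves like a zero-prepended cumulative sum; the `locate` helper is illustrative and not part of the diff.

```python
import numpy as np

sizes = [3, 5, 2]  # num_documents of each sub-dataset
splits = np.concatenate(([0], np.cumsum(sizes)))  # -> [0, 3, 8, 10], like padded_cumsum(sizes)


def locate(document: int) -> tuple[int, int]:
    # Map a global document index to (sub-dataset index, local document index).
    dataset_index = int(np.searchsorted(splits, document, side="right") - 1)
    return dataset_index, document - int(splits[dataset_index])


assert locate(0) == (0, 0)  # first document of dataset 0
assert locate(4) == (1, 1)  # second document of dataset 1
assert locate(9) == (2, 1)  # last document of dataset 2
```
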
189 changes: 187 additions & 2 deletions fast_llm/data/gpt/config.py
@@ -1,14 +1,199 @@
from fast_llm.config import Field, FieldHint, check_field, config_class
import pathlib
import typing

from fast_llm.config import Config, Field, FieldHint, check_field, config_class
from fast_llm.data.config import (
    Data,
    DataConfig,
    DatasetSource,
    FimConfig,
    MultiprocessingContext,
    SamplableDataset,
    SampledDataset,
    SamplingConfig,
    TokenizerConfig,
    _validate_path,
    _validate_split,
)
from fast_llm.utils import Assert
from fast_llm.engine.distributed.config import PhaseType
from fast_llm.utils import Assert, Registry

if typing.TYPE_CHECKING:
    from fast_llm.data.gpt.data import GPTData
    from fast_llm.data.gpt.dataset import GPTIndexedDataset, GPTSamplingConfig

dataset_registry = Registry("dataset")


@config_class()
class DatasetConfig(Config):
    _abstract = True
    type: str = Field(
        desc="Format for the dataset definition.",
        hint=FieldHint.core,
    )

    @classmethod
    def from_dict(
        cls,
        default: Config | dict[str, typing.Any],
        *updates: Config | dict[str | tuple[str, ...], typing.Any],
        strict: bool = True,
    ):
        # Dispatch on the `type` field of the incoming config. The type strings
        # below are assumed names for the supported dataset formats.
        type_ = default["type"] if isinstance(default, dict) else default.type
        if type_ == "split":
            cls_ = GPTSplitDatasetConfig
        elif type_ == "concatenated":
            cls_ = GPTConcatenatedDatasetConfig
        elif type_ == "blended":
            cls_ = GPTBlendedDatasetConfig
        elif type_ == "memmap":
            cls_ = GPTMemmapDatasetConfig
        else:
            raise NotImplementedError(type_)
        Assert.custom(issubclass, cls_, cls)
        if cls_ is cls:
            # Avoid re-dispatching once the concrete class has been reached.
            return super().from_dict(default, *updates, strict=strict)
        return cls_.from_dict(default, *updates, strict=strict)

    def build(self, config: "SamplingConfig", data: "Data") -> dict[PhaseType, SampledDataset]:
        raise NotImplementedError()


@config_class()
class SamplableDatasetConfig(DatasetConfig):
    def build(self, config: "GPTSamplingConfig", data: "GPTData") -> dict[PhaseType, SampledDataset]:
        return {phase: dataset.sample(config, data) for phase, dataset in self.build_unsampled().items()}

    def build_unsampled(self) -> dict[PhaseType, SamplableDataset]:
        raise NotImplementedError()


@config_class()
class SplittableDatasetConfig(DatasetConfig):
    def build_unsampled(self) -> dict[PhaseType, SamplableDataset]:
        return {PhaseType.training: self.build_unsplit()}

    def build_unsplit(self) -> SamplableDataset:
        raise NotImplementedError()


@config_class()
class GPTIndexedDatasetConfig(SplittableDatasetConfig):
    def build_unsplit(self) -> "GPTIndexedDataset":
        raise NotImplementedError()


@config_class()
class GPTMemmapDatasetConfig(GPTIndexedDatasetConfig):
    # Path -> (unsampled, unsplit)
    _abstract = False
    path: pathlib.Path = Field(
        desc="The path to the dataset, excluding the `.bin` or `.idx` suffix.",
        hint=FieldHint.core,
    )

    def build_unsplit(self) -> SamplableDataset:
        from fast_llm.data.gpt.memmap import GPTMemmapDataset

        return GPTMemmapDataset(self)


@config_class()
class GPTConcatenatedDatasetConfig(GPTIndexedDatasetConfig):
    """
    Concatenate multiple datasets as if they were one.
    Must be done before sampling and splitting.
    TODO: Would it be OK after sampling (staged training?) or after splitting (equal split for each sub-dataset, probably better)?
    [(unsampled, unsplit)] -> (unsampled, unsplit)
    """

    _abstract = False
    name: str = Field(
        default="concatenated",
        desc="The name of the dataset.",
        hint=FieldHint.core,
    )
    datasets: list[GPTIndexedDatasetConfig] = Field(
        desc="The datasets to concatenate.",
        hint=FieldHint.core,
    )

    def build_unsplit(self) -> SamplableDataset:
        from fast_llm.data.gpt.concatenated import GPTConcatenatedDataset

        return GPTConcatenatedDataset(self, [dataset.build_unsplit() for dataset in self.datasets])


@config_class()
class GPTSplitDatasetConfig(SamplableDatasetConfig):
    """
    Split a single dataset into multiple phases.
    Must be done before sampling.
    TODO: Ok after sampling?
    (unsampled, unsplit) -> (unsampled, split)
    """

    _abstract = False
    dataset: GPTIndexedDatasetConfig = Field(
        desc="The dataset to split.",
        hint=FieldHint.core,
    )
    ratios: dict[PhaseType, float] = Field(
        desc="The split ratio for each phase.",
        hint=FieldHint.core,
    )

    def build_unsampled(self) -> dict[PhaseType, SamplableDataset]:
        from fast_llm.data.gpt.slice import GPTDatasetSlice

        return GPTDatasetSlice.from_splits(self)


@config_class()
class GPTBlendedDatasetConfig(DatasetConfig):
    # [(?sampled, ?split)] -> (sampled, split)
    _abstract = False
    datasets: list[DatasetConfig] = Field(
        desc="The datasets to blend.",
        hint=FieldHint.core,
    )
    weights: list[float] = Field(
        desc="The blending weight of each dataset.",
        hint=FieldHint.core,
    )

    @property
    def split(self) -> bool:
        return True

    @property
    def sampled(self) -> bool:
        return True

    def build(self, config: "GPTSamplingConfig", data: "GPTData") -> dict[PhaseType, SampledDataset]:
        from fast_llm.data.blended import BlendedDataset

        # Build each sub-dataset, group the per-phase results, then blend them per phase.
        datasets = {}
        for dataset in self.datasets:
            dataset_split = dataset.build(config, data)
            if datasets:
                Assert.eq(set(datasets), set(dataset_split))
            else:
                datasets = {phase: [] for phase in dataset_split}
            for phase, phase_datasets in datasets.items():
                phase_datasets.append(dataset_split[phase])

        return {
            phase: BlendedDataset(phase_datasets, self.weights, data) for phase, phase_datasets in datasets.items()
        }


@config_class()
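
Taken together, these config classes are what make the dataset definition modular: a memmap dataset can be wrapped in a split, and several splits can be blended, all from plain configuration. A hedged sketch of how such a definition might be composed and built; the `type` strings, the phase keys, and the exact `from_dict`/`build` call shapes are assumptions based on this diff, not confirmed API.

```python
from fast_llm.data.gpt.config import DatasetConfig

# Two memmap datasets, each split into training/validation, blended 70/30.
dataset_config = DatasetConfig.from_dict(
    {
        "type": "blended",  # assumed type string, see from_dict above
        "weights": [0.7, 0.3],
        "datasets": [
            {
                "type": "split",
                "ratios": {"training": 0.98, "validation": 0.02},  # phase key format assumed
                "dataset": {"type": "memmap", "path": "/data/wiki/shard_0"},
            },
            {
                "type": "split",
                "ratios": {"training": 0.98, "validation": 0.02},
                "dataset": {"type": "memmap", "path": "/data/code/shard_0"},
            },
        ],
    }
)

# Building would then yield one sampled dataset per phase, e.g.:
#     sampled = dataset_config.build(sampling_config, data)
#     train = sampled[PhaseType.training]
```
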
6 changes: 3 additions & 3 deletions fast_llm/data/gpt/dataset.py
@@ -1,12 +1,12 @@
import abc
import typing

import numpy as np

from fast_llm.config import Field, config_class
from fast_llm.data.config import SamplableDataset, SamplingConfig

if typing.TYPE_CHECKING:
    import numpy as np

    from fast_llm.data.gpt.data import GPTData


@@ -59,4 +59,4 @@ def get_document_sizes(self) -> "np.ndarray":
    def sample(self, config: GPTSamplingConfig, data: "GPTData"):
        from fast_llm.data.gpt.sampled import GPTSampledIndexedDataset

        return GPTSampledIndexedDataset(self, config, data)
        return GPTSampledIndexedDataset(config, self, data)
31 changes: 16 additions & 15 deletions fast_llm/data/gpt/memmap.py
@@ -3,6 +3,7 @@

import numpy as np

from fast_llm.data.gpt.config import GPTMemmapDatasetConfig
from fast_llm.data.gpt.dataset import GPTIndexedDataset
from fast_llm.utils import Assert, div, padded_cumsum

@@ -28,15 +29,15 @@ class GPTMemmapDataset(GPTIndexedDataset):
    }
    _INDEX_HEADER = b"MMIDIDX\x00\x00"

    def __init__(self, name: str, prefix: pathlib.Path | str):
        self._init(name, prefix)
    def __init__(self, config: GPTMemmapDatasetConfig):
        self._init(config)

    def _init(self, name: str, prefix: pathlib.Path | str):
    def _init(self, config: GPTMemmapDatasetConfig):
        super().__init__()
        self._name = name
        self._prefix = pathlib.Path(prefix)
        self._config = config
        self._name = str(self._config.path).replace("/", "__")

        with self._prefix.with_suffix(".idx").open("rb") as stream:
        with self._config.path.with_suffix(".idx").open("rb") as stream:
            Assert.eq(stream.read(9), self._INDEX_HEADER)
            Assert.eq(struct.unpack("<Q", stream.read(8))[0], 1)

@@ -45,7 +46,7 @@ def _init(self, name: str, prefix: pathlib.Path | str):
            _ = struct.unpack("<Q", stream.read(8))[0]
            offset = stream.tell()

        self._index_bin_buffer_mmap = np.memmap(self._prefix.with_suffix(".idx"), mode="r", order="C")
        self._index_bin_buffer_mmap = np.memmap(self._config.path.with_suffix(".idx"), mode="r", order="C")
        self._index_bin_buffer = memoryview(self._index_bin_buffer_mmap)
        self._document_sizes = np.frombuffer(
            self._index_bin_buffer, dtype=np.int32, count=self._num_documents, offset=offset
@@ -57,31 +58,31 @@ def _init(self, name: str, prefix: pathlib.Path | str):
            offset=offset + self._document_sizes.nbytes,
        )

        self._bin_buffer_mmap = np.memmap(self._prefix.with_suffix(".bin"), mode="r", order="C")
        self._bin_buffer_mmap = np.memmap(self._config.path.with_suffix(".bin"), mode="r", order="C")
        self._bin_buffer = memoryview(self._bin_buffer_mmap)

    def __getstate__(self):
        return (self._name, self._prefix)
        return self._config.to_serialized()

    def __setstate__(self, state):
        self._init(*state)
        self._init(GPTMemmapDatasetConfig.from_dict(state))

    def __del__(self):
        self._bin_buffer_mmap._mmap.close() # noqa
        del self._bin_buffer_mmap
        self._index_bin_buffer_mmap._mmap.close() # noqa
        del self._index_bin_buffer_mmap

    def get(self, idx, offset=0, length=None):
    def get(self, document: int, offset: int = 0, length: int | None = None):
        return np.frombuffer(
            self._bin_buffer,
            dtype=self._dtype,
            count=self._document_sizes[idx] - offset if length is None else length,
            offset=self._pointers[idx] + offset * np.dtype(self._dtype).itemsize,
            count=self._document_sizes[document] - offset if length is None else length,
            offset=self._pointers[document] + offset * np.dtype(self._dtype).itemsize,
        )

    @property
    def name(self):
    def name(self) -> str:
        return self._name

    @property
@@ -92,7 +93,7 @@ def num_documents(self) -> int:
    def num_tokens(self) -> int:
        return div(self._bin_buffer_mmap.size, np.dtype(self._dtype).itemsize)

    def get_document_sizes(self) -> "np.ndarray":
    def get_document_sizes(self) -> np.ndarray:
        """
        The size of each document in the dataset.
        The resulting array could be very large, so this method should be called cautiously,
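
Since `__getstate__` now returns the serialized config and `__setstate__` rebuilds from it, a `GPTMemmapDataset` can be pickled (for example across data-loader worker processes) without shipping the memory maps themselves. A rough sketch, assuming `to_serialized()` round-trips through `from_dict()` and using an illustrative path:

```python
import pickle

from fast_llm.data.gpt.config import GPTMemmapDatasetConfig
from fast_llm.data.gpt.memmap import GPTMemmapDataset

# Illustrative path; the matching .bin/.idx files must exist for the mmaps to open.
config = GPTMemmapDatasetConfig.from_dict({"type": "memmap", "path": "/data/wiki/shard_0"})
dataset = GPTMemmapDataset(config)

# Pickling ships only the serialized config; unpickling re-runs _init and re-opens the mmaps.
restored = pickle.loads(pickle.dumps(dataset))
assert restored.name == dataset.name
```
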
4 changes: 2 additions & 2 deletions fast_llm/data/gpt/sampled.py
@@ -29,8 +29,8 @@ class GPTSampledIndexedDataset(SampledDataset):

    def __init__(
        self,
        indexed_dataset: GPTIndexedDataset,
        config: GPTSamplingConfig,
        indexed_dataset: GPTIndexedDataset,
        data: GPTData,
    ):
        assert isinstance(config, GPTSamplingConfig)
@@ -188,7 +188,7 @@ def __getitem__(self, idx):
            dtype=np.int64,
        )
        if self._fim is not None:
            sample = self._fim(sample, np.random.RandomState(seed=(self._seed + idx) % MAX_SEED))
            sample = self._fim(sample, np.random.RandomState(seed=(self._config.seed + idx) % MAX_SEED))

        return sample

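
The last change above also keeps the FIM augmentation reproducible: the RNG is derived from the sampling config's seed plus the sample index, so the same index always gets the same stream. A standalone sketch of that seeding scheme; the seed value and the `MAX_SEED` bound here are illustrative, not taken from the codebase.

```python
import numpy as np

MAX_SEED = 2**32 - 1  # assumed bound; numpy RandomState seeds must fit in 32 bits
config_seed = 1234  # illustrative; stands in for the sampling config's seed


def fim_rng(idx: int) -> np.random.RandomState:
    # Same index -> same RNG stream, independent of worker or iteration order.
    return np.random.RandomState(seed=(config_seed + idx) % MAX_SEED)


assert fim_rng(7).randint(1 << 30) == fim_rng(7).randint(1 << 30)
```
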