Split dataset #44

Merged 4 commits on Nov 19, 2024

78 changes: 55 additions & 23 deletions fast_llm/data/blended.py
@@ -1,13 +1,17 @@
import logging
import pathlib
import time
import typing

import numpy as np

from fast_llm.core.distributed import ProcessGroup, safe_barrier
from fast_llm.data.config import SampledDataset
from fast_llm.core.distributed import safe_barrier
from fast_llm.data.config import PhaseSplits, SampledDataset, SamplingConfig, SplitDataset
from fast_llm.engine.config_utils.run import log_main_rank
from fast_llm.utils import Assert
from fast_llm.utils import Assert, normalize_probabilities

if typing.TYPE_CHECKING:
from fast_llm.data.gpt.data import GPTData

try:
from fast_llm.csrc.data import build_blending_indices # noqa
@@ -29,41 +33,69 @@ class BlendedDataset(SampledDataset):

def __init__(
self,
datasets: list[SampledDataset],
weights: list[float],
*,
name: str = "blended",
num_samples: int,
cache_directory: pathlib.Path | None = None,
group: ProcessGroup | None = None,
verbose: bool = True,
data_sample_warn_time_ms: float = 1000,
name: str,
datasets_and_weights: list[tuple[SampledDataset, float]],
sampling_config: SamplingConfig,
# TODO: Generalize
data: "GPTData",
):
self._datasets = datasets
self._name = name
self._num_samples = num_samples
self._weights = weights
self._data_sample_warn_time_ms = data_sample_warn_time_ms
assert len(datasets_and_weights) > 0
self._datasets, weights = zip(*datasets_and_weights)
self._weights = normalize_probabilities(weights)
self._num_samples = sampling_config.num_samples
self._data_sample_warn_time_ms = data.config.data_sample_warn_time_ms

if cache_directory is None:
if sampling_config.cache_directory is None:
self._dataset_idx_filename, self._sample_idx_filename = None, None
self._dataset_index, self._sample_index = self._build_blending_indices(verbose and len(datasets) <= 20)
self._dataset_index, self._sample_index = self._build_blending_indices(
sampling_config.verbose and len(self._datasets) <= 20
)
else:
self._dataset_idx_filename = cache_directory / (self._name + "_blending_dataset_idx.npy")
self._sample_idx_filename = cache_directory / (self._name + "_blending_sample_idx.npy")
group = data.distributed.world_group
self._dataset_idx_filename = sampling_config.cache_directory / (self._name + "_blending_dataset_idx.npy")
self._sample_idx_filename = sampling_config.cache_directory / (self._name + "_blending_sample_idx.npy")

# Build the indexed mapping if it doesn't exist.
# TODO: This only works if the dataset location is accessible by all jobs.
if (group is None or group.rank() == 0) and not (
self._dataset_idx_filename.is_file() and self._sample_idx_filename.is_file()
):
dataset_index, sample_index = self._build_blending_indices(verbose and len(datasets) <= 20)
cache_directory.mkdir(exist_ok=True, parents=True)
dataset_index, sample_index = self._build_blending_indices(
sampling_config.verbose and len(self._datasets) <= 20
)
sampling_config.cache_directory.mkdir(exist_ok=True, parents=True)
np.save(self._dataset_idx_filename, dataset_index)
np.save(self._sample_idx_filename, sample_index)

safe_barrier(group, self._name)
self._load_mappings(verbose)
self._load_mappings(sampling_config.verbose)

@classmethod
def apply(
cls,
name: str,
datasets_and_weights: list[(SplitDataset[SampledDataset], float)],
sampling_configs: PhaseSplits[SamplingConfig],
data: "GPTData",
):
Assert.leq(set(sampling_configs), set.union(*[set(dataset) for dataset, _ in datasets_and_weights]))
return SplitDataset[BlendedDataset](
name,
{
phase: BlendedDataset(
f"{name}_{phase.value}",
[
(dataset[phase], weight)
for dataset, weight in datasets_and_weights
if phase in dataset and weight > 0
],
sampling_config,
data,
)
for phase, sampling_config in sampling_configs.items()
},
)

def __getstate__(self):
return (
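
For readers skimming the diff: the new `BlendedDataset.apply` classmethod replaces the old per-call constructor arguments with a `SamplingConfig` and builds one blended dataset per phase, keeping only constituents that contain that phase and have a positive weight, with the weights renormalized via `normalize_probabilities`. Below is a minimal, self-contained sketch of that per-phase blending pattern; `Phase`, `normalize`, and `blend_per_phase` are hypothetical stand-ins, not Fast-LLM APIs.

from enum import Enum


class Phase(Enum):
    training = "training"
    validation = "validation"


def normalize(weights):
    total = sum(weights)
    return [weight / total for weight in weights]


def blend_per_phase(name, datasets_and_weights, sampling_configs):
    # One blended entry per phase that has a sampling config, mirroring the
    # shape of BlendedDataset.apply: keep constituents that contain the phase
    # and have a positive weight, then renormalize the remaining weights.
    blended = {}
    for phase, config in sampling_configs.items():
        pairs = [
            (dataset[phase], weight)
            for dataset, weight in datasets_and_weights
            if phase in dataset and weight > 0
        ]
        blended[phase] = {
            "name": f"{name}_{phase.value}",
            "datasets": [dataset for dataset, _ in pairs],
            "weights": normalize([weight for _, weight in pairs]),
            "num_samples": config["num_samples"],
        }
    return blended


# Two split datasets; the second one has no validation split.
d0 = {Phase.training: "d0-train", Phase.validation: "d0-valid"}
d1 = {Phase.training: "d1-train"}
configs = {Phase.training: {"num_samples": 100}, Phase.validation: {"num_samples": 10}}
print(blend_per_phase("blended", [(d0, 1.0), (d1, 3.0)], configs))
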
76 changes: 65 additions & 11 deletions fast_llm/data/config.py
@@ -123,9 +123,18 @@ class TokenizerConfig(Config):
)


@config_class
class SamplingConfig(Config):
num_samples: int = Field(default=1, desc="Number of samples to generate.")
seed: int = Field(default=0, desc="Random seed.")
cache_directory: pathlib.Path | None = Field(default=None, desc="Path to the sampling cache directory.")
verbose: bool = Field(default=True, desc="Log sampling progress.")


@config_class()
class DataConfig(Config):
_abstract = True
_sampling_config_class: typing.ClassVar[type[SamplingConfig]]


class Data(abc.ABC):
@@ -159,17 +168,8 @@ def name(self):
A name for the dataset to facilitate identification and debugging.
"""


@config_class
class SamplingConfig(Config):
num_samples: int = Field(default=1, desc="Number of samples to generate.")
seed: int = Field(default=0, desc="Random seed.")
cache_directory: pathlib.Path | None = Field(default=None, desc="Path to the sampling cache directory.")
verbose: bool = Field(default=True, desc="Log sampling progress.")


class SamplableDataset(Dataset):
def sample(self, config: SamplingConfig, data: Data):
@abc.abstractmethod
def as_split(self, default_phase: PhaseType = PhaseType.training):
pass


@@ -186,3 +186,57 @@ def __getitem__(self, index: int):
@abc.abstractmethod
def __len__(self):
pass

def as_split(self, default_phase: PhaseType = PhaseType.training):
return SplitDataset(self.name, {default_phase: self})


class SamplableDataset(Dataset):
# TODO: Move to dataset config?
_data_config_class: typing.ClassVar[type[DataConfig]]

def sample(self, config: SamplingConfig, data: Data) -> SampledDataset:
pass

def as_split(self, default_phase: PhaseType = PhaseType.training):
return SplitDataset(self.name, {default_phase: self})


_SplittableType = typing.TypeVar("_SplittableType")
_DatasetType = typing.TypeVar("_DatasetType", bound=Dataset)
_SampledDatasetType = typing.TypeVar("_SampledDatasetType", bound=SampledDataset)
_SamplableDatasetType = typing.TypeVar("_SamplableDatasetType", bound=SamplableDataset)


class PhaseSplits(dict[PhaseType, _SplittableType], typing.Generic[_SplittableType]):
pass


class SplitDataset(Dataset, PhaseSplits[_DatasetType], typing.Generic[_DatasetType]):
def __init__(self, name: str, datasets: dict[PhaseType, _DatasetType]):
super().__init__(datasets)
self._name = name

def as_split(self, default_phase: PhaseType = PhaseType.training):
return self

@property
def name(self):
return self._name


class SampledSplitDataset(SplitDataset[_SampledDatasetType], typing.Generic[_SampledDatasetType]):
pass


class SamplableSplitDataset(SplitDataset[_SamplableDatasetType], typing.Generic[_SamplableDatasetType]):
def sample(self, sampling_configs: PhaseSplits[SamplingConfig], data: Data):
return SampledSplitDataset(
f"{self.name}_sampled",
{phase: self[phase].sample(sampling_config, data) for phase, sampling_config in sampling_configs.items()},
)


class CopySplitDataset(SamplableSplitDataset):
def __init__(self, name: str, dataset: _SplittableType, phases: list[PhaseType]):
super().__init__(name, {phase: dataset for phase in phases})
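
The split-dataset hierarchy added here is essentially a phase-keyed dict: `PhaseSplits` maps a phase to anything, `SplitDataset` adds a name, `SamplableSplitDataset.sample` samples each phase with its own config, and `CopySplitDataset` reuses one dataset for every phase. A toy sketch of the same pattern follows; the `Phase` enum and the no-op `sample` behavior are illustrative, not the Fast-LLM implementation.

import typing
from enum import Enum


class Phase(Enum):
    training = "training"
    validation = "validation"


_T = typing.TypeVar("_T")


class PhaseSplits(dict[Phase, _T], typing.Generic[_T]):
    """A plain dict keyed by phase; subclasses add naming and sampling."""


class SplitDataset(PhaseSplits[_T], typing.Generic[_T]):
    def __init__(self, name: str, datasets: dict[Phase, _T]):
        super().__init__(datasets)
        self._name = name

    @property
    def name(self) -> str:
        return self._name


class SamplableSplitDataset(SplitDataset[_T], typing.Generic[_T]):
    def sample(self, sampling_configs: PhaseSplits[dict]) -> SplitDataset:
        # Sample each phase with its own config; here "sampling" just pairs
        # the phase's dataset with its config for illustration.
        return SplitDataset(
            f"{self.name}_sampled",
            {phase: (self[phase], config) for phase, config in sampling_configs.items()},
        )


class CopySplitDataset(SamplableSplitDataset):
    def __init__(self, name: str, dataset, phases: list[Phase]):
        # The same underlying dataset is reused for every requested phase.
        super().__init__(name, {phase: dataset for phase in phases})


splits = CopySplitDataset("dummy_split", "dummy-dataset", [Phase.training, Phase.validation])
sampled = splits.sample(PhaseSplits({Phase.training: {"num_samples": 8}}))
print(sampled.name, dict(sampled))
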
6 changes: 6 additions & 0 deletions fast_llm/data/gpt/config.py
@@ -4,6 +4,7 @@
DatasetSource,
FimConfig,
MultiprocessingContext,
SamplingConfig,
TokenizerConfig,
_validate_path,
_validate_split,
@@ -60,3 +61,8 @@ class GPTDataConfig(DataConfig):
desc="Multiprocessing context. Do not touch.",
hint=FieldHint.expert,
)


@config_class
class GPTSamplingConfig(SamplingConfig):
sequence_length: int = Field(default=None, desc="Number of tokens in each sample.")
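
`GPTSamplingConfig` simply extends the relocated `SamplingConfig` with a `sequence_length` field. A rough dataclass approximation of the subclassing pattern is below; Fast-LLM's own `config_class`/`Field` machinery is not reproduced here, so treat this purely as an illustration.

import dataclasses
import pathlib


@dataclasses.dataclass
class SamplingConfig:
    num_samples: int = 1
    seed: int = 0
    cache_directory: pathlib.Path | None = None
    verbose: bool = True


@dataclasses.dataclass
class GPTSamplingConfig(SamplingConfig):
    # GPT sampling additionally needs the sequence length of each sample.
    sequence_length: int | None = None


config = GPTSamplingConfig(num_samples=100, sequence_length=2048)
print(config)
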
119 changes: 57 additions & 62 deletions fast_llm/data/gpt/data.py
@@ -9,9 +9,8 @@
import torch.utils.data

from fast_llm.data.blended import BlendedDataset
from fast_llm.data.config import Data, DatasetSource, SampledDataset
from fast_llm.data.gpt.config import GPTDataConfig
from fast_llm.data.gpt.dataset import GPTSamplingConfig
from fast_llm.data.config import CopySplitDataset, Data, DatasetSource, PhaseSplits, SampledSplitDataset
from fast_llm.data.gpt.config import GPTDataConfig, GPTSamplingConfig
from fast_llm.data.gpt.dummy import DummyGPTDataset
from fast_llm.data.gpt.memmap import GPTMemmapDataset
from fast_llm.data.gpt.slice import GPTDatasetSlice
@@ -33,8 +32,7 @@ class GPTData(Data):
TODO: Separate generic and GPT classes.
"""

_sampled_datasets: dict[PhaseType, dict[str, SampledDataset]]
_blended_datasets: dict[PhaseType, SampledDataset]
_datasets: SampledSplitDataset
_tokenizer: Tokenizer | None
_distributed: Distributed
_cache_directory: pathlib.Path | None
@@ -132,8 +130,7 @@ def setup(self, distributed: Distributed, samples_per_phase: dict[PhaseType, int
else:
self._cache_directory = run.experiment_directory / "dataset_cache"

# Build and split datasets.
self._sampled_datasets = {phase: {} for phase in self._samples_per_phase}
datasets_and_weights = []
for i, (name, weight) in enumerate(self._dataset_weights.items()):
if i % 100 == 0 and i > 0:
log_main_rank(f"Prepared {i} of {self._num_datasets} datasets.")
@@ -146,27 +143,47 @@ def setup(self, distributed: Distributed, samples_per_phase: dict[PhaseType, int
expected_samples
+ 5 * math.sqrt(expected_samples * self._dataset_weights[name] * (1 - self._dataset_weights[name]))
)
sampled_datasets = self._build_and_sample_dataset(name, dataset_samples_per_phase)
for phase, dataset in sampled_datasets.items():
self._sampled_datasets[phase][name] = dataset

self._blended_datasets = {
phase: (
list(datasets.values())[0]
if len(datasets) == 1
else BlendedDataset(
list(datasets.values()),
weights=[self._dataset_weights[name] for name in datasets],
name=phase.value,
num_samples=self._samples_per_phase[phase],
cache_directory=self._cache_directory,
group=self._distributed.world_group,
verbose=run.is_main_rank,
data_sample_warn_time_ms=self._config.data_sample_warn_time_ms,
)
sampling_configs = PhaseSplits[GPTSamplingConfig](
{
phase: GPTSamplingConfig(
num_samples=dataset_samples_per_phase[phase],
sequence_length=self._max_sequence_length,
seed=self._distributed_config.seed,
cache_directory=(
self._dataset_prefixes[name].parent
if self._cache_directory is None
else self._cache_directory
),
verbose=self._num_datasets <= 5,
)
for phase, num_samples in dataset_samples_per_phase.items()
if num_samples > 0
}
)
datasets_and_weights.append(
(self._build_and_sample_dataset(name, sampling_configs), self._dataset_weights[name])
)

if len(datasets_and_weights) == 1:
self._datasets = datasets_and_weights[0][0]
else:
self._datasets = BlendedDataset.apply(
"blended",
datasets_and_weights,
PhaseSplits[GPTSamplingConfig](
{
phase: GPTSamplingConfig(
num_samples=samples_per_phase,
sequence_length=self._max_sequence_length,
seed=self._distributed_config.seed,
cache_directory=None if self._cache_directory is None else self._cache_directory,
verbose=self._num_datasets <= 5,
)
for phase, samples_per_phase in self._samples_per_phase.items()
}
),
self,
)
for phase, datasets in self._sampled_datasets.items()
}
self._is_setup = True

@property
@@ -192,14 +209,14 @@ def get_iterator(
prefetch_factor: int | None = None,
):
assert self._is_setup
Assert.incl(phase, self._blended_datasets)
Assert.incl(phase, self._datasets)
Assert.in_range_incl(batch_config.sequence_length, 1, self._max_sequence_length)
log_main_rank(f"Initializing {phase} data iterator from sample {consumed_samples}...")
return iter(
torch.utils.data.DataLoader(
self._blended_datasets[phase], # noqa
self._datasets[phase], # noqa
batch_sampler=SampledDatasetIterator(
total_samples=len(self._blended_datasets[phase]),
total_samples=len(self._datasets[phase]),
begin_index=consumed_samples,
micro_batch_size=batch_config.micro_batch_size,
data_rank=self._distributed.config.batch_data_rank,
@@ -212,36 +229,14 @@ def get_iterator(
)
)

def _build_and_sample_gpt_dataset(self, name: str, dataset_samples_per_phase: dict[PhaseType, int]):
dataset_split = GPTDatasetSlice.from_splits(
def _build_and_sample_gpt_dataset(self, name: str, sampling_configs: PhaseSplits[GPTSamplingConfig]):
return GPTDatasetSlice.from_splits(
GPTMemmapDataset(name, self._dataset_prefixes[name]), self._phase_split
)

sampled_datasets = {}
for phase, num_samples in dataset_samples_per_phase.items():
if num_samples == 0:
continue
sampled_datasets[phase] = dataset_split[phase].sample(
GPTSamplingConfig(
num_samples=num_samples,
sequence_length=self._max_sequence_length,
seed=self._distributed_config.seed,
cache_directory=(
self._dataset_prefixes[name].parent if self._cache_directory is None else self._cache_directory
),
verbose=self._num_datasets <= 5,
),
self,
)
return sampled_datasets

def _build_and_sample_dummy_dataset(self, name: str, dataset_samples_per_phase: dict[PhaseType, int]):
return {
phase: DummyGPTDataset(
dataset_samples_per_phase[phase],
self._max_sequence_length,
self._vocab_size,
name,
)
for phase in dataset_samples_per_phase
}
).sample(sampling_configs, self)

def _build_and_sample_dummy_dataset(self, name: str, sampling_configs: PhaseSplits[GPTSamplingConfig]):
return CopySplitDataset(
f"{name}_split",
DummyGPTDataset(name, self._max_sequence_length, self._vocab_size),
list(sampling_configs),
).sample(sampling_configs, self)
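
In the reworked `setup`, each named dataset gets a `PhaseSplits[GPTSamplingConfig]` covering only phases with a positive sample count, and the per-dataset sample budget is padded with a roughly five-sigma margin so no constituent of the blend can run out of samples. A standalone sketch of that budgeting step follows, assuming `expected_samples` is the phase total scaled by the dataset weight; the wrapper function and the inputs are hypothetical, only the margin expression mirrors the diff.

import math


def samples_per_dataset_phase(samples_per_phase, weights):
    # For each dataset and phase, request the expected number of samples plus
    # a ~5-sigma margin, mirroring the expression in GPTData.setup.
    plan = {}
    for name, weight in weights.items():
        plan[name] = {}
        for phase, total in samples_per_phase.items():
            expected_samples = total * weight
            margin = 5 * math.sqrt(expected_samples * weight * (1 - weight))
            num_samples = math.ceil(expected_samples + margin)
            if num_samples > 0:
                plan[name][phase] = num_samples
    return plan


print(samples_per_dataset_phase({"training": 10000, "validation": 500}, {"a": 0.7, "b": 0.3}))
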