
Modular dataset configuration #104


Merged: 56 commits, Jan 22, 2025

Commits (56):
147e33b
Modular dataset configuration
jlamypoirier Jan 6, 2025
c41a2c5
fixes
jlamypoirier Jan 7, 2025
e013ba2
fix
jlamypoirier Jan 7, 2025
6b45944
Merge branch 'main' into modular_dataset
tscholak Jan 9, 2025
952a03d
Merge branch 'main' into modular_dataset
jlamypoirier Jan 9, 2025
82285ae
Generalize indexed
jlamypoirier Jan 9, 2025
7011ca3
fix
jlamypoirier Jan 9, 2025
9574715
Modularize fim, decouple data from dataset, basic tests, misc
jlamypoirier Jan 10, 2025
5532b97
Make tests pass
jlamypoirier Jan 11, 2025
5d5e0ab
Remove split datasets
jlamypoirier Jan 11, 2025
baacc4e
Make tests pass
jlamypoirier Jan 13, 2025
09640d8
misc
jlamypoirier Jan 13, 2025
a73acf6
misc
jlamypoirier Jan 13, 2025
bb1b87f
Fix merge
jlamypoirier Jan 15, 2025
148b448
Type hints
jlamypoirier Jan 15, 2025
13e4f43
misc
jlamypoirier Jan 16, 2025
0219006
Dataset tweaks
jlamypoirier Jan 16, 2025
b9b516f
Merge branch 'dataset_tweaks' into modular_dataset
jlamypoirier Jan 16, 2025
8a33cef
fix
jlamypoirier Jan 16, 2025
62fbe01
misc
jlamypoirier Jan 16, 2025
1934828
misc
jlamypoirier Jan 16, 2025
c0be45c
misc
jlamypoirier Jan 16, 2025
6358d08
Merge branch 'dataset_tweaks' into modular_dataset
jlamypoirier Jan 16, 2025
53922e2
fix
jlamypoirier Jan 16, 2025
c2ee93d
fix
jlamypoirier Jan 16, 2025
05cf63f
Merge branch 'dataset_tweaks' into modular_dataset
jlamypoirier Jan 16, 2025
dc11ca6
Merge branch 'main' into modular_dataset
jlamypoirier Jan 16, 2025
a8facf0
Merge remote-tracking branch 'origin/main' into modular_dataset
jlamypoirier Jan 17, 2025
c2105b5
stuff
jlamypoirier Jan 17, 2025
b5e816c
More tests
jlamypoirier Jan 17, 2025
a4fe3d4
tests and misc
jlamypoirier Jan 17, 2025
0b184d3
Dataset tests
jlamypoirier Jan 17, 2025
92aef96
Merge branch 'dataset_tests' into modular_dataset
jlamypoirier Jan 17, 2025
1e57882
fixes
jlamypoirier Jan 17, 2025
9213b0a
Merge branch 'dataset_tests' into modular_dataset
jlamypoirier Jan 17, 2025
7e43ea9
Legacy tests
jlamypoirier Jan 19, 2025
bae839b
Merge branch 'dataset_tests' into modular_dataset
jlamypoirier Jan 20, 2025
41ad25f
Fix merge, update tests
jlamypoirier Jan 20, 2025
d7c1f38
Fix
jlamypoirier Jan 20, 2025
28bf7de
stuff
jlamypoirier Jan 20, 2025
401f56e
Merge branch 'dataset_tests' into modular_dataset
jlamypoirier Jan 20, 2025
17e3aea
fix
jlamypoirier Jan 20, 2025
ab2f468
Drop class kwarg
jlamypoirier Jan 20, 2025
27587a4
fixes
jlamypoirier Jan 22, 2025
ca1f944
fixes
jlamypoirier Jan 22, 2025
3c17819
fixes
jlamypoirier Jan 22, 2025
bd2fcec
Merge branch 'dataset_tests' into modular_dataset
jlamypoirier Jan 22, 2025
5405d42
Merge branch 'main' into modular_dataset
jlamypoirier Jan 22, 2025
0d8bf14
Fix merge
jlamypoirier Jan 22, 2025
a0aae75
Fix merge
jlamypoirier Jan 22, 2025
9dbbcf9
Fix merge
jlamypoirier Jan 22, 2025
6dea63e
Fix merge
jlamypoirier Jan 22, 2025
8245041
fixes
jlamypoirier Jan 22, 2025
77b1324
fixes
jlamypoirier Jan 22, 2025
54e5fa5
fixes
jlamypoirier Jan 22, 2025
755c355
Match legacy
jlamypoirier Jan 22, 2025
5 changes: 3 additions & 2 deletions examples/mistral.yaml
@@ -11,8 +11,9 @@ batch:
   micro_batch_size: 2
   batch_size: 64
 data:
-  format: random
-  split: [1, 0, 0]
+  datasets:
+    Training:
+      type: random
 optimizer:
   learning_rate:
     base: 1.0e-05
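For reference, the new format nests one dataset configuration per phase under data.datasets, replacing the flat format and split fields. A minimal sketch of a two-phase configuration follows; only the Training entry with type: random appears in this diff, so the Validation key and its reuse of the random type are assumptions for illustration.

data:
  datasets:
    # One entry per phase; each value is a sampled dataset configuration.
    Training:
      type: random
    Validation:
      type: random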
24 changes: 23 additions & 1 deletion fast_llm/data/data/gpt/config.py
@@ -1,9 +1,14 @@
import logging

from fast_llm.config import Field, FieldHint, check_field, config_class
from fast_llm.data.config import MultiprocessingContext, TokenizerConfig
from fast_llm.data.data.config import DataConfig
from fast_llm.data.dataset.gpt.config import GPTLegacyConfig
from fast_llm.data.dataset.gpt.config import GPTLegacyConfig, GPTLegacyDatasetConfig, GPTSampledDatasetConfig
from fast_llm.engine.distributed.config import PhaseType
from fast_llm.utils import Assert

logger = logging.getLogger(__name__)


@config_class()
class GPTDataConfig(DataConfig, GPTLegacyConfig):
@@ -20,6 +25,12 @@ class GPTDataConfig(DataConfig, GPTLegacyConfig):
         desc="Configuration for the tokenizer (for FIM).",
         hint=FieldHint.feature,
     )
+    # TODO: Review field. Move closer to phase definition in training config?
+    datasets: dict[PhaseType, GPTSampledDatasetConfig] = Field(
+        default_factory=dict,
+        desc="Configuration for the dataset(s).",
+        hint=FieldHint.core,
+    )
     data_sample_warn_time_ms: float = Field(
         default=1000,
         desc="Warn if a sample takes too long to load.",
@@ -31,3 +42,14 @@ class GPTDataConfig(DataConfig, GPTLegacyConfig):
         desc="Multiprocessing context. Do not touch.",
         hint=FieldHint.expert,
     )
+
+    def _validate(self) -> None:
+        if not self.datasets:
+            logger.warning(
+                "Using the legacy dataset definition format." " Specify it through `data.datasets` instead."
+            )
+            self.datasets = {
+                phase: GPTLegacyDatasetConfig.from_dict(self, strict=False)
+                for phase in (PhaseType.training, PhaseType.validation, PhaseType.test)
+            }
+        super()._validate()
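The _validate fallback above keeps the old flat definition working: when data.datasets is empty, each of the training, validation, and test phases receives a GPTLegacyDatasetConfig built from the remaining legacy fields. A sketch of a legacy-style configuration that should still validate, assuming the format, path, and split field names used by the legacy code removed from data.py below; the path value is a made-up example.

data:
  format: list
  path: [/data/my_dataset_prefix]  # a single prefix gets an implicit weight of 1.0
  split: [0.98, 0.01, 0.01]        # training / validation / test ratios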
180 changes: 23 additions & 157 deletions fast_llm/data/data/gpt/data.py
@@ -1,6 +1,4 @@
import json
import logging
import math
import pathlib
import typing
import warnings
@@ -10,21 +8,16 @@

from fast_llm.data.data.abstract import Data
from fast_llm.data.data.gpt.config import GPTDataConfig
from fast_llm.data.dataset.abstract import CopySplitDataset, PhaseSplits, SampledSplitDataset
from fast_llm.data.dataset.blended import BlendedDataset
from fast_llm.data.dataset.gpt.config import GPTSamplingConfig, LegacyDatasetSource
from fast_llm.data.dataset.gpt.fim import GPTFimDataset
from fast_llm.data.dataset.gpt.indexed import GPTDatasetSlice
from fast_llm.data.dataset.gpt.memmap import GPTMemmapDataset
from fast_llm.data.dataset.gpt.random import GPTRandomDataset
from fast_llm.data.dataset.abstract import SampledDataset
from fast_llm.data.dataset.gpt.config import GPTSamplingConfig
from fast_llm.data.dataset.monitor import DatasetMonitor
from fast_llm.data.iterator import SampledDatasetIterator
from fast_llm.data.tokenizer import Tokenizer
from fast_llm.engine.config_utils.run import log_main_rank
from fast_llm.engine.distributed.config import DistributedConfig, PhaseType
from fast_llm.engine.distributed.distributed import Distributed
from fast_llm.engine.schedule.config import BatchConfig
from fast_llm.utils import Assert, normalize_probabilities
from fast_llm.utils import Assert

logger = logging.getLogger(__name__)

@@ -36,9 +29,8 @@ class GPTData[ConfigType: GPTDataConfig](Data[ConfigType]):
TODO: Separate generic and GPT classes.
"""

_datasets: SampledSplitDataset
_datasets: dict[PhaseType, SampledDataset]
_tokenizer: Tokenizer | None
_phases: typing.ClassVar[tuple[PhaseType, ...]] = (PhaseType.training, PhaseType.validation, PhaseType.test)
_is_setup: bool = False

def __init__(
@@ -55,59 +47,6 @@ def __init__(
super().__init__(config, distributed_config)
self._vocab_size = vocab_size
self._max_sequence_length = max_sequence_length
Assert.eq(len(self._config.split), len(self._phases))
self._phase_split = {
phase: ratio
for phase, ratio in zip(self._phases, normalize_probabilities(self._config.split))
if ratio > 0
}

data_base_path = None
if self._config.format == LegacyDatasetSource.file:
Assert.eq(len(self._config.path), 1)
data_path = pathlib.Path(self._config.path[0])
dataset_defs = json.load(data_path.open("r"))
data_base_path = data_path.parent
dataset_prefixes = [dataset_def["prefix"] for dataset_def in dataset_defs["datasets"]]
dataset_weights = normalize_probabilities(
[dataset_def["weight"] for dataset_def in dataset_defs["datasets"]]
)
self._build_and_sample_dataset = self._build_and_sample_gpt_dataset
elif self._config.format == LegacyDatasetSource.list:
Assert.geq(len(self._config.path), 1)
if len(self._config.path) == 1:
dataset_prefixes, dataset_weights = [self._config.path[0].strip()], [1.0]
else:
Assert.custom(lambda x: x % 2 == 0, len(self._config.path))
dataset_prefixes = [x.strip() for x in self._config.path[1::2]]
assert len(dataset_prefixes) == len(set(dataset_prefixes))
dataset_weights = normalize_probabilities([float(x) for x in self._config.path[::2]])
self._build_and_sample_dataset = self._build_and_sample_gpt_dataset
elif self._config.format == LegacyDatasetSource.random:
Assert.eq(len(self._config.path), 0)
dataset_prefixes, dataset_weights = [None], [1.0]
self._build_and_sample_dataset = self._build_and_sample_dummy_dataset
else:
raise NotImplementedError(self._config.format)

dataset_names = [
f"dataset_{i}_{'dummy' if prefix is None else prefix.replace('/','__')}"
for i, prefix in enumerate(dataset_prefixes)
]
self._num_datasets = len(dataset_names)
self._dataset_prefixes = {
name: (
None
if prefix is None
else (
pathlib.Path(prefix).resolve()
if data_base_path is None
else (pathlib.Path(data_base_path) / prefix).resolve()
)
)
for name, prefix in zip(dataset_names, dataset_prefixes)
}
self._dataset_weights = {name: weight for name, weight in zip(dataset_names, dataset_weights)}

def setup(
self,
@@ -120,82 +59,30 @@ def setup(
This may take a while and a significant amount of cpu memory.
"""
super().setup(distributed, samples_per_phase, cache_directory)
Assert.leq(set(samples_per_phase), set(self._phase_split))
log_main_rank(f"Preparing {self._num_datasets} datasets. This may take several minutes.")
self._tokenizer = Tokenizer(self._config.tokenizer) if self._config.fim.rate > 0 else None
self._distributed = distributed
self._samples_per_phase = samples_per_phase
log_main_rank(f"Preparing dataset. This may take several minutes.")
self._tokenizer = None if self._config.tokenizer.path is None else Tokenizer(self._config.tokenizer)

if self._cache_directory is None:
# TODO: Avoid this
warnings.warn(f"Using the dataset directory for the index cache.")

datasets_and_weights = []
for i, (name, weight) in enumerate(self._dataset_weights.items()):
if i % 100 == 0 and i > 0:
log_main_rank(f"Prepared {i} of {self._num_datasets} datasets.")
dataset_samples_per_phase = {}
for phase, samples_per_phase in self._samples_per_phase.items():
expected_samples = self._dataset_weights[name] * samples_per_phase
# Add 5 times the standard deviation (of a binomial distribution)
# so the probability of sampling more than this amount during blending is negligible.
dataset_samples_per_phase[phase] = math.ceil(
expected_samples
+ 5 * math.sqrt(expected_samples * self._dataset_weights[name] * (1 - self._dataset_weights[name]))
self._datasets = {}
for phase, num_samples in samples_per_phase.items():
if num_samples > 0:
# TODO: Do the check earlier.
assert phase in self._config.datasets
sampling_config = GPTSamplingConfig(
num_samples=samples_per_phase[phase],
seed=self._distributed_config.seed,
cache_directory=self._cache_directory,
distributed=distributed,
phase=phase,
sequence_length=self._max_sequence_length,
vocab_size=self._vocab_size,
tokenizer=self._tokenizer,
)

sampling_configs = PhaseSplits[GPTSamplingConfig](
{
phase: GPTSamplingConfig(
num_samples=dataset_samples_per_phase[phase],
seed=self._distributed_config.seed,
cache_directory=(
self._dataset_prefixes[name].parent
if self._cache_directory is None and isinstance(self._dataset_prefixes[name], pathlib.Path)
else self._cache_directory
),
verbose=self._num_datasets <= 5,
distributed=self._distributed,
sequence_length=self._max_sequence_length,
vocab_size=self._vocab_size,
tokenizer=self._tokenizer,
)
for phase, num_samples in dataset_samples_per_phase.items()
if num_samples > 0
}
)
datasets_and_weights.append(
(self._build_and_sample_dataset(name, sampling_configs), self._dataset_weights[name])
)

if len(datasets_and_weights) == 1:
datasets = datasets_and_weights[0][0]
else:
datasets = BlendedDataset.apply(
"blended",
datasets_and_weights,
PhaseSplits[GPTSamplingConfig](
{
phase: GPTSamplingConfig(
num_samples=samples_per_phase,
seed=self._distributed_config.seed,
cache_directory=None if self._cache_directory is None else self._cache_directory,
verbose=self._num_datasets <= 5,
distributed=self._distributed,
sequence_length=self._max_sequence_length,
vocab_size=self._vocab_size,
tokenizer=self._tokenizer,
)
for phase, samples_per_phase in self._samples_per_phase.items()
}
),
)
self._datasets = SampledSplitDataset[GPTDatasetSlice](
"monitor",
{
phase: DatasetMonitor(dataset, self._config.data_sample_warn_time_ms)
for phase, dataset in datasets.items()
},
)
dataset = self._config.datasets[phase].build_and_sample(sampling_config)
self._datasets[phase] = DatasetMonitor(dataset, self._config.data_sample_warn_time_ms)
self._is_setup = True

@property
@@ -232,24 +119,3 @@ def get_iterator(
multiprocessing_context=self._config.multiprocessing_context.value if num_workers > 0 else None,
)
)

def _build_and_sample_gpt_dataset(self, name: str, sampling_configs: PhaseSplits[GPTSamplingConfig]):
datasets = GPTDatasetSlice.from_splits(
GPTMemmapDataset(name, self._dataset_prefixes[name]), self._phase_split
).sample(sampling_configs)
if self._config.fim.rate > 0:
datasets = SampledSplitDataset[GPTDatasetSlice](
"fim",
{
phase: GPTFimDataset(self.config.fim, dataset, sampling_configs[phase])
for phase, dataset in datasets.items()
},
)
return datasets

def _build_and_sample_dummy_dataset(self, name: str, sampling_configs: PhaseSplits[GPTSamplingConfig]):
return CopySplitDataset(
f"{name}_split",
GPTRandomDataset(name),
list(sampling_configs),
).sample(sampling_configs)
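For reference, the per-phase flow that replaces the removed blending and split machinery is small enough to read in isolation. Below is a condensed sketch of what setup() now does for a single phase, using only classes and keyword arguments visible in this diff; the standalone function, its parameter list, and the explicit seed argument are illustrative rather than part of the PR.

import pathlib

from fast_llm.data.data.gpt.config import GPTDataConfig
from fast_llm.data.dataset.abstract import SampledDataset
from fast_llm.data.dataset.gpt.config import GPTSamplingConfig
from fast_llm.data.dataset.monitor import DatasetMonitor
from fast_llm.data.tokenizer import Tokenizer
from fast_llm.engine.distributed.config import PhaseType
from fast_llm.engine.distributed.distributed import Distributed


def build_phase_dataset(
    config: GPTDataConfig,
    distributed: Distributed,
    phase: PhaseType,
    num_samples: int,
    seed: int,
    sequence_length: int,
    vocab_size: int,
    cache_directory: pathlib.Path | None,
    tokenizer: Tokenizer | None = None,
) -> SampledDataset:
    # One sampling configuration per phase, mirroring the added setup() code.
    sampling_config = GPTSamplingConfig(
        num_samples=num_samples,
        seed=seed,
        cache_directory=cache_directory,
        distributed=distributed,
        phase=phase,
        sequence_length=sequence_length,
        vocab_size=vocab_size,
        tokenizer=tokenizer,
    )
    # The phase's dataset configuration builds and samples itself.
    dataset = config.datasets[phase].build_and_sample(sampling_config)
    # Wrap the result so that slow sample loads trigger a warning.
    return DatasetMonitor(dataset, config.data_sample_warn_time_ms)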
59 changes: 3 additions & 56 deletions fast_llm/data/dataset/abstract.py
@@ -1,9 +1,8 @@
import abc
import typing

from fast_llm.data.data.config import DataConfig
from fast_llm.data.dataset.config import SamplingConfig
from fast_llm.engine.distributed.config import PhaseType
if typing.TYPE_CHECKING:
from fast_llm.data.dataset.config import SamplingConfig


class Dataset(abc.ABC):
@@ -18,10 +17,6 @@ def name(self) -> str:
A name for the dataset to facilitate identification and debugging.
"""

@abc.abstractmethod
def as_split(self, default_phase: PhaseType = PhaseType.training):
pass


class SampledDataset(Dataset):
"""
@@ -37,57 +32,9 @@ def __getitem__(self, index: int) -> typing.Any:
def __len__(self) -> int:
pass

def as_split(self, default_phase: PhaseType = PhaseType.training):
return SplitDataset(self.name, {default_phase: self})


class SamplableDataset(Dataset):
# TODO: Move to dataset config?
_data_config_class: typing.ClassVar[type[DataConfig]]

@abc.abstractmethod
def sample(self, config: SamplingConfig) -> SampledDataset:
def sample(self, config: "SamplingConfig") -> SampledDataset:
pass

def as_split(self, default_phase: PhaseType = PhaseType.training) -> "SplitDataset":
return SplitDataset(self.name, {default_phase: self})


_SplittableType = typing.TypeVar("_SplittableType")
_DatasetType = typing.TypeVar("_DatasetType", bound=Dataset)
_SampledDatasetType = typing.TypeVar("_SampledDatasetType", bound=SampledDataset)
_SamplableDatasetType = typing.TypeVar("_SamplableDatasetType", bound=SamplableDataset)


class PhaseSplits(dict[PhaseType, _SplittableType], typing.Generic[_SplittableType]):
pass


class SplitDataset(Dataset, PhaseSplits[_DatasetType], typing.Generic[_DatasetType]):
def __init__(self, name: str, datasets: dict[PhaseType, _DatasetType]):
super().__init__(datasets)
self._name = name

def as_split(self, default_phase: PhaseType = PhaseType.training):
return self

@property
def name(self):
return self._name


class SampledSplitDataset(SplitDataset[_SampledDatasetType], typing.Generic[_SampledDatasetType]):
pass


class SamplableSplitDataset(SplitDataset[_SamplableDatasetType], typing.Generic[_SamplableDatasetType]):
def sample(self, sampling_configs: PhaseSplits[SamplingConfig]):
return SampledSplitDataset(
f"{self.name}_sampled",
{phase: self[phase].sample(sampling_config) for phase, sampling_config in sampling_configs.items()},
)


class CopySplitDataset(SamplableSplitDataset):
def __init__(self, name: str, dataset: _SplittableType, phases: list[PhaseType]):
super().__init__(name, {phase: dataset for phase in phases})
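With the split and phase wrappers removed, the surviving interface is small: a Dataset exposes a name, a SampledDataset adds __getitem__ and __len__, and a SamplableDataset turns into one through sample(config). A minimal sketch of a custom SampledDataset follows; it assumes name is an abstract property and that the base class needs no constructor arguments, neither of which is visible in this diff, and the class itself is purely illustrative.

import typing

from fast_llm.data.dataset.abstract import SampledDataset


class ConstantSampleDataset(SampledDataset):
    """Illustrative dataset that returns the same token list for every index."""

    def __init__(self, name: str, sample: list[int], num_samples: int):
        self._name = name
        self._sample = sample
        self._num_samples = num_samples

    @property
    def name(self) -> str:
        return self._name

    def __getitem__(self, index: int) -> typing.Any:
        return self._sample

    def __len__(self) -> int:
        return self._num_samples

A dataset like this can be wrapped in a DatasetMonitor or returned from a dataset configuration's build_and_sample, as in the data.py changes above.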