Improve dataset sampling #138

Merged: 17 commits, Feb 12, 2025
2 changes: 1 addition & 1 deletion Megatron-LM
19 changes: 16 additions & 3 deletions fast_llm/data/data/config.py
@@ -1,10 +1,23 @@
import typing

from fast_llm.config import Config, config_class
from fast_llm.data.dataset.config import SamplingConfig
from fast_llm.config import Config, Field, FieldHint, FieldUpdate, config_class
from fast_llm.data.dataset.config import SamplingConfig, SamplingData


@config_class()
class SamplingDefaultConfig(SamplingConfig):
seed: int = FieldUpdate(
default=784569,
desc="Seed for random sampling.",
hint=FieldHint.feature,
)


@config_class()
class DataConfig(Config):
_abstract = True
_sampling_config_class: typing.ClassVar[type[SamplingConfig]]
_sampling_config_class: typing.ClassVar[type[SamplingData]]

sampling: SamplingConfig = Field(
default_factory=SamplingConfig, desc="Default configuration for dataset sampling."
)
20 changes: 17 additions & 3 deletions fast_llm/data/data/gpt/config.py
@@ -1,15 +1,28 @@
import logging

from fast_llm.config import Field, FieldHint, check_field, config_class
from fast_llm.config import Field, FieldHint, FieldUpdate, check_field, config_class
from fast_llm.data.config import MultiprocessingContext, TokenizerConfig
from fast_llm.data.data.config import DataConfig
from fast_llm.data.dataset.gpt.config import GPTLegacyConfig, GPTLegacyDatasetConfig, GPTSampledDatasetConfig
from fast_llm.data.data.config import DataConfig, SamplingDefaultConfig
from fast_llm.data.dataset.gpt.config import (
GPTLegacyConfig,
GPTLegacyDatasetConfig,
GPTSampledDatasetConfig,
GPTSamplingConfig,
ShufflingType,
)
from fast_llm.engine.distributed.config import PhaseType
from fast_llm.utils import Assert

logger = logging.getLogger(__name__)


@config_class()
class GPTSamplingDefaultConfig(SamplingDefaultConfig, GPTSamplingConfig):
gpu: bool = FieldUpdate(default=True)
use_loss_masking_spans: bool = FieldUpdate(default=False)
shuffle: ShufflingType = FieldUpdate(default=ShufflingType.epoch)


@config_class()
class GPTDataConfig(DataConfig, GPTLegacyConfig):
"""
@@ -31,6 +44,7 @@ class GPTDataConfig(DataConfig, GPTLegacyConfig):
desc="Configuration for the dataset(s).",
hint=FieldHint.core,
)
sampling: GPTSamplingDefaultConfig = FieldUpdate(default_factory=GPTSamplingDefaultConfig)
data_sample_warn_time_ms: float = Field(
default=1000,
desc="Warn if a sample takes too long to load.",
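
Taken together, the two default-config hunks above layer the sampling defaults: the base `SamplingConfig` leaves every field unset (`None`) so it can serve as a partial override, while `SamplingDefaultConfig` and `GPTSamplingDefaultConfig` pin concrete values (seed 784569, GPU sampling on, loss-masking spans off, epoch shuffling) that `GPTDataConfig.sampling` exposes as the global default. A minimal sketch of that layering, with plain dataclasses standing in for the `Config`/`FieldUpdate` machinery (all names below are illustrative, not the Fast-LLM classes):

```python
import dataclasses
import enum


class Shuffling(enum.Enum):  # stand-in for ShufflingType
    epoch = "epoch"


@dataclasses.dataclass
class Sampling:
    # Stand-in for GPTSamplingConfig: everything optional, None means "not set".
    seed: int | None = None
    gpu: bool | None = None
    shuffle: Shuffling | None = None


@dataclasses.dataclass
class SamplingDefaults(Sampling):
    # Stand-in for GPTSamplingDefaultConfig: the same fields, but with the
    # concrete defaults that FieldUpdate pins in the hunks above.
    seed: int = 784569
    gpu: bool = True
    shuffle: Shuffling = Shuffling.epoch


print(SamplingDefaults())  # seed=784569, gpu=True, shuffle=Shuffling.epoch
```
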
9 changes: 4 additions & 5 deletions fast_llm/data/data/gpt/data.py
@@ -13,7 +13,7 @@
from fast_llm.data.data.abstract import Data
from fast_llm.data.data.gpt.config import GPTDataConfig
from fast_llm.data.dataset.abstract import SampledDataset
from fast_llm.data.dataset.gpt.config import GPTSamplingConfig
from fast_llm.data.dataset.gpt.config import GPTSamplingData
from fast_llm.data.dataset.gpt.sampled import GPTSample
from fast_llm.data.dataset.monitor import DatasetMonitor
from fast_llm.data.iterator import SampledDatasetIterator
@@ -91,18 +91,17 @@ def setup(
if num_samples > 0:
# TODO: Do the check earlier.
assert phase in self._config.datasets
sampling_config = GPTSamplingConfig(
sampling = GPTSamplingData(
num_samples=samples_per_phase[phase],
seed=self._distributed_config.seed,
config=self._config.sampling,
cache_directory=self._cache_directory,
distributed=distributed,
phase=phase,
sequence_length=self._max_sequence_length,
vocab_size=self._vocab_size,
tokenizer=self._tokenizer,
use_loss_masking_spans=self._config.use_loss_masking_spans,
)
dataset = self._config.datasets[phase].build_and_sample(sampling_config)
dataset = self._config.datasets[phase].build_and_sample(sampling)
self._datasets[phase] = DatasetMonitor(dataset, self._config.data_sample_warn_time_ms)

safe_barrier(self._distributed.world_group, "data_preparation", timeout)
4 changes: 2 additions & 2 deletions fast_llm/data/dataset/abstract.py
@@ -2,7 +2,7 @@
import typing

if typing.TYPE_CHECKING:
from fast_llm.data.dataset.config import SamplingConfig
from fast_llm.data.dataset.config import SamplingData


class Dataset(abc.ABC):
@@ -36,5 +36,5 @@ def __len__(self) -> int:
class SamplableDataset(Dataset):

@abc.abstractmethod
def sample(self, config: "SamplingConfig") -> SampledDataset:
def sample(self, config: "SamplingData") -> SampledDataset:
pass
4 changes: 2 additions & 2 deletions fast_llm/data/dataset/blended.py
@@ -4,7 +4,7 @@
import numpy as np

from fast_llm.data.dataset.abstract import SampledDataset
from fast_llm.data.dataset.config import SamplingConfig
from fast_llm.data.dataset.config import SamplingData
from fast_llm.utils import Assert, normalize_probabilities

logger = logging.getLogger(__name__)
@@ -23,7 +23,7 @@ def __init__(
name: str,
datasets: list[SampledDataset],
weights: list[float],
sampling_config: SamplingConfig,
sampling_config: SamplingData,
):
self._name = name
assert len(datasets) > 0
79 changes: 65 additions & 14 deletions fast_llm/data/dataset/config.py
@@ -1,10 +1,11 @@
import dataclasses
import functools
import itertools
import math
import pathlib
import typing

from fast_llm.config import Config, Field, FieldHint, check_field, config_class
from fast_llm.config import Config, Field, FieldHint, FieldVerboseLevel, check_field, config_class
from fast_llm.data.dataset.abstract import SamplableDataset, SampledDataset
from fast_llm.engine.distributed.config import PhaseType
from fast_llm.utils import Assert
@@ -14,15 +15,43 @@
from fast_llm.engine.distributed.distributed import Distributed


@config_class()
class SamplingConfig(Config):
seed: int | None = Field(
default=None,
desc="Seed for random sampling.",
hint=FieldHint.feature,
)

@property
def updates(self) -> dict[str, typing.Any]:
return {
key: value
for key, value in self.to_serialized(verbose=FieldVerboseLevel.everything).items()
if value is not None
}


@dataclasses.dataclass(kw_only=True)
class SamplingConfig:
class SamplingData:
# TODO: Have a separate configuration (subset?) for `build`?
config: SamplingConfig
num_samples: int
seed: int
cache_directory: pathlib.Path | None
# TODO: This prevents the sampling config from being pickled in multiprocessing.
distributed: "Distributed"
phase: PhaseType
# Using a mutable rather than an int so it's shared with all copies made with `update`.
_rank_counter: typing.Iterator[int] = itertools.count

def update(self, config: SamplingConfig, **kwargs):
if config_updates := config.updates:
kwargs["config"] = self.config.to_copy(config_updates)
return dataclasses.replace(self, **kwargs) if kwargs else self

def get_next_rank(self) -> int:
# Counter that loops over ranks to try to distribute workloads evenly between ranks.
return next(self._rank_counter()) % self.distributed.config.world_size
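
The `updates`/`update` pair above is the merge mechanism the rest of the PR relies on: `SamplingConfig.updates` keeps only the fields a user explicitly set (anything left at `None` is treated as "not specified"), and `SamplingData.update` folds those into a copy of the carried config. A self-contained sketch of that merge behaviour, with plain dataclasses standing in for the `Config` machinery (names are illustrative):

```python
import dataclasses


@dataclasses.dataclass
class Sampling:
    seed: int | None = None
    gpu: bool | None = None


def merge(base: Sampling, override: Sampling) -> Sampling:
    # Mirror SamplingConfig.updates: keep only explicitly set (non-None) fields...
    updates = {k: v for k, v in dataclasses.asdict(override).items() if v is not None}
    # ...and mirror SamplingData.update: copy the base with just those fields replaced.
    return dataclasses.replace(base, **updates) if updates else base


print(merge(Sampling(seed=784569, gpu=True), Sampling(seed=1234)))
# Sampling(seed=1234, gpu=True): the unset `gpu` override leaves the base value alone.
```
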


@config_class()
@@ -34,10 +63,9 @@ class DatasetConfig(Config):
class SampledDatasetConfig(DatasetConfig):
"""
A sampled dataset containing a prepared list of samples to be indexed sequentially (as-is) during training.
(See `fast_llm.data.sampler.Sampler`.)
"""

def build_and_sample(self, config: SamplingConfig) -> SampledDataset:
def build_and_sample(self, sampling: SamplingData) -> SampledDataset:
raise NotImplementedError()


@@ -46,13 +74,13 @@ class SamplableDatasetConfig(SampledDatasetConfig):
def build(self) -> SamplableDataset:
raise NotImplementedError()

def build_and_sample(self, config: SamplingConfig) -> SampledDataset:
return self.build().sample(config)
def build_and_sample(self, sampling: SamplingData) -> SampledDataset:
return self.build().sample(sampling)


@config_class()
class IndexedDatasetConfig(SamplableDatasetConfig):
def build(self) -> "IndexedDataset":
def _build(self) -> "IndexedDataset":
raise NotImplementedError()


@@ -128,6 +156,29 @@ def _build[T: DatasetSlice](self, cls: type[T]) -> T:
)


@config_class()
class SampledDatasetUpdateConfig(SampledDatasetConfig):
"""
Wrap a dataset to explicitly sample from it and optionally update its configuration parameters.
Only explicitly set parameters (not None) will be updated; others will still be taken from `build_and_sample`'s argument.
"""

_abstract = False
sampling: SamplingConfig = Field(
default_factory=SamplingConfig,
desc="Optional override to sampling configuration parameters.",
hint=FieldHint.core,
)
dataset: SampledDatasetConfig = Field(
default_factory=SampledDatasetConfig,
desc="The dataset to sample from.",
hint=FieldHint.core,
)

def build_and_sample(self, data: SamplingData) -> SampledDataset:
return self.dataset.build_and_sample(data.update(self.sampling))


@config_class()
class BlendedDatasetConfig(SampledDatasetConfig):
_abstract = False
@@ -159,7 +210,7 @@ def _validate(self) -> None:

def build_and_sample(
self,
config: SamplingConfig,
sampling: SamplingData,
) -> SampledDataset:
from fast_llm.data.dataset.blended import BlendedDataset

@@ -172,13 +223,13 @@
dataset.build_and_sample(
# Blending is deterministic and the error will never be higher than 1.
dataclasses.replace(
config,
sampling,
num_samples=(
math.ceil(weight * (config.num_samples + 5 * (config.num_samples * (1 - weight)) ** 0.5))
math.ceil(weight * (sampling.num_samples + 5 * (sampling.num_samples * (1 - weight)) ** 0.5))
if self.legacy
else math.ceil(weight * config.num_samples) + 1
else math.ceil(weight * sampling.num_samples) + 1
),
seed=config.seed + i * (0 if self.legacy else 697),
config=sampling.config.to_copy({"seed": sampling.config.seed + i * (0 if self.legacy else 697)}),
),
)
for i, (dataset, weight) in enumerate(zip(self.datasets, self.weights, strict=True))
@@ -188,5 +239,5 @@ def build_and_sample(
self.name,
sampled_datasets,
self.weights,
config,
sampling,
)
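
In the non-legacy branch above, each sub-dataset is asked for `ceil(weight * num_samples) + 1` samples; per the comment in the diff, blending is deterministic and its rounding error never exceeds one sample, so one extra sample per sub-dataset is enough margin. The legacy branch keeps the older, more generous margin, and in non-legacy mode each sub-dataset also gets a distinct seed via the `+ i * 697` offset. A quick worked example of the two sample-count formulas with illustrative numbers:

```python
import math

num_samples, weights = 1000, (0.7, 0.3)

# Non-legacy margin: one extra sample per sub-dataset.
print([math.ceil(w * num_samples) + 1 for w in weights])  # [701, 301]

# Legacy margin: weight * (num_samples plus ~5 standard deviations of the blended count).
print([math.ceil(w * (num_samples + 5 * (num_samples * (1 - w)) ** 0.5)) for w in weights])  # [761, 340]
```
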