Skip to content

[Prototype] Flexible dataset configuration #34

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 16 commits into from
4 changes: 2 additions & 2 deletions fast_llm/data/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -154,7 +154,7 @@ class Dataset(abc.ABC):

@property
@abc.abstractmethod
def name(self):
def name(self) -> str:
"""
A name for the dataset to facilitate identification and debugging.
"""
Expand All @@ -169,7 +169,7 @@ class SamplingConfig(Config):


class SamplableDataset(Dataset):
def sample(self, config: SamplingConfig, data: Data):
def sample(self, config: SamplingConfig, data: Data) -> "SampledDataset":
pass


Expand Down
13 changes: 5 additions & 8 deletions fast_llm/data/gpt/concatenated.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import numpy as np

from fast_llm.data.gpt.config import GPTConcatenatedDatasetConfig
from fast_llm.data.gpt.dataset import GPTIndexedDataset
from fast_llm.utils import padded_cumsum

Expand All @@ -8,10 +9,10 @@ class GPTConcatenatedDataset(GPTIndexedDataset):

def __init__(
self,
name: str,
config: GPTConcatenatedDatasetConfig,
datasets: list[GPTIndexedDataset],
):
self._name = name
self._config = config
self._datasets = datasets
sizes = [dataset.num_documents for dataset in self._datasets]
self._dataset_splits = padded_cumsum(sizes)
Expand All @@ -22,11 +23,7 @@ def num_tokens(self):
return sum(dataset.num_tokens for dataset in self._datasets)

def num_documents(self):
return sum(dataset.num_documents for dataset in self._datasets)

def get_document_sizes(self) -> "np.ndarray":
# TODO: This can be really big.
return np.concatenate([dataset.get_document_sizes() for dataset in self._datasets])
return self._num_documents

def get(self, document: int, offset: int = 0, length: int | None = None):
"""
Expand All @@ -39,4 +36,4 @@ def get(self, document: int, offset: int = 0, length: int | None = None):

@property
def name(self):
return self._name
return self._config.name
Loading
Loading