diff --git a/docs/source/about_mapstyle_vs_iterable.mdx b/docs/source/about_mapstyle_vs_iterable.mdx
index 84bdb108ed4..92b0cb01d7a 100644
--- a/docs/source/about_mapstyle_vs_iterable.mdx
+++ b/docs/source/about_mapstyle_vs_iterable.mdx
@@ -205,10 +205,6 @@ for epoch in range(n_epochs):
     pass
 ```
 
-## Checkpoint and resuming differences
-
-If you training loop stops, you may want to restart the training from where it was. To do so you can save a checkpoint of your model and optimizers, as well as your data loader.
-
 To restart the iteration of a map-style dataset, you can simply skip the first examples:
 
 ```python
diff --git a/docs/source/package_reference/main_classes.mdx b/docs/source/package_reference/main_classes.mdx
index 6bfa21d9d83..9c964bc56d5 100644
--- a/docs/source/package_reference/main_classes.mdx
+++ b/docs/source/package_reference/main_classes.mdx
@@ -170,6 +170,7 @@ The base class [`IterableDataset`] implements an iterable Dataset backed by pyth
     - rename_column
     - filter
     - shuffle
+    - batch
     - skip
     - take
     - load_state_dict
diff --git a/docs/source/stream.mdx b/docs/source/stream.mdx
index 93f52cf8e07..df5109b25aa 100644
--- a/docs/source/stream.mdx
+++ b/docs/source/stream.mdx
@@ -318,6 +318,44 @@ You can filter rows in the dataset based on a predicate function using [`Dataset
 {'id': 4, 'text': 'Are you looking for Number the Stars (Essential Modern Classics)? Normally, ...'}]
 ```
 
+## Batch
+
+The `batch` method transforms your `IterableDataset` into an iterable of batches. This is particularly useful when you want to work with batches in your training loop or when using frameworks that expect batched inputs.
+
+<Tip>
+
+There is also a "Batch Processing" option when using the `map` function to apply a function to batches of data, which is discussed in the [Map section](#map) above. The `batch` method described here is different and provides a more direct way to create batches from your dataset.
+
+</Tip>
+
+You can use the `batch` method like this:
+
+```python
+from datasets import load_dataset
+
+# Load a dataset in streaming mode
+dataset = load_dataset("some_dataset", split="train", streaming=True)
+
+# Create batches of 32 samples
+batched_dataset = dataset.batch(batch_size=32)
+
+# Iterate over the batched dataset
+for batch in batched_dataset:
+    print(batch)
+    break
+```
+
+In this example, `batched_dataset` is still an `IterableDataset`, but each item yielded is now a batch of 32 samples instead of a single sample.
+This batching is done on-the-fly as you iterate over the dataset, preserving the memory-efficient nature of `IterableDataset`.
+
+The `batch` method also provides a `drop_last_batch` parameter.
+When set to `True`, it will discard the last batch if it's smaller than the specified `batch_size`.
+This can be useful in scenarios where your downstream processing requires all batches to be of the same size:
+
+```python
+batched_dataset = dataset.batch(batch_size=32, drop_last_batch=True)
+```
+
 ## Stream in a training loop
 
 [`IterableDataset`] can be integrated into a training loop. First, shuffle the dataset:
diff --git a/src/datasets/iterable_dataset.py b/src/datasets/iterable_dataset.py
index c23f45570b4..94fd11a1b55 100644
--- a/src/datasets/iterable_dataset.py
+++ b/src/datasets/iterable_dataset.py
@@ -2885,6 +2885,26 @@ def _resolve_features(self):
             token_per_repo_id=self._token_per_repo_id,
         )
 
+    def batch(self, batch_size: int, drop_last_batch: bool = False) -> "IterableDataset":
+        """
+        Group samples from the dataset into batches.
+
+        Args:
+            batch_size (`int`): The number of samples in each batch.
+            drop_last_batch (`bool`, defaults to `False`): Whether to drop the last incomplete batch.
+
+        Example:
+        ```py
+        >>> ds = load_dataset("some_dataset", streaming=True)
+        >>> batched_ds = ds.batch(batch_size=32)
+        ```
+        """
+
+        def batch_fn(unbatched):
+            return {k: [v] for k, v in unbatched.items()}
+
+        return self.map(batch_fn, batched=True, batch_size=batch_size, drop_last_batch=drop_last_batch)
 
 def _concatenate_iterable_datasets(
     dsets: List[IterableDataset],
diff --git a/tests/test_iterable_dataset.py b/tests/test_iterable_dataset.py
index 6d21eda3863..232652f1fa3 100644
--- a/tests/test_iterable_dataset.py
+++ b/tests/test_iterable_dataset.py
@@ -2176,3 +2176,54 @@ def test_resume_dataloader(dataset: IterableDataset):
     dl = StatefulDataLoader(dataset)
     dl.load_state_dict(state_dict)
     assert remaining == list(dl)
+
+
+def test_iterable_dataset_batch():
+    # Create a simple IterableDataset
+    data = [{"id": i, "text": f"Text {i}"} for i in range(10)]
+    ds = IterableDataset.from_generator(lambda: (x for x in data))
+
+    # Test with batch_size=3, drop_last_batch=False
+    batched_ds = ds.batch(batch_size=3, drop_last_batch=False)
+    batches = list(batched_ds)
+
+    assert len(batches) == 4  # 3 full batches and 1 partial batch
+    for i, batch in enumerate(batches[:3]):  # Check full batches
+        assert len(batch["id"]) == 3
+        assert len(batch["text"]) == 3
+        assert batch["id"] == [3 * i, 3 * i + 1, 3 * i + 2]
+        assert batch["text"] == [f"Text {3*i}", f"Text {3*i+1}", f"Text {3*i+2}"]
+
+    # Check last partial batch
+    assert len(batches[3]["id"]) == 1
+    assert len(batches[3]["text"]) == 1
+    assert batches[3]["id"] == [9]
+    assert batches[3]["text"] == ["Text 9"]
+
+    # Test with batch_size=3, drop_last_batch=True
+    batched_ds = ds.batch(batch_size=3, drop_last_batch=True)
+    batches = list(batched_ds)
+
+    assert len(batches) == 3  # Only full batches
+    for i, batch in enumerate(batches):
+        assert len(batch["id"]) == 3
+        assert len(batch["text"]) == 3
+        assert batch["id"] == [3 * i, 3 * i + 1, 3 * i + 2]
+        assert batch["text"] == [f"Text {3*i}", f"Text {3*i+1}", f"Text {3*i+2}"]
+
+    # Test with batch_size=4 (doesn't evenly divide dataset size)
+    batched_ds = ds.batch(batch_size=4, drop_last_batch=False)
+    batches = list(batched_ds)
+
+    assert len(batches) == 3  # 2 full batches and 1 partial batch
+    for i, batch in enumerate(batches[:2]):  # Check full batches
+        assert len(batch["id"]) == 4
+        assert len(batch["text"]) == 4
+        assert batch["id"] == [4 * i, 4 * i + 1, 4 * i + 2, 4 * i + 3]
+        assert batch["text"] == [f"Text {4*i}", f"Text {4*i+1}", f"Text {4*i+2}", f"Text {4*i+3}"]
+
+    # Check last partial batch
+    assert len(batches[2]["id"]) == 2
+    assert len(batches[2]["text"]) == 2
+    assert batches[2]["id"] == [8, 9]
+    assert batches[2]["text"] == ["Text 8", "Text 9"]