Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add repeat method to datasets #7198

Open
wants to merge 9 commits into
base: main
Choose a base branch
from
Next Next commit
implement repeat method for iterable dataset
  • Loading branch information
alex-hh committed Oct 4, 2024
commit df3a5dc265333d966b5f12128fd7534054307a50
74 changes: 74 additions & 0 deletions src/datasets/iterable_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -1483,6 +1483,57 @@ def n_shards(self) -> int:
return self.ex_iterable.n_shards


class RepeatExamplesIterable(_BaseExamplesIterable):
"""
Iterable that repeats the underlying iterable a given number of times.
It does not duplicate shards, so that duplicate shards are not seen in
the same iteration.
"""

def __init__(
self,
ex_iterable: _BaseExamplesIterable,
num_times: int,
alex-hh marked this conversation as resolved.
Show resolved Hide resolved
):
super().__init__()
self.ex_iterable = ex_iterable
self.num_times = num_times

def _init_state_dict(self) -> dict:
self._state_dict = {
"repeat_index": 0,
"ex_iterable": self.ex_iterable._init_state_dict(),
}
return self._state_dict

def __iter__(self):
if self.num_times is None:
while True:
yield from self.ex_iterable
if self._state_dict:
self._state_dict["repeat_index"] += 1
else:
for _ in range(self.num_times):
yield from self.ex_iterable
if self._state_dict:
self._state_dict["repeat_index"] += 1

def shuffle_data_sources(self, generator: np.random.Generator) -> "RepeatExamplesIterable":
"""Shuffle the underlying iterable, then repeat."""
return RepeatExamplesIterable(self.ex_iterable.shuffle_data_sources(generator), num_times=self.num_times)

def shard_data_sources(self, worker_id: int, num_workers: int) -> "RepeatExamplesIterable":
"""Shard, then repeat shards."""
return RepeatExamplesIterable(
self.ex_iterable.shard_data_sources(worker_id, num_workers),
num_times=self.num_times,
)

@property
def n_shards(self) -> int:
return self.ex_iterable.n_shards


class TakeExamplesIterable(_BaseExamplesIterable):
def __init__(
self,
Expand Down Expand Up @@ -2513,6 +2564,29 @@ def skip(self, n: int) -> "IterableDataset":
token_per_repo_id=self._token_per_repo_id,
)

def repeat(self, num_times: Optional[int] = None) -> "IterableDataset":
"""
Create a new [`IterableDataset`] that repeats the underlying dataset `num_times` times.

N.B. Duplicate data is never seen in the same iteration, even after shuffling:
ds.repeat(n).shuffle(seed=42) is equivalent to ds.shuffle(seed=42).repeat(n)

Args:
num_times (`int`) or (`None`):
Number of times to repeat the dataset. If `None`, the dataset will be repeated indefinitely.

Example:
"""
return IterableDataset(
ex_iterable=RepeatExamplesIterable(self._ex_iterable, num_times=num_times),
info=self._info,
split=self._split,
formatting=self._formatting,
shuffling=copy.deepcopy(self._shuffling),
distributed=copy.deepcopy(self._distributed),
token_per_repo_id=self._token_per_repo_id,
)

def take(self, n: int) -> "IterableDataset":
"""
Create a new [`IterableDataset`] with only the first `n` elements.
Expand Down