Skip to content

(fix) CombinedDataset with more than 2 streaming datasets #164

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 5 commits into from
Jun 12, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 17 additions & 12 deletions src/litdata/streaming/combined.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@
from copy import deepcopy
from typing import Any, Dict, Iterator, List, Optional, Sequence

import numpy as np
from torch.utils.data import IterableDataset

from litdata.streaming.dataset import StreamingDataset
Expand Down Expand Up @@ -182,14 +181,14 @@ def __init__(
self,
datasets: List[StreamingDataset],
seed: int,
weights: Sequence[float],
weights: Sequence[Optional[float]],
use_streaming_dataloader: bool,
num_samples_yielded: Any,
iterate_over_all: bool = False,
) -> None:
self._datasets = datasets
self._dataset_iters = [iter(dataset) for dataset in datasets]
self._dataset_indexes = list(range(len(datasets)))
self._dataset_indexes: List[Optional[int]] = list(range(len(datasets)))
self._num_samples_yielded = num_samples_yielded or [0 for _ in range(len(datasets))]
self._original_weights = deepcopy(weights)
self._weights = deepcopy(weights)
Expand All @@ -200,7 +199,9 @@ def __init__(
if num_samples_yielded is not None:
self._num_samples_yielded = num_samples_yielded
for _ in range(sum(num_samples_yielded)):
self._rng.choices(self._dataset_indexes, weights=self._weights, k=1)
choice_indexes: List[int] = [index for index in self._dataset_indexes if index is not None]
choice_weights: List[float] = [w for w in self._weights if w is not None]
self._rng.choices(choice_indexes, weights=choice_weights, k=1)

self._use_streaming_dataloader = use_streaming_dataloader
self._is_done = False
def __next__(self) -> Any:
    """Return the next sample from the combined streams.

    With ``iterate_over_all`` enabled, samples keep coming from whichever
    datasets are still active; when one dataset raises ``StopIteration`` its
    slot in ``_dataset_indexes``/``_weights`` is disabled (set to ``None``)
    and the surviving weights are renormalized to sum to 1. Once the final
    dataset runs dry, the active set and the original weights are restored
    for the next epoch and ``StopIteration`` propagates to the caller.

    Without ``iterate_over_all``, a single weighted draw is made and the
    first exhausted dataset ends the iteration.
    """
    if not self._iterate_over_all:
        # Weighted mode: stop as soon as any chosen dataset is exhausted.
        return self._get_sample(self._get_dataset_index())

    while True:
        # Slots marked None correspond to already-exhausted datasets.
        active = [i for i in self._dataset_indexes if i is not None]
        try:
            if len(active) > 1:
                chosen = self._get_dataset_index()
            elif len(active) == 1:
                chosen = active[0]
            return self._get_sample(chosen)
        except StopIteration:
            if len(active) == 1:
                # Last remaining dataset is done: restore the full index set
                # and the original weights so a fresh epoch can start.
                self._dataset_indexes = list(range(len(self._datasets)))
                self._weights = deepcopy(self._original_weights)
                raise

            # Disable the exhausted dataset in place (keeping list positions
            # stable) and renormalize the remaining weights.
            self._dataset_indexes[chosen] = None
            self._weights[chosen] = None  # type: ignore
            remaining = sum(w for w in self._weights if w is not None)
            self._weights = [None if w is None else w / remaining for w in self._weights]

def _get_dataset_index(self) -> int:
    """Draw one dataset index at random among the still-active datasets.

    Exhausted datasets are marked ``None`` in both ``_dataset_indexes`` and
    ``_weights``; both lists are filtered in lockstep before sampling so the
    weighted draw only considers live streams.
    """
    active = [i for i in self._dataset_indexes if i is not None]
    live_weights = [w for w in self._weights if w is not None]
    # random.Random.choices with k=1 always returns a one-element list.
    return self._rng.choices(active, weights=live_weights, k=1)[0]

def _get_sample(self, dataset_index: int) -> Any:
Expand Down
26 changes: 26 additions & 0 deletions tests/streaming/test_combined.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import sys
from unittest.mock import ANY, MagicMock

import numpy as np
import pytest
import torch
from litdata.streaming.cache import Cache
Expand Down Expand Up @@ -46,6 +47,31 @@ def test_combined_dataset_num_samples_yield():
assert dataset._iterator._num_samples_yielded == [2, 4]


class Range:
    """Deterministic stand-in for a streaming dataset in tests.

    Yields the integers of ``range(start, end, step)`` and supports a
    seeded per-epoch reshuffle via :meth:`set_epoch`.
    """

    def __init__(self, start, end, step=1):
        # Materialize the values so repeated iteration and in-place
        # reshuffling both work.
        self.values = [*range(start, end, step)]

    def set_epoch(self, epoch):
        # Seeded shuffle: the same epoch always produces the same order.
        # Note it permutes the *current* order, so successive calls compound.
        rng = np.random.RandomState(42 + epoch)
        self.values = rng.permutation(self.values).tolist()

    def __iter__(self):
        return iter(self.values)


def test_combined_dataset_iterate_over_all_4_datasets():
    """Regression test for iterate_over_all with more than 2 datasets."""
    # Four disjoint 10-element streams covering 0..39.
    streams = [Range(10 * k, 10 * (k + 1)) for k in range(4)]
    dataset = TestCombinedStreamingDataset(streams, 42, iterate_over_all=True)

    epochs = []
    for epoch in range(2):
        dataset.set_epoch(epoch)
        epochs.append(list(dataset))

    # Every sample from all four datasets is yielded exactly once per epoch.
    assert len(epochs[0]) == 40
    # Tail values pin the seeded sampling order (regression guard).
    assert epochs[0][-3:] == [14, 13, 16]
    assert epochs[1][-3:] == [14, 18, 17]


def test_combined_dataset_num_samples_yield_iterate_over_all():
dataset = TestCombinedStreamingDataset([range(10), range(0, -10, -1)], 42, iterate_over_all=True)
assert len(dataset) == 20
Expand Down
Loading