Skip to content

Commit

Permalink
MultiEpochDataset
Browse files Browse the repository at this point in the history
  • Loading branch information
albertz committed Oct 25, 2024
1 parent 85405bd commit 98e9755
Show file tree
Hide file tree
Showing 2 changed files with 167 additions and 2 deletions.
104 changes: 104 additions & 0 deletions returnn/datasets/meta.py
Original file line number Diff line number Diff line change
Expand Up @@ -1973,6 +1973,110 @@ def is_data_sparse(self, key: str) -> bool:
return self._dataset.is_data_sparse(key)


class MultiEpochDataset(CachedDataset2):
    """
    It wraps some dataset, where one outer epoch corresponds to multiple epochs in the inner wrapped dataset.

    Outer epoch N (1-based) covers the inner epochs
    ``(N - 1) * multi_epoch + 1`` up to and including ``N * multi_epoch``,
    iterated one after the other.

    This can be useful when the inner dataset uses partition_epoch, and we want to cover the whole full epoch.
    One specific example is when the data is distributed over multiple files,
    and for reasonable performance, you want to have the data copied to the local disk,
    but all data together is too large to fit on the local disk.
    Then :class:`DistributeFilesDataset` is the logical choice,
    which solves these issues.
    However, you must use some partition_epoch in :class:`DistributeFilesDataset`
    such that it will not load all data at once.
    To cover all the data, you can use this :class:`MultiEpochDataset`
    and set multi_epoch = partition_epoch of the inner dataset.

    If a predefined seq order (``seq_list`` or ``seq_order``) is passed to :func:`init_seq_order`,
    or if no epoch is given, there is no multi-epoch handling:
    only a single inner epoch is iterated.
    """

    def __init__(self, *, dataset: Dict[str, Any], multi_epoch: int, **kwargs):
        """
        :param dataset: the inner wrapped dataset, passed to :func:`init_dataset`
        :param multi_epoch: how many inner epochs correspond to one outer epoch. must be >= 1
        :param kwargs: passed on to the base class
        :raises ValueError: if ``multi_epoch`` is smaller than 1
        """
        if multi_epoch < 1:
            raise ValueError(f"{type(self).__name__}: multi_epoch must be >= 1, got {multi_epoch!r}")
        super().__init__(**kwargs)
        self._dataset = init_dataset(dataset, parent_dataset=self)
        assert self._dataset
        self._multi_epoch = multi_epoch
        # Mirror the data description of the inner dataset.
        self.num_inputs = self._dataset.num_inputs
        self.num_outputs = self._dataset.num_outputs
        self.labels = self._dataset.labels
        # First inner epoch (1-based) of the current outer epoch, or None if no epoch was given.
        self._cur_inner_start_epoch: Optional[int] = None
        # How many inner epochs of the current outer epoch were already finished (0-based).
        self._cur_inner_epoch_offset = 0
        # Outer seq idx at which the current inner epoch started.
        self._cur_inner_epoch_seq_idx_offset = 0
        # Whether seq_list/seq_order was given for the current outer epoch.
        self._epoch_have_predefined_seq_order = False

    def init_seq_order(self, epoch=None, seq_list=None, seq_order=None):
        """
        Start a new outer epoch, and initialize the first corresponding inner epoch.

        :param int|None epoch: outer epoch, 1-based. if None, no multi-epoch handling is done
        :param seq_list: predefined seq tags, passed on to the inner dataset
        :param seq_order: predefined seq indices, passed on to the inner dataset
        """
        super().init_seq_order(epoch=epoch, seq_list=seq_list, seq_order=seq_order)
        # Explicit None checks instead of truthiness:
        # an explicitly given but empty seq_list/seq_order is still a predefined order,
        # and the truth value of a multi-element numpy array is ambiguous (raises ValueError).
        self._epoch_have_predefined_seq_order = seq_list is not None or seq_order is not None
        # epoch is 1-based
        self._cur_inner_start_epoch = ((epoch - 1) * self._multi_epoch + 1) if epoch is not None else None
        self._cur_inner_epoch_offset = 0
        self._cur_inner_epoch_seq_idx_offset = 0
        self._dataset.init_seq_order(epoch=self._cur_inner_start_epoch, seq_list=seq_list, seq_order=seq_order)

    def finish_epoch(self, *, free_resources: bool = False):
        """
        Finish the epoch, also for the inner dataset.

        :param free_resources: passed on to the base class and to the inner dataset
        """
        super().finish_epoch(free_resources=free_resources)
        self._dataset.finish_epoch(free_resources=free_resources)

    def get_all_tags(self) -> List[str]:
        """
        :return: all tags, as reported by the inner dataset
        """
        return self._dataset.get_all_tags()

    def get_total_num_seqs(self, *, fast: bool = False) -> int:
        """
        :param fast: passed on to the inner dataset
        :return: total num seqs, as reported by the inner dataset
        """
        return self._dataset.get_total_num_seqs(fast=fast)

    def get_data_keys(self) -> List[str]:
        """
        :return: data keys of the inner dataset
        """
        return self._dataset.get_data_keys()

    def get_target_list(self) -> List[str]:
        """
        :return: target list of the inner dataset
        """
        return self._dataset.get_target_list()

    def _collect_single_seq(self, seq_idx: int) -> Optional[DatasetSeq]:
        """
        Fetch one seq from the inner dataset,
        switching over to the next inner epoch once the current inner epoch is exhausted.

        :param seq_idx: outer seq idx, counted over all inner epochs of the current outer epoch
        :return: the seq, or None if the outer epoch is finished
        """
        # The outer seq idx must not lie before the start of the current inner epoch.
        assert seq_idx >= self._cur_inner_epoch_seq_idx_offset
        sub_seq_idx = seq_idx - self._cur_inner_epoch_seq_idx_offset
        if not self._dataset.is_less_than_num_seqs(sub_seq_idx):
            # The current inner epoch is exhausted.
            if self._epoch_have_predefined_seq_order:
                return None  # predefined seq order, so no multi-epoch handling
            if self._cur_inner_start_epoch is None:
                return None  # there was no epoch given, so no multi-epoch handling
            self._cur_inner_epoch_offset += 1
            if self._cur_inner_epoch_offset >= self._multi_epoch:
                return None  # we are done
            # Advance to the next inner epoch. It starts at the current outer seq idx.
            self._dataset.init_seq_order(epoch=self._cur_inner_start_epoch + self._cur_inner_epoch_offset)
            self._cur_inner_epoch_seq_idx_offset = seq_idx
            sub_seq_idx = 0
            assert self._dataset.is_less_than_num_seqs(sub_seq_idx)  # expect that the sub epoch has some seqs
        self._dataset.load_seqs(sub_seq_idx, sub_seq_idx + 1)
        data = {key: self._dataset.get_data(sub_seq_idx, key) for key in self.get_data_keys()}
        seq_tag = self._dataset.get_tag(sub_seq_idx)
        return DatasetSeq(seq_idx=seq_idx, seq_tag=seq_tag, features=data)

    def get_data_dim(self, key: str) -> int:
        """
        :param key: data key
        :return: data dim, as reported by the inner dataset
        """
        return self._dataset.get_data_dim(key)

    def get_data_shape(self, data_key: str) -> List[int]:
        """
        :param data_key: data key
        :return: data shape, as reported by the inner dataset
        """
        return self._dataset.get_data_shape(data_key)

    def get_data_dtype(self, key: str) -> str:
        """
        :param key: data key
        :return: data dtype, as reported by the inner dataset
        """
        return self._dataset.get_data_dtype(key)

    def is_data_sparse(self, key: str) -> bool:
        """
        :param key: data key
        :return: whether the data is sparse, as reported by the inner dataset
        """
        return self._dataset.is_data_sparse(key)


class AnythingDataset(Dataset):
"""
An infinite dataset, creating dummy (zero) data on the fly,
Expand Down
65 changes: 63 additions & 2 deletions tests/test_Dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,12 +18,12 @@
from returnn.util import better_exchook


def dummy_iter_dataset(dataset: Dataset) -> List[DatasetSeq]:
def dummy_iter_dataset(dataset: Dataset, *, epoch: int = 1) -> List[DatasetSeq]:
"""
:param Dataset dataset:
:return: seqs
"""
dataset.init_seq_order(epoch=1)
dataset.init_seq_order(epoch=epoch)
data_keys = dataset.get_data_keys()
seq_idx = 0
seqs = []
Expand Down Expand Up @@ -1147,6 +1147,67 @@ def _repeat2(input_iter: Iterator[TensorDict], **kwargs) -> Iterator[TensorDict]
assert func(2) == 21


def test_MultiEpochDataset():
    """
    Checks that one outer epoch of :class:`MultiEpochDataset`
    iterates through ``multi_epoch`` consecutive inner epochs, in order,
    with a continuous outer seq idx.
    """
    import ast
    from returnn.datasets.meta import MultiEpochDataset
    from returnn.datasets.cached2 import CachedDataset2

    in_dim, out_dim = 11, 7
    seq_len = 5
    inner_num_seqs = 10

    class _MyDataset(CachedDataset2):
        """Dummy dataset with a fixed number of seqs. The seq tag encodes the epoch and the seq idx."""

        def __init__(self):
            super().__init__()
            self.num_inputs = in_dim
            self.num_outputs = {"classes": out_dim}

        # noinspection PyShadowingNames
        def init_seq_order(self, epoch=None, seq_list=None, seq_order=None):
            """init seq order"""
            super().init_seq_order(epoch=epoch, seq_list=seq_list, seq_order=seq_order)
            self._num_seqs = inner_num_seqs

        def _collect_single_seq(self, seq_idx: int) -> Optional[DatasetSeq]:
            if seq_idx >= self._num_seqs:
                return None
            return DatasetSeq(
                seq_idx=seq_idx,
                seq_tag=repr({"epoch": self.epoch, "seq_idx": seq_idx}),
                features=numpy.zeros((seq_len, in_dim)),
                targets={"classes": numpy.zeros((seq_len,), dtype=numpy.int32)},
            )

    inner_dataset = _MyDataset()
    inner_dataset.initialize()

    multi_epoch = 3
    dataset = MultiEpochDataset(dataset=inner_dataset, multi_epoch=multi_epoch)
    for outer_epoch in [1, 7]:
        seqs = dummy_iter_dataset(dataset, epoch=outer_epoch)
        assert len(seqs) == inner_num_seqs * multi_epoch
        outer_seq_idx = 0
        sub_ep = (outer_epoch - 1) * multi_epoch + 1  # 1-based
        sub_seq_idx = 0
        for seq in seqs:
            assert outer_seq_idx == seq.seq_idx
            assert seq.features["data"].shape == (seq_len, in_dim)
            assert seq.features["classes"].shape == (seq_len,)
            print("seq:", seq.seq_tag)
            # The seq tag is a dict repr consisting only of literals,
            # so parse it with literal_eval instead of evaluating arbitrary code.
            d = ast.literal_eval(seq.seq_tag)
            assert isinstance(d, dict)
            assert d["epoch"] == sub_ep
            assert d["seq_idx"] == sub_seq_idx
            # Calc next expected values.
            if sub_seq_idx >= inner_num_seqs - 1:
                sub_seq_idx = 0
                sub_ep += 1
            else:
                sub_seq_idx += 1
            outer_seq_idx += 1
        assert outer_seq_idx == len(seqs)
        # After the last seq, we expect to be right at the start of the next outer epoch.
        assert sub_ep == outer_epoch * multi_epoch + 1
        assert sub_seq_idx == 0


if __name__ == "__main__":
better_exchook.install()
if len(sys.argv) <= 1:
Expand Down

0 comments on commit 98e9755

Please sign in to comment.