Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

MultiEpochDataset, and some other smaller things #1639

Merged
merged 6 commits into from
Oct 25, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions returnn/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -576,6 +576,10 @@ def execute_main_task():
assert data, "set forward_data"
else:
data = init_dataset(config.opt_typed_value("forward_data"))
# engine.epoch is usually the epoch of the loaded checkpoint,
# or what EngineBase.get_epoch_model will return.
# You can have both load and load_epoch, where load points to the checkpoint,
# and load_epoch is some other epoch, which you will get here for the dataset.
data.init_seq_order(epoch=engine.epoch or 1)
forward_callback = config.typed_value("forward_callback")
assert forward_callback, "no forward_callback specified"
Expand Down
106 changes: 105 additions & 1 deletion returnn/datasets/meta.py
Original file line number Diff line number Diff line change
Expand Up @@ -1875,7 +1875,7 @@ def _load_dataset(self, epoch: int):

def init_seq_order(self, epoch=None, seq_list=None, seq_order=None):
"""init seq order"""
super().init_seq_order()
super().init_seq_order(epoch=epoch, seq_list=seq_list, seq_order=seq_order)
if epoch is None:
if seq_list is not None or seq_order is not None:
raise ValueError(f"{self}: epoch is None, but given seq_list or seq_order, not supported")
Expand Down Expand Up @@ -1973,6 +1973,110 @@ def is_data_sparse(self, key: str) -> bool:
return self._dataset.is_data_sparse(key)


class MultiEpochDataset(CachedDataset2):
    """
    Wraps another dataset such that one epoch of this (outer) dataset
    iterates over multiple consecutive epochs of the wrapped (inner) dataset.

    This can be useful when the inner dataset uses partition_epoch
    and you still want a single outer epoch to cover the whole full epoch.

    One specific example: the data is distributed over multiple files,
    and for reasonable performance, you want the data copied to the local disk,
    but all data together is too large to fit on the local disk.
    Then :class:`DistributeFilesDataset` is the logical choice,
    which solves these issues.
    However, you must use some partition_epoch in :class:`DistributeFilesDataset`
    such that it will not load all data at once.
    To cover all the data, use this :class:`MultiEpochDataset`
    with multi_epoch set to the partition_epoch of the inner dataset.
    """

    def __init__(self, *, dataset: Dict[str, Any], multi_epoch: int, **kwargs):
        """
        :param dataset: the inner wrapped dataset
        :param multi_epoch: number of inner epochs which make up one outer epoch
        """
        super().__init__(**kwargs)
        self._dataset = init_dataset(dataset, parent_dataset=self)
        assert self._dataset
        self._multi_epoch = multi_epoch
        # Mirror the metadata of the wrapped dataset.
        self.num_inputs = self._dataset.num_inputs
        self.num_outputs = self._dataset.num_outputs
        self.labels = self._dataset.labels
        # Inner epoch at which the current outer epoch starts (None if no epoch was given).
        self._cur_inner_start_epoch: Optional[int] = None
        # How many inner epochs we already advanced within the current outer epoch.
        self._cur_inner_epoch_offset = 0
        # Outer seq_idx at which the current inner epoch started.
        self._cur_inner_epoch_seq_idx_offset = 0
        self._epoch_have_predefined_seq_order = False

    def init_seq_order(self, epoch=None, seq_list=None, seq_order=None):
        """init seq order"""
        super().init_seq_order(epoch=epoch, seq_list=seq_list, seq_order=seq_order)
        self._epoch_have_predefined_seq_order = bool(seq_list or seq_order)
        if epoch is None:
            self._cur_inner_start_epoch = None
        else:
            # epoch is 1-based
            self._cur_inner_start_epoch = (epoch - 1) * self._multi_epoch + 1
        self._cur_inner_epoch_offset = 0
        self._cur_inner_epoch_seq_idx_offset = 0
        self._dataset.init_seq_order(epoch=self._cur_inner_start_epoch, seq_list=seq_list, seq_order=seq_order)

    def finish_epoch(self, *, free_resources: bool = False):
        """finish epoch"""
        super().finish_epoch(free_resources=free_resources)
        self._dataset.finish_epoch(free_resources=free_resources)

    def get_all_tags(self) -> List[str]:
        """all tags"""
        return self._dataset.get_all_tags()

    def get_total_num_seqs(self, *, fast: bool = False) -> int:
        """total num seqs"""
        return self._dataset.get_total_num_seqs(fast=fast)

    def get_data_keys(self) -> List[str]:
        """data keys"""
        return self._dataset.get_data_keys()

    def get_target_list(self) -> List[str]:
        """target list"""
        return self._dataset.get_target_list()

    def _collect_single_seq(self, seq_idx: int) -> Optional[DatasetSeq]:
        assert seq_idx >= self._cur_inner_epoch_seq_idx_offset
        sub_seq_idx = seq_idx - self._cur_inner_epoch_seq_idx_offset
        if not self._dataset.is_less_than_num_seqs(sub_seq_idx):
            # Current inner epoch exhausted -- maybe advance to the next inner epoch.
            if self._epoch_have_predefined_seq_order:
                # Predefined seq order, so no multi-epoch handling.
                return None
            if self._cur_inner_start_epoch is None:
                # There was no epoch given, so no multi-epoch handling.
                return None
            self._cur_inner_epoch_offset += 1
            if self._cur_inner_epoch_offset >= self._multi_epoch:
                return None  # all inner epochs consumed, we are done
            self._dataset.init_seq_order(epoch=self._cur_inner_start_epoch + self._cur_inner_epoch_offset)
            self._cur_inner_epoch_seq_idx_offset = seq_idx
            sub_seq_idx = 0
            # Expect that the fresh sub epoch has some seqs.
            assert self._dataset.is_less_than_num_seqs(sub_seq_idx)
        self._dataset.load_seqs(sub_seq_idx, sub_seq_idx + 1)
        data = {key: self._dataset.get_data(sub_seq_idx, key) for key in self.get_data_keys()}
        return DatasetSeq(seq_idx=seq_idx, seq_tag=self._dataset.get_tag(sub_seq_idx), features=data)

    def get_data_dim(self, key: str) -> int:
        """data dim"""
        return self._dataset.get_data_dim(key)

    def get_data_shape(self, data_key: str) -> List[int]:
        """data shape"""
        return self._dataset.get_data_shape(data_key)

    def get_data_dtype(self, key: str) -> str:
        """data dtype"""
        return self._dataset.get_data_dtype(key)

    def is_data_sparse(self, key: str) -> bool:
        """is data sparse"""
        return self._dataset.is_data_sparse(key)


class AnythingDataset(Dataset):
"""
An infinite dataset, creating dummy (zero) data on the fly,
Expand Down
6 changes: 6 additions & 0 deletions returnn/engine/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -266,6 +266,12 @@ def is_first_epoch_after_pretrain(self):
"""
return self.pretrain and self.epoch == self.pretrain.get_train_num_epochs() + 1

def set_epoch(self, epoch: int):
    """
    Set the current epoch (stored in :attr:`epoch`).

    Subclasses can override this to propagate the epoch further.

    :param epoch: new value for :attr:`epoch`
    """
    self.epoch = epoch

def forward_with_callback(self, *, dataset: Dataset, callback: ForwardCallbackIface):
"""
Iterate through the dataset, calling `forward_step` from user config,
Expand Down
9 changes: 6 additions & 3 deletions returnn/torch/engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -225,6 +225,11 @@ def init_train_from_config(
self._train_step_func = self.config.typed_value("train_step")
assert self._train_step_func, "train_step not defined"

def set_epoch(self, epoch: int):
    """
    Set the current epoch, and additionally update the shared epoch value.

    :param epoch: new epoch
    """
    super().set_epoch(epoch)
    # NOTE(review): presumably this multiprocessing shared value is read by
    # data-loader worker procs -- not visible from here, confirm against users of it.
    self._epoch_mp_shared.value = epoch

def train(self):
"""
Main training loop.
Expand All @@ -243,9 +248,7 @@ def train(self):
)
self.epoch = self._start_epoch - 1
while self.epoch + 1 <= self._final_epoch:
self.epoch += 1
self._epoch_mp_shared.value = self.epoch

self.set_epoch(self.epoch + 1)
self.init_train_epoch()
self.train_epoch()

Expand Down
65 changes: 63 additions & 2 deletions tests/test_Dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,12 +18,12 @@
from returnn.util import better_exchook


def dummy_iter_dataset(dataset: Dataset) -> List[DatasetSeq]:
def dummy_iter_dataset(dataset: Dataset, *, epoch: int = 1) -> List[DatasetSeq]:
"""
:param Dataset dataset:
:return: seqs
"""
dataset.init_seq_order(epoch=1)
dataset.init_seq_order(epoch=epoch)
data_keys = dataset.get_data_keys()
seq_idx = 0
seqs = []
Expand Down Expand Up @@ -1147,6 +1147,67 @@ def _repeat2(input_iter: Iterator[TensorDict], **kwargs) -> Iterator[TensorDict]
assert func(2) == 21


def test_MultiEpochDataset():
    """
    Test :class:`MultiEpochDataset`:
    one outer epoch must iterate over multi_epoch consecutive inner epochs,
    in order, with contiguous outer seq indices.
    """
    import ast

    from returnn.datasets.meta import MultiEpochDataset
    from returnn.datasets.cached2 import CachedDataset2

    in_dim, out_dim = 11, 7
    seq_len = 5
    inner_num_seqs = 10

    class _MyDataset(CachedDataset2):
        # Dummy inner dataset with a fixed number of seqs per epoch.
        # The seq tag encodes (epoch, seq_idx) so the outer iteration can be verified.

        def __init__(self):
            super().__init__()
            self.num_inputs = in_dim
            self.num_outputs = {"classes": out_dim}

        # noinspection PyShadowingNames
        def init_seq_order(self, epoch=None, seq_list=None, seq_order=None):
            """init seq order"""
            super().init_seq_order(epoch=epoch, seq_list=seq_list, seq_order=seq_order)
            self._num_seqs = inner_num_seqs

        def _collect_single_seq(self, seq_idx: int) -> Optional[DatasetSeq]:
            if seq_idx >= self._num_seqs:
                return None
            return DatasetSeq(
                seq_idx=seq_idx,
                seq_tag=repr({"epoch": self.epoch, "seq_idx": seq_idx}),
                features=numpy.zeros((seq_len, in_dim)),
                targets={"classes": numpy.zeros((seq_len,), dtype=numpy.int32)},
            )

    inner_dataset = _MyDataset()
    inner_dataset.initialize()

    multi_epoch = 3
    dataset = MultiEpochDataset(dataset=inner_dataset, multi_epoch=multi_epoch)
    for outer_epoch in [1, 7]:
        seqs = dummy_iter_dataset(dataset, epoch=outer_epoch)
        assert len(seqs) == inner_num_seqs * multi_epoch
        outer_seq_idx = 0
        sub_ep = (outer_epoch - 1) * multi_epoch + 1  # 1-based
        sub_seq_idx = 0
        for seq in seqs:
            assert outer_seq_idx == seq.seq_idx
            assert seq.features["data"].shape == (seq_len, in_dim)
            assert seq.features["classes"].shape == (seq_len,)
            print("seq:", seq.seq_tag)
            # Seq tag is the repr of a plain dict literal; parse it with
            # ast.literal_eval instead of eval (no arbitrary code execution).
            d = ast.literal_eval(seq.seq_tag)
            assert isinstance(d, dict)
            assert d["epoch"] == sub_ep
            assert d["seq_idx"] == sub_seq_idx
            # Calc next expected values.
            if sub_seq_idx >= inner_num_seqs - 1:
                sub_seq_idx = 0
                sub_ep += 1
            else:
                sub_seq_idx += 1
            outer_seq_idx += 1
        assert outer_seq_idx == len(seqs)
        # After the full outer epoch, we must have consumed exactly multi_epoch inner epochs.
        assert sub_ep == outer_epoch * multi_epoch + 1 and sub_seq_idx == 0


if __name__ == "__main__":
better_exchook.install()
if len(sys.argv) <= 1:
Expand Down
Loading
Loading