PT engine, report train complete status, expected remaining time
albertz committed Oct 18, 2024
1 parent c3878ec commit a8aeea1
Showing 3 changed files with 67 additions and 11 deletions.
5 changes: 4 additions & 1 deletion returnn/torch/data/pipeline.py
@@ -59,6 +59,9 @@ def collate_batch(batch: List[Dict[str, numpy.ndarray]]) -> Dict[str, Union[torc

res = {}
for key in data_keys:
if key == "num_seqs":
res[key] = batch[0][key] # it should always be the same
continue
ls = [create_tensor(sample[key]) for sample in batch]
if not ls:
raise ValueError("batch is empty?")
@@ -116,7 +119,7 @@ def __iter__(self):

if not chunking_data_keys:
chunking_data_keys = list(data_dict.keys()) # use all if not configured separately
chunking_data_key_black_list = ["seq_tag"]
chunking_data_key_black_list = ["seq_tag", "seq_idx", "num_seqs"]
for key in chunking_data_key_black_list:
if key in chunking_data_keys:
chunking_data_keys.remove(key)
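For illustration, a minimal sketch (with made-up sample values, not from the commit) of what the special case above does: every sample of a batch belongs to the same epoch, so each one carries the identical num_seqs, and collation can simply take it from the first sample instead of stacking/padding it like a normal data key; like seq_tag and seq_idx, it is also excluded from chunking.

    import numpy

    # Hypothetical samples, as yielded per sequence by the dataset wrapper; all share num_seqs.
    batch = [
        {"data": numpy.zeros(5), "seq_idx": numpy.array(0), "num_seqs": numpy.array(100)},
        {"data": numpy.zeros(7), "seq_idx": numpy.array(1), "num_seqs": numpy.array(100)},
    ]

    # As in the patched collate_batch: num_seqs is taken from the first sample;
    # all other keys are collated (stacked/padded) as before.
    collated_num_seqs = batch[0]["num_seqs"]
    print(int(collated_num_seqs))  # -> 100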
11 changes: 11 additions & 0 deletions returnn/torch/data/returnn_dataset_wrapper.py
@@ -75,6 +75,13 @@ def __iter__(self) -> Iterable[Dict[str, numpy.ndarray]]:
"""
:return: generator providing data samples in the form of a dict data_key -> data
"""
# noinspection PyBroadException
try:
num_seqs = self._dataset.num_seqs
except Exception: # might not work for all datasets
num_seqs = -1
num_seqs = numpy.array(num_seqs)

try:
data_keys = self._dataset.get_data_keys()

@@ -83,6 +90,10 @@ def __iter__(self) -> Iterable[Dict[str, numpy.ndarray]]:
self._dataset.load_seqs(seq_index, seq_index + 1)
data = {data_key: self._dataset.get_data(seq_index, data_key) for data_key in data_keys}
data["seq_tag"] = str_to_numpy_array(self._dataset.get_tag(seq_index))
data["seq_idx"] = numpy.array(seq_index)
# It's slightly redundant to have num_seqs in each entry,
# but it's difficult to pass this back to the main proc otherwise.
data["num_seqs"] = num_seqs
yield data
seq_index += 1

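A small sketch of the fallback above, with hypothetical values (not from the commit): not every dataset can report its epoch length, so querying num_seqs may raise, in which case -1 is stored to mean "unknown"; each yielded sample then carries its seq_idx plus the (redundant but convenient) num_seqs.

    import numpy

    def query_num_seqs(dataset):
        """Mirrors the fallback above: some datasets raise when num_seqs is queried."""
        # noinspection PyBroadException
        try:
            return numpy.array(dataset.num_seqs)
        except Exception:  # e.g. for streamed / unbounded datasets
            return numpy.array(-1)  # -1 means: epoch length unknown

    # Rough shape of a yielded sample (hypothetical values):
    sample = {
        "data": numpy.zeros(5, dtype=numpy.float32),
        "seq_tag": numpy.array("seq-0"),
        "seq_idx": numpy.array(0),
        "num_seqs": numpy.array(100),  # or -1 if the dataset cannot tell
    }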
62 changes: 52 additions & 10 deletions returnn/torch/engine.py
@@ -311,7 +311,7 @@ def train_epoch(self):
accumulated_losses_dict = NumbersDict()
accumulated_inv_norm_factors_dict = NumbersDict()
step_idx = 0
epoch_start_time = time.time()
epoch_start_time = time.monotonic()

data_iter = iter(self._train_dataloader)
elapsed_computation_time = 0
@@ -339,12 +339,14 @@ def train_epoch(self):
zero_grad_next_step = True
cur_count_grad_accum = 0
extern_data = None
num_seqs = None
last_seq_idx = 0
try:
while True:
with torch.no_grad():
extern_data_raw = next(data_iter, None)

step_begin_time = time.time()
step_begin_time = time.monotonic()

_has_data = torch.tensor([extern_data_raw is not None], dtype=torch.int8)
if self._torch_distributed_ctx:
@@ -353,6 +355,22 @@
torch.distributed.all_reduce(_has_data, op=torch.distributed.ReduceOp.MIN)
if not _has_data[0]:
break
num_seqs_ = (
int(extern_data_raw["num_seqs"]) if extern_data_raw.get("num_seqs", None) is not None else -1
)
last_seq_idx_ = extern_data_raw["seq_idx"].max()
assert last_seq_idx_ >= last_seq_idx
last_seq_idx = int(last_seq_idx_)
del last_seq_idx_
if step_idx == 0:
if num_seqs_ >= 0:
print(f"Epoch {self.epoch} num_seqs: {num_seqs_}", file=log.v5)
num_seqs = num_seqs_
elif num_seqs_ >= 0:
assert num_seqs_ == num_seqs
del num_seqs_
if num_seqs is not None:
assert last_seq_idx < num_seqs

# clear the gradients when every gradient accumulation loop starts
if zero_grad_next_step:
Expand Down Expand Up @@ -404,7 +422,8 @@ def train_epoch(self):
if self._torch_distributed_ctx:
self._torch_distributed_ctx.step_after_param_update(module=self._pt_model, epoch_step_idx=step_idx)

step_duration = time.time() - step_begin_time
step_end_time = time.monotonic()
step_duration = step_end_time - step_begin_time
elapsed_computation_time += step_duration

accumulated_losses_dict += losses_dict
@@ -415,6 +434,9 @@
step=step_idx,
eval_info=dict(eval_info),
step_duration=step_duration,
start_elapsed=step_end_time - epoch_start_time,
seq_idx=last_seq_idx,
num_seqs=num_seqs,
batch_size_info=_get_batch_size_info(extern_data) if self._log_batch_size else None,
log_memory_usage_device=self._device if self._log_memory_usage else None,
)
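To make the data flow concrete, here is a standalone sketch (made-up batch values, not from the commit) of the engine-side extraction above: the raw batch carries one seq_idx per sequence plus the epoch length, from which the engine tracks last_seq_idx and num_seqs for the progress report.

    import numpy

    # Hypothetical raw batch: three sequences (indices 12..14) of a 100-sequence epoch.
    extern_data_raw = {
        "seq_idx": numpy.array([12, 13, 14]),
        "num_seqs": numpy.array(100),
    }

    num_seqs = int(extern_data_raw["num_seqs"]) if extern_data_raw.get("num_seqs") is not None else -1
    last_seq_idx = int(extern_data_raw["seq_idx"].max())
    assert last_seq_idx < num_seqs
    print(last_seq_idx, num_seqs)  # -> 14 100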
@@ -436,7 +458,7 @@
help_on_torch_exception(exc, step_idx=step_idx, model=self._orig_model, extern_data=extern_data)
raise

elapsed = time.time() - epoch_start_time
elapsed = time.monotonic() - epoch_start_time
elapsed_computation_percentage = elapsed_computation_time / elapsed
print(
"Trained %i steps, %s elapsed (%.1f%% computing time)"
@@ -1008,7 +1030,7 @@ def forward_with_callback(self, *, dataset: Dataset, callback: ForwardCallbackIf
assert isinstance(dataset, Dataset)
assert isinstance(callback, ForwardCallbackIface)

epoch_start_time = time.time()
epoch_start_time = time.monotonic()
elapsed_computation_time = 0.0

self._pt_model.eval()
@@ -1087,7 +1109,7 @@ def _get_dim_tag_wo_batch(dim: Dim) -> Dim:

step_idx = 0
for extern_data_raw in data_loader:
step_begin_time = time.time()
step_begin_time = time.monotonic()
if self._forward_step_expected_outputs:
# Also resets any dyn dims, which might have been set in the prev step.
self._forward_step_expected_outputs.reset_content()
@@ -1121,7 +1143,7 @@ def _get_dim_tag_wo_batch(dim: Dim) -> Dim:
model_outputs_per_batch.data[k] = _get_tensor_wo_batch_numpy(v)
callback.process_seq(seq_tag=seq_tag, outputs=model_outputs_per_batch)

elapsed_computation_time += time.time() - step_begin_time
elapsed_computation_time += time.monotonic() - step_begin_time
_print_process(
report_prefix,
step=step_idx,
@@ -1132,7 +1154,7 @@ def _get_dim_tag_wo_batch(dim: Dim) -> Dim:

callback.finish()

elapsed = time.time() - epoch_start_time
elapsed = time.monotonic() - epoch_start_time
elapsed_computation_percentage = elapsed_computation_time / elapsed
print(
"Forward %i steps, %s elapsed (%.1f%% computing time)"
@@ -1202,20 +1224,26 @@ def _to_raw(n: Union[int, float, Tensor]):

def _print_process(
report_prefix: str,
*,
step: int,
eval_info: Optional[Dict[str, Any]] = None,
batch_size_info: Optional[Dict[str, Any]] = None,
step_duration: Optional[float] = None,
start_elapsed: Optional[float] = None,
seq_idx: Optional[int] = None,
num_seqs: Optional[int] = None,
log_memory_usage_device: Optional[str] = None,
):
"""
Similar but simplified from TF engine _print_process.
:param report_prefix:
:param step:
:param step: for this epoch
:param eval_info:
:param batch_size_info:
:param step_duration:
:param step_duration: time elapsed for this step (secs)
:param start_elapsed: time elapsed since epoch start (secs)
:param num_seqs: total number of sequences for this epoch
:param log_memory_usage_device: if given, will log memory usage (peak allocated memory)
:return: nothing, will be printed to log
"""
@@ -1233,6 +1261,20 @@
]
if step_duration is not None:
info += ["%.3f sec/step" % step_duration]
if start_elapsed is not None:
info += ["elapsed %s" % hms(start_elapsed)]
if num_seqs is not None:
assert seq_idx is not None and start_elapsed is not None # unexpected combination...
complete = (seq_idx + 1) / num_seqs
assert 1 >= complete > 0, f"{step} step, {num_seqs} num_seqs"
total_time_estimated = start_elapsed / complete
remaining_estimated = total_time_estimated - start_elapsed
info += [
"exp. remaining %s" % hms(remaining_estimated),
"complete %.02f%%" % (complete * 100),
]
if start_elapsed is not None and num_seqs is None:
info += ["(unk epoch len)"]
print(", ".join(filter(None, info)), file=log.v5)


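The "exp. remaining" figure above is a plain linear extrapolation from the fraction of sequences completed so far, so it assumes a roughly constant per-sequence rate. A worked sketch with made-up numbers (not from the commit):

    # Made-up numbers: 250 of 1000 sequences done, 120 sec into the epoch.
    seq_idx, num_seqs = 249, 1000  # last_seq_idx is 0-based, hence the +1 below
    start_elapsed = 120.0          # seconds since epoch start

    complete = (seq_idx + 1) / num_seqs              # 0.25 -> "complete 25.00%"
    total_time_estimated = start_elapsed / complete  # 480 sec estimated for the whole epoch
    remaining_estimated = total_time_estimated - start_elapsed  # 360 sec, printed via hms()

    print("complete %.02f%%, exp. remaining %.0f sec" % (complete * 100, remaining_estimated))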
