From a8aeea190cbf60f36ee87452ba4b9ba9adff8e23 Mon Sep 17 00:00:00 2001
From: Albert Zeyer
Date: Fri, 18 Oct 2024 15:10:15 +0200
Subject: [PATCH] PT engine, report train complete status, expected remaining time

---
 returnn/torch/data/pipeline.py                |  5 +-
 returnn/torch/data/returnn_dataset_wrapper.py | 11 ++++
 returnn/torch/engine.py                       | 62 ++++++++++++++++---
 3 files changed, 67 insertions(+), 11 deletions(-)

diff --git a/returnn/torch/data/pipeline.py b/returnn/torch/data/pipeline.py
index c8a4a9059..9ff5f4054 100644
--- a/returnn/torch/data/pipeline.py
+++ b/returnn/torch/data/pipeline.py
@@ -59,6 +59,9 @@ def collate_batch(batch: List[Dict[str, numpy.ndarray]]) -> Dict[str, Union[torc
 
     res = {}
     for key in data_keys:
+        if key == "num_seqs":
+            res[key] = batch[0][key]  # it should always be the same
+            continue
         ls = [create_tensor(sample[key]) for sample in batch]
         if not ls:
             raise ValueError("batch is empty?")
@@ -116,7 +119,7 @@ def __iter__(self):
 
             if not chunking_data_keys:
                 chunking_data_keys = list(data_dict.keys())  # use all if not configured separately
-                chunking_data_key_black_list = ["seq_tag"]
+                chunking_data_key_black_list = ["seq_tag", "seq_idx", "num_seqs"]
                 for key in chunking_data_key_black_list:
                     if key in chunking_data_keys:
                         chunking_data_keys.remove(key)
diff --git a/returnn/torch/data/returnn_dataset_wrapper.py b/returnn/torch/data/returnn_dataset_wrapper.py
index bd6367271..78e305346 100644
--- a/returnn/torch/data/returnn_dataset_wrapper.py
+++ b/returnn/torch/data/returnn_dataset_wrapper.py
@@ -75,6 +75,13 @@ def __iter__(self) -> Iterable[Dict[str, numpy.ndarray]]:
         """
         :return: generator providing data samples in the form of a dict data_key -> data
         """
+        # noinspection PyBroadException
+        try:
+            num_seqs = self._dataset.num_seqs
+        except Exception:  # might not work for all datasets
+            num_seqs = -1
+        num_seqs = numpy.array(num_seqs)
+
         try:
             data_keys = self._dataset.get_data_keys()
 
@@ -83,6 +90,10 @@ def __iter__(self) -> Iterable[Dict[str, numpy.ndarray]]:
                 self._dataset.load_seqs(seq_index, seq_index + 1)
                 data = {data_key: self._dataset.get_data(seq_index, data_key) for data_key in data_keys}
                 data["seq_tag"] = str_to_numpy_array(self._dataset.get_tag(seq_index))
+                data["seq_idx"] = numpy.array(seq_index)
+                # It's slightly redundant to have num_seqs in each entry,
+                # but it's difficult to pass this back to the main proc otherwise.
+ data["num_seqs"] = num_seqs yield data seq_index += 1 diff --git a/returnn/torch/engine.py b/returnn/torch/engine.py index f00358cbf..c43b5363d 100644 --- a/returnn/torch/engine.py +++ b/returnn/torch/engine.py @@ -311,7 +311,7 @@ def train_epoch(self): accumulated_losses_dict = NumbersDict() accumulated_inv_norm_factors_dict = NumbersDict() step_idx = 0 - epoch_start_time = time.time() + epoch_start_time = time.monotonic() data_iter = iter(self._train_dataloader) elapsed_computation_time = 0 @@ -339,12 +339,14 @@ def train_epoch(self): zero_grad_next_step = True cur_count_grad_accum = 0 extern_data = None + num_seqs = None + last_seq_idx = 0 try: while True: with torch.no_grad(): extern_data_raw = next(data_iter, None) - step_begin_time = time.time() + step_begin_time = time.monotonic() _has_data = torch.tensor([extern_data_raw is not None], dtype=torch.int8) if self._torch_distributed_ctx: @@ -353,6 +355,22 @@ def train_epoch(self): torch.distributed.all_reduce(_has_data, op=torch.distributed.ReduceOp.MIN) if not _has_data[0]: break + num_seqs_ = ( + int(extern_data_raw["num_seqs"]) if extern_data_raw.get("num_seqs", None) is not None else -1 + ) + last_seq_idx_ = extern_data_raw["seq_idx"].max() + assert last_seq_idx_ >= last_seq_idx + last_seq_idx = int(last_seq_idx_) + del last_seq_idx_ + if step_idx == 0: + if num_seqs_ >= 0: + print(f"Epoch {self.epoch} num_seqs: {num_seqs_}", file=log.v5) + num_seqs = num_seqs_ + elif num_seqs_ >= 0: + assert num_seqs_ == num_seqs + del num_seqs_ + if num_seqs is not None: + assert last_seq_idx < num_seqs # clear the gradients when every gradient accumulation loop starts if zero_grad_next_step: @@ -404,7 +422,8 @@ def train_epoch(self): if self._torch_distributed_ctx: self._torch_distributed_ctx.step_after_param_update(module=self._pt_model, epoch_step_idx=step_idx) - step_duration = time.time() - step_begin_time + step_end_time = time.monotonic() + step_duration = step_end_time - step_begin_time elapsed_computation_time += step_duration accumulated_losses_dict += losses_dict @@ -415,6 +434,9 @@ def train_epoch(self): step=step_idx, eval_info=dict(eval_info), step_duration=step_duration, + start_elapsed=step_end_time - epoch_start_time, + seq_idx=last_seq_idx, + num_seqs=num_seqs, batch_size_info=_get_batch_size_info(extern_data) if self._log_batch_size else None, log_memory_usage_device=self._device if self._log_memory_usage else None, ) @@ -436,7 +458,7 @@ def train_epoch(self): help_on_torch_exception(exc, step_idx=step_idx, model=self._orig_model, extern_data=extern_data) raise - elapsed = time.time() - epoch_start_time + elapsed = time.monotonic() - epoch_start_time elapsed_computation_percentage = elapsed_computation_time / elapsed print( "Trained %i steps, %s elapsed (%.1f%% computing time)" @@ -1008,7 +1030,7 @@ def forward_with_callback(self, *, dataset: Dataset, callback: ForwardCallbackIf assert isinstance(dataset, Dataset) assert isinstance(callback, ForwardCallbackIface) - epoch_start_time = time.time() + epoch_start_time = time.monotonic() elapsed_computation_time = 0.0 self._pt_model.eval() @@ -1087,7 +1109,7 @@ def _get_dim_tag_wo_batch(dim: Dim) -> Dim: step_idx = 0 for extern_data_raw in data_loader: - step_begin_time = time.time() + step_begin_time = time.monotonic() if self._forward_step_expected_outputs: # Also resets any dyn dims, which might have been set in the prev step. 
                    self._forward_step_expected_outputs.reset_content()
@@ -1121,7 +1143,7 @@ def _get_dim_tag_wo_batch(dim: Dim) -> Dim:
                         model_outputs_per_batch.data[k] = _get_tensor_wo_batch_numpy(v)
                     callback.process_seq(seq_tag=seq_tag, outputs=model_outputs_per_batch)
 
-                elapsed_computation_time += time.time() - step_begin_time
+                elapsed_computation_time += time.monotonic() - step_begin_time
                 _print_process(
                     report_prefix,
                     step=step_idx,
@@ -1132,7 +1154,7 @@ def _get_dim_tag_wo_batch(dim: Dim) -> Dim:
 
             callback.finish()
 
-        elapsed = time.time() - epoch_start_time
+        elapsed = time.monotonic() - epoch_start_time
         elapsed_computation_percentage = elapsed_computation_time / elapsed
         print(
             "Forward %i steps, %s elapsed (%.1f%% computing time)"
@@ -1202,20 +1224,26 @@ def _to_raw(n: Union[int, float, Tensor]):
 
 def _print_process(
     report_prefix: str,
+    *,
     step: int,
     eval_info: Optional[Dict[str, Any]] = None,
     batch_size_info: Optional[Dict[str, Any]] = None,
     step_duration: Optional[float] = None,
+    start_elapsed: Optional[float] = None,
+    seq_idx: Optional[int] = None,
+    num_seqs: Optional[int] = None,
     log_memory_usage_device: Optional[str] = None,
 ):
     """
     Similar but simplified from TF engine _print_process.
 
     :param report_prefix:
-    :param step:
+    :param step: for this epoch
    :param eval_info:
     :param batch_size_info:
-    :param step_duration:
+    :param step_duration: time elapsed for this step (secs)
+    :param start_elapsed: time elapsed since epoch start (secs)
+    :param num_seqs: total number of sequences for this epoch
     :param log_memory_usage_device: if given, will log memory usage (peak allocated memory)
     :return: nothing, will be printed to log
     """
@@ -1233,6 +1261,20 @@ def _print_process(
             ]
         if step_duration is not None:
             info += ["%.3f sec/step" % step_duration]
+        if start_elapsed is not None:
+            info += ["elapsed %s" % hms(start_elapsed)]
+        if num_seqs is not None:
+            assert seq_idx is not None and start_elapsed is not None  # unexpected combination...
+            complete = (seq_idx + 1) / num_seqs
+            assert 1 >= complete > 0, f"{step} step, {num_seqs} num_seqs"
+            total_time_estimated = start_elapsed / complete
+            remaining_estimated = total_time_estimated - start_elapsed
+            info += [
+                "exp. remaining %s" % hms(remaining_estimated),
+                "complete %.02f%%" % (complete * 100),
+            ]
+        if start_elapsed is not None and num_seqs is None:
+            info += ["(unk epoch len)"]
         print(", ".join(filter(None, info)), file=log.v5)
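
Note (not part of the patch): the "exp. remaining" / "complete" figures added to _print_process come from a simple linear
extrapolation over the sequence index. Below is a minimal standalone sketch of that arithmetic; the helper name
estimate_remaining and the example numbers are made up for illustration, only the formula mirrors the patch.

from typing import Optional, Tuple


def estimate_remaining(
    seq_idx: int, num_seqs: Optional[int], start_elapsed: float
) -> Optional[Tuple[float, float]]:
    """
    Same arithmetic as in _print_process: fraction complete -> estimated total -> remaining.

    :param seq_idx: last seen sequence index within the epoch (0-based)
    :param num_seqs: total number of sequences in the epoch, or None if the dataset cannot tell
    :param start_elapsed: seconds elapsed since epoch start
    :return: (complete fraction, estimated remaining secs), or None if the epoch length is unknown
    """
    if num_seqs is None:
        return None  # corresponds to the "(unk epoch len)" case in the patch
    complete = (seq_idx + 1) / num_seqs
    assert 1 >= complete > 0
    total_estimated = start_elapsed / complete
    return complete, total_estimated - start_elapsed


# Example: 250 of 1000 seqs done after 120 secs -> 25% complete, ~360 secs remaining.
print(estimate_remaining(seq_idx=249, num_seqs=1000, start_elapsed=120.0))

The estimate assumes sequences take roughly uniform time, so it can swing noticeably early in an epoch; that is why the
patch prints it alongside the raw elapsed time rather than replacing it.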