loss averaging across multiple devices #254

Merged
2 commits merged on May 10, 2023
53 changes: 33 additions & 20 deletions mindocr/utils/callbacks.py
@@ -2,9 +2,8 @@
import time
from tqdm import tqdm
from typing import List
import shutil
from packaging import version

import numpy as np
import mindspore as ms
from mindspore.ops import functional as F
from mindspore.common import dtype as mstype
@@ -15,6 +14,13 @@

__all__ = ['Evaluator', 'EvalSaveCallback']

# WARNING: `mindspore.ms_function` will be deprecated and removed in a future version.
if version.parse(ms.__version__) >= version.parse('2.0.0rc'):
from mindspore import jit
else:
from mindspore import ms_function
jit = ms_function


class Evaluator:
"""
@@ -41,7 +47,7 @@ def __init__(self,
label_indices=None,
meta_data_indices=None,
num_epochs=-1,
visualize=False,
visualize=False,
verbose=False,
**kwargs):
self.net = network
@@ -52,7 +58,7 @@ def __init__(self,
assert hasattr(m, 'metric_names') and isinstance(m.metric_names,
List), f'Metric object must contain `metric_names` attribute to indicate the metric names as a List type, but not found in {m.__class__.__name__}'
self.metric_names += m.metric_names

self.pred_cast_fp32 = pred_cast_fp32
self.visualize = visualize
self.verbose = verbose
@@ -65,7 +71,6 @@ def __init__(self,
# create iterator
self.reload(dataloader, input_indices, label_indices, meta_data_indices, num_epochs)


def reload(self, dataloader, input_indices=None, label_indices=None, meta_data_indices=None, num_epochs=-1):
# create iterator
self.iterator = dataloader.create_tuple_iterator(num_epochs=num_epochs, output_numpy=False, do_copy=False)
@@ -86,7 +91,6 @@ def eval(self):
for m in self.metrics:
m.clear()


for i, data in tqdm(enumerate(self.iterator), total=self.num_batches_eval):
if self.input_indices is not None:
inputs = [data[x] for x in self.input_indices]
@@ -162,6 +166,7 @@ def __init__(self,
metrics=None,
pred_cast_fp32=False,
rank_id=0,
device_num=None,
logger=None,
batch_size=20,
ckpt_save_dir='./',
@@ -178,30 +183,41 @@ def __init__(self,
self.is_main_device = rank_id in [0, None]
self.loader_eval = loader
self.network = network
self.ema= ema
self.ema = ema
self.logger = print if logger is None else logger.info
self.val_interval = val_interval
self.val_start_epoch = val_start_epoch
self.log_interval = log_interval
self.batch_size = batch_size
if self.loader_eval is not None:
self.net_evaluator = Evaluator(network, loader, loss_fn, postprocessor, metrics, pred_cast_fp32=pred_cast_fp32, input_indices=input_indices, label_indices=label_indices, meta_data_indices=meta_data_indices)
self.net_evaluator = Evaluator(network, loader, loss_fn, postprocessor, metrics,
pred_cast_fp32=pred_cast_fp32, input_indices=input_indices,
label_indices=label_indices, meta_data_indices=meta_data_indices)
self.main_indicator = main_indicator
self.best_perf = -1e8
else:
self.main_indicator = 'train_loss'
self.best_perf = 1e8


self.ckpt_save_dir = ckpt_save_dir
if not os.path.exists(ckpt_save_dir):
os.makedirs(ckpt_save_dir)

self._losses = list()
self.last_epoch_end_time = time.time()
self.epoch_start_time = time.time()
self.step_start_time = time.time()

self._losses = []

self._reduce_sum = ms.ops.AllReduce()
self._device_num = device_num
# lambda expression is not supported in jit
self._loss_reduce = self._reduce if device_num is not None else lambda x: x

Collaborator: It is better to put the reduce outside the function.

Collaborator: Agree. BTW, is it necessary to use the jit decorator here, as we are already in graph mode and the reduce computation should be lightweight and fast? If not, we can simply use `self._loss_reduce = lambda x: reduce_sum(x) / device_num`.

Collaborator (Author) @hadipash, May 5, 2023:

Talked to Jun about it. He said that: 1. callbacks are always executed in native (PyNative) mode; 2. ops.AllReduce() may take a noticeable amount of time in native mode due to some overhead computations in the backend, so it is generally recommended to wrap it with jit.

That said, I agree that reducing single-number tensors may be very quick and jit could be overkill here. Maybe we can benchmark later and see if it is really necessary.
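For illustration, a minimal, self-contained sketch of the pattern this thread converges on: the AllReduce sum wrapped in jit and divided by the device count, with an identity fallback on a single device. `LossReducer` is a hypothetical helper used only to make the sketch runnable (it is not part of this PR), and it assumes MindSpore 2.0+ where `mindspore.jit` is available; older versions would fall back to `ms_function` as in the diff above.

```python
import mindspore as ms
from mindspore import jit
from mindspore.communication import get_group_size


class LossReducer:
    """Average a scalar loss tensor across devices (illustrative sketch only)."""

    def __init__(self, device_num=None):
        self.device_num = device_num
        if device_num is not None:
            # Requires mindspore.communication.init() to have been called beforehand.
            self.reduce_sum = ms.ops.AllReduce()  # sums the input tensor over all devices

    @jit  # compile the one-op graph once to avoid per-call PyNative dispatch overhead
    def _reduce(self, x):
        return self.reduce_sum(x) / self.device_num  # global mean of the per-device losses

    def __call__(self, loss):
        if self.device_num is None:
            return loss  # single device: nothing to average
        return self._reduce(loss)


# usage (assumptions): reducer = LossReducer(get_group_size()) in a distributed run,
# or reducer = LossReducer() on a single device; then avg_loss = reducer(loss)
```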

    @jit
    def _reduce(self, x):
        return self._reduce_sum(x) / self._device_num  # average value across all devices

Collaborator (Author): @zhtmike @SamitHuang Please check.

Collaborator: Not sure whether running with ms_function in a callback is a stable choice. For MS 1.9, PyNative with ms_function is not as stable as in MS 2.0. If using ms_function is risky, I don't think it is worth adding jit/ms_function, considering the negligible acceleration of this one-step division computation.
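For reference, a sketch of the simpler alternative raised above: a plain closure over AllReduce with no jit/ms_function, which sidesteps the MS 1.9 stability concern at the cost of the (likely negligible) compilation speed-up. The `device_num` value here is an assumption about the surrounding setup.

```python
import mindspore as ms

# Assumption: in a distributed run, device_num would come from
# mindspore.communication.get_group_size() after init(); it is None here so the
# sketch stays runnable on a single device without a distributed launcher.
device_num = None

if device_num is not None:
    reduce_sum = ms.ops.AllReduce()
    loss_reduce = lambda x: reduce_sum(x) / device_num  # average the loss across devices
else:
    loss_reduce = lambda x: x  # single device: pass the loss through unchanged
```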

def on_train_step_end(self, run_context):
"""
Print training loss at the end of step.
@@ -215,19 +231,17 @@ def on_train_step_end(self, run_context):
data_sink_mode = cb_params.dataset_sink_mode
cur_step_in_epoch = (cb_params.cur_step_num - 1) % cb_params.batch_num + 1

# TODO: need to stop gradient here ?
self._losses.append(loss.asnumpy())
self._losses.append(self._loss_reduce(loss))

if not data_sink_mode and cur_step_in_epoch % self.log_interval == 0:
opt = cb_params.train_network.optimizer
learning_rate = opt.learning_rate
cur_lr = learning_rate(opt.global_step - 1).asnumpy()
per_step_time = (time.time() - self.step_start_time) * 1000 / self.log_interval
fps = self.batch_size * 1000 / per_step_time
loss = np.average(self._losses)
msg = "epoch: [%s/%s] step: [%s/%s], loss: %.6f, lr: %.6f, per step time: %.3f ms, fps: %.2f img/s" % (
cur_epoch, cb_params.epoch_num, cur_step_in_epoch, cb_params.batch_num,
loss, cur_lr, per_step_time, fps)
loss = self._losses[-1].asnumpy()
msg = f"epoch: [{cur_epoch}/{cb_params.epoch_num}] step: [{cur_step_in_epoch}/{cb_params.batch_num}], " \
f"loss: {loss:.6f}, lr: {cur_lr:.6f}, per step time: {per_step_time:.3f} ms, fps: {fps:.2f} img/s"
self.logger(msg)
self.step_start_time = time.time()

@@ -249,16 +263,15 @@ def on_train_epoch_end(self, run_context):
run_context (RunContext): Include some information of the model.
"""
cb_params = run_context.original_args()
loss = cb_params.net_outputs
cur_epoch = cb_params.cur_epoch_num
train_time = (time.time() - self.epoch_start_time)
train_loss = np.average(self._losses) # TODO: aggregate training loss for multiple cards
train_loss = ms.ops.stack(self._losses).mean().asnumpy()

epoch_time = (time.time() - self.epoch_start_time)
per_step_time = epoch_time * 1000 / cb_params.batch_num
fps = 1000 * self.batch_size / per_step_time
msg = "epoch: [%s/%s], loss: %.6f, epoch time: %.3f s, per step time: %.3f ms, fps: %.2f img/s" % (
cur_epoch, cb_params.epoch_num, train_loss, epoch_time, per_step_time, fps)
msg = f"epoch: [{cur_epoch}/{cb_params.epoch_num}], loss: {train_loss:.6f}, " \
f"epoch time: {epoch_time:.3f} s, per step time: {per_step_time:.3f} ms, fps: {fps:.2f} img/s"
self.logger(msg)

eval_done = False
1 change: 1 addition & 0 deletions tools/train.py
@@ -146,6 +146,7 @@ def main(cfg):
metrics=[metric],
pred_cast_fp32=(amp_level!='O0'),
rank_id=rank_id,
device_num=device_num,
logger=logger,
batch_size=cfg.train.loader.batch_size,
ckpt_save_dir=cfg.train.ckpt_save_dir,
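For completeness, a hedged sketch of how `rank_id` and `device_num` might be produced on the training side before being passed to the callback. The `distribute` flag and the helper name are assumptions for illustration and may not match the actual tools/train.py setup.

```python
import mindspore as ms
from mindspore.communication import init, get_rank, get_group_size


def get_parallel_info(distribute: bool):
    """Return (rank_id, device_num); illustrative only."""
    if not distribute:
        return None, None  # single device: the callback skips AllReduce entirely
    init()  # set up the communication group (HCCL on Ascend, NCCL on GPU)
    ms.context.set_auto_parallel_context(
        parallel_mode=ms.context.ParallelMode.DATA_PARALLEL,
        gradients_mean=True,
    )
    return get_rank(), get_group_size()


# usage (hypothetical config flag): rank_id, device_num = get_parallel_info(cfg.system.distribute)
```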