loss averaging across multiple devices #254
Changes from all commits
@@ -2,9 +2,8 @@
 import time
 from tqdm import tqdm
 from typing import List
-import shutil
+from packaging import version

 import numpy as np
 import mindspore as ms
-from mindspore.ops import functional as F
 from mindspore.common import dtype as mstype
@@ -15,6 +14,13 @@

 __all__ = ['Evaluator', 'EvalSaveCallback']

+# WARNING: `mindspore.ms_function` will be deprecated and removed in a future version.
+if version.parse(ms.__version__) >= version.parse('2.0.0rc'):
+    from mindspore import jit
+else:
+    from mindspore import ms_function
+    jit = ms_function
+

 class Evaluator:
     """
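For context, a minimal sketch (not part of the diff) of what this compatibility shim enables: `ms_function` is the pre-2.0 name of `jit`, and both compile the decorated function to a graph, so the alias lets the same decorator syntax work on either MindSpore version. The `square` function below is a made-up example.

```python
# Illustrative only: version-gated alias, then a jit-compiled toy function.
from packaging import version
import mindspore as ms

if version.parse(ms.__version__) >= version.parse('2.0.0rc'):
    from mindspore import jit
else:
    from mindspore import ms_function as jit  # same graph-compilation behavior

@jit
def square(x):
    return x * x

print(square(ms.Tensor(3.0, ms.float32)))  # compiled to a graph on first call
```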
@@ -41,7 +47,7 @@ def __init__(self,
                  label_indices=None,
                  meta_data_indices=None,
                  num_epochs=-1,
-                  visualize=False,
+                 visualize=False,
                  verbose=False,
                  **kwargs):
         self.net = network
@@ -52,7 +58,7 @@ def __init__(self,
             assert hasattr(m, 'metric_names') and isinstance(m.metric_names,
                 List), f'Metric object must contain `metric_names` attribute to indicate the metric names as a List type, but not found in {m.__class__.__name__}'
             self.metric_names += m.metric_names

         self.pred_cast_fp32 = pred_cast_fp32
         self.visualize = visualize
         self.verbose = verbose
@@ -65,7 +71,6 @@ def __init__(self,
         # create iterator
         self.reload(dataloader, input_indices, label_indices, meta_data_indices, num_epochs)

-
     def reload(self, dataloader, input_indices=None, label_indices=None, meta_data_indices=None, num_epochs=-1):
         # create iterator
         self.iterator = dataloader.create_tuple_iterator(num_epochs=num_epochs, output_numpy=False, do_copy=False)
@@ -86,7 +91,6 @@ def eval(self):
         for m in self.metrics:
             m.clear()

-
         for i, data in tqdm(enumerate(self.iterator), total=self.num_batches_eval):
             if self.input_indices is not None:
                 inputs = [data[x] for x in self.input_indices]
@@ -162,6 +166,7 @@ def __init__(self,
                  metrics=None,
                  pred_cast_fp32=False,
                  rank_id=0,
+                 device_num=None,
                  logger=None,
                  batch_size=20,
                  ckpt_save_dir='./',
@@ -178,30 +183,41 @@ def __init__(self,
         self.is_main_device = rank_id in [0, None]
         self.loader_eval = loader
         self.network = network
-        self.ema= ema
+        self.ema = ema
         self.logger = print if logger is None else logger.info
         self.val_interval = val_interval
         self.val_start_epoch = val_start_epoch
         self.log_interval = log_interval
         self.batch_size = batch_size
         if self.loader_eval is not None:
-            self.net_evaluator = Evaluator(network, loader, loss_fn, postprocessor, metrics, pred_cast_fp32=pred_cast_fp32, input_indices=input_indices, label_indices=label_indices, meta_data_indices=meta_data_indices)
+            self.net_evaluator = Evaluator(network, loader, loss_fn, postprocessor, metrics,
+                                           pred_cast_fp32=pred_cast_fp32, input_indices=input_indices,
+                                           label_indices=label_indices, meta_data_indices=meta_data_indices)
             self.main_indicator = main_indicator
             self.best_perf = -1e8
         else:
             self.main_indicator = 'train_loss'
             self.best_perf = 1e8

         self.ckpt_save_dir = ckpt_save_dir
         if not os.path.exists(ckpt_save_dir):
             os.makedirs(ckpt_save_dir)

-        self._losses = list()
         self.last_epoch_end_time = time.time()
         self.epoch_start_time = time.time()
         self.step_start_time = time.time()

+        self._losses = []
+
+        self._reduce_sum = ms.ops.AllReduce()
+        self._device_num = device_num
+        # lambda expression is not supported in jit
+        self._loss_reduce = self._reduce if device_num is not None else lambda x: x
+
+    @jit
+    def _reduce(self, x):
+        return self._reduce_sum(x) / self._device_num  # average value across all devices
+
     def on_train_step_end(self, run_context):
         """
         Print training loss at the end of step.

Review thread on `_reduce`:

@zhtmike @SamitHuang Please check.

Not sure whether running with `ms_function` in a callback is a stable choice. For MS 1.9, PyNative with `ms_function` is not as stable as in MS 2.0. If using `ms_function` is risky, I don't think it is worth adding `jit`/`ms_function`, considering the negligible acceleration on this one-step division computation.
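For reference, a minimal standalone sketch (assumed setup, not part of the diff) of the cross-device averaging this PR introduces: `AllReduce` defaults to a sum across all devices, so dividing the reduced value by the device count yields the mean.

```python
# Illustrative sketch: average a per-device scalar loss across all devices.
# Assumes a distributed MindSpore job launched with the usual tooling.
import mindspore as ms
from mindspore.communication import init, get_group_size

init()                                    # initialize collective communication
device_num = get_group_size()
reduce_sum = ms.ops.AllReduce()           # default reduce op is SUM

def average_loss(loss: ms.Tensor) -> ms.Tensor:
    return reduce_sum(loss) / device_num  # sum over devices, then divide

local_loss = ms.Tensor(0.5, ms.float32)   # each rank holds its own loss value
print(average_loss(local_loss))           # same averaged value on every rank
```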
@@ -215,19 +231,17 @@ def on_train_step_end(self, run_context):
         data_sink_mode = cb_params.dataset_sink_mode
         cur_step_in_epoch = (cb_params.cur_step_num - 1) % cb_params.batch_num + 1

-        # TODO: need to stop gradient here ?
-        self._losses.append(loss.asnumpy())
+        self._losses.append(self._loss_reduce(loss))

         if not data_sink_mode and cur_step_in_epoch % self.log_interval == 0:
             opt = cb_params.train_network.optimizer
             learning_rate = opt.learning_rate
             cur_lr = learning_rate(opt.global_step - 1).asnumpy()
             per_step_time = (time.time() - self.step_start_time) * 1000 / self.log_interval
             fps = self.batch_size * 1000 / per_step_time
-            loss = np.average(self._losses)
-            msg = "epoch: [%s/%s] step: [%s/%s], loss: %.6f, lr: %.6f, per step time: %.3f ms, fps: %.2f img/s" % (
-                cur_epoch, cb_params.epoch_num, cur_step_in_epoch, cb_params.batch_num,
-                loss, cur_lr, per_step_time, fps)
+            loss = self._losses[-1].asnumpy()
+            msg = f"epoch: [{cur_epoch}/{cb_params.epoch_num}] step: [{cur_step_in_epoch}/{cb_params.batch_num}], " \
+                  f"loss: {loss:.6f}, lr: {cur_lr:.6f}, per step time: {per_step_time:.3f} ms, fps: {fps:.2f} img/s"
             self.logger(msg)
             self.step_start_time = time.time()
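One side effect of this hunk worth noting (my reading, not stated in the PR): `Tensor.asnumpy()` forces a device-to-host synchronization, so keeping losses as device Tensors and converting only the latest one at log time avoids a sync on every step. A hypothetical illustration of the pattern (names are made up):

```python
# Hypothetical illustration: defer host synchronization by keeping losses
# as device Tensors and converting to numpy only when a log line is emitted.
import mindspore as ms

losses = []          # list of ms.Tensor, one per step
log_interval = 100   # assumed logging cadence

def on_step(step: int, loss: ms.Tensor):
    losses.append(loss)                             # no device-to-host sync here
    if step % log_interval == 0:
        print(f"loss: {losses[-1].asnumpy():.6f}")  # sync only when logging
```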
@@ -249,16 +263,15 @@ def on_train_epoch_end(self, run_context):
             run_context (RunContext): Include some information of the model.
         """
         cb_params = run_context.original_args()
-        loss = cb_params.net_outputs
         cur_epoch = cb_params.cur_epoch_num
-        train_time = (time.time() - self.epoch_start_time)
-        train_loss = np.average(self._losses)  # TODO: aggregate training loss for multiple cards
+        train_loss = ms.ops.stack(self._losses).mean().asnumpy()

         epoch_time = (time.time() - self.epoch_start_time)
         per_step_time = epoch_time * 1000 / cb_params.batch_num
         fps = 1000 * self.batch_size / per_step_time
-        msg = "epoch: [%s/%s], loss: %.6f, epoch time: %.3f s, per step time: %.3f ms, fps: %.2f img/s" % (
-            cur_epoch, cb_params.epoch_num, train_loss, epoch_time, per_step_time, fps)
+        msg = f"epoch: [{cur_epoch}/{cb_params.epoch_num}], loss: {train_loss:.6f}, " \
+              f"epoch time: {epoch_time:.3f} s, per step time: {per_step_time:.3f} ms, fps: {fps:.2f} img/s"
         self.logger(msg)

         eval_done = False
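As a sanity check on the aggregation (illustrative, not from the PR): when every device runs the same number of steps, taking the cross-device mean each step and then averaging those means over the epoch is identical to the global mean over all devices and steps.

```python
# Verify the two-stage averaging numerically with plain numpy.
import numpy as np

rng = np.random.default_rng(0)
losses = rng.random((4, 100))        # 4 devices x 100 steps of per-device loss

per_step_mean = losses.mean(axis=0)  # what AllReduce / device_num yields each step
epoch_mean = per_step_mean.mean()    # what on_train_epoch_end computes

assert np.isclose(epoch_mean, losses.mean())  # equals the global mean
```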
Review thread:

It is better to put the reduce outside the function.

Agree. BTW, is it necessary to use the `jit` decorator here, since we are already in graph mode and the reduce computation should be lightweight and fast? If not, we can simply use:
self._loss_reduce = lambda x: reduce_sum(x) / device_num

Talked to Jun about it. He said that: 1. callbacks are always executed in native mode; 2. `ops.AllReduce()` may take a noticeable amount of time in native mode due to some overhead computations in the backend, so it is generally recommended to wrap it with `jit`. Although I agree that reducing single-number tensors may be very quick and `jit` could be overkill here. Maybe we can benchmark later and see if it is really necessary.
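To settle that question, a rough micro-benchmark sketch (assumed setup, MindSpore >= 2.0, distributed job; not part of the PR) comparing the eager and jit-compiled reduce under PyNative, where callbacks actually run:

```python
# Rough micro-benchmark sketch (illustrative): eager vs jit-compiled
# AllReduce-then-divide under PyNative mode.
import time
import mindspore as ms
from mindspore import jit
from mindspore.communication import init, get_group_size

ms.set_context(mode=ms.PYNATIVE_MODE)
init()
device_num = get_group_size()
reduce_sum = ms.ops.AllReduce()

def reduce_eager(x):
    return reduce_sum(x) / device_num

@jit
def reduce_jit(x):
    return reduce_sum(x) / device_num

x = ms.Tensor(1.0, ms.float32)
for name, fn in (("eager", reduce_eager), ("jit", reduce_jit)):
    fn(x).asnumpy()                      # warm-up (jit compiles on first call)
    start = time.time()
    for _ in range(1000):
        fn(x).asnumpy()                  # sync each call so timing is honest
    print(f"{name}: {time.time() - start:.3f} s per 1000 calls")
```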