loss averaging across multiple devices #254
Conversation
@jit
def reduce(x):  # lambda expression is not supported in MindSpore
    return reduce_sum(x) / device_num  # average loss across all cards
It is better to put reduce outside the function.
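A minimal sketch of that suggestion, with illustrative names (not the PR's actual code): define the reduction helper once at module scope instead of rebuilding it inside the calling function:

from mindspore import ops
from mindspore.communication import get_group_size

reduce_sum = ops.AllReduce()   # AllReduce defaults to summing across devices
device_num = get_group_size()  # number of cards in the communication group

def loss_reduce(x):
    return reduce_sum(x) / device_num  # average loss across all cards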
Agree. BTW, is it necessary to use the jit decorator here, as we are already in graph mode and the reduce computation should be lightweight and fast? If not, we can simply use self._loss_reduce = lambda x: reduce_sum(x) / device_num.
Talked to Jun about it. He said that: 1. callbacks are always executed in native mode, and 2. ops.AllReduce() may take a noticeable amount of time in native mode due to some overhead computations in the backend, so it is generally recommended to wrap it with jit.
That said, I agree that reducing single-number tensors may be very quick, and jit could be overkill here. Maybe we can benchmark later and see if it is really necessary.
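To make the trade-off concrete, here is a hedged sketch of the two variants discussed above (using MindSpore 2.x names, where jit replaced ms_function; the communication group is assumed to be initialized):

import mindspore as ms
from mindspore import ops
from mindspore.communication import get_group_size

device_num = get_group_size()
reduce_sum = ops.AllReduce()  # sums a tensor across all devices

# Variant 1: plain lambda, executed eagerly (PyNative) inside the callback.
loss_reduce_eager = lambda x: reduce_sum(x) / device_num

# Variant 2: jit-compiled, avoiding per-call PyNative overhead of AllReduce.
@ms.jit
def loss_reduce_jit(x):
    return reduce_sum(x) / device_num  # average loss across all cards

Benchmarking both on a single scalar loss tensor would settle whether jit pays off here.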
else:
    self._loss_reduce = lambda x: x

@jit
def _reduce(self, x):
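Putting the two branches together, an illustrative callback fragment (the class name and surrounding structure are assumed, not taken from the repository):

import mindspore as ms
from mindspore import ops

class LossCallback:
    def __init__(self, distributed: bool, device_num: int = 1):
        if distributed:
            self._all_reduce = ops.AllReduce()
            self._device_num = device_num
            self._loss_reduce = self._reduce  # average across all cards
        else:
            self._loss_reduce = lambda x: x   # single card: identity

    @ms.jit
    def _reduce(self, x):
        return self._all_reduce(x) / self._device_num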
@zhtmike @SamitHuang Please check.
Not sure whether running with ms_function in a callback is a stable choice. For MS 1.9, pynative with ms_function is not as stable as in MS 2.0. If using ms_function is risky, I don't think it is worth adding jit/ms_function, considering the negligible acceleration on this one-step division computation.
Motivation
Average the loss across devices to report more meaningful training stats.