
loss averaging across multiple devices #254


Merged: 2 commits merged into mindspore-lab:main on May 10, 2023

Conversation

@hadipash (Collaborator) commented May 4, 2023

Motivation

Average the loss across devices to report more accurate training statistics.
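
For reference, a minimal sketch of the pattern this PR adds, expanding on the snippet quoted below (the reduce_sum and device_num names follow the diff; the AllReduce / get_group_size wiring and the callback usage are assumptions for illustration, not code from the PR):

    import mindspore as ms
    from mindspore import ops
    from mindspore.communication import get_group_size

    # assumes mindspore.communication.init() has already been called
    device_num = get_group_size()                  # number of cards in the data-parallel group
    reduce_sum = ops.AllReduce(ops.ReduceOp.SUM)   # sums a tensor across all devices

    @ms.jit
    def _reduce(x):
        # sum the per-card loss over all devices, then divide by the card count
        return reduce_sum(x) / device_num

    # hypothetical usage inside a loss-logging callback:
    # avg_loss = _reduce(per_device_loss)          # per_device_loss is a scalar Tensor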

@hadipash hadipash requested review from SamitHuang and zhtmike May 4, 2023 09:07
@jit
def reduce(x):  # lambda expression is not supported in MindSpore
    return reduce_sum(x) / device_num  # average loss across all cards

Collaborator:
It is better to put reduce outside the function.

Collaborator:
Agree. By the way, is it necessary to use the jit decorator here, since we are already in graph mode and the reduce computation should be lightweight and fast? If not, we can simply use self._loss_reduce = lambda x: reduce_sum(x) / device_num

@hadipash (Collaborator, Author) commented May 5, 2023:
Talked to Jun about it; he said that (1) callbacks are always executed in native mode, and (2) ops.AllReduce() may take a noticeable amount of time in native mode due to some overhead computations in the backend, so it is generally recommended to wrap it with jit.

That said, I agree that reducing single-number tensors may be very quick and jit could be overkill here. Maybe we can benchmark later and see whether it is really necessary.
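
If that benchmark is ever done, a rough sketch of how it could look (hypothetical, not part of the PR; assumes distributed communication has been initialized for the current job):

    import time
    import mindspore as ms
    from mindspore import ops
    from mindspore.communication import get_group_size, init

    init()                                          # set up the communication backend
    device_num = get_group_size()
    reduce_sum = ops.AllReduce(ops.ReduceOp.SUM)

    def plain_reduce(x):                            # pynative version
        return reduce_sum(x) / device_num

    jitted_reduce = ms.jit(plain_reduce)            # graph-compiled version

    x = ms.Tensor(1.0, ms.float32)
    for name, fn in (("pynative", plain_reduce), ("jit", jitted_reduce)):
        fn(x)                                       # warm-up (jit compiles on first call)
        start = time.perf_counter()
        for _ in range(100):
            fn(x)
        print(name, (time.perf_counter() - start) / 100)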

        else:
            self._loss_reduce = lambda x: x

    @jit
    def _reduce(self, x):
@hadipash (Collaborator, Author):
@zhtmike @SamitHuang Please check.

Collaborator:
Not sure whether running with ms_function in a callback is a stable choice. For MS 1.9, pynative with ms_function is not as stable as in MS 2.0. If using ms_function is risky, I don't think it is worth adding jit/ms_function, considering the negligible acceleration on this one-step division computation.

@hadipash hadipash requested a review from zhtmike May 8, 2023 03:49
@SamitHuang SamitHuang merged commit 62724b4 into mindspore-lab:main May 10, 2023
@hadipash hadipash deleted the fix branch May 23, 2023 08:26