-
Notifications
You must be signed in to change notification settings - Fork 4.2k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Allow modification of zero partitioned parameters #4192
Conversation
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This is a great feature! The code modifications and the document are also clear. I do have one observation, though it's not immediately pressing:
Currently we have three get_*
functions (safe_get_full_fp32_param
, safe_get_full_grad
, and safe_get_full_optimizer_state
). This PR introduces safe_set_full_fp32_param
and safe_set_full_optimizer_state
. Is there a specific reason we're omitting safe_set_full_grad?
Maintaining consistency in the APIs can help users understand the design better.
Co-authored-by: Michael Wyatt <michaelwyatt@microsoft.com>
@tohtana, thanks for this valid question. I am delaying support for |
@tjruwase |
Essentialy whats needed: grad = get_grads(layernorm.weight)
dist.all_reduce(grad, group=tp_group)
safe_set_grads(grad, layernorm.weight) if there is an alternative way to do it, that will also be helpful. |
Utilities for flexible modification of partitioned fp32 parameters and optimizer states.
Fix #3830