Samyamr/full precision for ZeRO Stage2 and Stage3 #1004
Conversation
And just to indicate the priority of this PR: we have all those bfloat16 models that won't train under fp16/mixed precision, and users want to use DeepSpeed to overcome GPU memory limitations, so they badly need this. Thank you!
When you feel this looks good enough to test, please let me know and I will start testing this branch on the
The assert that checks whether param.dtype is torch.half for ZeRO3 should only fire if the model was initialized in the ZeRO3 context.
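For illustration, a minimal sketch of the guard being described, assuming a hypothetical flag `built_under_zero3_init` standing in for however the engine records that the model was constructed inside the ZeRO3 init context (this is not existing DeepSpeed API):

```
import torch

def check_param_dtype(param, built_under_zero3_init: bool):
    # Only enforce the fp16 expectation when the parameter was actually
    # created inside the ZeRO3 init context; a model built outside that
    # context (e.g. a full-precision fp32 model) should pass through.
    # `built_under_zero3_init` is an assumed flag, not DeepSpeed API.
    if built_under_zero3_init:
        assert param.dtype == torch.half, (
            f"expected torch.half under ZeRO3 init, got {param.dtype}"
        )
```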
This is awesome - thank you! I encountered only one issue: As I am writing HF transformers tests for fp32, I found that
I thought the whole point of passing
I'm not sure if this is the best approach, but with microsoft#1004 I still have to pass `zero.Init(dtype)` because this branch never gets executed:

```
def _set_dtype(self, ds_config, dtype):
    if ds_config is not None and dtype is None:
        _ds_config = DeepSpeedConfig(ds_config)
        self.dtype = torch.half if _ds_config.fp16_enabled else torch.float
```
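For context, a sketch of the workaround being described: passing the dtype to `deepspeed.zero.Init` explicitly at model construction time, since the config-derived branch above is not taken. The model class here is a placeholder for whatever module is being built:

```
import torch
import deepspeed

from my_project.modeling import MyModel  # hypothetical model class

# Workaround: since _set_dtype does not derive fp32 from the ds_config here,
# the dtype is handed to zero.Init directly so that parameters are created
# in full precision under ZeRO-3.
with deepspeed.zero.Init(dtype=torch.float):
    model = MyModel()
```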
No description provided.