Samyamr/full precision for ZeRO Stage2 and Stage3 #1004

Merged: 21 commits, merged on Apr 29, 2021

Changes from 1 commit (of 21)
Formatting fix
tjruwase committed Apr 29, 2021
commit bb46f5880ce92ed9286af18f5f47bacfda98be86
4 changes: 3 additions & 1 deletion deepspeed/runtime/engine.py
@@ -582,7 +582,9 @@ def is_replicated(p):
     def _configure_distributed_model(self, model):
         self.module = model
         if self.fp16_enabled():
-            if self.zero_optimization_partition_weights() and any([hasattr(param,'ds_id') for param in self.module.parameters()]):
+            if self.zero_optimization_partition_weights() and any(
+                [hasattr(param,
+                         'ds_id') for param in self.module.parameters()]):
                 assert all([param.dtype == torch.half for param in self.module.parameters()]), f"Model must initialized in fp16 mode for ZeRO Stage 3."
                 self.module.half()
         else:
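
For reference, below is a minimal standalone sketch of the check this hunk reformats: when ZeRO Stage 3 weight partitioning is enabled and parameters already carry a `ds_id` attribute (i.e. they were partitioned at construction time), the model must already be fp16; otherwise the module is cast with `.half()`. This is an illustration only, not the DeepSpeed engine itself; the function name `ensure_fp16` and the `partition_weights_enabled` flag are hypothetical stand-ins for the engine's methods.

import torch


def ensure_fp16(module: torch.nn.Module, partition_weights_enabled: bool) -> None:
    params = list(module.parameters())
    if partition_weights_enabled and any(hasattr(p, "ds_id") for p in params):
        # Parameters were already partitioned by ZeRO Stage 3; casting them now
        # would invalidate the partitioning bookkeeping, so they must already
        # be fp16.
        assert all(p.dtype == torch.half for p in params), \
            "Model must be initialized in fp16 mode for ZeRO Stage 3."
    else:
        # No pre-partitioned parameters: safe to cast the whole module to fp16.
        module.half()


if __name__ == "__main__":
    model = torch.nn.Linear(4, 4)
    ensure_fp16(model, partition_weights_enabled=False)
    print(model.weight.dtype)  # torch.float16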