Make z3 respect comm dtype #2807

Merged · 13 commits · Feb 22, 2023

This PR makes ZeRO stage 3 respect the configured `communication_data_type`: gradients are cast to that dtype before the reduce-scatter and the resulting partitions are cast back to the model dtype afterwards; the engine's fp16 parameter check is also factored into a helper.
19 changes: 10 additions & 9 deletions deepspeed/runtime/engine.py
@@ -1124,15 +1124,16 @@ def _configure_distributed_model(self, model):
if self.zero_optimization_partition_weights() and any(
[hasattr(param,
"ds_id") for param in self.module.parameters()]):
if not all(
[param.dtype == torch.half for param in self.module.parameters()]):
names = [
n for n,
p in self.module.named_parameters() if p.dtype != torch.half
]
raise ValueError(
f"fp16 is enabled but the following parameters have dtype that is not fp16: {', '.join(names)}"
)
self.__check_params(self.module, torch.half)
# if not all(
# [param.dtype == torch.half for param in self.module.parameters()]):
# names = [
# n for n,
# p in self.module.named_parameters() if p.dtype != torch.half
# ]
# raise ValueError(
# f"fp16 is enabled but the following parameters have dtype that is not fp16: {', '.join(names)}"
# )
self.module.half()
elif self.bfloat16_enabled():
if self.zero_optimization_partition_weights() and any(
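The inline fp16 check above is folded into the new `self.__check_params(self.module, torch.half)` call. The helper itself is not shown in this hunk; below is a minimal sketch of what such a parameter-dtype check could look like. The name, signature, and error text are assumptions, not the PR's actual implementation.

```python
# Hypothetical sketch of a parameter-dtype check in the spirit of the new
# self.__check_params(self.module, torch.half) call; the helper actually
# added by this PR may differ in name, signature, and error message.
import torch
from torch.nn import Module


def check_params(model: Module, dtype: torch.dtype) -> None:
    """Raise if any parameter of `model` does not have the expected dtype."""
    mismatched = [(name, p.dtype) for name, p in model.named_parameters()
                  if p.dtype != dtype]
    if mismatched:
        raise ValueError(
            f"{dtype} is enabled but the following parameters have a different dtype: {mismatched}")
```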
16 changes: 10 additions & 6 deletions deepspeed/runtime/zero/stage3.py
@@ -213,7 +213,7 @@ def __init__(self,
self.reduce_bucket_size = int(reduce_bucket_size)

if self.reduce_scatter:
assert self.communication_data_type in (torch.float16, torch.bfloat16), f"ZeRO-3 supports only float16 or bfloat16 communication_data_type with reduce scatter enabled. Got: '{self.communication_data_type}'"
assert self.communication_data_type in (torch.float16, torch.bfloat16, torch.float32), f"ZeRO-3 supports only float16, bfloat16, or float32 communication_data_type with reduce scatter enabled. Got: '{self.communication_data_type}'"
assert self.gradient_predivide_factor == 1.0, "gradient_predivide_factor != 1.0 is not yet supported with ZeRO-3 with reduce scatter enabled"
assert self.postscale_gradients, "pre-scale gradients is not yet supported with ZeRO-3 with reduce scatter enabled"
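With `torch.float32` now accepted by the assertion above, gradients can be communicated in fp32 while the model itself trains in fp16 or bf16. As a usage sketch, a DeepSpeed config enabling this path might look like the dict below; `communication_data_type` is the relevant key, and the remaining entries are ordinary illustrative settings rather than anything this PR prescribes.

```python
# Illustrative DeepSpeed config (as a Python dict) requesting fp32
# reduce-scatter communication for a ZeRO stage 3 + fp16 run. Only
# "communication_data_type" is specific to the behavior touched by this PR.
ds_config = {
    "train_micro_batch_size_per_gpu": 4,
    "fp16": {"enabled": True},
    "communication_data_type": "fp32",
    "zero_optimization": {
        "stage": 3,
        "reduce_scatter": True,
    },
}
```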

@@ -1162,11 +1162,13 @@ def __reduce_and_partition_ipg_grads(self, safe_mode: bool = False) -> None:
@instrument_w_nvtx
def __avg_scatter_grads(self, params_to_reduce: List[Parameter]) -> List[Tensor]:
"""average gradients and scatter partitions across ranks"""
dtype = get_only_unique_item(p.grad.dtype for p in params_to_reduce)
# model_dtype = get_only_unique_item(p.grad.dtype for p in params_to_reduce)

full_grads_for_rank = [p.grad for p in params_to_reduce]
if self.communication_data_type == torch.float32:
full_grads_for_rank = [g.float() for g in full_grads_for_rank]
if self.communication_data_type != self.dtype:
full_grads_for_rank = [
g.to(self.communication_data_type) for g in full_grads_for_rank
]

if self.postscale_gradients and self.gradient_predivide_factor != 1.0:
full_grads_for_rank = [
@@ -1182,8 +1184,10 @@ def __avg_scatter_grads(self, params_to_reduce: List[Parameter]) -> List[Tensor]
g.mul(self.gradient_predivide_factor) for g in grad_partitions_for_rank
]

if self.communication_data_type == torch.float32:
grad_partitions_for_rank = [g.to(dtype) for g in grad_partitions_for_rank]
if self.communication_data_type != self.dtype:
grad_partitions_for_rank = [
g.to(self.dtype) for g in grad_partitions_for_rank
]

return grad_partitions_for_rank

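Taken together, the `__avg_scatter_grads` changes follow a simple pattern: cast the full gradients to `communication_data_type` when it differs from the model dtype, reduce-scatter in that dtype, then cast the received partitions back to the model dtype. The sketch below illustrates that pattern with plain `torch.distributed` calls rather than DeepSpeed's internal coalesced helpers, and assumes the gradient's element count divides evenly across ranks.

```python
# Standalone illustration of the cast / reduce-scatter / cast-back pattern
# this PR applies inside __avg_scatter_grads. Uses plain torch.distributed,
# so it is a sketch of the idea rather than DeepSpeed's implementation.
import torch
import torch.distributed as dist


def avg_scatter_grad(grad: torch.Tensor,
                     comm_dtype: torch.dtype,
                     model_dtype: torch.dtype) -> torch.Tensor:
    """Average `grad` across ranks; return this rank's partition in model_dtype."""
    world_size = dist.get_world_size()
    flat = grad.contiguous().view(-1)
    assert flat.numel() % world_size == 0, "sketch assumes an evenly divisible gradient"

    # Cast to the communication dtype only when it differs from the model dtype.
    send = flat.to(comm_dtype) if comm_dtype != model_dtype else flat

    # Each rank receives one equal-sized partition of the summed gradient.
    chunks = list(send.chunk(world_size))
    out = torch.empty_like(chunks[0])
    dist.reduce_scatter(out, chunks, op=dist.ReduceOp.SUM)
    out.div_(world_size)  # turn the sum into an average

    # Cast the received partition back before the optimizer consumes it.
    return out.to(model_dtype) if comm_dtype != model_dtype else out
```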