Make z3 respect comm dtype (#2807)
* Make z3 respect comm dtype
* Support fp32 comm dtype
* Remove obsolete assert
* Code cleanup

tjruwase authored Feb 22, 2023
1 parent 7c99def commit 81b4d5d
Showing 2 changed files with 10 additions and 16 deletions.
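For context, a minimal usage sketch of what this change enables: requesting a gradient communication dtype that differs from the model dtype through the DeepSpeed config. The "communication_data_type": "fp32" value and its interaction with ZeRO stage 3 are assumptions here and should be verified against the DeepSpeed documentation; the snippet is not taken from this commit.

# Hypothetical sketch: fp16 parameters with fp32 gradient communication.
import torch
import deepspeed

model = torch.nn.Linear(16, 16)

ds_config = {
    "train_micro_batch_size_per_gpu": 1,
    "fp16": {"enabled": True},
    "communication_data_type": "fp32",   # assumed accepted value
    "zero_optimization": {"stage": 3},
    "optimizer": {"type": "Adam", "params": {"lr": 1e-3}},
}

model_engine, optimizer, _, _ = deepspeed.initialize(
    model=model,
    model_parameters=model.parameters(),
    config=ds_config,
)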
11 changes: 1 addition & 10 deletions deepspeed/runtime/engine.py
@@ -1124,15 +1124,7 @@ def _configure_distributed_model(self, model):
             if self.zero_optimization_partition_weights() and any(
                 [hasattr(param,
                          "ds_id") for param in self.module.parameters()]):
-                if not all(
-                    [param.dtype == torch.half for param in self.module.parameters()]):
-                    names = [
-                        n for n,
-                        p in self.module.named_parameters() if p.dtype != torch.half
-                    ]
-                    raise ValueError(
-                        f"fp16 is enabled but the following parameters have dtype that is not fp16: {', '.join(names)}"
-                    )
+                self.__check_params(self.module, torch.half)
             self.module.half()
         elif self.bfloat16_enabled():
             if self.zero_optimization_partition_weights() and any(
@@ -1506,7 +1498,6 @@ def _configure_bf16_optimizer(self, optimizer):
     def _configure_zero_optimizer(self, optimizer):
         zero_stage = self.zero_optimization_stage()
         model_dtype, grad_accum_dtype = self.get_data_types()
-        assert self.communication_data_type in (torch.float16, torch.bfloat16), "ZeRO supports only 'communication_data_type': ['fp16', 'bfp16']"
         timers = self.timers if self.wall_clock_breakdown() else None

         if optimizer is None:
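The `self.__check_params(...)` helper added above is not shown in this diff. As a rough sketch only, consistent with the inline check it replaces (the real helper's name, signature, and behavior may differ):

# Illustration of a parameter-dtype check mirroring the removed inline code;
# not the actual DeepSpeedEngine.__check_params implementation.
import torch

def check_params(module: torch.nn.Module, expected_dtype: torch.dtype) -> None:
    # Gather the names of parameters whose dtype differs from the expected one.
    mismatched = [
        name for name, param in module.named_parameters()
        if param.dtype != expected_dtype
    ]
    if mismatched:
        raise ValueError(
            f"{expected_dtype} is enabled but the following parameters have a "
            f"different dtype: {', '.join(mismatched)}")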
15 changes: 9 additions & 6 deletions deepspeed/runtime/zero/stage3.py
@@ -213,7 +213,7 @@ def __init__(self,
         self.reduce_bucket_size = int(reduce_bucket_size)

         if self.reduce_scatter:
-            assert self.communication_data_type in (torch.float16, torch.bfloat16), f"ZeRO-3 supports only float16 or bfloat16 communication_data_type with reduce scatter enabled. Got: '{self.communication_data_type}'"
+            assert self.communication_data_type in (torch.float16, torch.bfloat16, torch.float32), f"ZeRO-3 supports only float16, bfloat16, or float32 communication_data_type with reduce scatter enabled. Got: '{self.communication_data_type}'"
             assert self.gradient_predivide_factor == 1.0, "gradient_predivide_factor != 1.0 is not yet supported with ZeRO-3 with reduce scatter enabled"
             assert self.postscale_gradients, "pre-scale gradients is not yet supported with ZeRO-3 with reduce scatter enabled"

@@ -1162,11 +1162,12 @@ def __reduce_and_partition_ipg_grads(self, safe_mode: bool = False) -> None:
     @instrument_w_nvtx
     def __avg_scatter_grads(self, params_to_reduce: List[Parameter]) -> List[Tensor]:
         """average gradients and scatter partitions across ranks"""
-        dtype = get_only_unique_item(p.grad.dtype for p in params_to_reduce)

         full_grads_for_rank = [p.grad for p in params_to_reduce]
-        if self.communication_data_type == torch.float32:
-            full_grads_for_rank = [g.float() for g in full_grads_for_rank]
+        if self.communication_data_type != self.dtype:
+            full_grads_for_rank = [
+                g.to(self.communication_data_type) for g in full_grads_for_rank
+            ]

         if self.postscale_gradients and self.gradient_predivide_factor != 1.0:
             full_grads_for_rank = [
@@ -1182,8 +1183,10 @@ def __avg_scatter_grads(self, params_to_reduce: List[Parameter]) -> List[Tensor]
                 g.mul(self.gradient_predivide_factor) for g in grad_partitions_for_rank
             ]

-        if self.communication_data_type == torch.float32:
-            grad_partitions_for_rank = [g.to(dtype) for g in grad_partitions_for_rank]
+        if self.communication_data_type != self.dtype:
+            grad_partitions_for_rank = [
+                g.to(self.dtype) for g in grad_partitions_for_rank
+            ]

         return grad_partitions_for_rank

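To make the pattern above concrete outside of DeepSpeed, here is a small self-contained sketch of the cast-to-communication-dtype, reduce-scatter, cast-back flow using plain torch.distributed. The function name, the equal-shard assumption, and the collective used are illustrative stand-ins, not the DeepSpeed internals (which use a coalesced reduce-scatter).

# Standalone sketch of the dtype round-trip the stage3.py change performs:
# cast gradients to the configured communication dtype, run the collective,
# then cast the owned partitions back to the parameter dtype.
# Assumes torch.distributed is initialized and each gradient's numel is
# divisible by the world size; this is not the actual DeepSpeed code path.
from typing import List

import torch
import torch.distributed as dist


def avg_scatter_grads_sketch(grads: List[torch.Tensor],
                             comm_dtype: torch.dtype,
                             param_dtype: torch.dtype,
                             group=None) -> List[torch.Tensor]:
    # Only pay for the cast when the communication dtype differs.
    if comm_dtype != param_dtype:
        grads = [g.to(comm_dtype) for g in grads]

    world_size = dist.get_world_size(group)
    rank = dist.get_rank(group)

    partitions = []
    for g in grads:
        # Split the flattened gradient into equal shards; each rank receives
        # the sum of its shard from all ranks, then divides to average.
        shards = list(torch.chunk(g.flatten(), world_size))
        out = torch.empty_like(shards[rank])
        dist.reduce_scatter(out, shards, op=dist.ReduceOp.SUM, group=group)
        partitions.append(out.div_(world_size))

    # Hand the averaged partitions back in the parameter dtype.
    if comm_dtype != param_dtype:
        partitions = [p.to(param_dtype) for p in partitions]
    return partitions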
