Comments for better understanding of zero stage1_2 (#2027)
Co-authored-by: Olatunji Ruwase <olruwase@microsoft.com>
Co-authored-by: Jeff Rasley <jerasley@microsoft.com>
3 people authored Jul 6, 2022
1 parent 9fc4e5f commit 9305916
Showing 2 changed files with 15 additions and 7 deletions.
2 changes: 2 additions & 0 deletions deepspeed/runtime/engine.py
@@ -1089,6 +1089,8 @@ def _configure_optimizer(self, client_optimizer, model_parameters):
logger.warning(
"**** You are using ZeRO with an untested optimizer, proceed with caution *****"
)
# This optimizer in engine is the ZeRO optimizer of stage 1/2 or stage 3, depending on the 'stage' config,
# while the ZeRO optimizer itself wraps the original optimizer.
self.optimizer = self._configure_zero_optimizer(basic_optimizer)
elif self.amp_enabled():
assert not (self.fp16_enabled() or self.bfloat16_enabled()), "Cannot enable both amp with (legacy) fp16 or bfloat16 mode"
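The comment added above notes that engine.optimizer becomes a ZeRO optimizer (stage 1/2 or stage 3) that wraps the client optimizer. As a rough, standalone sketch of that wrapping idea, assuming made-up class and function names rather than DeepSpeed's actual API:

import torch

class ToyZeroStage12Optimizer:
    # Illustrative stand-in: holds the original client optimizer; real ZeRO stage 1/2
    # would also partition its state across data-parallel ranks.
    def __init__(self, init_optimizer):
        self.optimizer = init_optimizer

    def step(self):
        self.optimizer.step()

def configure_zero_optimizer(basic_optimizer, stage):
    # stage 1 and 2 share one wrapper in this sketch; stage 3 is omitted here
    if stage in (1, 2):
        return ToyZeroStage12Optimizer(basic_optimizer)
    raise NotImplementedError(f"stage {stage} is not covered in this sketch")

model = torch.nn.Linear(8, 8)
basic_optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
engine_optimizer = configure_zero_optimizer(basic_optimizer, stage=2)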
20 changes: 13 additions & 7 deletions deepspeed/runtime/zero/stage_1_and_2.py
@@ -346,7 +346,9 @@ def __init__(self,
assert (partitioned_data.data_ptr() %
(2 * self.nccl_start_alignment_factor) == 0)

# a partition of the fp32 master weights that will be updated by this process
# A partition of the fp32 master weights that will be updated by this process.
# Note that the params in single_partition_of_fp32_groups are cloned and detached
# from the original params of the model.
if not fp16_master_weights_and_gradients:
self.single_partition_of_fp32_groups.append(
self.parallel_partitioned_bit16_groups[i][partition_id].to(
Expand All @@ -356,7 +358,9 @@ def __init__(self,
self.parallel_partitioned_bit16_groups[i][partition_id].to(
self.device).clone().half().detach())

# modify optimizer of have flat master weight
# Set the local optimizer to use the flat params of its own partition.
# After this, the local optimizer only contains its own partition of params.
# As a result, the local optimizer only stores the states (momentum, variance, etc.) for its partition's params (ZeRO stage 1).
self.single_partition_of_fp32_groups[
i].requires_grad = True # keep this in case internal optimizer uses it
param_group['params'] = [self.single_partition_of_fp32_groups[i]]
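The comments in the hunk above describe how each rank keeps an fp32 clone of only its own slice of the flat fp16 weights and repoints the wrapped optimizer's param_group at that clone, so optimizer state exists only for the local partition. A hedged, single-process sketch of that idea, with toy sizes and illustrative names:

import torch

world_size, rank = 4, 1                              # pretend DP group of 4, this is rank 1
flat_bit16 = torch.randn(32, dtype=torch.float16)    # stand-in for one flat fp16 weight group
partition_size = flat_bit16.numel() // world_size
local_bit16 = flat_bit16.narrow(0, rank * partition_size, partition_size)

# fp32 master copy of just this rank's slice, cloned and detached from the model weights
fp32_partition = local_bit16.clone().float().detach()
fp32_partition.requires_grad = True

# the wrapped optimizer now only ever sees (and keeps state for) this one flat partition
local_optimizer = torch.optim.Adam([fp32_partition], lr=1e-3)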
@@ -1426,7 +1430,7 @@ def get_data_parallel_partitions(self, tensor, group_id):
partitions = []

dp = dist.get_world_size(group=self.real_dp_process_group[group_id])
dp_id = dist.get_rank(group=self.real_dp_process_group[group_id])
# dp_id = dist.get_rank(group=self.real_dp_process_group[group_id])

total_num_elements = tensor.numel()

@@ -1691,7 +1695,7 @@ def step(self, closure=None):
self.start_timers([OPTIMIZER_GRADIENTS])
norm_groups = []
single_partition_grad_groups = []
skip = False
# skip = False
for i, group in enumerate(self.bit16_groups):
partition_id = dist.get_rank(group=self.real_dp_process_group[i])
if self.cpu_offload:
@@ -1704,7 +1708,7 @@
self.get_grad_norm_direct(self.averaged_gradients[i],
self.params_in_partition[i]))

# free gradients for all the parameters that are not updated by this process
# free gradients for all the parameters that are not updated by this process (ZeRO stage 2)
self.free_grad_in_param_list(self.params_not_in_partition[i])

# create a flat gradients for parameters updated by this process
@@ -1723,7 +1727,7 @@
single_grad_partition.numel(), self.partition_size[i], i, partition_id)

self.single_partition_of_fp32_groups[i].grad = single_grad_partition
# release all the gradient since we have already created a necessary copy in dp_grad_partition
# release all the gradients since we have already created a necessary copy in dp_grad_partition (ZeRO stage 2)
self.free_grad_in_param_list(self.params_in_partition[i])

self.averaged_gradients[i] = None
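The two hunks above free the gradients this rank does not own (the stage 2 behaviour), flatten the owned gradients, and attach the result as the .grad of the rank's fp32 partition. A toy, single-process sketch of that flow; all names and sizes are illustrative:

import torch
from torch._utils import _flatten_dense_tensors

params_in_partition = [torch.randn(4, requires_grad=True) for _ in range(2)]
params_not_in_partition = [torch.randn(4, requires_grad=True) for _ in range(2)]
for p in params_in_partition + params_not_in_partition:
    p.grad = torch.ones_like(p)              # pretend backward() already ran

# free gradients for parameters this rank will not update (the stage 2 behaviour noted above)
for p in params_not_in_partition:
    p.grad = None

# flatten the owned gradients and hand them to the fp32 master partition
fp32_partition = torch.zeros(8, requires_grad=True)
fp32_partition.grad = _flatten_dense_tensors(
    [p.grad for p in params_in_partition]).float()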
@@ -1752,6 +1756,7 @@ def step(self, closure=None):
self.optimizer.step(fp16_param_groups=bit16_param_groups)
else:
self.optimizer.step()
# After step(), single_partition_of_fp32_groups holds the local optimizer's own partition of the updated params
for bit16_partitions, fp32_partition in zip(self.parallel_partitioned_bit16_groups, self.single_partition_of_fp32_groups):
bit16_partitions[partition_id].data.copy_(fp32_partition.data)
else:
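The comment added above points out that, once the wrapped optimizer has stepped, single_partition_of_fp32_groups holds the updated values, which are copied back into this rank's fp16 slice. A minimal sketch with toy tensors (a single pretend rank, illustrative names):

import torch

fp32_partition = torch.randn(8, requires_grad=True)  # fp32 master shard for this rank
fp32_partition.grad = torch.randn(8)                  # pretend gradients are already attached
bit16_partition = fp32_partition.detach().half()      # the fp16 shard used by the model

local_optimizer = torch.optim.SGD([fp32_partition], lr=0.1)
local_optimizer.step()                                 # updates only the fp32 master copy
bit16_partition.data.copy_(fp32_partition.data)        # refresh the fp16 shard in place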
@@ -1772,7 +1777,8 @@
self.reset_cpu_buffers()

self.start_timers([OPTIMIZER_ALLGATHER])
# gather the updated weights from everyone
# Gather the updated weights from everyone.
# Then all partitions of the model parameters are updated and ready for the next forward pass.
all_gather_dp_groups(
partitioned_param_groups=self.parallel_partitioned_bit16_groups,
dp_process_group=self.real_dp_process_group,
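The hunk above notes that an all-gather across the data-parallel group leaves every rank with the full set of updated bit16 partitions before the next forward pass. A single-process stand-in for what each rank ends up holding; the real code uses torch.distributed collectives, and the sizes here are made up:

import torch

world_size = 4
# each entry stands for the updated bit16 partition owned by one rank
partitions = [torch.full((4,), float(r), dtype=torch.float16) for r in range(world_size)]
full_flat_weights = torch.cat(partitions)  # what every rank holds after the all-gather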
