Bug fix for the "Link bit16 and fp32 parameters in partition" (#5681)
In the function `_link_all_hp_params`
[link](https://github.com/microsoft/DeepSpeed/blob/b33873d234cf6679a3046be9a137682c3469d1fb/deepspeed/runtime/zero/stage_1_and_2.py#L575):
```python
def _link_all_hp_params(self):
    dp_world_size = dist.get_world_size(group=self.dp_process_group)
    if self.cpu_offload:
        self._get_offload_gradient_dict()

    for i, _ in enumerate(self.optimizer.param_groups):
        # Link bit16 and fp32 params in partition
        partition_id = dist.get_rank(group=self.real_dp_process_group[i])
        partition_size = self.bit16_groups_flat[i].numel() // dp_world_size
        flat_hp_partition = self.single_partition_of_fp32_groups[i]
        link_hp_params(lp_param_list=self.bit16_groups[i],
                       flat_hp_partition=flat_hp_partition,
                       gradient_dict=self.averaged_gradients,
                       offload_gradient_dict=self.offload_gradient_dict,
                       use_offload=self.cpu_offload,
                       param_group_index=i,
                       partition_start=partition_id * partition_size,
                       partition_size=partition_size,
                       dp_group=self.real_dp_process_group[i])
```
`dp_world_size = dist.get_world_size(group=self.dp_process_group)`
ensures that `dp_world_size` is always the global data parallel world
size.
However, for MoE parameter groups, the line
`partition_size = self.bit16_groups_flat[i].numel() // dp_world_size`
yields an incorrect `partition_size` when `ep_size > 1` (that is, when
expert parallelism is enabled), because those groups are partitioned
across `self.real_dp_process_group[i]` rather than the global data
parallel group.
As a result, `link_hp_params`
[link](https://github.com/microsoft/DeepSpeed/blob/b33873d234cf6679a3046be9a137682c3469d1fb/deepspeed/runtime/zero/stage_1_and_2.py#L568)
links only some of the MoE parameters, while the remaining parameters
have `_hp_mapping` set to None.
Those parameters are then missing from
`self._param_slice_mappings = self._create_param_mapping()`, which
directly causes errors when saving the optimizer state file for MoE
parameters.
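
To make the failure concrete, here is a minimal sketch in plain Python (no DeepSpeed or `torch.distributed`). The numbers (8 ranks, `ep_size = 4`, four expert parameters of 100 elements each) and the helper `linked_params` are hypothetical. The overlap test is a simplified stand-in for what `link_hp_params` does, and the sketch assumes the expert data parallel group has `world_size // ep_size` ranks.

```python
def linked_params(param_sizes, partition_size, group_size):
    """Return indices of parameters that overlap at least one rank's window.

    Hypothetical helper: each rank in the group owns the window
    [rank * partition_size, (rank + 1) * partition_size) of the flat buffer.
    """
    linked = set()
    for rank in range(group_size):
        start = rank * partition_size
        end = start + partition_size
        offset = 0
        for idx, numel in enumerate(param_sizes):
            # The parameter occupies [offset, offset + numel) in the flat buffer.
            if offset < end and offset + numel > start:
                linked.add(idx)
            offset += numel
    return sorted(linked)


world_size = 8                           # global data parallel world size (assumed)
ep_size = 4                              # expert parallel size (assumed)
expert_dp_size = world_size // ep_size   # ranks in the expert data parallel group

param_sizes = [100, 100, 100, 100]       # four hypothetical expert parameters
flat_numel = sum(param_sizes)

buggy_size = flat_numel // world_size      # divides by the global world size
fixed_size = flat_numel // expert_dp_size  # divides by the group's own size

buggy_linked = linked_params(param_sizes, buggy_size, expert_dp_size)
fixed_linked = linked_params(param_sizes, fixed_size, expert_dp_size)

print("buggy:", buggy_linked)
print("fixed:", fixed_linked)
```

In this toy setup the windows built from the global world size stop well short of the end of the flat buffer, so the trailing parameters are never linked on any rank. That corresponds to the `_hp_mapping` of None described above.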

To fix this bug, we need to use the correct `dp_world_size` for each
parameter group:
```python
    def _link_all_hp_params(self):
        if self.cpu_offload:
            self._get_offload_gradient_dict()

        for i, _ in enumerate(self.optimizer.param_groups):
            # Link bit16 and fp32 params in partition
            partition_id = dist.get_rank(group=self.real_dp_process_group[i])
            partition_size = self.bit16_groups_flat[i].numel() // dist.get_world_size(group=self.real_dp_process_group[i]) # <--
            flat_hp_partition = self.single_partition_of_fp32_groups[i]
            link_hp_params(lp_param_list=self.bit16_groups[i],
                           flat_hp_partition=flat_hp_partition,
                           gradient_dict=self.averaged_gradients,
                           offload_gradient_dict=self.offload_gradient_dict,
                           use_offload=self.cpu_offload,
                           param_group_index=i,
                           partition_start=partition_id * partition_size,
                           partition_size=partition_size,
                           dp_group=self.real_dp_process_group[i])
```
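
As a rough check of the per-group computation, the sketch below mirrors the fixed loop in plain Python for one dense group and one expert group. The group names, sizes, and the helper `partition_bounds` are hypothetical, and the sketch assumes each flat buffer length is divisible by its group size.

```python
def partition_bounds(flat_numel, group_size, rank_in_group):
    """Hypothetical helper: window owned by one rank, sized by its own group."""
    partition_size = flat_numel // group_size
    start = rank_in_group * partition_size
    return start, start + partition_size


# Hypothetical parameter groups: the dense group spans all 8 ranks, while the
# expert group spans only the 2 ranks of its expert data parallel group.
groups = [
    {"name": "dense", "flat_numel": 800, "dp_group_size": 8},
    {"name": "experts", "flat_numel": 400, "dp_group_size": 2},
]

coverage = {}
for g in groups:
    covered = 0
    for rank in range(g["dp_group_size"]):
        start, end = partition_bounds(g["flat_numel"], g["dp_group_size"], rank)
        covered += end - start
    coverage[g["name"]] = covered

print(coverage)
```

With the group's own size as the divisor, the windows of both groups tile their flat buffers completely. The dense group is unaffected by the change because its group is already the global one.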

Co-authored-by: Olatunji Ruwase <olruwase@microsoft.com>
U-rara and tjruwase authored Jun 26, 2024
1 parent b3767d0 commit 224a05c
Showing 1 changed file with 2 additions and 2 deletions.
4 changes: 2 additions & 2 deletions deepspeed/runtime/zero/stage_1_and_2.py
@@ -573,14 +573,14 @@ def _create_param_mapping(self):
         return param_mapping

     def _link_all_hp_params(self):
-        dp_world_size = dist.get_world_size(group=self.dp_process_group)
         if self.cpu_offload:
             self._get_offload_gradient_dict()

         for i, _ in enumerate(self.optimizer.param_groups):
             # Link bit16 and fp32 params in partition
             partition_id = dist.get_rank(group=self.real_dp_process_group[i])
-            partition_size = self.bit16_groups_flat[i].numel() // dp_world_size
+            partition_size = self.bit16_groups_flat[i].numel() // dist.get_world_size(
+                group=self.real_dp_process_group[i])
             flat_hp_partition = self.single_partition_of_fp32_groups[i]
             link_hp_params(lp_param_list=self.bit16_groups[i],
                            flat_hp_partition=flat_hp_partition,