Skip to content

Commit

Permalink
[zero_to_fp32] adapt to 4-bytes alignment in z2 (#1372)
Browse files Browse the repository at this point in the history
Co-authored-by: Olatunji Ruwase <olruwase@microsoft.com>
  • Loading branch information
stas00 and tjruwase authored Sep 16, 2021
1 parent cf22a69 commit 30537e7
Show file tree
Hide file tree
Showing 2 changed files with 41 additions and 14 deletions.
35 changes: 28 additions & 7 deletions deepspeed/utils/zero_to_fp32.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
import argparse
import torch
import glob
import math
import os
from collections import OrderedDict

Expand Down Expand Up @@ -116,6 +117,11 @@ def parse_optim_states(files, ds_checkpoint_dir):


def zero3_partitioned_param_info(unpartitioned_numel, world_size):
#print("*** ", unpartitioned_numel, world_size, " ***",)
# handle an edge case where there is only 1 element (e.g. bias in a tiny test model)
if unpartitioned_numel == 1:
return 1, 0

remainder = unpartitioned_numel % world_size
padding_numel = (world_size - remainder) if remainder else 0
partitioned_numel = int(unpartitioned_numel / world_size)
Expand Down Expand Up @@ -205,18 +211,33 @@ def _get_fp32_state_dict_from_zero_checkpoint(ds_checkpoint_dir):
f"{total_params} {name} full shape: {shape} partition0 numel={partitioned_numel} partitioned_padding_numel={partitioned_padding_numel}"
)

# XXX: memory usage doubles here (zero3)
state_dict[name] = torch.cat(
tuple(fp32_flat_groups[i].narrow(0,
offset,
partitioned_numel)
for i in range(world_size)),
0).view(shape)
if unpartitioned_numel > 1:
# XXX: memory usage doubles here (zero3)
state_dict[name] = torch.cat(
tuple(fp32_flat_groups[i].narrow(0,
offset,
partitioned_numel)
for i in range(world_size)),
0).view(shape)
else:
# handle an edge case where there is only 1 element (e.g. bias in a tiny test model)
state_dict[name] = fp32_flat_groups[0].narrow(
0,
offset,
partitioned_numel).view(shape)
offset += partitioned_numel + partitioned_padding_numel

if zero_stage == 3:
offset *= world_size

def align_to_4(x):
return 4 * math.ceil(x / 4)

if zero_stage == 2:
# Z2 started to align to 4 to improve nccl performance
offset = align_to_4(offset)
avail_numel = align_to_4(avail_numel)

# Sanity check
if offset != avail_numel:
raise ValueError(
Expand Down
20 changes: 13 additions & 7 deletions tests/unit/test_zero.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,6 +133,7 @@ def _test_zero3_repeat_forward_loop(args, model, hidden_dim):


# testing the fix https://github.com/microsoft/DeepSpeed/pull/1227
# also reproduces the https://github.com/microsoft/DeepSpeed/pull/1372
@pytest.mark.parametrize('zero_stage', [2, 3])
def test_zero_to_fp32(tmpdir, zero_stage):

Expand Down Expand Up @@ -168,6 +169,10 @@ def __init__(self, hidden_dim, n_layers):
self.ll = torch.nn.ModuleList(
torch.nn.Linear(hidden_dim,
hidden_dim) for i in range(n_layers))
# to reproduce https://github.com/microsoft/DeepSpeed/pull/1372 it is important that
# the number of params is uneven - the following adds 4+1 params - the linear
# layers are 6 param each + 5 - so total 17 elements (for 1 gpu)
self.classifier = torch.nn.Linear(4, 1)
self.cross_entropy_loss = torch.nn.CrossEntropyLoss()

def forward(self, x, y):
Expand Down Expand Up @@ -221,14 +226,15 @@ def dump_state_dict(model):
orig_state_dict[name] = param.detach().cpu()
print(orig_state_dict)

fp32_model = load_state_dict_from_zero_checkpoint(model.module, tmpdir)
#dump_state_dict(fp32_model)
if dist.get_rank() == 0:
fp32_model = load_state_dict_from_zero_checkpoint(model.module, tmpdir)
#dump_state_dict(fp32_model)

fp32_state_dict = fp32_model.state_dict()
for name in orig_state_dict.keys():
# float() workaround for torch<1.6
assert torch.allclose(orig_state_dict[name].float(),
fp32_state_dict[name].float())
fp32_state_dict = fp32_model.state_dict()
for name in orig_state_dict.keys():
# float() workaround for torch<1.6
assert torch.allclose(orig_state_dict[name].float(),
fp32_state_dict[name].float())

_test_zero_to_fp32()

Expand Down

0 comments on commit 30537e7

Please sign in to comment.