From dd0c8fa73bd84f2fe377aa46ad34b8b0bdd977a6 Mon Sep 17 00:00:00 2001 From: eelxpeng Date: Mon, 4 Oct 2021 10:23:45 -0700 Subject: [PATCH] Revise param_shapes to be a list of ordered dict (#1424) * Revise param_shapes to be a list of ordered dict * test i can push * add tests; split z2 and z3 into separate funcs Co-authored-by: Xiaopeng Li Co-authored-by: Stas Bekman Co-authored-by: Olatunji Ruwase --- deepspeed/runtime/engine.py | 6 +- deepspeed/utils/zero_to_fp32.py | 208 ++++++++++++++++++++++---------- tests/unit/test_zero.py | 125 +++++++++++++++++-- 3 files changed, 261 insertions(+), 78 deletions(-) diff --git a/deepspeed/runtime/engine.py b/deepspeed/runtime/engine.py index ead38583fd75..1fdab186fba1 100755 --- a/deepspeed/runtime/engine.py +++ b/deepspeed/runtime/engine.py @@ -2441,7 +2441,7 @@ def _get_zero_param_shapes(self): optimizer.fp16_groups seems to be the easiest to use as it's in all zeroX versions. """ - param_shapes = OrderedDict() + param_group_shapes = [] cnt = 0 numel = 0 @@ -2453,6 +2453,7 @@ def _get_zero_param_shapes(self): fp16_groups = self.optimizer.fp16_groups for fp16_group in fp16_groups: + param_shapes = OrderedDict() for param in fp16_group: cnt += 1 numel += param.ds_numel if hasattr(param, "ds_numel") else param.numel() @@ -2464,9 +2465,10 @@ def _get_zero_param_shapes(self): # uncomment to debug zero_to_fp32.py problems # if self.global_rank == 0: print(f"saving param {name} {shape} (numel={shape.numel()})") + param_group_shapes.append(param_shapes) # if self.global_rank == 0: print(f"Total saved {numel} numels in {cnt} params") - return param_shapes + return param_group_shapes def _copy_recovery_script(self, save_path): base_dir = os.path.dirname(os.path.dirname(__file__)) diff --git a/deepspeed/utils/zero_to_fp32.py b/deepspeed/utils/zero_to_fp32.py index 337c21abd7e6..37787a7962af 100755 --- a/deepspeed/utils/zero_to_fp32.py +++ b/deepspeed/utils/zero_to_fp32.py @@ -53,9 +53,6 @@ def get_optim_files(checkpoint_dir): def parse_model_state(file): - - # load to cpu - device = torch.device('cpu') state_dict = torch.load(file, map_location=device) if "buffer_names" not in state_dict: @@ -85,10 +82,10 @@ def parse_optim_states(files, ds_checkpoint_dir): zero_stage = state_dicts[0]['optimizer_state_dict']["zero_stage"] world_size = state_dicts[0]['optimizer_state_dict']["partition_count"] param_shapes = state_dicts[0]["param_shapes"] - '''For ZeRO-2 each param group can have different partiiton_count as data parallelism for expert - parameters can be different from data parallelism for non-expert parameters. So we can just use the max of - the partition_count to get the dp world_size. - ''' + # For ZeRO-2 each param group can have different partition_count as data parallelism for expert + # parameters can be different from data parallelism for non-expert parameters. So we can just + # use the max of the partition_count to get the dp world_size. 
+ if type(world_size) is list: world_size = max(world_size) @@ -106,24 +103,24 @@ def parse_optim_states(files, ds_checkpoint_dir): else: raise ValueError(f"unknown zero stage {zero_stage}") - # if there is more than one param group, there will be multiple flattened tensors - one - # flattened tensor per group - for simplicity merge them into a single tensor - # - # XXX: could make the script more memory efficient for when there are multiple groups - it - # will require matching the sub-lists of param_shapes for each param group flattened tensor - fp32_flat_groups = [ - torch.cat(state_dicts[i]['optimizer_state_dict'][fp32_groups_key], - 0) for i in range(len(state_dicts)) - ] - - return zero_stage, world_size, param_shapes, fp32_flat_groups + if zero_stage == 2: + fp32_flat_groups = [ + state_dicts[i]['optimizer_state_dict'][fp32_groups_key] + for i in range(len(state_dicts)) + ] + elif zero_stage == 3: + # if there is more than one param group, there will be multiple flattened tensors - one + # flattened tensor per group - for simplicity merge them into a single tensor + # + # XXX: could make the script more memory efficient for when there are multiple groups - it + # will require matching the sub-lists of param_shapes for each param group flattened tensor + fp32_flat_groups = [ + torch.cat(state_dicts[i]['optimizer_state_dict'][fp32_groups_key], + 0) for i in range(len(state_dicts)) + ] -def zero3_partitioned_param_info(unpartitioned_numel, world_size): - remainder = unpartitioned_numel % world_size - padding_numel = (world_size - remainder) if remainder else 0 - partitioned_numel = math.ceil(unpartitioned_numel / world_size) - return partitioned_numel, padding_numel + return zero_stage, world_size, param_shapes, fp32_flat_groups def _get_fp32_state_dict_from_zero_checkpoint(ds_checkpoint_dir): @@ -144,29 +141,48 @@ def _get_fp32_state_dict_from_zero_checkpoint(ds_checkpoint_dir): model_file = get_model_state_file(ds_checkpoint_dir, zero_stage) buffers = parse_model_state(model_file) + if zero_stage == 2: + return _get_fp32_state_dict_from_zero2_checkpoint(world_size, + param_shapes, + fp32_flat_groups, + buffers) + elif zero_stage == 3: + return _get_fp32_state_dict_from_zero3_checkpoint(world_size, + param_shapes, + fp32_flat_groups, + buffers) + + +def _get_fp32_state_dict_from_zero2_checkpoint(world_size, + param_shapes, + fp32_flat_groups, + buffers): + # Reconstruction protocol: # - # - for zero2 we just need to concat the partitions back to back and reconsolidate over one huge - # flat buffer - no need to deal with padding since if there is any it will be only in the tail - # of the last partition so there it will be just left out - # - # - for zero3 we need to zip the partitions together at boundary of each param, re-consolidating - # each param, while dealing with padding if any + # XXX: document this if debug: for i in range(world_size): - print(f"fp32_flat_groups[i].shape={fp32_flat_groups[i].shape}") - - if zero_stage == 2: - # XXX: memory usage doubles here (zero2) - full_single_fp32_vector = torch.cat(fp32_flat_groups, 0) - avail_numel = full_single_fp32_vector.numel() - elif zero_stage == 3: - avail_numel = fp32_flat_groups[0].numel() * world_size + for j in range(len(fp32_flat_groups[0])): + print(f"fp32_flat_groups[{i}][{j}].shape={fp32_flat_groups[i][j].shape}") + + # XXX: memory usage doubles here (zero2) + num_param_groups = len(fp32_flat_groups[0]) + merged_single_partition_of_fp32_groups = [] + for i in range(num_param_groups): + merged_partitions = [sd[i] for sd 
in fp32_flat_groups] + full_single_fp32_vector = torch.cat(merged_partitions, 0) + merged_single_partition_of_fp32_groups.append(full_single_fp32_vector) + avail_numel = sum([ + full_single_fp32_vector.numel() + for full_single_fp32_vector in merged_single_partition_of_fp32_groups + ]) if debug: - wanted_params = len(param_shapes) - wanted_numel = sum(shape.numel() for shape in param_shapes.values()) + wanted_params = sum([len(shapes) for shapes in param_shapes]) + wanted_numel = sum( + [sum(shape.numel() for shape in shapes.values()) for shapes in param_shapes]) # not asserting if there is a mismatch due to possible padding print(f"Have {avail_numel} numels to process.") print(f"Need {wanted_numel} numels in {wanted_params} params.") @@ -181,16 +197,17 @@ def _get_fp32_state_dict_from_zero_checkpoint(ds_checkpoint_dir): # params # XXX: for huge models that can't fit into the host's RAM we will have to recode this to support # out-of-core computing solution - offset = 0 total_numel = 0 total_params = 0 - for name, shape in param_shapes.items(): + for shapes, full_single_fp32_vector in zip(param_shapes, merged_single_partition_of_fp32_groups): + offset = 0 + avail_numel = full_single_fp32_vector.numel() + for name, shape in shapes.items(): - unpartitioned_numel = shape.numel() - total_numel += unpartitioned_numel - total_params += 1 + unpartitioned_numel = shape.numel() + total_numel += unpartitioned_numel + total_params += 1 - if zero_stage == 2: if debug: print( f"{name} full shape: {shape} unpartitioned numel {unpartitioned_numel} " @@ -201,26 +218,6 @@ def _get_fp32_state_dict_from_zero_checkpoint(ds_checkpoint_dir): unpartitioned_numel).view(shape) offset += unpartitioned_numel - elif zero_stage == 3: - partitioned_numel, partitioned_padding_numel = zero3_partitioned_param_info(unpartitioned_numel, world_size) - - if debug: - print( - f"{total_params} {name} full shape: {shape} partition0 numel={partitioned_numel} partitioned_padding_numel={partitioned_padding_numel}" - ) - - # XXX: memory usage doubles here (zero3) - state_dict[name] = torch.cat( - tuple(fp32_flat_groups[i].narrow(0, - offset, - partitioned_numel) - for i in range(world_size)), - 0).narrow(0, - 0, - unpartitioned_numel).view(shape) - offset += partitioned_numel - - if zero_stage == 2: # Z2 started to align to 2*world_size to improve nccl performance. Therefore both offset and # avail_numel can differ by anywhere between 0..2*world_size. 
Due to two unrelated complex # paddings performed in the code it's almost impossible to predict the exact numbers w/o the @@ -239,8 +236,85 @@ def zero2_align(x): if debug: print(f"aligned offset={offset}, avail_numel={avail_numel}") - elif zero_stage == 3: - offset *= world_size + # Sanity check + if offset != avail_numel: + raise ValueError( + f"consumed {offset} numels out of {avail_numel} - something is wrong") + + print( + f"Reconstructed fp32 state dict with {total_params} params {total_numel} elements" + ) + + return state_dict + + +def zero3_partitioned_param_info(unpartitioned_numel, world_size): + remainder = unpartitioned_numel % world_size + padding_numel = (world_size - remainder) if remainder else 0 + partitioned_numel = math.ceil(unpartitioned_numel / world_size) + return partitioned_numel, padding_numel + + +def _get_fp32_state_dict_from_zero3_checkpoint(world_size, + param_shapes, + fp32_flat_groups, + buffers): + + # Reconstruction protocol: For zero3 we need to zip the partitions together at boundary of each + # param, re-consolidating each param, while dealing with padding if any + + avail_numel = fp32_flat_groups[0].numel() * world_size + # merge list of dicts, preserving order + param_shapes = {k: v for d in param_shapes for k, v in d.items()} + + if debug: + for i in range(world_size): + print(f"fp32_flat_groups[{i}].shape={fp32_flat_groups[i].shape}") + + wanted_params = len(param_shapes) + wanted_numel = sum(shape.numel() for shape in param_shapes.values()) + # not asserting if there is a mismatch due to possible padding + print(f"Have {avail_numel} numels to process.") + print(f"Need {wanted_numel} numels in {wanted_params} params.") + + state_dict = OrderedDict() + + # buffers + state_dict.update(buffers) + if debug: + print(f"added {len(buffers)} buffers") + + # params + # XXX: for huge models that can't fit into the host's RAM we will have to recode this to support + # out-of-core computing solution + offset = 0 + total_numel = 0 + total_params = 0 + for name, shape in param_shapes.items(): + + unpartitioned_numel = shape.numel() + total_numel += unpartitioned_numel + total_params += 1 + + partitioned_numel, partitioned_padding_numel = zero3_partitioned_param_info(unpartitioned_numel, world_size) + + if debug: + print( + f"{total_params} {name} full shape: {shape} partition0 numel={partitioned_numel} partitioned_padding_numel={partitioned_padding_numel}" + ) + + # XXX: memory usage doubles here + state_dict[name] = torch.cat( + tuple(fp32_flat_groups[i].narrow(0, + offset, + partitioned_numel) + for i in range(world_size)), + 0).narrow(0, + 0, + unpartitioned_numel).view(shape) + offset += partitioned_numel + + offset *= world_size # Sanity check if offset != avail_numel: diff --git a/tests/unit/test_zero.py b/tests/unit/test_zero.py index 74925d97c560..173e60e26b81 100755 --- a/tests/unit/test_zero.py +++ b/tests/unit/test_zero.py @@ -30,6 +30,13 @@ def enable_grads(model): enable_grads(model) +def dump_state_dict(model): + if dist.get_rank() == 0: + print("state_dict:") + for name, param in model.named_parameters(): + print(f"{name} {param.data}") + + @pytest.mark.parametrize('zero_stage', [1, 2, 3]) def test_zero_unbalanced_gradients(tmpdir, zero_stage): config_dict = { @@ -135,10 +142,9 @@ def _test_zero3_repeat_forward_loop(args, model, hidden_dim): # testing the fix https://github.com/microsoft/DeepSpeed/pull/1227 # also reproduces the https://github.com/microsoft/DeepSpeed/pull/1372 @pytest.mark.parametrize('zero_stage', [2, 3]) -def 
test_zero_to_fp32(tmpdir, zero_stage): +def test_zero_to_fp32_1_param_group(tmpdir, zero_stage): - # TODO: - # - need to test with multiple param groups + # XXX: ideally refactor with the 2_param_group test as 75% is the same # force all params to be partitioned by forcing threshold=0 config_dict = { @@ -209,11 +215,113 @@ def forward(self, x, y): # make sure all sides saved it dist.barrier() - def dump_state_dict(model): - if dist.get_rank() != 0: - return - for name, param in model.named_parameters(): - print(f"{name} {param}") + if zero_stage == 3: + with deepspeed.zero.GatheredParameters(list( + model.module.parameters(recurse=True)), + modifier_rank=None): + pass # this forces gathering the model + + #dump_state_dict(model) + + orig_state_dict = {} + for name, param in model.module.named_parameters(): + orig_state_dict[name] = param.detach().cpu() + + if dist.get_rank() == 0: + fp32_model = load_state_dict_from_zero_checkpoint(model.module, tmpdir) + #dump_state_dict(fp32_model) + + fp32_state_dict = fp32_model.state_dict() + for name in orig_state_dict.keys(): + # float() workaround for torch<1.6 + assert torch.allclose(orig_state_dict[name].float(), + fp32_state_dict[name].float()) + + _test_zero_to_fp32() + + +@pytest.mark.parametrize('zero_stage', [2, 3]) +def test_zero_to_fp32_2_param_groups(tmpdir, zero_stage): + + # TODO: + # - need to test with multiple param groups + + # force all params to be partitioned by forcing threshold=0 + config_dict = { + "train_micro_batch_size_per_gpu": 2, + "gradient_accumulation_steps": 2, + "steps_per_print": 1, + "zero_allow_untested_optimizer": 1, + "zero_optimization": { + "stage": zero_stage, + "stage3_param_persistence_threshold": 0 + }, + "optimizer": { + "type": "Adam", + "params": { + "lr": 1e-3 + } + }, + "fp16": { + "enabled": True, + "initial_scale_power": 8 + } + } + + @distributed_test(world_size=[2]) + def _test_zero_to_fp32(): + class MyModel(torch.nn.Module): + def __init__(self, hidden_dim, n_layers): + super().__init__() + self.ll = torch.nn.ModuleList( + torch.nn.Linear(hidden_dim, + hidden_dim) for i in range(n_layers)) + self.cross_entropy_loss = torch.nn.CrossEntropyLoss() + + def forward(self, x, y): + hidden = x + for l in self.ll: + hidden = l(hidden) + return self.cross_entropy_loss(hidden, y) + + args = args_from_dict(tmpdir, config_dict) + hidden_dim = 3 + + world_size = dist.get_world_size() + n_layers = world_size * 2 + model = MyModel(hidden_dim=hidden_dim, n_layers=n_layers) + + optim_groups = [ + { + "params": [l.weight for l in model.ll], + "weight_decay": 0.01, + }, + { + "params": [l.bias for l in model.ll], + "weight_decay": 0.0 + }, + ] + optim = torch.optim.SGD(optim_groups, lr=0.1) + + model, _, _, _ = deepspeed.initialize(args=args, + model=model, + model_parameters=model.parameters(), + optimizer = optim, + ) + data_loader = random_dataloader(model=model, + total_samples=16, + hidden_dim=hidden_dim, + device=model.device) + + for i, batch in enumerate(data_loader): + loss = model(batch[0], batch[1]) + model.backward(loss) + model.step() + + model.save_checkpoint(tmpdir) + + # make sure all sides saved it + dist.barrier() if zero_stage == 3: with deepspeed.zero.GatheredParameters(list( @@ -226,7 +334,6 @@ def dump_state_dict(model): orig_state_dict = {} for name, param in model.module.named_parameters(): orig_state_dict[name] = param.detach().cpu() - print(orig_state_dict) if dist.get_rank() == 0: fp32_model = load_state_dict_from_zero_checkpoint(model.module, tmpdir)
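
A minimal standalone sketch of the data-structure change above, for reference: _get_zero_param_shapes() now returns a list with one OrderedDict of name -> shape per optimizer param group, and the ZeRO-3 path merges those dicts back into a single ordered mapping before unpartitioning each param. The layer names and shapes below are invented for illustration; only the list-of-OrderedDict structure and the partition math mirror the patch.

# Illustrative only -- names and shapes are made up; the structure (one
# OrderedDict of name -> shape per optimizer param group) follows the patch.
import math
from collections import OrderedDict

import torch

# One OrderedDict per param group, e.g. weights in one group, biases in another,
# as in the test_zero_to_fp32_2_param_groups test above.
param_shapes = [
    OrderedDict([("ll.0.weight", torch.Size([3, 3])),
                 ("ll.1.weight", torch.Size([3, 3]))]),
    OrderedDict([("ll.0.bias", torch.Size([3])),
                 ("ll.1.bias", torch.Size([3]))]),
]

# ZeRO-3 path: merge the per-group dicts into a single ordered mapping
# (dict comprehensions preserve insertion order on Python 3.7+), then compute
# per-rank partition sizes with the same math as zero3_partitioned_param_info.
merged = {k: v for d in param_shapes for k, v in d.items()}


def zero3_partitioned_param_info(unpartitioned_numel, world_size):
    # Each rank holds ceil(numel / world_size) elements; the shard is padded
    # when numel is not evenly divisible by world_size.
    remainder = unpartitioned_numel % world_size
    padding_numel = (world_size - remainder) if remainder else 0
    partitioned_numel = math.ceil(unpartitioned_numel / world_size)
    return partitioned_numel, padding_numel


world_size = 2
for name, shape in merged.items():
    part, pad = zero3_partitioned_param_info(shape.numel(), world_size)
    print(f"{name}: {shape.numel()} numels -> {part} per rank (+{pad} padding)")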