Skip to content

Commit

Permalink
Enable universal checkpoint for zero stage 1 (#4516)
Browse files Browse the repository at this point in the history
* Enable uni_ckpt for z1

* Remove logging fix to seperate PR. Relocate conversion script to avoid logging circular import issue

* Formatting fix

* PR feedback

* Handle replicated params

* Detect bf16_optimizer

* Docs

* Fix docs
  • Loading branch information
tjruwase authored Oct 25, 2023
1 parent 869629c commit 8fdd9b3
Show file tree
Hide file tree
Showing 8 changed files with 370 additions and 38 deletions.
2 changes: 2 additions & 0 deletions deepspeed/checkpoint/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
ZERO_STAGE = 'zero_stage'
CLIP_GRAD = 'clip_grad'
FP32_WEIGHT_KEY = "fp32"
LOSS_SCALER = 'loss_scaler'

#########################################
# Module checkpoint keys
Expand Down Expand Up @@ -69,3 +70,4 @@
PIPELINE_REPLICATED_PARAMETER_PATTERNS = 'pipeline_replicated_parameter_patterns'
PARAMETER_TO_AVERAGE_PATTERNS = 'parameter_to_average_patterns'
PARAMETER_WITH_ROW_PARALLELISM_PATTERNS = 'parameter_with_row_parallelism_patterns'
TP_REPLICATED_PARAMETER_PATTERNS = 'tp_replicated_parameter_patterns'
303 changes: 303 additions & 0 deletions deepspeed/checkpoint/ds_to_universal.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,303 @@
#!/usr/bin/env python

# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team

from functools import partial
import argparse
import glob
import itertools
import multiprocessing
import os
import re
import shutil
import torch
import tqdm
# from pprint import pprint

from deepspeed.checkpoint import DeepSpeedCheckpoint
from deepspeed.checkpoint import (
OPTIMIZER_STATE_DICT,
BASE_OPTIMIZER_STATE,
SINGLE_PARTITION_OF_FP32_GROUPS,
PARAM_SLICE_MAPPINGS,
PARAM_SHAPES,
PARAM,
CAT_DIM,
VOCAB_DIVISIBILITY_PADDING_TENSOR,
ORIGINAL_VOCAB_SIZE,
UNIVERSAL_CHECKPOINT_INFO,
VOCABULARY_PARAMETER_PATTERNS,
PIPELINE_REPLICATED_PARAMETER_PATTERNS,
TP_REPLICATED_PARAMETER_PATTERNS,
PARAMETER_TO_AVERAGE_PATTERNS,
PARAMETER_WITH_ROW_PARALLELISM_PATTERNS,
)


def parse_arguments():
parser = argparse.ArgumentParser()
parser.add_argument('--input_folder', type=str, required=True, help='Input DeepSpeed Checkpoint folder')
parser.add_argument('--output_folder', type=str, required=True, help='Output DeepSpeed checkpoint folder')
parser.add_argument('--num_extract_workers',
default=4,
type=int,
help='How many parallel processes to extract zero shards')
parser.add_argument(
'--num_merge_workers',
default=2,
type=int,
help=
'How many parallel processes to merge tp slices (more memory intensive, use much fewer than --num_extract_workers))'
)
parser.add_argument('--keep_temp_folder',
action='store_true',
help='Preserve temporary folder of intermediate checkpoint slice files. Useful for debugging.')
args = parser.parse_args()
print(f'args = {args}')
return args


def _create_checkpoint_paths(base_folder, iteration, tp_degree, pp_degree):
path_list = []
iter_folder = f'iter_{iteration:07d}'
for i in range(0, tp_degree):
path_list.append([])
for j in range(0, pp_degree):
rank_folder = f'mp_rank_{i:02d}' if pp_degree == 1 else f'mp_rank_{i:02d}_{j:03d}'
ckpt_path = os.path.join(rank_folder, 'model_optim_rng.pt')
path_list[i].append(os.path.join(base_folder, iter_folder, ckpt_path))

return path_list


def _save_checkpoint(file_path, chkpt_sd):
dir, _ = os.path.split(file_path)
os.makedirs(dir, exist_ok=True)
torch.save(chkpt_sd, file_path)


def extract_zero_shards(dir, ds_checkpoint, indices_3D):
pp_index, tp_index, dp_index = indices_3D
sd = ds_checkpoint.get_zero_checkpoint_state(pp_index=pp_index, tp_index=tp_index, dp_index=dp_index)

# pprint(f"Processing {dp_index=} {pp_index=}, {tp_index=}")

optim_sd = sd[OPTIMIZER_STATE_DICT]
param_slice_mappings = optim_sd[PARAM_SLICE_MAPPINGS]
universal_checkpoint_info = ds_checkpoint.get_checkpoint_info(UNIVERSAL_CHECKPOINT_INFO)
pipeline_replicated_params = universal_checkpoint_info.get(PIPELINE_REPLICATED_PARAMETER_PATTERNS, [])
# print(f'{pipeline_replicated_params=}')

# dict
state_groups = optim_sd[BASE_OPTIMIZER_STATE]["state"]
# list
fp32_groups = optim_sd[SINGLE_PARTITION_OF_FP32_GROUPS]
param_groups_cnt = len(state_groups)

for param_group_id in range(param_groups_cnt):

flat_state = dict(
exp_avg=state_groups[param_group_id]["exp_avg"],
exp_avg_sq=state_groups[param_group_id]["exp_avg_sq"],
fp32=fp32_groups[param_group_id],
)

for name, fragment_mapping in param_slice_mappings[param_group_id].items():
if pp_index > 0 and any(re.match(pattern, name) for pattern in pipeline_replicated_params):
# Skip tied weights that are replicated in first and last pp stages
continue

# pprint(f"dpt{dp_index}{pp_index}{tp_index} {param_group_id} {name} => {fragment_mapping.start}:{fragment_mapping.numel}")
for state_key in flat_state.keys():
dump_param_fragment(dir, tp_index, dp_index, state_key, flat_state[state_key], name,
fragment_mapping.start, fragment_mapping.numel)


cnt = 0


def dump_param_fragment(dir, tp_index, dp_index, state_name, state_flat_tensor, param_name, offset, numel):

global cnt # temp hack

param_base_path = os.path.join(dir, param_name, str(tp_index))
os.makedirs(param_base_path, exist_ok=True)

cnt += 1
counter = f"{dp_index:0>2d}"

path = os.path.join(param_base_path, f"{state_name}.{counter}")

#print(f"{param_name}: {offset}: {numel} => {path}")

t = state_flat_tensor.narrow(0, offset, numel).clone()
_save_checkpoint(path, t)


def _merge_zero_shards(param_base_path, state, tp_degree, slice_shape):
slices = []
for tp_index in range(tp_degree):
prefix_path = os.path.join(param_base_path, str(tp_index), f"{state}")
paths = sorted(list(glob.glob(f"{prefix_path}.*")))
shards = [torch.load(p) for p in paths]
slice = torch.cat(shards, dim=0).reshape(slice_shape)
slices.append(slice)

return slices


def _get_vocab_divisibility_padding_tensor(universal_checkpoint_info, padded_vocab_tensor):
original_vocab_size = universal_checkpoint_info.get(ORIGINAL_VOCAB_SIZE)
if padded_vocab_tensor.shape[0] > original_vocab_size:
return padded_vocab_tensor[-1]
else:
return torch.zeros(padded_vocab_tensor.shape[1])


def merge_tp_slices(ds_checkpoint, dir, slice_dir, tp_degree, name_and_shape):
name, shape = name_and_shape
slice_base_path = os.path.join(slice_dir, name)
param_base_path = os.path.join(dir, name)

universal_checkpoint_info = ds_checkpoint.get_checkpoint_info(UNIVERSAL_CHECKPOINT_INFO)
replicated_parameters = universal_checkpoint_info.get(TP_REPLICATED_PARAMETER_PATTERNS, [])
parameters_to_average = universal_checkpoint_info.get(PARAMETER_TO_AVERAGE_PATTERNS, [])
parameters_with_row_parallelism = universal_checkpoint_info.get(PARAMETER_WITH_ROW_PARALLELISM_PATTERNS, [])
vocabulary_parameters = universal_checkpoint_info.get(VOCABULARY_PARAMETER_PATTERNS, [])
for state in ("fp32", "exp_avg", "exp_avg_sq"):
slices = _merge_zero_shards(slice_base_path, state, tp_degree, shape)
final_path = os.path.join(param_base_path, f"{state}.pt")

#print(f"Expected shape: {shape}")
#print(f"Fragment sizes:", list(frag.shape for frag in slices))
ckpt_dict = {}
if any(re.match(pattern, name) for pattern in replicated_parameters):
if len(slices) > 1:
assert all([slices[0].equal(other_slice) for other_slice in slices[1:]])
param = slices[0]
# print(f'replicate {name} using first slice')
elif any(re.match(pattern, name) for pattern in parameters_to_average):
param = sum(slices) / len(slices)
# print(f'merge {name} using average')
else:
cat_dim = 1 if any(re.match(pattern, name) for pattern in parameters_with_row_parallelism) else 0
# print(f"merge {name} with CAT DIM: {cat_dim}")
param = torch.cat(slices, dim=cat_dim)
ckpt_dict[CAT_DIM] = cat_dim

if any(re.match(pattern, name) for pattern in vocabulary_parameters):
#print(f"Before {param.shape=}")
# strip padding
#param = _strip_vocab_padding(ds_checkpoint, param)
ckpt_dict[VOCAB_DIVISIBILITY_PADDING_TENSOR] = _get_vocab_divisibility_padding_tensor(
universal_checkpoint_info, param)
#print(f"After {param.shape=}")

#print(f"Final shape: {param.shape}")
ckpt_dict[PARAM] = param
_save_checkpoint(final_path, ckpt_dict)


def _get_chunks(l, n):
for i in range(0, len(l), n):
yield l[i:i + n]


def _do_parallel_work(do_work, work_chunks, num_workers):
pool = multiprocessing.Pool(num_workers)
for batch in tqdm.tqdm(work_chunks):
pool.map(do_work, batch)
pool.close()
pool.join()


def _extract_zero_shard_files(args, ds_checkpoint, temp_dir):
_3d_range_list = list(
itertools.product(range(ds_checkpoint.pp_degree), range(ds_checkpoint.tp_degree),
range(ds_checkpoint.dp_degree)))
# pprint(f'{_3d_range_list=}')
work_chunks = list(_get_chunks(_3d_range_list, args.num_extract_workers))
# pprint(f'{work_chunks=}')

# extract_zero_shards(temp_dir, ds_checkpoint, _3d_range_list[0])
do_work = partial(extract_zero_shards, temp_dir, ds_checkpoint)
_do_parallel_work(do_work, work_chunks, args.num_extract_workers)


def _merge_tp_slice_files(args, ds_checkpoint, slice_shapes, temp_dir):
work_chunks = list(_get_chunks(list(slice_shapes.items()), args.num_merge_workers))
#pprint(work_chunks)
zero_output_folder = os.path.join(args.output_folder, "zero")
do_work = partial(merge_tp_slices, ds_checkpoint, zero_output_folder, temp_dir, ds_checkpoint.tp_degree)
_do_parallel_work(do_work, work_chunks, args.num_merge_workers)


def _save_optimizer_state(args, ds_checkpoint):
sharded_states = [BASE_OPTIMIZER_STATE, PARAM_SLICE_MAPPINGS, SINGLE_PARTITION_OF_FP32_GROUPS]
sd = ds_checkpoint.get_zero_checkpoint_state(pp_index=0, tp_index=0, dp_index=0)

optim_sd = sd[OPTIMIZER_STATE_DICT]
output_sd = {k: v for k, v in optim_sd.items() if k not in sharded_states}
zero_output_folder = os.path.join(args.output_folder, "zero")
output_file_path = os.path.join(zero_output_folder, f"optimizer_state.pt")
_save_checkpoint(output_file_path, output_sd)


def _check_for_required_state(ds_checkpoint):
universal_checkpoint_info = ds_checkpoint.get_checkpoint_info(UNIVERSAL_CHECKPOINT_INFO)
assert universal_checkpoint_info is not None, f'Required {UNIVERSAL_CHECKPOINT_INFO} state is missing in checkpoint. Verify that client creates this state.'


def main():
print(f'Convert DeepSpeed Checkpoint to Universal Checkpoint')

args = parse_arguments()
print(f'Converting DeepSpeed checkpoint in {args.input_folder} to Universal checkpoint in {args.output_folder}')

ds_checkpoint = DeepSpeedCheckpoint(args.input_folder)
_check_for_required_state(ds_checkpoint)

iteration = ds_checkpoint.get_iteration()
#_create_latest_file(args.output_folder, iteration)
checkpoint_paths = _create_checkpoint_paths(args.output_folder, iteration, ds_checkpoint.tp_degree,
ds_checkpoint.pp_degree)

slice_shapes = []
for mp_rank_file in ds_checkpoint.mp_rank_files:
mp_sd = torch.load(mp_rank_file, map_location=torch.device('cpu'))
slice_shapes += mp_sd[PARAM_SHAPES]

# fix back to normal flat dict, merge duplicates for tp>1
slice_shapes = dict((k, v) for d in slice_shapes for k, v in d.items())
temp_dir = os.path.join(args.output_folder, 'tmp')

print('*** 1. Extracting ZeRO fragments')
_extract_zero_shard_files(args, ds_checkpoint, temp_dir)

print('*** 2. Merging slices')
_merge_tp_slice_files(args, ds_checkpoint, slice_shapes, temp_dir)

print('*** 3. Saving common optimizer states')
_save_optimizer_state(args, ds_checkpoint)

if not args.keep_temp_folder:
shutil.rmtree(temp_dir, ignore_errors=True)

# Copy mp* files into output folder
for f in glob.glob(os.path.join(args.input_folder, 'mp*')):
shutil.copy2(f, args.output_folder)

# Update latest to output folder
checkpoint_root_folder, step_folder = os.path.split(args.output_folder)
latest_file = os.path.join(checkpoint_root_folder, 'latest_universal')
with open(latest_file, "w") as f:
f.write(step_folder)

print('*** Done!')


if __name__ == "__main__":
main()
3 changes: 3 additions & 0 deletions deepspeed/runtime/engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -2712,6 +2712,9 @@ def load_checkpoint(self,
if self._optimizer_has_ckpt_event_epilogue():
self.optimizer.checkpoint_event_epilogue()

if self.load_universal_checkpoint():
self.optimizer.update_lp_params()

return load_path, client_states

def _load_checkpoint(self,
Expand Down
11 changes: 7 additions & 4 deletions deepspeed/runtime/pipe/engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from deepspeed.utils import logger
from deepspeed.utils.timer import ThroughputTimer
from deepspeed.accelerator import get_accelerator
from deepspeed.runtime.bf16_optimizer import BF16_Optimizer

from ..engine import DeepSpeedEngine, MEMORY_OPT_ALLREDUCE_SIZE
from deepspeed.utils.timer import FORWARD_MICRO_TIMER, FORWARD_GLOBAL_TIMER, BACKWARD_MICRO_TIMER, \
Expand Down Expand Up @@ -75,6 +76,8 @@ def __init__(self, has_bool_tensors=False, *super_args, **super_kwargs):
self.has_bool_tensors = has_bool_tensors
self.eval_return_logits = False
self.outputs = None
# BF16 Optimizer is hardcoded for fp32 gradient accumulation
self.using_bf16_optimizer = type(self.optimizer) == BF16_Optimizer

# used to disable the pipeline all-reduce when used with 1-bit Adam/1-bit LAMB
self.pipeline_enable_backward_allreduce = True
Expand Down Expand Up @@ -257,13 +260,13 @@ def _exec_reduce_tied_grads(self):

weight_group_list = self.module.get_tied_weights_and_groups()
for weight, group in weight_group_list:
grad = weight._hp_grad if self.bfloat16_enabled() else weight.grad
grad = weight._hp_grad if self.using_bf16_optimizer else weight.grad
dist.all_reduce(grad, group=group)

def _exec_reduce_grads(self):
self._force_grad_boundary = True
if self.pipeline_enable_backward_allreduce:
if self.bfloat16_enabled():
if self.using_bf16_optimizer:
# PP+BF16 work for ZeRO Stage 1
self._bf16_reduce_grads()
else:
Expand Down Expand Up @@ -745,7 +748,7 @@ def _exec_backward_pass(self, buffer_id):
part_grad = None
#print(f'RANK={self.global_rank} BEFORE-BWD restored grad={self.grad_layer[0].size()} {self.grad_layer[1].size()}')

if self.bfloat16_enabled() and not self.is_last_stage():
if self.using_bf16_optimizer and not self.is_last_stage():
# manually call because we don't call optimizer.backward()
self.optimizer.clear_lp_grads()

Expand All @@ -757,7 +760,7 @@ def _exec_backward_pass(self, buffer_id):
else:
torch.autograd.backward(tensors=(outputs, ), grad_tensors=(grad_tensors, ))

if self.bfloat16_enabled() and not self.is_last_stage():
if self.using_bf16_optimizer and not self.is_last_stage():
# manually call because we don't call optimizer.backward()
self.optimizer.update_hp_grads(clear_lp_grads=False)

Expand Down
Loading

0 comments on commit 8fdd9b3

Please sign in to comment.