Skip to content

Commit

Permalink
Universal checkpoint for zero stage 1 (deepspeedai#2284)
Browse files Browse the repository at this point in the history
* Refactor universal checkpointing and tensor fragments

* Formatting

* Support zero stage1; Expand TP dim

* Remove debug prints

* Detect sharded optimizer state

* Format fixes

* Encode reshaping guide

* More symbolic constants

Co-authored-by: Michael Wyatt <michaelwyatt@microsoft.com>
  • Loading branch information
tjruwase and mrwyattii authored Oct 18, 2022
1 parent 906b4a0 commit 799120e
Show file tree
Hide file tree
Showing 11 changed files with 291 additions and 146 deletions.
2 changes: 2 additions & 0 deletions deepspeed/checkpoint/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,3 +13,5 @@
from .zero_checkpoint import ZeROCheckpoint

from .universal_checkpoint import enable_universal_checkpoint

from .constants import *
31 changes: 26 additions & 5 deletions deepspeed/checkpoint/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@
PARTITION_COUNT = 'partition_count'
ZERO_STAGE = 'zero_stage'
CLIP_GRAD = 'clip_grad'
PARAM_SLICE_MAPPINGS = 'param_slice_mappings'
FP32_WEIGHT_KEY = "fp32"

#########################################
Expand All @@ -24,20 +23,42 @@
PARAM = 'param'
PARAM_SHAPES = 'param_shapes'
BUFFER_NAMES = 'buffer_names'
VOCAB_DIVISIBILITY_PADDING_TENSOR = 'vocab_divisibility_padding_tensor'
CAT_DIM = "cat_dim"

#########################################
# Checkpoint naming constants
#########################################
MODEL_FILE_PREFIX = 'mp_rank_'
ZERO_FILE_PREFIX = 'bf16_' + 'zero_pp_rank_'
ZERO_FILE_PREFIX = 'zero_pp_rank_'
OPTIM_FILE_SUFFIX = '_optim_states.pt'
MODEL_FILE_SUFFIX = '_model_states.pt'
LAYER_FILE_PREFIX = 'layer_'
BF16_ZERO_FILE_PREFIX = ZERO_FILE_PREFIX
BF16_ZERO_FILE_PREFIX = 'bf16_' + ZERO_FILE_PREFIX
FP16_ZERO_FILE_PREFIX = 'fp16_' + ZERO_FILE_PREFIX

#########################################
# Checkpoint utility keys
#########################################
DS_VERSION = 'ds_version'

#########################################
# Universal Checkpoint keys
#########################################
UNIVERSAL_CHECKPOINT_INFO = 'universal_checkpoint_info'
UNIVERSAL_CHECKPOINT_VERSION_KEY = 'universal_checkpoint_version'
# Reserve version 0.1 for the hardcoded logic used in BLOOM-176B training
UNIVERSAL_CHECKPOINT_VERSION_VALUE = 0.2

# Vocabulary padding
VOCAB_DIVISIBILITY_PADDING_TENSOR = 'vocab_divisibility_padding_tensor'
PADDED_VOCAB_SIZE = 'padded_vocab_size'
ORIGINAL_VOCAB_SIZE = 'original_vocab_size'

# Parameter splitting/merging
PARAM_SLICE_MAPPINGS = 'param_slice_mappings'
CAT_DIM = "cat_dim"

# Regex list of parameters that require special handling
VOCABULARY_PARAMETER_PATTERNS = 'vocabulary_parameter_patterns'
PIPELINE_REPLICATED_PARAMETER_PATTERNS = 'pipeline_replicated_parameter_patterns'
PARAMETER_TO_AVERAGE_PATTERNS = 'parameter_to_average_patterns'
PARAMETER_WITH_ROW_PARALLELISM_PATTERNS = 'parameter_with_row_parallelism_patterns'
55 changes: 27 additions & 28 deletions deepspeed/checkpoint/deepspeed_checkpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
get_files,
get_files_with_prefix)

from .constants import (ZERO_FILE_PREFIX, MODEL_FILE_PREFIX, LAYER_FILE_PREFIX)
from .constants import (MODEL_FILE_PREFIX, LAYER_FILE_PREFIX)

from .reshape_meg_2d import reshape_meg_2d_parallel, meg_2d_parallel_map
from .zero_checkpoint import ZeROCheckpoint
Expand Down Expand Up @@ -39,37 +39,36 @@ def __init__(self, dir, tp_degree=None, pp_degree=None, dp_degree=None):
self.dir = dir
self._validate_folder(dir)

self.zero_checkpoint = ZeROCheckpoint(dir)

self.file_list = get_files(dir)
self.zero_files = get_files_with_prefix(self.file_list, ZERO_FILE_PREFIX)
self.layer_files = get_files_with_prefix(self.file_list, LAYER_FILE_PREFIX)
self.mp_rank_files = get_files_with_prefix(self.file_list, MODEL_FILE_PREFIX)

self.layer_keys = self._get_layer_keys()
self.layer_count = len(self.layer_keys)
self.original_tp_degree = len(
get_files_with_prefix(self.layer_files,
f'{LAYER_FILE_PREFIX}01'))
self.original_pp_degree = len(self.mp_rank_files) // self.original_tp_degree
self.original_dp_degree = max(
1,
len(self.zero_files) // (self.original_pp_degree * self.original_tp_degree))

self.tp_degree = self.original_tp_degree if tp_degree is None else tp_degree
self.pp_degree = self.original_pp_degree if pp_degree is None else pp_degree
self.dp_degree = self.original_dp_degree if dp_degree is None else dp_degree

self.original_world_size = self.original_tp_degree * self.original_pp_degree * self.original_dp_degree

self.tp_degree = self.zero_checkpoint.get_src_tp_degree(
) if tp_degree is None else tp_degree
self.pp_degree = self.zero_checkpoint.get_src_pp_degree(
) if pp_degree is None else pp_degree
self.dp_degree = self.zero_checkpoint.get_src_dp_degree(
) if dp_degree is None else dp_degree

self.original_world_size = self.zero_checkpoint.get_src_tp_degree(
) * self.zero_checkpoint.get_src_pp_degree(
) * self.zero_checkpoint.get_src_dp_degree()
self.world_size = self.tp_degree * self.pp_degree * self.dp_degree

self.old_2d_map = meg_2d_parallel_map(self.original_pp_degree,
self.original_tp_degree)
self.old_2d_map = meg_2d_parallel_map(self.zero_checkpoint.get_src_pp_degree(),
self.zero_checkpoint.get_src_tp_degree())
self.old_2d_map.simple_init()
self.new_2d_map = reshape_meg_2d_parallel(old_pp_degree=self.original_pp_degree,
old_tp_degree=self.original_tp_degree,
new_pp_degree=self.pp_degree,
new_tp_degree=self.tp_degree)
self.new_2d_map = reshape_meg_2d_parallel(
old_pp_degree=self.zero_checkpoint.get_src_pp_degree(),
old_tp_degree=self.zero_checkpoint.get_src_tp_degree(),
new_pp_degree=self.pp_degree,
new_tp_degree=self.tp_degree)

self.zero_checkpoint = ZeROCheckpoint(dir)
if self.is_change_pp_degree() or self.is_change_tp_degree(
) or self.is_change_dp_degree():
self.zero_checkpoint.reshape(
Expand All @@ -88,13 +87,13 @@ def __init__(self, dir, tp_degree=None, pp_degree=None, dp_degree=None):
self._build_global_state()

def is_change_tp_degree(self):
return self.tp_degree != self.original_tp_degree
return self.tp_degree != self.zero_checkpoint.get_src_tp_degree()

def is_change_pp_degree(self):
return self.pp_degree != self.original_pp_degree
return self.pp_degree != self.zero_checkpoint.get_src_pp_degree()

def is_change_dp_degree(self):
return self.dp_degree != self.original_dp_degree
return self.dp_degree != self.zero_checkpoint.get_src_dp_degree()

def show_2d_mapping(self):
print(f'reshaped 2d map ---- begin')
Expand Down Expand Up @@ -171,8 +170,8 @@ def _get_checkpoint_value(self, key):
def get_args(self):
    """Return the checkpoint value stored under ``ARGS_KEY``."""
    args_value = self._get_checkpoint_value(ARGS_KEY)
    return args_value

def get_checkpoint_info(self):
return self._get_checkpoint_value(CHECKPOINT_INFO_KEY)
def get_checkpoint_info(self, info_key=CHECKPOINT_INFO_KEY):
return self._get_checkpoint_value(info_key)

def get_2d_parallel_state(self, tp_index: int, pp_index: int) -> dict:
assert tp_index < self.tp_degree
Expand Down Expand Up @@ -272,8 +271,8 @@ def _build_transformer_file_map(self):

def _sanity_check(self):
assert len(self.mp_rank_files) % self.tp_degree == 0
assert len(self.zero_files) % (self.pp_degree * self.tp_degree) == 0
assert len(self.layer_keys) > 2
assert self.zero_checkpoint.num_files % (self.pp_degree * self.tp_degree) == 0
# XXX: fix me - isn't always the case
# only true with --pp-partition-method 'type:transformer|embedding' \
# assert (len(self.layer_keys) - 2) % self.pp_degree == 0
Expand Down
25 changes: 19 additions & 6 deletions deepspeed/checkpoint/reshape_3d_utils.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
from .reshape_utils import (get_files, get_files_with_prefix, partition_data)
from .reshape_utils import (get_files,
get_files_with_prefix,
partition_data,
get_zero_files)

from .constants import (ZERO_FILE_PREFIX, MODEL_FILE_PREFIX, LAYER_FILE_PREFIX)
from .constants import (MODEL_FILE_PREFIX, LAYER_FILE_PREFIX)

from .reshape_meg_2d import (reshape_meg_2d_parallel, meg_2d_parallel_map)

Expand Down Expand Up @@ -34,6 +37,9 @@ def reshape(self, target_3d_desc, verbose=False):
def get_desc(self):
    """Return a one-line text summary of the dimension names and their degrees."""
    description = f'{PP_DIM},{TP_DIM},{DP_DIM} = ({self.pp_degree}, {self.tp_degree}, {self.dp_degree})'
    return description

def world_size(self):
    """Return the product of the pipeline, tensor and data parallel degrees."""
    pipeline_by_tensor = self.pp_degree * self.tp_degree
    total_ranks = pipeline_by_tensor * self.dp_degree
    return total_ranks

def is_valid(self, pp_index, tp_index, dp_index):
err_msg = []
valid = True
Expand Down Expand Up @@ -70,10 +76,17 @@ def can_reshape(self, target_3d_desc):

def get_model_3d_descriptor(dir):
    """Infer the pipeline/tensor/data parallel degrees of the checkpoint in ``dir``.

    Args:
        dir: Checkpoint folder whose file names are inspected.

    Returns:
        A ``model_3d_desc`` built from the inferred ``(pp, tp, dp)`` degrees.
    """
    file_list = get_files(dir)
    zero_file_list = get_zero_files(dir)
    num_pp0_files = len(get_files_with_prefix(file_list, f'{LAYER_FILE_PREFIX}01'))
    if num_pp0_files > 0:
        # Layer files exist: the count of `<LAYER_FILE_PREFIX>01` files is used
        # as the TP degree, and the model-state file count gives PP * TP.
        tp_degree = num_pp0_files
        pp_degree = len(get_files_with_prefix(file_list, MODEL_FILE_PREFIX)) // tp_degree
        dp_degree = max(1, len(zero_file_list) // (pp_degree * tp_degree))
    else:
        # No layer files: every model-state file is counted as one TP rank.
        tp_degree = len(get_files_with_prefix(file_list, MODEL_FILE_PREFIX))
        dp_degree = max(1, len(zero_file_list) // tp_degree)
        # A checkpoint without layer files still has a single pipeline stage.
        # A degree of 0 would zero out the world size and make later
        # `% (pp_degree * tp_degree)` checks divide by zero.
        pp_degree = 1

    return model_3d_desc(pp_degree, tp_degree, dp_degree)


Expand Down
11 changes: 11 additions & 0 deletions deepspeed/checkpoint/reshape_utils.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import os
import torch
from collections import OrderedDict
from .constants import (ZERO_FILE_PREFIX, FP16_ZERO_FILE_PREFIX, BF16_ZERO_FILE_PREFIX)


def basic_folder_validation(dir):
Expand Down Expand Up @@ -32,6 +33,16 @@ def get_files(dir):
return file_list


def get_zero_files(dir):
    """Return the ZeRO checkpoint files found in ``dir``.

    The plain, fp16 and bf16 file-name prefixes are tried in that order; the
    matches for the first prefix that yields any files are returned. An empty
    list is returned when none of the prefixes match.
    """
    all_files = get_files(dir)
    candidate_prefixes = [ZERO_FILE_PREFIX, FP16_ZERO_FILE_PREFIX, BF16_ZERO_FILE_PREFIX]
    # Lazy generator: later prefixes are only searched if earlier ones found nothing.
    matches_per_prefix = (get_files_with_prefix(all_files, candidate)
                          for candidate in candidate_prefixes)
    return next((found for found in matches_per_prefix if len(found) > 0), [])


def partition_data(data_list, num_partitions):
num_elems = len(data_list)
assert num_elems % num_partitions == 0
Expand Down
8 changes: 3 additions & 5 deletions deepspeed/checkpoint/universal_checkpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
import os
import torch
import types

from .constants import (FP32_WEIGHT_KEY,
PARAM,
VOCAB_DIVISIBILITY_PADDING_TENSOR,
Expand Down Expand Up @@ -54,18 +53,17 @@ def load_hp_checkpoint_state(self, folder, tp_rank, tp_world_size):
padded_target_vocab_size = self.shape[0] * tp_world_size
if padded_target_vocab_size > full_hp_param.shape[0]:
# Need to expand
padding_tensor = vocab_divisibility_padding_tensor.expand(
padded_target_vocab_size - full_hp_param.shape[0])
padding_size = padded_target_vocab_size - full_hp_param.shape[0]
# Implement the following concat in efficient way using pad
#full_hp_param = torch.cat((full_hp_param, padding_tensor), 0)
full_hp_param = torch.nn.functional.pad(full_hp_param,
(0,
0,
0,
padding_tensor.shape[0]),
padding_size),
"constant",
0)
full_hp_param[:-padding_tensor.shape[0], :] = padding_tensor
full_hp_param[:-padding_size, :] = vocab_divisibility_padding_tensor
else:
# Need to shrink or keep the same
full_hp_param = full_hp_param[:padded_target_vocab_size, :]
Expand Down
30 changes: 15 additions & 15 deletions deepspeed/checkpoint/zero_checkpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,14 +3,9 @@
from .constants import (BASE_OPTIMIZER_STATE,
GROUP_PADDINGS,
OPTIMIZER_STATE_DICT,
PARTITION_COUNT,
ZERO_FILE_PREFIX,
BF16_ZERO_FILE_PREFIX)
PARTITION_COUNT)

from .reshape_utils import (basic_folder_validation,
get_files,
get_files_with_prefix,
merge_state)
from .reshape_utils import (basic_folder_validation, get_zero_files, merge_state)

from .reshape_3d_utils import (model_3d_desc, get_model_3d_descriptor)

Expand All @@ -21,7 +16,7 @@ class ZeROCheckpoint(object):
def __init__(self, dir):
basic_folder_validation(dir)
self.dir = dir
self.file_list = self._get_zero_files(dir)
self.file_list = get_zero_files(dir)
self.num_files = len(self.file_list)
assert self.num_files > 0, f'No ZeRO files found in {dir}'

Expand All @@ -31,6 +26,18 @@ def __init__(self, dir):
dp_degree=self.src_3d.dp_degree)
self._3d_file_map = self.src_3d.reshape(self.target_3d)

def get_src_world_size(self):
    """Return the world size reported by the source 3D descriptor (``self.src_3d``)."""
    source_desc = self.src_3d
    return source_desc.world_size()

def get_src_tp_degree(self):
    """Return the ``tp_degree`` of the source 3D descriptor (``self.src_3d``)."""
    source_desc = self.src_3d
    return source_desc.tp_degree

def get_src_pp_degree(self):
    """Return the ``pp_degree`` of the source 3D descriptor (``self.src_3d``)."""
    source_desc = self.src_3d
    return source_desc.pp_degree

def get_src_dp_degree(self):
    """Return the ``dp_degree`` of the source 3D descriptor (``self.src_3d``)."""
    source_desc = self.src_3d
    return source_desc.dp_degree

def get_file_indices_for_rank(self, pp_index, tp_index, dp_index):
assert dp_index < len(self._3d_file_map), f'DP index {dp_index} >= DP degree {len(self._3d_file_map)}'
dp_2d_map = self._3d_file_map[dp_index]
Expand Down Expand Up @@ -137,10 +144,3 @@ def _update_partition_count(self, sd):
num_groups = len(partition_counts)
sd[OPTIMIZER_STATE_DICT][PARTITION_COUNT] = [self.target_3d.dp_degree
] * num_groups

def _get_zero_files(self, dir):
    """Return the ZeRO files in ``dir``, falling back to the bf16-prefixed files.

    Files matching ``ZERO_FILE_PREFIX`` are preferred; only when there are
    none is the ``BF16_ZERO_FILE_PREFIX`` match list returned instead.
    """
    all_files = get_files(dir)
    plain_matches = get_files_with_prefix(all_files, ZERO_FILE_PREFIX)
    if len(plain_matches) == 0:
        return get_files_with_prefix(all_files, BF16_ZERO_FILE_PREFIX)
    return plain_matches
Loading

0 comments on commit 799120e

Please sign in to comment.