save_checkpoint, load_checkpoint and aggregate_checkpoints #6136

Merged
merged 12 commits on Dec 18, 2020
Changes from 11 commits
15 changes: 15 additions & 0 deletions orttraining/orttraining/python/training/_checkpoint_storage.py
@@ -4,6 +4,7 @@

import h5py
from collections.abc import Mapping
import pickle

def _dfs_save(group, save_obj):
"""Recursively go over each level in the save_obj dictionary and save values to a hdf5 group"""
@@ -79,3 +80,17 @@ def load(path, key=None):
_dfs_load(f, load_obj)

return load_obj

def to_serialized_hex(user_dict):
"""Serialize the user_dict and convert the serialized bytes to a hex string and return"""

return pickle.dumps(user_dict).hex()

def from_serialized_hex(serialized_hex):
"""Convert serialized_hex to bytes and deserialize it and return"""

try:
serialized_hex = serialized_hex.decode()
except AttributeError:
pass
return pickle.loads(bytes.fromhex(serialized_hex))
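
A quick round-trip sketch for the two helpers above; the import path and the dictionary contents are illustrative, so adjust the import to wherever _checkpoint_storage is exposed in your build:

from onnxruntime.training import _checkpoint_storage  # import path is an assumption

user_dict = {"epoch": 3, "custom_lr": 1e-4}                 # illustrative payload
hex_str = _checkpoint_storage.to_serialized_hex(user_dict)  # pickle bytes rendered as a hex string
assert _checkpoint_storage.from_serialized_hex(hex_str) == user_dict
# h5py may hand the stored value back as bytes; the decode()/AttributeError fallback above covers that case
assert _checkpoint_storage.from_serialized_hex(hex_str.encode()) == user_dict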
45 changes: 44 additions & 1 deletion orttraining/orttraining/python/training/_utils.py
@@ -202,4 +202,47 @@ def state_dict_trainer_options_key():
def state_dict_full_precision_key():
"""Returns the full precision key name in the state dictionary"""

return 'fp32'
return 'full_precision'

def state_dict_original_dimension_key():
"""Returns the original dimension key name in the state dictionary"""

return 'original_dim'

def state_dict_sharded_optimizer_keys():
"""Returns the optimizer key names that can be sharded in the state dictionary"""

return {
'Moment_1',
'Moment_2'
}

def state_dict_user_dict_key():
"""Returns the user dict key name in the state dictionary"""

return 'user_dict'

def state_dict_trainer_options_mixed_precision_key():
"""Returns the trainer options mixed precision key name in the state dictionary"""

return 'mixed_precision'

def state_dict_trainer_options_zero_stage_key():
"""Returns the trainer options zero_stage key name in the state dictionary"""

return 'zero_stage'

def state_dict_trainer_options_world_rank_key():
"""Returns the trainer options world_rank key name in the state dictionary"""

return 'world_rank'

def state_dict_trainer_options_world_size_key():
"""Returns the trainer options world_size key name in the state dictionary"""

return 'world_size'

def state_dict_trainer_options_optimizer_name_key():
"""Returns the trainer options optimizer_name key name in the state dictionary"""

return 'optimizer_name'
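
These helpers centralize the key names used in the checkpoint state dictionary. Below is a rough sketch of the per-rank layout that the aggregation code in checkpoint.py (next file) expects, written in terms of the helpers; the parameter name, shapes, optimizer name, and import path are illustrative assumptions, while the helper functions themselves come from the _utils module:

from onnxruntime.training import _utils  # import path is an assumption

# illustrative layout of one rank's checkpoint state_dict, keyed via the _utils helpers
rank_state_dict = {
    _utils.state_dict_model_key(): {
        _utils.state_dict_full_precision_key(): {
            "linear.weight": ...,  # numpy array; a flattened shard when ZeRO partitioning is on
        },
    },
    _utils.state_dict_optimizer_key(): {
        "linear.weight": {"Moment_1": ..., "Moment_2": ...},  # Moment_* may be sharded across ranks
    },
    _utils.state_dict_partition_info_key(): {
        "linear.weight": {_utils.state_dict_original_dimension_key(): [784, 10]},
    },
    _utils.state_dict_trainer_options_key(): {
        _utils.state_dict_trainer_options_mixed_precision_key(): False,
        _utils.state_dict_trainer_options_zero_stage_key(): 1,
        _utils.state_dict_trainer_options_world_rank_key(): 0,
        _utils.state_dict_trainer_options_world_size_key(): 4,
        _utils.state_dict_trainer_options_optimizer_name_key(): "AdamOptimizer",  # illustrative
    },
    _utils.state_dict_user_dict_key(): {},  # arbitrary user state, stored via the hex helpers above
}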
230 changes: 229 additions & 1 deletion orttraining/orttraining/python/training/checkpoint.py
@@ -3,6 +3,7 @@
import os
import torch
import warnings
from . import _checkpoint_storage, _utils


################################################################################
@@ -108,6 +109,233 @@ def experimental_load_checkpoint(ort_trainer, checkpoint_dir, checkpoint_prefix=
else:
return _load_single_checkpoint(ort_trainer, checkpoint_dir, checkpoint_prefix, is_partitioned, strict)

def _order_paths(paths):
"""Reorders the given paths in ascending order of rank and return the ordered list"""

trainer_options_path_tuples = []
world_rank = _utils.state_dict_trainer_options_world_rank_key()

for path in paths:
trainer_options_path_tuples.append((_checkpoint_storage.load(path,
key=_utils.state_dict_trainer_options_key()), path))

ordered_paths = [path for _, path in sorted(trainer_options_path_tuples,
key=lambda trainer_options_path_pair: trainer_options_path_pair[0][world_rank])]

return ordered_paths

def _add_or_update_sharded_key_for_zero(state_key, state_value, state_sub_dict,
model_state_key, original_dim, sharded_states_original_dims):
"""Add or update the record for the sharded state_key in the state_sub_dict"""

# record the original dimension for this state
sharded_states_original_dims[model_state_key] = original_dim

if state_key in state_sub_dict:
# state_dict already contains a record for this state
# since this state is sharded, concatenate the state value to
# the record in the state_dict
state_sub_dict[state_key] = \
np.concatenate((state_sub_dict[state_key], state_value))
else:
# create a new entry for this state in the state_dict
state_sub_dict[state_key] = state_value

def _add_or_validate_unsharded_key_for_zero(state_key, state_value, state_sub_dict, mismatch_error_string):
"""Add or validate the record for the unsharded state_key in the state_sub_dict"""

if state_key in state_sub_dict:
# state_dict already contains a record for this unsharded state.
# assert that all values are the same for this previously loaded state
assert (state_sub_dict[state_key] == state_value).all(), mismatch_error_string
else:
# create a new entry for this state in the state_sub_dict
state_sub_dict[state_key] = state_value

def _aggregate_model_states(rank_state_dict, sharded_states_original_dims, state_dict):
"""Aggregates all model states from the rank_state_dict into state_dict"""

model = _utils.state_dict_model_key()
full_precision = _utils.state_dict_full_precision_key()
partition_info = _utils.state_dict_partition_info_key()
original_dim = _utils.state_dict_original_dimension_key()

# if there are no model states in the rank_state_dict, no model aggregation is needed
if model not in rank_state_dict:
return

if model not in state_dict:
state_dict[model] = {}

if full_precision not in state_dict[model]:
state_dict[model][full_precision] = {}

# iterate over all model state keys
for model_state_key, model_state_value in rank_state_dict[model][full_precision].items():
if model_state_key in rank_state_dict[partition_info]:
# this model state is sharded since a record exists in the partition_info subdict
_add_or_update_sharded_key_for_zero(model_state_key, model_state_value,
state_dict[model][full_precision], model_state_key,
rank_state_dict[partition_info][model_state_key][original_dim], sharded_states_original_dims)
else:
# this model state is not sharded since a record for it does not exist in the partition_info subdict
_add_or_validate_unsharded_key_for_zero(model_state_key, model_state_value,
state_dict[model][full_precision], "Value mismatch for model state {}".format(model_state_key))

def _aggregate_optimizer_states(rank_state_dict, sharded_states_original_dims, state_dict):
"""Aggregates all optimizer states from the rank_state_dict into state_dict"""

optimizer = _utils.state_dict_optimizer_key()
partition_info = _utils.state_dict_partition_info_key()
original_dim = _utils.state_dict_original_dimension_key()
sharded_optimizer_keys = _utils.state_dict_sharded_optimizer_keys()

# if there are no optimizer states in the rank_state_dict, no optimizer aggregation is needed
if optimizer not in rank_state_dict:
return

if optimizer not in state_dict:
state_dict[optimizer] = {}

# iterate over all optimizer state keys
for model_state_key, optimizer_dict in rank_state_dict[optimizer].items():
for optimizer_key, optimizer_value in optimizer_dict.items():
if model_state_key not in state_dict[optimizer]:
state_dict[optimizer][model_state_key] = {}

if optimizer_key in sharded_optimizer_keys and model_state_key in rank_state_dict[partition_info]:
# this optimizer state is sharded since a record exists in the partition_info subdict
_add_or_update_sharded_key_for_zero(optimizer_key, optimizer_value,
state_dict[optimizer][model_state_key], model_state_key,
rank_state_dict[partition_info][model_state_key][original_dim], sharded_states_original_dims)
else:
# this optimizer state is not sharded since a record for it does not exist in the partition_info subdict
# or this optimizer key is not one of the sharded optimizer keys
_add_or_validate_unsharded_key_for_zero(optimizer_key, optimizer_value,
state_dict[optimizer][model_state_key],
"Value mismatch for model state {} and optimizer state {}".format(model_state_key, optimizer_key))

def _reshape_states(sharded_states_original_dims, state_dict):
"""Reshape model and optimizer states in the state_dict according to dimensions in sharded_states_original_dims"""

model = _utils.state_dict_model_key()
full_precision = _utils.state_dict_full_precision_key()
optimizer = _utils.state_dict_optimizer_key()
sharded_optimizer_keys = _utils.state_dict_sharded_optimizer_keys()

for sharded_state_key, original_dim in sharded_states_original_dims.items():
# reshape model states to original_dim
if model in state_dict:
state_dict[model][full_precision][sharded_state_key] = \
state_dict[model][full_precision][sharded_state_key].reshape(original_dim)

# reshape optimizer states to original_dim
if optimizer in state_dict:
for optimizer_key, optimizer_value in state_dict[optimizer][sharded_state_key].items():
if optimizer_key in sharded_optimizer_keys:
state_dict[optimizer][sharded_state_key][optimizer_key] = optimizer_value.reshape(original_dim)

def _aggregate_trainer_options(rank_state_dict, state_dict):
"""Extracts trainer options from rank_state_dict and loads them accordingly on state_dict"""

state_dict[_utils.state_dict_trainer_options_key()] = {}

mixed_precision = _utils.state_dict_trainer_options_mixed_precision_key()
zero_stage = _utils.state_dict_trainer_options_zero_stage_key()
world_rank = _utils.state_dict_trainer_options_world_rank_key()
world_size = _utils.state_dict_trainer_options_world_size_key()
optimizer_name = _utils.state_dict_trainer_options_optimizer_name_key()

state_dict[_utils.state_dict_trainer_options_key()][mixed_precision] = \
rank_state_dict[_utils.state_dict_trainer_options_key()][mixed_precision]
state_dict[_utils.state_dict_trainer_options_key()][zero_stage] = 0
state_dict[_utils.state_dict_trainer_options_key()][world_rank] = 0
state_dict[_utils.state_dict_trainer_options_key()][world_size] = 1
state_dict[_utils.state_dict_trainer_options_key()][optimizer_name] = \
rank_state_dict[_utils.state_dict_trainer_options_key()][optimizer_name]

def aggregate_checkpoints(paths, pytorch_format=True):
"""Aggregate checkpoint files and return a single state dictionary

Aggregates the checkpoint files specified by paths, loading them one at a time and merging
them into a single state dictionary.
The checkpoint files represented by paths must have been saved through the ORTTrainer.save_checkpoint() function.
The schema of the returned state_dict is the same as the one returned by ORTTrainer.state_dict()

Args:
paths: list of paths (as strings) to the checkpoint files to be aggregated; more than one path is expected
pytorch_format: boolean flag to select either ONNX Runtime or PyTorch state schema of the returned state_dict
Returns:
state_dict that can be loaded into an ORTTrainer or into a PyTorch model
"""

# order the paths in ascending order of ranks
ordered_paths = _order_paths(paths)

state_dict = {}
sharded_states_original_dims = {}
world_rank = _utils.state_dict_trainer_options_world_rank_key()
mixed_precision = _utils.state_dict_trainer_options_mixed_precision_key()
zero_stage = _utils.state_dict_trainer_options_zero_stage_key()
world_size = _utils.state_dict_trainer_options_world_size_key()
optimizer_name = _utils.state_dict_trainer_options_optimizer_name_key()

loaded_mixed_precision = None
loaded_world_size = None
loaded_zero_stage = None
loaded_optimizer_name = None

for rank, path in enumerate(ordered_paths):
rank_state_dict = _checkpoint_storage.load(path)

assert _utils.state_dict_partition_info_key() in rank_state_dict, "Missing information: partition_info"
assert _utils.state_dict_trainer_options_key() in rank_state_dict, "Missing information: trainer_options"
assert rank == rank_state_dict[_utils.state_dict_trainer_options_key()][world_rank], \
"Unexpected rank in file at path {}. Expected {}, got {}".\
format(path, rank, rank_state_dict[_utils.state_dict_trainer_options_key()][world_rank])
if loaded_mixed_precision is None:
loaded_mixed_precision = rank_state_dict[_utils.state_dict_trainer_options_key()][mixed_precision]
else:
assert loaded_mixed_precision == rank_state_dict[_utils.state_dict_trainer_options_key()][mixed_precision], \
"Mixed precision state mismatch among checkpoint files. File: {}".format(path)
if loaded_world_size is None:
loaded_world_size = rank_state_dict[_utils.state_dict_trainer_options_key()][world_size]
else:
assert loaded_world_size == rank_state_dict[_utils.state_dict_trainer_options_key()][world_size], \
"World size state mismatch among checkpoint files. File: {}".format(path)
if loaded_zero_stage is None:
loaded_zero_stage = rank_state_dict[_utils.state_dict_trainer_options_key()][zero_stage]
else:
assert loaded_zero_stage == rank_state_dict[_utils.state_dict_trainer_options_key()][zero_stage], \
"Zero stage mismatch among checkpoint files. File: {}".format(path)
if loaded_optimizer_name is None:
loaded_optimizer_name = rank_state_dict[_utils.state_dict_trainer_options_key()][optimizer_name]
else:
assert loaded_optimizer_name == rank_state_dict[_utils.state_dict_trainer_options_key()][optimizer_name], \
"Optimizer name mismatch among checkpoint files. File: {}".format(path)

# aggregate all model states
_aggregate_model_states(rank_state_dict, sharded_states_original_dims, state_dict)

if not pytorch_format:
# aggregate all optimizer states if pytorch_format is False
_aggregate_optimizer_states(rank_state_dict, sharded_states_original_dims, state_dict)

# entry for trainer_options in the state_dict to perform other sanity checks
if _utils.state_dict_trainer_options_key() not in state_dict:
_aggregate_trainer_options(rank_state_dict, state_dict)

# entry for user_dict in the state_dict if not already present
if _utils.state_dict_user_dict_key() not in state_dict and \
_utils.state_dict_user_dict_key() in rank_state_dict:
state_dict[_utils.state_dict_user_dict_key()] = rank_state_dict[_utils.state_dict_user_dict_key()]

# reshape all the sharded tensors based on the original dimensions stored in sharded_states_original_dims
_reshape_states(sharded_states_original_dims, state_dict)

# return a flat structure for PyTorch model in case pytorch_format is True
# else return the hierarchical structure for ORTTrainer
return state_dict[_utils.state_dict_model_key()][_utils.state_dict_full_precision_key()] if pytorch_format else state_dict

################################################################################
# Helper functions
@@ -201,7 +429,7 @@ def _split_name(self, name):
name_split = name.split('_view_')
view_num = None
if(len(name_split) > 1):
view_num = int(name_split[1])
view_num = int(name_split[1])
optimizer_key = ''
mp_suffix = ''
if name_split[0].startswith('Moment_1'):
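
To close the loop, a usage sketch for the new aggregate_checkpoints API; the file names, import path, and two-rank setup are illustrative, and the ORTTrainer training/saving step is elided:

from onnxruntime.training import checkpoint  # import path is an assumption for this sketch

# checkpoint files previously written by ORTTrainer.save_checkpoint(), one per rank
paths = ["checkpoint.rank0.ortcp", "checkpoint.rank1.ortcp"]

# pytorch_format=True (the default) returns a flat {parameter_name: full-precision weight} dict;
# sharded ZeRO states are concatenated across ranks and reshaped back to their original dimensions
pytorch_state = checkpoint.aggregate_checkpoints(paths, pytorch_format=True)

# pytorch_format=False keeps the hierarchical ORTTrainer schema (model / optimizer / trainer_options / ...)
ort_state = checkpoint.aggregate_checkpoints(paths, pytorch_format=False)

Note that optimizer states are only aggregated when pytorch_format=False; the flat PyTorch view drops them and returns the full-precision model weights alone.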