Revise param_shapes to be a list of ordered dict (#1424)
* Revise param_shapes to be a list of ordered dict

* test i can push

* add tests; split z2 and z3 into separate funcs

Co-authored-by: Xiaopeng Li <xiaopel@amazon.com>
Co-authored-by: Stas Bekman <stas@stason.org>
Co-authored-by: Olatunji Ruwase <olruwase@microsoft.com>
4 people authored Oct 4, 2021
1 parent 466b0e6 commit dd0c8fa
Showing 3 changed files with 261 additions and 78 deletions.
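
For orientation, a minimal sketch of the structural change (not part of the diff; the parameter names and shapes below are invented): _get_zero_param_shapes used to return a single flat OrderedDict mapping parameter name to shape, and after this commit it returns a list with one OrderedDict per optimizer param group.

from collections import OrderedDict
import torch

# Old return value: one OrderedDict covering every parameter.
old_param_shapes = OrderedDict([
    ("embed.weight", torch.Size([50, 8])),
    ("fc.weight", torch.Size([8, 8])),
])

# New return value: one OrderedDict per optimizer param group, in group order.
new_param_shapes = [
    OrderedDict([("embed.weight", torch.Size([50, 8]))]),  # param group 0
    OrderedDict([("fc.weight", torch.Size([8, 8]))]),      # param group 1
]

# Downstream consumers (zero_to_fp32.py) now iterate groups first, then params.
for group_shapes in new_param_shapes:
    for name, shape in group_shapes.items():
        print(name, tuple(shape))
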
6 changes: 4 additions & 2 deletions deepspeed/runtime/engine.py
@@ -2441,7 +2441,7 @@ def _get_zero_param_shapes(self):
optimizer.fp16_groups seems to be the easiest to use as it's in all zeroX versions.
"""
param_shapes = OrderedDict()
param_group_shapes = []
cnt = 0
numel = 0

@@ -2453,6 +2453,7 @@ def _get_zero_param_shapes(self):
fp16_groups = self.optimizer.fp16_groups

for fp16_group in fp16_groups:
param_shapes = OrderedDict()
for param in fp16_group:
cnt += 1
numel += param.ds_numel if hasattr(param, "ds_numel") else param.numel()
@@ -2464,9 +2465,10 @@ def _get_zero_param_shapes(self):

# uncomment to debug zero_to_fp32.py problems
# if self.global_rank == 0: print(f"saving param {name} {shape} (numel={shape.numel()})")
param_group_shapes.append(param_shapes)
# if self.global_rank == 0: print(f"Total saved {numel} numels in {cnt} params")

return param_shapes
return param_group_shapes

def _copy_recovery_script(self, save_path):
base_dir = os.path.dirname(os.path.dirname(__file__))
208 changes: 141 additions & 67 deletions deepspeed/utils/zero_to_fp32.py
@@ -53,9 +53,6 @@ def get_optim_files(checkpoint_dir):


def parse_model_state(file):

# load to cpu
device = torch.device('cpu')
state_dict = torch.load(file, map_location=device)

if "buffer_names" not in state_dict:
@@ -85,10 +82,10 @@ def parse_optim_states(files, ds_checkpoint_dir):
zero_stage = state_dicts[0]['optimizer_state_dict']["zero_stage"]
world_size = state_dicts[0]['optimizer_state_dict']["partition_count"]
param_shapes = state_dicts[0]["param_shapes"]
'''For ZeRO-2 each param group can have different partiiton_count as data parallelism for expert
parameters can be different from data parallelism for non-expert parameters. So we can just use the max of
the partition_count to get the dp world_size.
'''
# For ZeRO-2 each param group can have different partition_count as data parallelism for expert
# parameters can be different from data parallelism for non-expert parameters. So we can just
# use the max of the partition_count to get the dp world_size.

if type(world_size) is list:
world_size = max(world_size)

@@ -106,24 +103,24 @@ def parse_optim_states(files, ds_checkpoint_dir):
else:
raise ValueError(f"unknown zero stage {zero_stage}")

# if there is more than one param group, there will be multiple flattened tensors - one
# flattened tensor per group - for simplicity merge them into a single tensor
#
# XXX: could make the script more memory efficient for when there are multiple groups - it
# will require matching the sub-lists of param_shapes for each param group flattened tensor
fp32_flat_groups = [
torch.cat(state_dicts[i]['optimizer_state_dict'][fp32_groups_key],
0) for i in range(len(state_dicts))
]

return zero_stage, world_size, param_shapes, fp32_flat_groups
if zero_stage == 2:
fp32_flat_groups = [
state_dicts[i]['optimizer_state_dict'][fp32_groups_key]
for i in range(len(state_dicts))
]
elif zero_stage == 3:
# if there is more than one param group, there will be multiple flattened tensors - one
# flattened tensor per group - for simplicity merge them into a single tensor
#
# XXX: could make the script more memory efficient for when there are multiple groups - it
# will require matching the sub-lists of param_shapes for each param group flattened tensor

fp32_flat_groups = [
torch.cat(state_dicts[i]['optimizer_state_dict'][fp32_groups_key],
0) for i in range(len(state_dicts))
]

def zero3_partitioned_param_info(unpartitioned_numel, world_size):
remainder = unpartitioned_numel % world_size
padding_numel = (world_size - remainder) if remainder else 0
partitioned_numel = math.ceil(unpartitioned_numel / world_size)
return partitioned_numel, padding_numel
return zero_stage, world_size, param_shapes, fp32_flat_groups
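
As a toy illustration of the two layouts parse_optim_states now returns (sizes invented): under ZeRO-2 each rank's entry is a list with one flat fp32 tensor per param group, while under ZeRO-3 each rank's entry is a single flat tensor with all groups concatenated.

import torch

world_size = 2

# ZeRO-2: fp32_flat_groups[rank] is a list of per-group 1-D tensors.
zero2_layout = [
    [torch.zeros(40), torch.zeros(24)],  # rank 0: group 0, group 1
    [torch.zeros(40), torch.zeros(24)],  # rank 1: group 0, group 1
]

# ZeRO-3: fp32_flat_groups[rank] is one 1-D tensor (groups concatenated).
zero3_layout = [torch.cat([torch.zeros(40), torch.zeros(24)]) for _ in range(world_size)]

assert zero3_layout[0].numel() == 64
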


def _get_fp32_state_dict_from_zero_checkpoint(ds_checkpoint_dir):
@@ -144,29 +141,48 @@ def _get_fp32_state_dict_from_zero_checkpoint(ds_checkpoint_dir):
model_file = get_model_state_file(ds_checkpoint_dir, zero_stage)
buffers = parse_model_state(model_file)

if zero_stage == 2:
return _get_fp32_state_dict_from_zero2_checkpoint(world_size,
param_shapes,
fp32_flat_groups,
buffers)
elif zero_stage == 3:
return _get_fp32_state_dict_from_zero3_checkpoint(world_size,
param_shapes,
fp32_flat_groups,
buffers)


def _get_fp32_state_dict_from_zero2_checkpoint(world_size,
param_shapes,
fp32_flat_groups,
buffers):

# Reconstruction protocol:
#
# - for zero2 we just need to concat the partitions back to back and reconsolidate over one huge
# flat buffer - no need to deal with padding since if there is any it will be only in the tail
# of the last partition so there it will be just left out
#
# - for zero3 we need to zip the partitions together at boundary of each param, re-consolidating
# each param, while dealing with padding if any
# XXX: document this

if debug:
for i in range(world_size):
print(f"fp32_flat_groups[i].shape={fp32_flat_groups[i].shape}")

if zero_stage == 2:
# XXX: memory usage doubles here (zero2)
full_single_fp32_vector = torch.cat(fp32_flat_groups, 0)
avail_numel = full_single_fp32_vector.numel()
elif zero_stage == 3:
avail_numel = fp32_flat_groups[0].numel() * world_size
for j in range(len(fp32_flat_groups[0])):
print(f"fp32_flat_groups[{i}][{j}].shape={fp32_flat_groups[i][j].shape}")

# XXX: memory usage doubles here (zero2)
num_param_groups = len(fp32_flat_groups[0])
merged_single_partition_of_fp32_groups = []
for i in range(num_param_groups):
merged_partitions = [sd[i] for sd in fp32_flat_groups]
full_single_fp32_vector = torch.cat(merged_partitions, 0)
merged_single_partition_of_fp32_groups.append(full_single_fp32_vector)
avail_numel = sum([
full_single_fp32_vector.numel()
for full_single_fp32_vector in merged_single_partition_of_fp32_groups
])

if debug:
wanted_params = len(param_shapes)
wanted_numel = sum(shape.numel() for shape in param_shapes.values())
wanted_params = sum([len(shapes) for shapes in param_shapes])
wanted_numel = sum(
[sum(shape.numel() for shape in shapes.values()) for shapes in param_shapes])
# not asserting if there is a mismatch due to possible padding
print(f"Have {avail_numel} numels to process.")
print(f"Need {wanted_numel} numels in {wanted_params} params.")
@@ -181,16 +197,17 @@ def _get_fp32_state_dict_from_zero_checkpoint(ds_checkpoint_dir):
# params
# XXX: for huge models that can't fit into the host's RAM we will have to recode this to support
# out-of-core computing solution
offset = 0
total_numel = 0
total_params = 0
for name, shape in param_shapes.items():
for shapes, full_single_fp32_vector in zip(param_shapes, merged_single_partition_of_fp32_groups):
offset = 0
avail_numel = full_single_fp32_vector.numel()
for name, shape in shapes.items():

unpartitioned_numel = shape.numel()
total_numel += unpartitioned_numel
total_params += 1
unpartitioned_numel = shape.numel()
total_numel += unpartitioned_numel
total_params += 1

if zero_stage == 2:
if debug:
print(
f"{name} full shape: {shape} unpartitioned numel {unpartitioned_numel} "
@@ -201,26 +218,6 @@ def _get_fp32_state_dict_from_zero_checkpoint(ds_checkpoint_dir):
unpartitioned_numel).view(shape)
offset += unpartitioned_numel

elif zero_stage == 3:
partitioned_numel, partitioned_padding_numel = zero3_partitioned_param_info(unpartitioned_numel, world_size)

if debug:
print(
f"{total_params} {name} full shape: {shape} partition0 numel={partitioned_numel} partitioned_padding_numel={partitioned_padding_numel}"
)

# XXX: memory usage doubles here (zero3)
state_dict[name] = torch.cat(
tuple(fp32_flat_groups[i].narrow(0,
offset,
partitioned_numel)
for i in range(world_size)),
0).narrow(0,
0,
unpartitioned_numel).view(shape)
offset += partitioned_numel

if zero_stage == 2:
# Z2 started to align to 2*world_size to improve nccl performance. Therefore both offset and
# avail_numel can differ by anywhere between 0..2*world_size. Due to two unrelated complex
# paddings performed in the code it's almost impossible to predict the exact numbers w/o the
@@ -239,8 +236,85 @@ def zero2_align(x):
if debug:
print(f"aligned offset={offset}, avail_numel={avail_numel}")

elif zero_stage == 3:
offset *= world_size
# Sanity check
if offset != avail_numel:
raise ValueError(
f"consumed {offset} numels out of {avail_numel} - something is wrong")

print(
f"Reconstructed fp32 state dict with {total_params} params {total_numel} elements"
)

return state_dict
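
The ZeRO-2 path above boils down to: for each param group, concatenate the ranks' partitions back to back, then slice each parameter out in declaration order. A self-contained toy version of that idea (names and sizes invented, and ignoring the 2*world_size alignment padding the real code has to account for):

from collections import OrderedDict
import torch

world_size = 2
shapes = OrderedDict([("w1", torch.Size([4, 3])), ("b1", torch.Size([4]))])  # 12 + 4 = 16 numels

full = torch.arange(16, dtype=torch.float32)   # the fp32 params we want to recover
partitions = [full[:8], full[8:]]              # each rank holds half of the group's flat params

merged = torch.cat(partitions, 0)              # concat the partitions back to back
state_dict, offset = OrderedDict(), 0
for name, shape in shapes.items():
    numel = shape.numel()
    state_dict[name] = merged.narrow(0, offset, numel).view(shape)
    offset += numel

assert torch.equal(state_dict["w1"].flatten(), full[:12])
assert torch.equal(state_dict["b1"], full[12:])
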


def zero3_partitioned_param_info(unpartitioned_numel, world_size):
remainder = unpartitioned_numel % world_size
padding_numel = (world_size - remainder) if remainder else 0
partitioned_numel = math.ceil(unpartitioned_numel / world_size)
return partitioned_numel, padding_numel
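
A quick check against the helper above (toy numbers): sharding a 10-element parameter across 4 ranks stores ceil(10/4) = 3 elements per rank, with 2 trailing padding elements on the last rank; a 12-element parameter divides evenly and needs no padding.

# assumes zero3_partitioned_param_info as defined above
assert zero3_partitioned_param_info(10, world_size=4) == (3, 2)
assert zero3_partitioned_param_info(12, world_size=4) == (3, 0)
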


def _get_fp32_state_dict_from_zero3_checkpoint(world_size,
param_shapes,
fp32_flat_groups,
buffers):

# Reconstruction protocol: For zero3 we need to zip the partitions together at boundary of each
# param, re-consolidating each param, while dealing with padding if any

avail_numel = fp32_flat_groups[0].numel() * world_size
# merge list of dicts, preserving order
param_shapes = {k: v for d in param_shapes for k, v in d.items()}

if debug:
for i in range(world_size):
print(f"fp32_flat_groups[{i}].shape={fp32_flat_groups[i].shape}")

wanted_params = len(param_shapes)
wanted_numel = sum(shape.numel() for shape in param_shapes.values())
# not asserting if there is a mismatch due to possible padding
print(f"Have {avail_numel} numels to process.")
print(f"Need {wanted_numel} numels in {wanted_params} params.")

state_dict = OrderedDict()

# buffers
state_dict.update(buffers)
if debug:
print(f"added {len(buffers)} buffers")

# params
# XXX: for huge models that can't fit into the host's RAM we will have to recode this to support
# out-of-core computing solution
offset = 0
total_numel = 0
total_params = 0
for name, shape in param_shapes.items():

unpartitioned_numel = shape.numel()
total_numel += unpartitioned_numel
total_params += 1

partitioned_numel, partitioned_padding_numel = zero3_partitioned_param_info(unpartitioned_numel, world_size)

if debug:
print(
f"{total_params} {name} full shape: {shape} partition0 numel={partitioned_numel} partitioned_padding_numel={partitioned_padding_numel}"
)

# XXX: memory usage doubles here
state_dict[name] = torch.cat(
tuple(fp32_flat_groups[i].narrow(0,
offset,
partitioned_numel)
for i in range(world_size)),
0).narrow(0,
0,
unpartitioned_numel).view(shape)
offset += partitioned_numel

offset *= world_size

# Sanity check
if offset != avail_numel:
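
To make the narrow/cat/trim pattern above concrete, here is a self-contained toy reassembly of a single ZeRO-3 parameter (sizes invented): every rank holds the param's contiguous slice at the same offset in its flat tensor, so the full param is recovered by narrowing each rank's tensor, concatenating the slices, and trimming the tail padding.

import math
import torch

world_size = 4
shape = torch.Size([2, 5])                          # 10 elements
unpartitioned_numel = shape.numel()
partitioned_numel = math.ceil(unpartitioned_numel / world_size)  # 3 per rank, 2 padding on the last

full = torch.arange(unpartitioned_numel, dtype=torch.float32)
padded = torch.cat([full, torch.zeros(partitioned_numel * world_size - unpartitioned_numel)])
flat_groups = [padded[r * partitioned_numel:(r + 1) * partitioned_numel] for r in range(world_size)]

offset = 0  # offset of this param inside each rank's flat tensor
param = torch.cat(
    tuple(flat_groups[i].narrow(0, offset, partitioned_numel) for i in range(world_size)),
    0).narrow(0, 0, unpartitioned_numel).view(shape)

assert torch.equal(param.flatten(), full)
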
