[Unified Checkpoint] Fix unified checkpoint by empty cache. (#7855)
ZHUI committed Jan 17, 2024
1 parent 38ef1f6 commit 1743466
Showing 2 changed files with 13 additions and 1 deletion.
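
The twelve lines added to unified_checkpoint.py all apply one pattern: call paddle.device.cuda.empty_cache() right after a step that materializes or merges a large state dict, so that Paddle's caching allocator returns freed blocks to the GPU before the next allocation-heavy step. A minimal sketch of the pattern, not part of the commit (the helper name and path are illustrative):

    import paddle

    def save_with_cache_release(model, path):
        # state_dict() and tensor-parallel merging can leave large freed
        # blocks parked in Paddle's caching allocator; an explicit
        # empty_cache() hands them back to the GPU so the following
        # allocation-heavy step does not hit an avoidable OOM.
        state_dict = model.state_dict()
        paddle.device.cuda.empty_cache()
        paddle.save(state_dict, path)
        paddle.device.cuda.empty_cache()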
paddlenlp/trainer/plugins/unified_checkpoint.py (12 additions, 0 deletions)
@@ -156,6 +156,8 @@ def save_unified_checkpoint(args, model, optimizer, output_dir, safe_serialization
     if args.should_save:
         config_to_save.save_pretrained(save_directory)
 
+    paddle.device.cuda.empty_cache()
+
 
 def load_unified_checkpoint(args, model, optimizer, resume_from_checkpoint: str, safe_serialization=False) -> None:
     """Load potential model checkpoint
@@ -281,6 +283,7 @@ def unified_checkpoint_into_shards(
     Returns:
         tuple: state_dict, config, shard_file: file name, sharded_index: map for weight to file name.
     """
+    paddle.device.cuda.empty_cache()
     assert hasattr(model_to_save, "config")
 
     state_dict = model_to_save.state_dict()
@@ -311,6 +314,8 @@
         total_size_list,
     )
 
+    paddle.device.cuda.empty_cache()
+
     return state_dict, shard_file, sharded_index
 
 
@@ -333,6 +338,8 @@ def save_unified_optimizer(args, model, optimizer, output_dir, safe_serialization
     optim_state_dict, shard_optim_file, sharded_optim_index = results[0]
     master_weight_state_dict, shard_master_weight_file, sharded_master_weight_index = results[1]
 
+    paddle.device.cuda.empty_cache()
+
     save_directory = output_dir
     os.makedirs(save_directory, exist_ok=True)
 
@@ -514,6 +521,7 @@ def unified_optimizer_into_shards(
         optimizer (Optimizer): optimizer to save.
         safe_serialization (bool, optional): safe serialization using safetensors. Defaults to False.
     """
+    paddle.device.cuda.empty_cache()
     optim_state_dict = nested_copy(optimizer.state_dict())
     master_weights = None
     if "master_weights" in optim_state_dict.keys():
@@ -559,12 +567,15 @@
             tp_actions,
             filter_optim_keys,
         )
+        paddle.device.cuda.empty_cache()
+
         if master_weights is not None:
             master_weights = merge_tensor_parallel_for_optimizer(
                 master_weights,
                 tp_actions,
                 filter_master_keys,
             )
+            paddle.device.cuda.empty_cache()
 
     # build index json file
     index_optimizer_file, index_master_weight_file = {}, {}
@@ -601,6 +612,7 @@
     else:
         sharded_optim_index["master_weights"] = False
 
+    paddle.device.cuda.empty_cache()
     if master_weights is None:
         return [(optim_state_dict, shard_optimizer_file, sharded_optim_index)]
     else:
paddlenlp/trainer/training_args.py (1 addition, 1 deletion)
@@ -1308,7 +1308,7 @@ def is_segment_parallel_supported():
             self.unified_checkpoint_config = [
                 "skip_save_model_weight",
                 "master_weight_compatible",
-                "async_save",
+                # "async_save",
             ]
         else:
             self.unified_checkpoint_config = self.unified_checkpoint_config.split(" ")
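
The training_args.py change drops "async_save" from the default unified_checkpoint_config. Judging from the else branch, which splits a user-supplied string on spaces, the option can presumably still be enabled explicitly; a hypothetical opt-in (flag names assumed from this diff, not verified against the full file):

    from paddlenlp.trainer import TrainingArguments

    # Hypothetical: the string form is split on spaces by the else branch
    # shown above, so "async_save" becomes opt-in rather than a default.
    args = TrainingArguments(
        output_dir="./checkpoints",
        unified_checkpoint=True,  # assumed companion flag
        unified_checkpoint_config="skip_save_model_weight master_weight_compatible async_save",
    )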
