[Unified Checkpoint] Cherry pick empty cache. #7868

Merged (2 commits) on Jan 22, 2024. The change cherry-picks paddle.device.cuda.empty_cache() calls into key points of the unified-checkpoint save path and comments out the async_save default in the trainer arguments.
paddlenlp/trainer/plugins/unified_checkpoint.py (12 additions, 0 deletions)
@@ -156,6 +156,8 @@ def save_unified_checkpoint(args, model, optimizer, output_dir, safe_serializati
     if args.should_save:
         config_to_save.save_pretrained(save_directory)
 
+    paddle.device.cuda.empty_cache()
+
 
 def load_unified_checkpoint(args, model, optimizer, resume_from_checkpoint: str, safe_serialization=False) -> None:
     """Load potential model checkpoint
@@ -281,6 +283,7 @@ def unified_checkpoint_into_shards(
     Returns:
         tuple: state_dict, config, shard_file: file name, sharded_index: map for weight to file name.
     """
+    paddle.device.cuda.empty_cache()
     assert hasattr(model_to_save, "config")
 
     state_dict = model_to_save.state_dict()
@@ -311,6 +314,8 @@
         total_size_list,
     )
 
+    paddle.device.cuda.empty_cache()
+
     return state_dict, shard_file, sharded_index
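
unified_checkpoint_into_shards now empties the cache both on entry and just before returning. A hypothetical decorator (not part of this PR) expressing the same entry/exit bracketing, shown only to make the pattern explicit:

    import functools

    import paddle

    def release_cuda_cache(fn):
        # Hypothetical helper: empty the CUDA cache around a call,
        # mirroring the entry and pre-return calls added in this diff.
        @functools.wraps(fn)
        def wrapper(*args, **kwargs):
            paddle.device.cuda.empty_cache()  # drop blocks cached by earlier steps
            try:
                return fn(*args, **kwargs)
            finally:
                paddle.device.cuda.empty_cache()  # drop sharding temporaries
        return wrapper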


@@ -333,6 +338,8 @@ def save_unified_optimizer(args, model, optimizer, output_dir, safe_serializatio
     optim_state_dict, shard_optim_file, sharded_optim_index = results[0]
     master_weight_state_dict, shard_master_weight_file, sharded_master_weight_index = results[1]
 
+    paddle.device.cuda.empty_cache()
+
     save_directory = output_dir
     os.makedirs(save_directory, exist_ok=True)

@@ -514,6 +521,7 @@ def unified_optimizer_into_shards(
         optimizer (Optimizer): optimizer to save.
         safe_serialization (bool, optional): safe serialization using safetensors. Defaults to False.
     """
+    paddle.device.cuda.empty_cache()
     optim_state_dict = nested_copy(optimizer.state_dict())
     master_weights = None
     if "master_weights" in optim_state_dict.keys():
@@ -559,12 +567,15 @@
         tp_actions,
         filter_optim_keys,
     )
+    paddle.device.cuda.empty_cache()
+
     if master_weights is not None:
         master_weights = merge_tensor_parallel_for_optimizer(
             master_weights,
             tp_actions,
             filter_master_keys,
         )
+        paddle.device.cuda.empty_cache()
 
     # build index json file
     index_optimizer_file, index_master_weight_file = {}, {}
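
Each merge_tensor_parallel_* call materializes merged tensors from the tensor-parallel shards, so the cache is dropped as soon as a merged result has been recorded. Note that paddle.device.cuda.empty_cache() presumes a CUDA runtime; a portability-minded variant (a sketch, not code from this PR) would guard it:

    import paddle

    def maybe_empty_cache():
        # Hypothetical guard: only touch the CUDA allocator when this
        # Paddle build actually has CUDA support.
        if paddle.device.is_compiled_with_cuda():
            paddle.device.cuda.empty_cache()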
@@ -601,6 +612,7 @@
     else:
         sharded_optim_index["master_weights"] = False
 
+    paddle.device.cuda.empty_cache()
     if master_weights is None:
         return [(optim_state_dict, shard_optimizer_file, sharded_optim_index)]
     else:
paddlenlp/trainer/training_args.py (1 addition, 1 deletion)
@@ -1309,7 +1309,7 @@ def is_segment_parallel_supported():
             self.unified_checkpoint_config = [
                 "skip_save_model_weight",
                 "master_weight_compatible",
-                "async_save",
+                # "async_save",
             ]
         else:
             self.unified_checkpoint_config = self.unified_checkpoint_config.split(" ")
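
Commenting the entry out only changes the default set: the else branch still splits a user-supplied, space-separated string, so async saving can presumably still be opted into explicitly. A sketch of such an opt-in (the surrounding argument values are illustrative; the field names come from this file and the PaddleNLP trainer):

    from paddlenlp.trainer import TrainingArguments

    args = TrainingArguments(
        output_dir="./checkpoints",
        unified_checkpoint=True,
        unified_checkpoint_config="skip_save_model_weight master_weight_compatible async_save",
    )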