[unified checkpoint] fix checkpoint names (#7795)
* fix(unified checkpoint): add config save

* fix(unified checkpoint): name change

Change the master weights file name to the model weights file name when SKIP_SAVE_MODEL_WEIGHT is set.

* fix(unified checkpoint): model weights load

When skipping the model weights save and saving master weights as model weights, unified checkpoint needs to choose the model weight files to load into the master weights.

* test(unified checkpoint): add test cases

Test the file list and file names when skip_save_model_weight=1.

* test(unified checkpoint): add enable_all_options

* fix(unified checkpoint): fix last ckpt save

* fix(unified checkpoint): unified config format
DrownFish19 authored Jan 17, 2024
1 parent 5a01362 commit 48c1313
Showing 2 changed files with 125 additions and 35 deletions.
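
The core of the change is a naming rule: when skip_save_model_weight is enabled, master weights are written under the model weight file names so that the next resumption can load them directly as model weights. The snippet below is a minimal illustrative sketch of that rule, not PaddleNLP source; the file names are the safetensors names asserted by the new test, and the helper name is invented for this example.

# Minimal sketch of the naming rule introduced by this commit (safetensors case only).
# File names match those asserted in tests/trainer/test_unified_checkpoint.py;
# the helper itself is hypothetical.
MODEL_INDEX_NAME = "model.safetensors.index.json"
MASTER_WEIGHTS_INDEX_NAME = "master_weights.safetensors.index.json"

def master_weights_index_name(unified_checkpoint_config: str) -> str:
    """Return the index file name that master weights are saved under."""
    if "skip_save_model_weight" in unified_checkpoint_config:
        # No separate model weight files are written; master weights take over
        # the model weight names so resumption can load them as model weights.
        return MODEL_INDEX_NAME
    return MASTER_WEIGHTS_INDEX_NAME

The diff below applies this same rule in unified_optimizer_into_shards (shard file names), save_unified_optimizer (index file), and update_master_weight_status (load path).
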
86 changes: 51 additions & 35 deletions paddlenlp/trainer/plugins/unified_checkpoint.py
@@ -115,44 +115,47 @@ def save_unified_checkpoint(args, model, optimizer, output_dir, safe_serializati
else:
raise ValueError("Unified checkpoint only supports PretrainedModel")

skip_save_model_weight = False
if UnifiedCheckpointOption.SKIP_SAVE_MODEL_WEIGHT.value in args.unified_checkpoint_config:
if is_need_master_weight(optimizer, is_fp16_or_bp16=(args.fp16 or args.bf16)):
logger.info(
f"With {UnifiedCheckpointOption.SKIP_SAVE_MODEL_WEIGHT.value}, skip the model checkpoint save."
"The master weight will be loaded as model weights for next resumption."
)
# not save model weight, load from master weight
return
config_to_save = None
state_dict, config_to_save, shard_file, sharded_index = unified_checkpoint_into_shards(
args, model_to_save, safe_serialization=safe_serialization
)
skip_save_model_weight = True

save_directory = output_dir
os.makedirs(save_directory, exist_ok=True)

is_sync_save = True
if "async_save" in args.unified_checkpoint_config:
is_sync_save = False
file_save_async_or_sync(
state_dict, os.path.join(save_directory, shard_file), safe_serialization, is_sync=is_sync_save
)
# save model weights
if not skip_save_model_weight:
state_dict, shard_file, sharded_index = unified_checkpoint_into_shards(
args, model_to_save, safe_serialization=safe_serialization
)
is_sync_save = True
if "async_save" in args.unified_checkpoint_config:
is_sync_save = False
file_save_async_or_sync(
state_dict, os.path.join(save_directory, shard_file), safe_serialization, is_sync=is_sync_save
)

if sharded_index is not None:
if not safe_serialization:
path = os.path.join(output_dir, PADDLE_WEIGHTS_INDEX_NAME)
else:
path = os.path.join(output_dir, SAFE_WEIGHTS_INDEX_NAME)

with open(path, "w") as f:
json.dump(sharded_index, f, indent=4)

# save the config
config_to_save = save_config(model_to_save)
# Attach architecture to the config
config_to_save.architectures = [model_to_save.__class__.__name__]
# Save the config
if args.should_save:
config_to_save.save_pretrained(save_directory)

if sharded_index is not None:
if not safe_serialization:
path = os.path.join(output_dir, PADDLE_WEIGHTS_INDEX_NAME)
else:
path = os.path.join(output_dir, SAFE_WEIGHTS_INDEX_NAME)

with open(path, "w") as f:
json.dump(sharded_index, f, indent=4)


def load_unified_checkpoint(args, model, optimizer, resume_from_checkpoint: str, safe_serialization=False) -> None:
"""Load potential model checkpoint
@@ -252,6 +255,18 @@ def _remove_unused_keys(
raise RuntimeError(f"Error(s) in loading state_dict for {model.__class__.__name__}:\n\t{error_msg}")


def save_config(model_to_save):
dtype = get_parameter_dtype(model_to_save)
model_to_save.config.dtype = str(dtype).split(".")[1]
config_to_save = copy.deepcopy(model_to_save.config)

if config_to_save.tensor_parallel_degree > 1:
# do we need to change?
config_to_save.tensor_parallel_degree = 1

return config_to_save


def unified_checkpoint_into_shards(
args,
model_to_save,
@@ -272,8 +287,6 @@ def unified_checkpoint_into_shards(

all_filter_keys = filter_params(model_to_save, state_dict)

dtype = get_parameter_dtype(model_to_save)
model_to_save.config.dtype = str(dtype).split(".")[1]
config_to_save = copy.deepcopy(model_to_save.config)

if config_to_save.tensor_parallel_degree > 1:
@@ -282,10 +295,6 @@ def unified_checkpoint_into_shards(
)
state_dict = merge_tensor_parallel_with_shard(state_dict, tp_actions, all_filter_keys)

if config_to_save.tensor_parallel_degree > 1:
# do we need to change?
config_to_save.tensor_parallel_degree = 1

# build index json file
index_weight_file = {}
total_size = 0
@@ -302,7 +311,7 @@ def unified_checkpoint_into_shards(
total_size_list,
)

return state_dict, config_to_save, shard_file, sharded_index
return state_dict, shard_file, sharded_index


def save_unified_optimizer(args, model, optimizer, output_dir, safe_serialization=False):
@@ -343,16 +352,17 @@ def save_unified_optimizer(args, model, optimizer, output_dir, safe_serializatio
)

if sharded_optim_index is not None:
if not safe_serialization:
path = os.path.join(output_dir, PADDLE_OPTIMIZER_INDEX_NAME)
master_path = os.path.join(output_dir, PADDLE_MASTER_WEIGHTS_INDEX_NAME)
else:
path = os.path.join(output_dir, SAFE_OPTIMIZER_INDEX_NAME)
master_path = os.path.join(output_dir, SAFE_MASTER_WEIGHTS_INDEX_NAME)

optimizer_index_name = SAFE_OPTIMIZER_INDEX_NAME if safe_serialization else PADDLE_OPTIMIZER_INDEX_NAME
path = os.path.join(output_dir, optimizer_index_name)
with open(path, "w") as f:
json.dump(sharded_optim_index, f, indent=4)

master_weights_name = (
SAFE_MASTER_WEIGHTS_INDEX_NAME if safe_serialization else PADDLE_MASTER_WEIGHTS_INDEX_NAME
)
if UnifiedCheckpointOption.SKIP_SAVE_MODEL_WEIGHT.value in args.unified_checkpoint_config:
master_weights_name = SAFE_WEIGHTS_INDEX_NAME if safe_serialization else PADDLE_WEIGHTS_INDEX_NAME
master_path = os.path.join(output_dir, master_weights_name)
if master_weight_state_dict is not None:
with open(master_path, "w") as f:
json.dump(sharded_master_weight_index, f, indent=4)
@@ -561,6 +571,8 @@ def unified_optimizer_into_shards(
total_optim_size, total_master_weight_size = 0, 0
optimizer_name = SAFE_OPTIMIZER_NAME if safe_serialization else PADDLE_OPTIMIZER_NAME
master_weights_name = SAFE_MASTER_WEIGHTS_NAME if safe_serialization else PADDLE_MASTER_WEIGHTS_NAME
if UnifiedCheckpointOption.SKIP_SAVE_MODEL_WEIGHT.value in args.unified_checkpoint_config:
master_weights_name = SAFE_WEIGHTS_NAME if safe_serialization else PADDLE_WEIGHTS_NAME
shard_optimizer_file = get_sharded_file_name(args, optimizer_name, is_optimizer=True)
shard_master_weight_file = get_sharded_file_name(args, master_weights_name, is_optimizer=True)

@@ -1648,6 +1660,10 @@ def update_master_weight_status(args, optimizer, has_master_weight, safe_seriali
index_filename_master_weights = (
PADDLE_MASTER_WEIGHTS_INDEX_NAME if not safe_serialization else SAFE_MASTER_WEIGHTS_INDEX_NAME
)
if UnifiedCheckpointOption.SKIP_SAVE_MODEL_WEIGHT.value in args.unified_checkpoint_config:
index_filename_master_weights = (
PADDLE_WEIGHTS_INDEX_NAME if not safe_serialization else SAFE_WEIGHTS_INDEX_NAME
)
else:
has_master_weight = False
index_filename_master_weights = None
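
On the resume side, update_master_weight_status applies the rule in reverse: if skip_save_model_weight was used, the master weights are read back from the model weight index. A rough sketch of that selection, under the same assumed file names as the sketch above (again illustrative, not the PaddleNLP code):

def master_weight_index_for_resume(unified_checkpoint_config: str, has_master_weight: bool):
    """Pick the index file that master weights should be loaded from on resumption."""
    if not has_master_weight:
        return None
    if "skip_save_model_weight" in unified_checkpoint_config:
        # Master weights were stored under the model weight file names.
        return MODEL_INDEX_NAME
    return MASTER_WEIGHTS_INDEX_NAME
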
74 changes: 74 additions & 0 deletions tests/trainer/test_unified_checkpoint.py
@@ -650,6 +650,26 @@ def setUp(self):
self.rtol = 1e-7

self.run_pretrain_file = "llm/llama/run_pretrain.py"
self.filelists = [
"config.json",
"master_weights-00001-of-00002.safetensors",
"master_weights-00002-of-00002.safetensors",
"master_weights.safetensors.index.json",
"model-00001-of-00002.safetensors",
"model-00002-of-00002.safetensors",
"model.safetensors.index.json",
"optimizer-00001-of-00002.safetensors",
"optimizer-00002-of-00002.safetensors",
"optimizer.safetensors.index.json",
"rng_state_2.pth",
"scaler.pdparams",
"scheduler.pdparams",
"sentencepiece.bpe.model",
"special_tokens_map.json",
"tokenizer_config.json",
"trainer_state.json",
"training_args.bin",
]

def runfrist(self, train_args):
train_args["unified_checkpoint"] = 1
@@ -674,6 +694,43 @@ def testTP2(self):
assert len(res) == 2
np.testing.assert_allclose(res[0], res[1], self.rtol)

@require_paddle_at_least_2_gpu
def testFileLists(self):
remove_logs()
remove_ckpt(pretrain_arguments["output_dir"])

save_steps = pretrain_arguments["save_steps"]
base_ckpt_path = os.path.join(pretrain_arguments["output_dir"], "checkpoint-%d" % save_steps)

train_args = self.configs["TP2"]
self.runfrist(train_args)
assert sorted(self.filelists) == sorted(os.listdir(base_ckpt_path))
self.rerun(train_args)

if self.need_allclose:
res = check_acc()
assert len(res) == 2
np.testing.assert_allclose(res[0], res[1], self.rtol)

# Test skip_save_model_weight
remove_logs()
remove_ckpt(pretrain_arguments["output_dir"])
train_args["unified_checkpoint_config"] = "skip_save_model_weight"
self.runfrist(train_args)
unsave_filelists = [
"master_weights-00001-of-00002.safetensors",
"master_weights-00002-of-00002.safetensors",
"master_weights.safetensors.index.json",
]
cur_filelists = [file for file in self.filelists if file not in unsave_filelists]
assert sorted(cur_filelists) == sorted(os.listdir(base_ckpt_path))
self.rerun(train_args)

if self.need_allclose:
res = check_acc()
assert len(res) == 2
np.testing.assert_allclose(res[0], res[1], self.rtol)


class TestUnifiedCheckpointOnN1C8AsyncSaveToDisk(TestUnifiedCheckpointBase):
def setUp(self):
@@ -985,3 +1042,20 @@ def rerun(self, train_args):
self.run_n1c8(self.run_pretrain_file, **config)
res = check_acc()
np.testing.assert_allclose(res[0], res[-1], rtol=self.rtol)


class TestUnifiedCheckpointOnN1C8EnableAll(TestUnifiedCheckpointBase):
def setUp(self):
super().setUp()
for config_key in self.configs:
self.configs[config_key]["unified_checkpoint"] = 1
self.configs[config_key]["unified_checkpoint_config"] = "enable_all_options"

self.need_allclose = True
self.rtol = 1e-7

def runfrist(self, train_args):
self.run_n1c8(self.run_pretrain_file, **train_args)

def rerun(self, train_args):
self.run_n1c8(self.run_pretrain_file, **train_args)
