[unified checkpoint] fix checkpoint names #7795

Merged
merged 9 commits on Jan 17, 2024
86 changes: 51 additions & 35 deletions paddlenlp/trainer/plugins/unified_checkpoint.py
@@ -115,44 +115,47 @@
else:
raise ValueError("Unified checkpoint only supports PretrainedModel")

skip_save_model_weight = False
if UnifiedCheckpointOption.SKIP_SAVE_MODEL_WEIGHT.value in args.unified_checkpoint_config:
if is_need_master_weight(optimizer, is_fp16_or_bp16=(args.fp16 or args.bf16)):
logger.info(
f"With {UnifiedCheckpointOption.SKIP_SAVE_MODEL_WEIGHT.value}, skip the model checkpoint save."
"The master weight will be loaded as model weights for next resumption."
)
# not save model weight, load from master weight
return
config_to_save = None
state_dict, config_to_save, shard_file, sharded_index = unified_checkpoint_into_shards(
args, model_to_save, safe_serialization=safe_serialization
)
skip_save_model_weight = True

save_directory = output_dir
os.makedirs(save_directory, exist_ok=True)

is_sync_save = True
if "async_save" in args.unified_checkpoint_config:
is_sync_save = False
file_save_async_or_sync(
state_dict, os.path.join(save_directory, shard_file), safe_serialization, is_sync=is_sync_save
)
# save model weights
if not skip_save_model_weight:
state_dict, shard_file, sharded_index = unified_checkpoint_into_shards(
args, model_to_save, safe_serialization=safe_serialization
)
is_sync_save = True
if "async_save" in args.unified_checkpoint_config:
is_sync_save = False
file_save_async_or_sync(
state_dict, os.path.join(save_directory, shard_file), safe_serialization, is_sync=is_sync_save
)

if sharded_index is not None:
if not safe_serialization:
path = os.path.join(output_dir, PADDLE_WEIGHTS_INDEX_NAME)
else:
path = os.path.join(output_dir, SAFE_WEIGHTS_INDEX_NAME)

with open(path, "w") as f:
json.dump(sharded_index, f, indent=4)

# save the config
config_to_save = save_config(model_to_save)
# Attach architecture to the config
config_to_save.architectures = [model_to_save.__class__.__name__]
# Save the config
if args.should_save:
config_to_save.save_pretrained(save_directory)

if sharded_index is not None:
if not safe_serialization:
path = os.path.join(output_dir, PADDLE_WEIGHTS_INDEX_NAME)
else:
path = os.path.join(output_dir, SAFE_WEIGHTS_INDEX_NAME)

with open(path, "w") as f:
json.dump(sharded_index, f, indent=4)


def load_unified_checkpoint(args, model, optimizer, resume_from_checkpoint: str, safe_serialization=False) -> None:
"""Load potential model checkpoint
@@ -252,6 +255,18 @@
raise RuntimeError(f"Error(s) in loading state_dict for {model.__class__.__name__}:\n\t{error_msg}")


def save_config(model_to_save):
dtype = get_parameter_dtype(model_to_save)
model_to_save.config.dtype = str(dtype).split(".")[1]
config_to_save = copy.deepcopy(model_to_save.config)

if config_to_save.tensor_parallel_degree > 1:
# do we need to change?
config_to_save.tensor_parallel_degree = 1

return config_to_save


def unified_checkpoint_into_shards(
args,
model_to_save,
@@ -272,8 +287,6 @@

all_filter_keys = filter_params(model_to_save, state_dict)

dtype = get_parameter_dtype(model_to_save)
model_to_save.config.dtype = str(dtype).split(".")[1]
config_to_save = copy.deepcopy(model_to_save.config)

if config_to_save.tensor_parallel_degree > 1:
@@ -282,10 +295,6 @@
)
state_dict = merge_tensor_parallel_with_shard(state_dict, tp_actions, all_filter_keys)

if config_to_save.tensor_parallel_degree > 1:
# do we need to change?
config_to_save.tensor_parallel_degree = 1

# build index json file
index_weight_file = {}
total_size = 0
@@ -302,7 +311,7 @@
total_size_list,
)

return state_dict, config_to_save, shard_file, sharded_index
return state_dict, shard_file, sharded_index


def save_unified_optimizer(args, model, optimizer, output_dir, safe_serialization=False):
@@ -343,16 +352,17 @@
)

if sharded_optim_index is not None:
if not safe_serialization:
path = os.path.join(output_dir, PADDLE_OPTIMIZER_INDEX_NAME)
master_path = os.path.join(output_dir, PADDLE_MASTER_WEIGHTS_INDEX_NAME)
else:
path = os.path.join(output_dir, SAFE_OPTIMIZER_INDEX_NAME)
master_path = os.path.join(output_dir, SAFE_MASTER_WEIGHTS_INDEX_NAME)

optimizer_index_name = SAFE_OPTIMIZER_INDEX_NAME if safe_serialization else PADDLE_OPTIMIZER_INDEX_NAME
path = os.path.join(output_dir, optimizer_index_name)
with open(path, "w") as f:
json.dump(sharded_optim_index, f, indent=4)

master_weights_name = (
SAFE_MASTER_WEIGHTS_INDEX_NAME if safe_serialization else PADDLE_MASTER_WEIGHTS_INDEX_NAME
)
if UnifiedCheckpointOption.SKIP_SAVE_MODEL_WEIGHT.value in args.unified_checkpoint_config:
master_weights_name = SAFE_WEIGHTS_INDEX_NAME if safe_serialization else PADDLE_WEIGHTS_INDEX_NAME
master_path = os.path.join(output_dir, master_weights_name)

if master_weight_state_dict is not None:
with open(master_path, "w") as f:
json.dump(sharded_master_weight_index, f, indent=4)
@@ -561,6 +571,8 @@
total_optim_size, total_master_weight_size = 0, 0
optimizer_name = SAFE_OPTIMIZER_NAME if safe_serialization else PADDLE_OPTIMIZER_NAME
master_weights_name = SAFE_MASTER_WEIGHTS_NAME if safe_serialization else PADDLE_MASTER_WEIGHTS_NAME
if UnifiedCheckpointOption.SKIP_SAVE_MODEL_WEIGHT.value in args.unified_checkpoint_config:
master_weights_name = SAFE_WEIGHTS_NAME if safe_serialization else PADDLE_WEIGHTS_NAME

shard_optimizer_file = get_sharded_file_name(args, optimizer_name, is_optimizer=True)
shard_master_weight_file = get_sharded_file_name(args, master_weights_name, is_optimizer=True)

@@ -1648,6 +1660,10 @@
index_filename_master_weights = (
PADDLE_MASTER_WEIGHTS_INDEX_NAME if not safe_serialization else SAFE_MASTER_WEIGHTS_INDEX_NAME
)
if UnifiedCheckpointOption.SKIP_SAVE_MODEL_WEIGHT.value in args.unified_checkpoint_config:
index_filename_master_weights = (
PADDLE_WEIGHTS_INDEX_NAME if not safe_serialization else SAFE_WEIGHTS_INDEX_NAME
)
else:
has_master_weight = False
index_filename_master_weights = None
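To summarize the renaming this diff introduces, here is a minimal standalone sketch (not part of the PR) of how the skip_save_model_weight option redirects the master-weights index name to the model-weights index name. The safetensors filenames match the test file list later in this PR; the pdparams values and the helper itself are illustrative assumptions.

import os

# Illustrative stand-ins for PaddleNLP's filename constants; the real values
# are defined in paddlenlp and the *.pdparams variants here are assumed.
SAFE_WEIGHTS_INDEX_NAME = "model.safetensors.index.json"
SAFE_MASTER_WEIGHTS_INDEX_NAME = "master_weights.safetensors.index.json"
PADDLE_WEIGHTS_INDEX_NAME = "model_state.pdparams.index.json"
PADDLE_MASTER_WEIGHTS_INDEX_NAME = "master_weights.pdparams.index.json"


def master_weights_index_path(output_dir, unified_checkpoint_config, safe_serialization):
    # Mirrors the renaming above: with skip_save_model_weight, the master-weight
    # index is written under the model-weight index name, so the next resume can
    # load the master weights as model weights.
    name = SAFE_MASTER_WEIGHTS_INDEX_NAME if safe_serialization else PADDLE_MASTER_WEIGHTS_INDEX_NAME
    if "skip_save_model_weight" in unified_checkpoint_config:
        name = SAFE_WEIGHTS_INDEX_NAME if safe_serialization else PADDLE_WEIGHTS_INDEX_NAME
    return os.path.join(output_dir, name)


print(master_weights_index_path("ckpt", ["skip_save_model_weight"], True))  # ckpt/model.safetensors.index.json
print(master_weights_index_path("ckpt", [], True))                          # ckpt/master_weights.safetensors.index.json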
9 changes: 9 additions & 0 deletions paddlenlp/trainer/trainer.py
@@ -2027,8 +2027,17 @@ def save_model(self, output_dir: Optional[str] = None, merge_tensor_parallel: Op
self.model_wrapped.get_all_parameters(convert2cpu=True)

if self.args.should_save_model_state:
unified_checkpoint_config_backup = self.args.unified_checkpoint_config
# back up and clear unified_checkpoint_config when not in the training stage
if not self.is_in_train:
self.args.unified_checkpoint_config = []

self._save(output_dir=output_dir, merge_tensor_parallel=merge_tensor_parallel)

# restore unified_checkpoint_config when not in the training stage
if not self.is_in_train:
self.args.unified_checkpoint_config = unified_checkpoint_config_backup

def _save_checkpoint(self, model, metrics=None):
# assert unwrap_model(model) is self.model, "internal model should be a reference to self.model"

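The save_model() change above follows a backup/restore pattern: outside training (is_in_train is False), the unified-checkpoint options are temporarily cleared so _save() writes a plain, full checkpoint, then the options are put back. A standalone sketch of that pattern with a simplified stand-in class (not the real Trainer; the try/finally is an extra safeguard the diff itself does not add):

class TrainerSketch:
    """Minimal stand-in for paddlenlp.trainer.Trainer, illustration only."""

    def __init__(self):
        self.is_in_train = False
        self.unified_checkpoint_config = ["skip_save_model_weight", "async_save"]

    def _save(self):
        # The real method writes the checkpoint; here we only show which config is in effect.
        print("saving with unified_checkpoint_config =", self.unified_checkpoint_config)

    def save_model(self):
        backup = self.unified_checkpoint_config
        if not self.is_in_train:
            # Outside the training loop (e.g. a final export), disable the
            # unified-checkpoint options so a complete model checkpoint is written.
            self.unified_checkpoint_config = []
        try:
            self._save()
        finally:
            if not self.is_in_train:
                self.unified_checkpoint_config = backup


TrainerSketch().save_model()
# saving with unified_checkpoint_config = []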
2 changes: 2 additions & 0 deletions paddlenlp/trainer/training_args.py
@@ -1311,6 +1311,8 @@
"master_weight_compatible",
"async_save",
]
else:
self.unified_checkpoint_config = self.unified_checkpoint_config.split(" ")

if self.report_to is None:
logger.info(
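The new else branch means unified_checkpoint_config also accepts a space-separated string of option names. A small sketch of the resulting behavior, assuming "enable_all_options" is the value that expands into the predefined list shown above (consistent with the TestUnifiedCheckpointOnN1C8EnableAll test later in this PR); the full option list is an assumption since the diff context is truncated:

def parse_unified_checkpoint_config(value):
    # "enable_all_options" expands to the predefined option list; any other
    # value is treated as a space-separated list of option names.
    all_options = ["skip_save_model_weight", "master_weight_compatible", "async_save"]  # assumed full list
    if value == "enable_all_options":
        return all_options
    return value.split(" ")


print(parse_unified_checkpoint_config("skip_save_model_weight async_save"))
# ['skip_save_model_weight', 'async_save']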
74 changes: 74 additions & 0 deletions tests/trainer/test_unified_checkpoint.py
@@ -650,6 +650,26 @@ def setUp(self):
self.rtol = 1e-7

self.run_pretrain_file = "llm/llama/run_pretrain.py"
self.filelists = [
"config.json",
"master_weights-00001-of-00002.safetensors",
"master_weights-00002-of-00002.safetensors",
"master_weights.safetensors.index.json",
"model-00001-of-00002.safetensors",
"model-00002-of-00002.safetensors",
"model.safetensors.index.json",
"optimizer-00001-of-00002.safetensors",
"optimizer-00002-of-00002.safetensors",
"optimizer.safetensors.index.json",
"rng_state_2.pth",
"scaler.pdparams",
"scheduler.pdparams",
"sentencepiece.bpe.model",
"special_tokens_map.json",
"tokenizer_config.json",
"trainer_state.json",
"training_args.bin",
]

def runfrist(self, train_args):
train_args["unified_checkpoint"] = 1
@@ -674,6 +694,43 @@ def testTP2(self):
assert len(res) == 2
np.testing.assert_allclose(res[0], res[1], self.rtol)

@require_paddle_at_least_2_gpu
def testFileLists(self):
remove_logs()
remove_ckpt(pretrain_arguments["output_dir"])

save_steps = pretrain_arguments["save_steps"]
base_ckpt_path = os.path.join(pretrain_arguments["output_dir"], "checkpoint-%d" % save_steps)

train_args = self.configs["TP2"]
self.runfrist(train_args)
assert sorted(self.filelists) == sorted(os.listdir(base_ckpt_path))
self.rerun(train_args)

if self.need_allclose:
res = check_acc()
assert len(res) == 2
np.testing.assert_allclose(res[0], res[1], self.rtol)

# Test skip_save_model_weight
remove_logs()
remove_ckpt(pretrain_arguments["output_dir"])
train_args["unified_checkpoint_config"] = "skip_save_model_weight"
self.runfrist(train_args)
unsave_filelists = [
"master_weights-00001-of-00002.safetensors",
"master_weights-00002-of-00002.safetensors",
"master_weights.safetensors.index.json",
]
cur_filelists = [file for file in self.filelists if file not in unsave_filelists]
assert sorted(cur_filelists) == sorted(os.listdir(base_ckpt_path))
self.rerun(train_args)

if self.need_allclose:
res = check_acc()
assert len(res) == 2
np.testing.assert_allclose(res[0], res[1], self.rtol)


class TestUnifiedCheckpointOnN1C8AsyncSaveToDisk(TestUnifiedCheckpointBase):
def setUp(self):
@@ -985,3 +1042,20 @@ def rerun(self, train_args):
self.run_n1c8(self.run_pretrain_file, **config)
res = check_acc()
np.testing.assert_allclose(res[0], res[-1], rtol=self.rtol)


class TestUnifiedCheckpointOnN1C8EnableAll(TestUnifiedCheckpointBase):
def setUp(self):
super().setUp()
for config_key in self.configs:
self.configs[config_key]["unified_checkpoint"] = 1
self.configs[config_key]["unified_checkpoint_config"] = "enable_all_options"

self.need_allclose = True
self.rtol = 1e-7

def runfrist(self, train_args):
self.run_n1c8(self.run_pretrain_file, **train_args)

def rerun(self, train_args):
self.run_n1c8(self.run_pretrain_file, **train_args)
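For reference, the directory check in testFileLists boils down to an order-insensitive comparison of an expected file list against os.listdir(). A small illustrative sketch (the helper names below are not part of the PR): with skip_save_model_weight enabled, no master_weights-* files are written, since their content is stored under the model-* names instead.

import os


def expected_checkpoint_files(filelists, skip_save_model_weight):
    # Filter out the master_weights-* entries when skip_save_model_weight is on,
    # mirroring the unsave_filelists logic in testFileLists.
    if not skip_save_model_weight:
        return sorted(filelists)
    return sorted(name for name in filelists if not name.startswith("master_weights"))


def assert_checkpoint_contents(ckpt_dir, filelists, skip_save_model_weight=False):
    # Order-insensitive comparison of the on-disk checkpoint directory against
    # the expected file list, as done in the test above.
    assert sorted(os.listdir(ckpt_dir)) == expected_checkpoint_files(filelists, skip_save_model_weight)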