Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix moe save load #9045

Merged
merged 4 commits into from
Aug 30, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 22 additions & 6 deletions paddlenlp/trainer/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -354,7 +354,17 @@
)
self.add_callback(PrinterCallback if self.args.disable_tqdm else DEFAULT_PROGRESS_CALLBACK)

self._save_ckpt_func = dist.save_state_dict if self.args.enable_auto_parallel else paddle.save
def _save_ckpt_func(state_dict, path, signal_path=None):
if self.args.enable_auto_parallel:
dist.save_state_dict(state_dict, path)

Check warning on line 359 in paddlenlp/trainer/trainer.py

View check run for this annotation

Codecov / codecov/patch

paddlenlp/trainer/trainer.py#L359

Added line #L359 was not covered by tests
else:
paddle.save(state_dict, path)

if signal_path is not None:
with open(signal_path, mode="w+") as f:
f.write("1")

self._save_ckpt_func = _save_ckpt_func
self._load_ckpt_func = dist.load_state_dict if self.args.enable_auto_parallel else paddle.load
if self.args.use_async_save:
self._async_optimizer_saver = AsyncSaver()
Expand Down Expand Up @@ -2297,7 +2307,9 @@
self._save_ckpt_func(
self._filter_moe_no_sync_optimizer_params(),
os.path.join(output_dir, optimizer_name),
saved_signal_path,
)

else:
state_dict = self.optimizer.state_dict()
save_path = os.path.join(output_dir, optimizer_name)
Expand All @@ -2307,9 +2319,7 @@
state_dict, save_path, saved_signal_path=saved_signal_path
)
else:
self._save_ckpt_func(state_dict, save_path)
with open(saved_signal_path, mode="w+") as f:
f.write("1")
self._save_ckpt_func(state_dict, save_path, saved_signal_path)

Check warning on line 2322 in paddlenlp/trainer/trainer.py

View check run for this annotation

Codecov / codecov/patch

paddlenlp/trainer/trainer.py#L2322

Added line #L2322 was not covered by tests
else:
if self.args.unified_checkpoint and "async_save" in self.args.unified_checkpoint_config:
global_rank = paddle.distributed.get_rank() if paddle.distributed.get_world_size() > 1 else -1
Expand All @@ -2328,10 +2338,16 @@
else:
if self.args.data_parallel_rank > 0 and self.args.use_expert_parallel:
self._save_ckpt_func(
self._filter_moe_no_sync_optimizer_params(), os.path.join(output_dir, OPTIMIZER_NAME)
self._filter_moe_no_sync_optimizer_params(),
os.path.join(output_dir, optimizer_name),
saved_signal_path,
)
else:
self._save_ckpt_func(self.optimizer.state_dict(), os.path.join(output_dir, OPTIMIZER_NAME))
self._save_ckpt_func(
self.optimizer.state_dict(),
os.path.join(output_dir, optimizer_name),
saved_signal_path,
)

# FIXME: maybe only save one copy
paddle.save(self.lr_scheduler.state_dict(), os.path.join(output_dir, SCHEDULER_NAME))
Expand Down
3 changes: 3 additions & 0 deletions paddlenlp/trainer/utils/sharding_io.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,9 @@
for (k, v) in state_dict.items():
if v.name in filtered_parameters:
filtered_state_dict[k] = v
else:
if sharding_rank == 0:
filtered_state_dict[k] = v

Check warning on line 88 in paddlenlp/trainer/utils/sharding_io.py

View check run for this annotation

Codecov / codecov/patch

paddlenlp/trainer/utils/sharding_io.py#L87-L88

Added lines #L87 - L88 were not covered by tests
return filtered_state_dict


Expand Down
Loading