Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[cherry-pick] Optimize async save #8878

Merged
merged 1 commit into from
Aug 9, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
96 changes: 5 additions & 91 deletions paddlenlp/trainer/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,10 +18,8 @@

import collections
import contextlib
import copy
import inspect
import math
import multiprocessing
import os
import random
import re
Expand Down Expand Up @@ -149,6 +147,7 @@
)
from .training_args import TrainingArguments
from .utils import reshard as reshard_util
from .utils.async_save import AsyncSaver
from .utils.helper import ( # nested_truncate,
broadcast_dp_optimizer,
broadcast_moe_optimizer,
Expand Down Expand Up @@ -191,88 +190,6 @@

__all__ = ["Trainer"]

async_save_queue = []
g_cpu_optimizer_state_dict = {}


def _save_func(obj, name_mapping, path, saved_signal_path, protocol):
    """Persist ``obj`` with ``paddle.save`` after restoring tensor names.

    Runs in a spawned subprocess.  ``name_mapping`` carries the original
    tensor names (pinned-memory copies lose them), which are written back
    onto the tensors before serialization.  Once the checkpoint is on disk,
    a signal file containing "1" is created at ``saved_signal_path`` so the
    parent process can tell that the save completed.
    """
    if isinstance(obj, dict):
        for key, value in obj.items():
            if key == "master_weights" and isinstance(value, dict):
                for sub_key, tensor in value.items():
                    if isinstance(tensor, paddle.Tensor):
                        tensor.name = name_mapping["master_weights"][sub_key]
            elif key in name_mapping and isinstance(value, paddle.Tensor):
                value.name = name_mapping[key]

    paddle.save(obj, path, protocol)
    # Dump the saved signal so the parent knows the checkpoint is complete.
    with open(saved_signal_path, mode="w+") as f:
        f.write("1")


def check_exitcode(task):
    """Report how the finished checkpoint-saving process ``task`` exited.

    ``multiprocessing`` sets a negative exitcode when the process was killed
    by a signal, a positive one on error exit, and ``None`` while the process
    is still running — the original code printed a misleading "failed" message
    for the ``None`` and signal cases.

    Returns:
        The task's exitcode (``None`` if the process has not finished).
        Callers that ignore the return value behave exactly as before.
    """
    exitcode = task.exitcode
    if exitcode is None:
        # Process has not terminated yet; nothing meaningful to report.
        print("Warning: save ckpt process is still running, exitcode unavailable!!!")
    elif exitcode < 0:
        # Negative exitcode means the process was killed by signal -exitcode.
        print(f"Error: save ckpt process was terminated by signal {-exitcode}!!!")
    elif exitcode != 0:
        print(f"Error: save ckpt process failed with exitcode {exitcode}!!!")
    return exitcode


def clear_async_save_task_queue():
    """Block until every pending async checkpoint-save subprocess has finished.

    A task that does not finish within 60 seconds is logged and pushed back
    onto the queue, so the wait is retried on the next loop iteration.
    Finished (or never-started) tasks have their exitcode checked.
    """
    while async_save_queue:
        task = async_save_queue.pop()
        if task and task.is_alive():
            task.join(timeout=60)
        if task and task.is_alive():
            # join() timed out: keep the task queued and wait again.
            logger.error("Error: save ckpt process timeout!!!")
            async_save_queue.append(task)
        else:
            check_exitcode(task)


def async_save_optimizer(optimizer_state_dict, path, saved_signal_path, protocol=4):
    """Save ``optimizer_state_dict`` to ``path`` in a background "spawn" process.

    Tensors are first copied to pinned host memory (so the subprocess needs no
    device access) and their original names recorded in a mapping that the
    subprocess restores before calling ``paddle.save`` (see ``_save_func``).
    Any previously queued save task is waited on first, so at most one save
    is in flight at a time.

    Args:
        optimizer_state_dict: state as returned by ``optimizer.state_dict()``;
            may contain a nested ``"master_weights"`` dict and an
            ``"LR_Scheduler"`` entry of plain python objects.
        path: destination checkpoint path.
        saved_signal_path: file the subprocess creates once the save is done.
        protocol: pickle protocol forwarded to ``paddle.save``.
    """
    global g_cpu_optimizer_state_dict
    g_cpu_optimizer_state_dict.clear()
    name_mapping = {"master_weights": {}}
    for k, v in optimizer_state_dict.items():
        if k == "master_weights":
            g_cpu_optimizer_state_dict[k] = {}
            for kk, vv in v.items():
                g_cpu_optimizer_state_dict[k][kk] = vv.pin_memory()
                name_mapping[k][kk] = vv.name
        elif k == "LR_Scheduler":
            # Plain python state; deep-copy so later scheduler steps don't mutate it.
            g_cpu_optimizer_state_dict[k] = copy.deepcopy(v)
        else:
            g_cpu_optimizer_state_dict[k] = v.pin_memory()
            name_mapping[k] = v.name
    # Make sure the async device-to-host copies issued by pin_memory() finished.
    paddle.device.synchronize()
    clear_async_save_task_queue()

    ctx = multiprocessing.get_context("spawn")

    # Retry iteratively.  The original implementation recursed into itself on
    # failure, which grows the call stack without bound if Process creation
    # keeps failing; this loop keeps the same retry-until-success semantics.
    attempt = 0
    while True:
        try:
            p = ctx.Process(
                target=_save_func,
                args=(g_cpu_optimizer_state_dict, name_mapping, path, saved_signal_path, protocol),
            )
            p.start()
            break
        except Exception as e:
            print(f"Attempt {attempt + 1} failed with error: {e}")
            attempt += 1
            time.sleep(1)

    async_save_queue.append(p)


class Trainer:
"""
Expand Down Expand Up @@ -440,6 +357,8 @@

self._save_ckpt_func = dist.save_state_dict if self.args.enable_auto_parallel else paddle.save
self._load_ckpt_func = dist.load_state_dict if self.args.enable_auto_parallel else paddle.load
if self.args.use_async_save:
self._async_optimizer_saver = AsyncSaver()

Check warning on line 361 in paddlenlp/trainer/trainer.py

View check run for this annotation

Codecov / codecov/patch

paddlenlp/trainer/trainer.py#L361

Added line #L361 was not covered by tests

if args.max_steps > 0:
logger.info("max_steps is given, it will override any value given in num_train_epochs")
Expand Down Expand Up @@ -2308,9 +2227,6 @@
# assert unwrap_model(model) is self.model, "internal model should be a reference to self.model"
self.runtime_timer.start("checkpoint saving time")

if self.args.use_async_save:
clear_async_save_task_queue()

# Save model checkpoint
checkpoint_folder = f"{PREFIX_CHECKPOINT_DIR}-{self.state.global_step}"

Expand Down Expand Up @@ -2353,10 +2269,8 @@
save_path = os.path.join(output_dir, optimizer_name)
if self.args.use_async_save:
assert not strtobool(os.getenv("FLAG_LLM_PDC", "False")), "Dont support FLAG_LLM_PDC"
async_save_optimizer(
state_dict,
save_path,
saved_signal_path=saved_signal_path,
self._async_optimizer_saver.run(

Check warning on line 2272 in paddlenlp/trainer/trainer.py

View check run for this annotation

Codecov / codecov/patch

paddlenlp/trainer/trainer.py#L2272

Added line #L2272 was not covered by tests
state_dict, save_path, saved_signal_path=saved_signal_path
)
else:
self._save_ckpt_func(state_dict, save_path)
Expand Down
126 changes: 126 additions & 0 deletions paddlenlp/trainer/utils/async_save.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,126 @@
# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import atexit
import copy
import multiprocessing
import os
import time

import paddle

from paddlenlp.utils.log import logger


def _save_optimizer(obj, name_mapping, path, saved_signal_path, protocol):
    """Worker-process entry point: restore tensor names, save ``obj``, write a signal file.

    ``name_mapping`` carries the original tensor names (pinned-memory copies
    lose them); they are written back onto the tensors before serialization.
    The signal file at ``saved_signal_path`` is flushed and fsync'ed so the
    parent never observes it before the checkpoint itself is durable.
    """
    started = time.time()
    for key, value in obj.items():
        if key == "master_weights" and isinstance(value, dict):
            for sub_key, tensor in value.items():
                if isinstance(tensor, paddle.Tensor):
                    tensor.name = name_mapping["master_weights"][sub_key]
        else:
            if key in name_mapping and isinstance(value, paddle.Tensor):
                value.name = name_mapping[key]
    paddle.save(obj, path, protocol)
    # Dump the saved signal; flush + fsync so it is on disk before we return.
    with open(saved_signal_path, mode="w+") as f:
        f.write("1")
        f.flush()
        os.fsync(f.fileno())
    logger.info(f"Async save optimizer took {time.time() - started:.6f} seconds to execute.")


class AsyncSaver:
    """Saves optimizer state dicts asynchronously in a single "spawn" worker process.

    ``run`` snapshots the state into pinned host memory, waits for any
    in-flight save, then hands the snapshot to a one-worker process pool,
    where ``_save_optimizer`` restores tensor names and writes the checkpoint
    plus a signal file.  At most one save task is in flight at a time.
    """

    # Number of attempts (initial try + retries) before a save is abandoned.
    MAX_RETRIES = 5

    def __init__(self):
        self.context = multiprocessing.get_context("spawn")
        self.cpu_optimizer_state_dict = {}
        self.pool = self.context.Pool(1)
        self.result = None  # AsyncResult of the in-flight save, if any
        self.name_mapping = None

        # Make sure the worker drains pending saves on interpreter exit.
        atexit.register(self.shutdown)

    def run(self, optimizer_state_dict, path, saved_signal_path, protocol=4):
        """Asynchronously save ``optimizer_state_dict`` to ``path``.

        Blocks until the previous save (if any) has finished, then snapshots
        the state to CPU memory and launches the save in the worker process.

        Args:
            optimizer_state_dict: state as returned by ``optimizer.state_dict()``.
            path: destination checkpoint path.
            saved_signal_path: file the worker creates once the save is done.
            protocol: pickle protocol forwarded to ``paddle.save``.
        """
        logger.info(f"Started saving optimizer_state_dict to {os.path.abspath(path)}.")
        self._wait_for_previous_result()

        self._reset_state(path, saved_signal_path, protocol)
        self._process_optimizer_state_dict(optimizer_state_dict)

        self.result = self._submit()

        logger.info("Finished launching saving optimizer_state_dict process")

    def _submit(self):
        # Launch _save_optimizer on the single-worker pool with the current snapshot.
        return self.pool.apply_async(
            _save_optimizer,
            args=(self.cpu_optimizer_state_dict, self.name_mapping, self.path, self.saved_signal_path, self.protocol),
        )

    def _wait_for_previous_result(self):
        """Wait for the in-flight save task, resubmitting it up to MAX_RETRIES times.

        Raises:
            RuntimeError: chained to the last underlying exception if every
                retry fails, or if the final task finished unsuccessfully.
        """
        if self.result is None:
            return

        max_retries = self.MAX_RETRIES
        for retries in range(max_retries):
            try:
                self.result.get()
                break
            except Exception as e:
                if retries == max_retries - 1:
                    # Chain the original exception so the root cause stays visible.
                    raise RuntimeError(f"Failed after {max_retries} retries during async save.") from e

                time.sleep(1 + retries * 2)  # simple linear backoff
                logger.warning(f"An error occurred during async save: {e}. Retrying...")
                self.result = self._submit()

        if self.result.ready() and not self.result.successful():
            raise RuntimeError("The previous async save task failed.")

    def _reset_state(self, path, saved_signal_path, protocol):
        # Drop the previous CPU snapshot and remember where the next save goes.
        self.cpu_optimizer_state_dict.clear()
        self.name_mapping = {"master_weights": {}}
        self.path = path
        self.saved_signal_path = saved_signal_path
        self.protocol = protocol

    def _process_optimizer_state_dict(self, optimizer_state_dict):
        """Snapshot ``optimizer_state_dict`` into pinned CPU memory, recording tensor names."""
        for k, v in optimizer_state_dict.items():
            if k == "master_weights":
                self.cpu_optimizer_state_dict[k] = {}
                for kk, vv in v.items():
                    self.cpu_optimizer_state_dict[k][kk] = vv.pin_memory()
                    self.name_mapping[k][kk] = vv.name
            elif k == "LR_Scheduler":
                # Plain python state; deep-copy so later scheduler steps don't mutate it.
                self.cpu_optimizer_state_dict[k] = copy.deepcopy(v)
            else:
                self.cpu_optimizer_state_dict[k] = v.pin_memory()
                self.name_mapping[k] = v.name
        # Make sure the async device-to-host copies issued by pin_memory() finished.
        paddle.device.synchronize()

    def shutdown(self):
        """Stop accepting tasks and wait for the worker; safe to call repeatedly.

        Guards against partially-constructed instances: ``__del__`` (and the
        atexit hook) may run even if ``__init__`` failed before ``self.pool``
        was assigned, which previously raised AttributeError.
        """
        pool = getattr(self, "pool", None)
        if pool is not None:
            pool.close()
            pool.join()

    def __del__(self):
        self.shutdown()
Loading