diff --git a/paddlenlp/trainer/trainer.py b/paddlenlp/trainer/trainer.py index af5c9b9f1d12..fecc832b4601 100644 --- a/paddlenlp/trainer/trainer.py +++ b/paddlenlp/trainer/trainer.py @@ -18,10 +18,8 @@ import collections import contextlib -import copy import inspect import math -import multiprocessing import os import random import re @@ -149,6 +147,7 @@ ) from .training_args import TrainingArguments from .utils import reshard as reshard_util +from .utils.async_save import AsyncSaver from .utils.helper import ( # nested_truncate, broadcast_dp_optimizer, broadcast_moe_optimizer, @@ -191,89 +190,6 @@ __all__ = ["Trainer"] -async_save_queue = [] -g_cpu_optimizer_state_dict = {} - - -def _save_func(obj, name_mapping, path, saved_signal_path, protocol): - if isinstance(obj, dict): - for k, v in obj.items(): - if k == "master_weights" and isinstance(v, dict): - for kk, vv in v.items(): - if isinstance(vv, paddle.Tensor): - vv.name = name_mapping["master_weights"][kk] - else: - if k in name_mapping and isinstance(v, paddle.Tensor): - v.name = name_mapping[k] - - paddle.save(obj, path, protocol) - # dump savd_siganl - with open(saved_signal_path, mode="w+") as f: - f.write("1") - - -def check_exitcode(task): - exitcode = task.exitcode - if exitcode != 0: - print(f"Error: save ckpt process failed with exitcode {exitcode}!!!") - - -def clear_async_save_task_queue(): - """ - wait until all async save task to be done. - """ - while len(async_save_queue) > 0: - task = async_save_queue.pop() - if task and task.is_alive(): - task.join(timeout=60) - if task.is_alive(): - logger.error("Error: save ckpt process timeout!!!") - async_save_queue.append(task) - else: - check_exitcode(task) - else: - check_exitcode(task) - - -def async_save_optimizer(optimizer_state_dict, path, saved_signal_path, protocol=4): - global g_cpu_optimizer_state_dict - g_cpu_optimizer_state_dict.clear() - name_mapping = {"master_weights": {}} - for k, v in optimizer_state_dict.items(): - if k == "master_weights": - g_cpu_optimizer_state_dict[k] = {} - for kk, vv in v.items(): - g_cpu_optimizer_state_dict[k][kk] = vv.pin_memory() - name_mapping[k][kk] = vv.name - elif k == "LR_Scheduler": - g_cpu_optimizer_state_dict[k] = copy.deepcopy(v) - else: - g_cpu_optimizer_state_dict[k] = v.pin_memory() - name_mapping[k] = v.name - paddle.device.synchronize() - clear_async_save_task_queue() - - attempt = 0 - ctx = multiprocessing.get_context("spawn") - - def start_process(): - nonlocal attempt - try: - p = ctx.Process( - target=_save_func, args=(g_cpu_optimizer_state_dict, name_mapping, path, saved_signal_path, protocol) - ) - p.start() - return p - except Exception as e: - print(f"Attempt {attempt + 1} failed with error: {e}") - attempt += 1 - time.sleep(1) - return start_process() - - p = start_process() - async_save_queue.append(p) - - class Trainer: """ Trainer is a simple but feature-complete training and eval loop for PaddlePaddle, optimized for PaddleNLP. @@ -440,6 +356,8 @@ def __init__( self._save_ckpt_func = dist.save_state_dict if self.args.enable_auto_parallel else paddle.save self._load_ckpt_func = dist.load_state_dict if self.args.enable_auto_parallel else paddle.load + if self.args.use_async_save: + self._async_optimizer_saver = AsyncSaver() if args.max_steps > 0: logger.info("max_steps is given, it will override any value given in num_train_epochs") @@ -2308,9 +2226,6 @@ def _save_checkpoint(self, model, metrics=None): # assert unwrap_model(model) is self.model, "internal model should be a reference to self.model" self.runtime_timer.start("checkpoint saving time") - if self.args.use_async_save: - clear_async_save_task_queue() - # Save model checkpoint checkpoint_folder = f"{PREFIX_CHECKPOINT_DIR}-{self.state.global_step}" @@ -2353,10 +2268,8 @@ def _save_checkpoint(self, model, metrics=None): save_path = os.path.join(output_dir, optimizer_name) if self.args.use_async_save: assert not strtobool(os.getenv("FLAG_LLM_PDC", "False")), "Dont support FLAG_LLM_PDC" - async_save_optimizer( - state_dict, - save_path, - saved_signal_path=saved_signal_path, + self._async_optimizer_saver.run( + state_dict, save_path, saved_signal_path=saved_signal_path ) else: self._save_ckpt_func(state_dict, save_path) diff --git a/paddlenlp/trainer/utils/async_save.py b/paddlenlp/trainer/utils/async_save.py new file mode 100644 index 000000000000..c652fd1e3109 --- /dev/null +++ b/paddlenlp/trainer/utils/async_save.py @@ -0,0 +1,126 @@ +# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import atexit +import copy +import multiprocessing +import os +import time + +import paddle + +from paddlenlp.utils.log import logger + + +def _save_optimizer(obj, name_mapping, path, saved_signal_path, protocol): + start_time = time.time() + for k, v in obj.items(): + if k == "master_weights" and isinstance(v, dict): + for kk, vv in v.items(): + if isinstance(vv, paddle.Tensor): + vv.name = name_mapping["master_weights"][kk] + else: + if k in name_mapping and isinstance(v, paddle.Tensor): + v.name = name_mapping[k] + paddle.save(obj, path, protocol) + # dump saved_signal + with open(saved_signal_path, mode="w+") as f: + f.write("1") + f.flush() + os.fsync(f.fileno()) + end_time = time.time() + elapsed_time = end_time - start_time + logger.info(f"Async save optimizer took {elapsed_time:.6f} seconds to execute.") + + +class AsyncSaver: + def __init__(self): + self.context = multiprocessing.get_context("spawn") + self.cpu_optimizer_state_dict = {} + self.pool = self.context.Pool(1) + self.result = None + self.name_mapping = None + + atexit.register(self.shutdown) + + def run(self, optimizer_state_dict, path, saved_signal_path, protocol=4): + logger.info(f"Started saving optimizer_state_dict to {os.path.abspath(path)}.") + self._wait_for_previous_result() + + self._reset_state(path, saved_signal_path, protocol) + self._process_optimizer_state_dict(optimizer_state_dict) + + self.result = self.pool.apply_async( + _save_optimizer, + args=(self.cpu_optimizer_state_dict, self.name_mapping, self.path, self.saved_signal_path, self.protocol), + ) + + logger.info("Finished launching saving optimizer_state_dict process") + + def _wait_for_previous_result(self): + if self.result is not None: + max_retries = 5 + for retries in range(max_retries): + try: + self.result.get() + break + except Exception as e: + if retries == max_retries - 1: + raise RuntimeError(f"Failed after {max_retries} retries during async save.") + + time.sleep(1 + retries * 2) + logger.warning(f"An error occurred during async save: {e}. Retrying...") + self.result = self.pool.apply_async( + _save_optimizer, + args=( + self.cpu_optimizer_state_dict, + self.name_mapping, + self.path, + self.saved_signal_path, + self.protocol, + ), + ) + + if self.result.ready() and not self.result.successful(): + raise RuntimeError("The previous async save task failed.") + else: + pass + + def _reset_state(self, path, saved_signal_path, protocol): + self.cpu_optimizer_state_dict.clear() + self.name_mapping = {"master_weights": {}} + self.path = path + self.saved_signal_path = saved_signal_path + self.protocol = protocol + + def _process_optimizer_state_dict(self, optimizer_state_dict): + for k, v in optimizer_state_dict.items(): + if k == "master_weights": + self.cpu_optimizer_state_dict[k] = {} + for kk, vv in v.items(): + self.cpu_optimizer_state_dict[k][kk] = vv.pin_memory() + self.name_mapping[k][kk] = vv.name + elif k == "LR_Scheduler": + self.cpu_optimizer_state_dict[k] = copy.deepcopy(v) + else: + self.cpu_optimizer_state_dict[k] = v.pin_memory() + self.name_mapping[k] = v.name + paddle.device.synchronize() + + def shutdown(self): + self.pool.close() + self.pool.join() + + def __del__(self): + self.shutdown()