
Commit

fix format
ForFishes committed Aug 9, 2024
1 parent dbf395f commit 32ca80e
Showing 2 changed files with 131 additions and 91 deletions.
96 changes: 5 additions & 91 deletions paddlenlp/trainer/trainer.py
@@ -18,10 +18,8 @@

import collections
import contextlib
import copy
import inspect
import math
import multiprocessing
import os
import random
import re
@@ -149,6 +147,7 @@
)
from .training_args import TrainingArguments
from .utils import reshard as reshard_util
from .utils.async_save import AsyncSaver
from .utils.helper import (  # nested_truncate,
    broadcast_dp_optimizer,
    broadcast_moe_optimizer,
@@ -191,88 +190,6 @@

__all__ = ["Trainer"]

async_save_queue = []
g_cpu_optimizer_state_dict = {}


def _save_func(obj, name_mapping, path, saved_signal_path, protocol):
    if isinstance(obj, dict):
        for k, v in obj.items():
            if k == "master_weights" and isinstance(v, dict):
                for kk, vv in v.items():
                    if isinstance(vv, paddle.Tensor):
                        vv.name = name_mapping["master_weights"][kk]
            else:
                if k in name_mapping and isinstance(v, paddle.Tensor):
                    v.name = name_mapping[k]

    paddle.save(obj, path, protocol)
    # dump saved_signal
    with open(saved_signal_path, mode="w+") as f:
        f.write("1")


def check_exitcode(task):
    exitcode = task.exitcode
    if exitcode != 0:
        print(f"Error: save ckpt process failed with exitcode {exitcode}!!!")


def clear_async_save_task_queue():
    """
    Wait until all pending async save tasks are done.
    """
    while len(async_save_queue) > 0:
        task = async_save_queue.pop()
        if task and task.is_alive():
            task.join(timeout=60)
            if task.is_alive():
                logger.error("Error: save ckpt process timeout!!!")
                async_save_queue.append(task)
            else:
                check_exitcode(task)
        else:
            check_exitcode(task)


def async_save_optimizer(optimizer_state_dict, path, saved_signal_path, protocol=4):
    global g_cpu_optimizer_state_dict
    g_cpu_optimizer_state_dict.clear()
    name_mapping = {"master_weights": {}}
    for k, v in optimizer_state_dict.items():
        if k == "master_weights":
            g_cpu_optimizer_state_dict[k] = {}
            for kk, vv in v.items():
                g_cpu_optimizer_state_dict[k][kk] = vv.pin_memory()
                name_mapping[k][kk] = vv.name
        elif k == "LR_Scheduler":
            g_cpu_optimizer_state_dict[k] = copy.deepcopy(v)
        else:
            g_cpu_optimizer_state_dict[k] = v.pin_memory()
            name_mapping[k] = v.name
    paddle.device.synchronize()
    clear_async_save_task_queue()

    attempt = 0
    ctx = multiprocessing.get_context("spawn")

    def start_process():
        nonlocal attempt
        try:
            p = ctx.Process(
                target=_save_func, args=(g_cpu_optimizer_state_dict, name_mapping, path, saved_signal_path, protocol)
            )
            p.start()
            return p
        except Exception as e:
            print(f"Attempt {attempt + 1} failed with error: {e}")
            attempt += 1
            time.sleep(1)
            return start_process()

    p = start_process()
    async_save_queue.append(p)


class Trainer:
"""
@@ -440,6 +357,8 @@ def __init__(

        self._save_ckpt_func = dist.save_state_dict if self.args.enable_auto_parallel else paddle.save
        self._load_ckpt_func = dist.load_state_dict if self.args.enable_auto_parallel else paddle.load
        if self.args.use_async_save:
            self._async_optimizer_saver = AsyncSaver()

        if args.max_steps > 0:
            logger.info("max_steps is given, it will override any value given in num_train_epochs")
@@ -2308,9 +2227,6 @@ def _save_checkpoint(self, model, metrics=None):
        # assert unwrap_model(model) is self.model, "internal model should be a reference to self.model"
        self.runtime_timer.start("checkpoint saving time")

        if self.args.use_async_save:
            clear_async_save_task_queue()

        # Save model checkpoint
        checkpoint_folder = f"{PREFIX_CHECKPOINT_DIR}-{self.state.global_step}"
@@ -2353,10 +2269,8 @@
                save_path = os.path.join(output_dir, optimizer_name)
                if self.args.use_async_save:
                    assert not strtobool(os.getenv("FLAG_LLM_PDC", "False")), "Dont support FLAG_LLM_PDC"
                    async_save_optimizer(
                        state_dict,
                        save_path,
                        saved_signal_path=saved_signal_path,
                    )
                    self._async_optimizer_saver.run(
                        state_dict, save_path, saved_signal_path=saved_signal_path
                    )
                else:
                    self._save_ckpt_func(state_dict, save_path)
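Taken together, the trainer.py changes swap the module-level helpers, a global task queue plus a freshly spawned Process per checkpoint, for a single AsyncSaver instance backed by a persistent one-worker pool. A runnable toy contrast of the two patterns, with `toy_save` standing in for the real `_save_func` / `_save_optimizer` (no Paddle required; the signal files are written to the current directory):

```python
# Toy contrast of the two async-save strategies in this commit.
import multiprocessing
import time


def toy_save(path):
    time.sleep(0.1)  # stands in for an expensive paddle.save(...)
    with open(path, "w") as f:
        f.write("1")


if __name__ == "__main__":
    ctx = multiprocessing.get_context("spawn")

    # Old pattern (removed above): spawn a fresh process per checkpoint
    # and track it in a queue that is drained before the next save.
    p = ctx.Process(target=toy_save, args=("old_ckpt.signal",))
    p.start()
    p.join(timeout=60)

    # New pattern (async_save.py below): one persistent single-worker
    # pool. apply_async returns an AsyncResult whose .get() re-raises
    # worker exceptions in the parent, which is what AsyncSaver's
    # retry loop relies on.
    pool = ctx.Pool(1)
    result = pool.apply_async(toy_save, ("new_ckpt.signal",))
    result.get()  # blocks until done; raises if the worker failed
    pool.close()
    pool.join()
```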
126 changes: 126 additions & 0 deletions paddlenlp/trainer/utils/async_save.py
@@ -0,0 +1,126 @@
# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import atexit
import copy
import multiprocessing
import os
import time

import paddle

from paddlenlp.utils.log import logger


def _save_optimizer(obj, name_mapping, path, saved_signal_path, protocol):
    start_time = time.time()
    for k, v in obj.items():
        if k == "master_weights" and isinstance(v, dict):
            for kk, vv in v.items():
                if isinstance(vv, paddle.Tensor):
                    vv.name = name_mapping["master_weights"][kk]
        else:
            if k in name_mapping and isinstance(v, paddle.Tensor):
                v.name = name_mapping[k]
    paddle.save(obj, path, protocol)
    # dump saved_signal
    with open(saved_signal_path, mode="w+") as f:
        f.write("1")
        f.flush()
        os.fsync(f.fileno())
    end_time = time.time()
    elapsed_time = end_time - start_time
    logger.info(f"Async save optimizer took {elapsed_time:.6f} seconds to execute.")


class AsyncSaver:
    def __init__(self):
        self.context = multiprocessing.get_context("spawn")
        self.cpu_optimizer_state_dict = {}
        self.pool = self.context.Pool(1)
        self.result = None
        self.name_mapping = None

        atexit.register(self.shutdown)

    def run(self, optimizer_state_dict, path, saved_signal_path, protocol=4):
        logger.info(f"Started saving optimizer_state_dict to {os.path.abspath(path)}.")
        self._wait_for_previous_result()

        self._reset_state(path, saved_signal_path, protocol)
        self._process_optimizer_state_dict(optimizer_state_dict)

        self.result = self.pool.apply_async(
            _save_optimizer,
            args=(self.cpu_optimizer_state_dict, self.name_mapping, self.path, self.saved_signal_path, self.protocol),
        )

        logger.info("Finished launching saving optimizer_state_dict process")

    def _wait_for_previous_result(self):
        if self.result is not None:
            max_retries = 5
            for retries in range(max_retries):
                try:
                    self.result.get()
                    break
                except Exception as e:
                    if retries == max_retries - 1:
                        raise RuntimeError(f"Failed after {max_retries} retries during async save.")

                    time.sleep(1 + retries * 2)
                    logger.warning(f"An error occurred during async save: {e}. Retrying...")
                    self.result = self.pool.apply_async(
                        _save_optimizer,
                        args=(
                            self.cpu_optimizer_state_dict,
                            self.name_mapping,
                            self.path,
                            self.saved_signal_path,
                            self.protocol,
                        ),
                    )

            if self.result.ready() and not self.result.successful():
                raise RuntimeError("The previous async save task failed.")
        else:
            pass

    def _reset_state(self, path, saved_signal_path, protocol):
        self.cpu_optimizer_state_dict.clear()
        self.name_mapping = {"master_weights": {}}
        self.path = path
        self.saved_signal_path = saved_signal_path
        self.protocol = protocol

    def _process_optimizer_state_dict(self, optimizer_state_dict):
        for k, v in optimizer_state_dict.items():
            if k == "master_weights":
                self.cpu_optimizer_state_dict[k] = {}
                for kk, vv in v.items():
                    self.cpu_optimizer_state_dict[k][kk] = vv.pin_memory()
                    self.name_mapping[k][kk] = vv.name
            elif k == "LR_Scheduler":
                self.cpu_optimizer_state_dict[k] = copy.deepcopy(v)
            else:
                self.cpu_optimizer_state_dict[k] = v.pin_memory()
                self.name_mapping[k] = v.name
        paddle.device.synchronize()

    def shutdown(self):
        self.pool.close()
        self.pool.join()

    def __del__(self):
        self.shutdown()
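Note the completion protocol here: `_save_optimizer` writes "1" to `saved_signal_path`, then flushes and fsyncs it, so a separate process can treat the file's existence as "checkpoint finished". A minimal consumer sketch; the polling helper below is illustrative, not a PaddleNLP API:

```python
import os
import time


def wait_for_saved_signal(saved_signal_path, timeout=600, poll_interval=1.0):
    """Block until the save-signal file appears, or raise on timeout."""
    deadline = time.time() + timeout
    while time.time() < deadline:
        if os.path.exists(saved_signal_path):
            return
        time.sleep(poll_interval)
    raise TimeoutError(f"no saved signal at {saved_signal_path} after {timeout}s")
```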

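Finally, a minimal sketch of driving AsyncSaver directly. It assumes a GPU build of Paddle (pin_memory() is called internally) and fabricates a toy stand-in for an optimizer state dict; real ones come from `optimizer.state_dict()` and usually also carry a "master_weights" entry:

```python
# Hypothetical usage sketch, not taken from the PaddleNLP test suite.
import os
import tempfile

import paddle

from paddlenlp.trainer.utils.async_save import AsyncSaver

if __name__ == "__main__":
    # Fabricated stand-in for an optimizer state dict (illustration only).
    state = {
        "linear_0.w_0_moment1_0": paddle.randn([8, 8]),
        "LR_Scheduler": {"last_epoch": 1, "last_lr": 1e-4},
    }

    saver = AsyncSaver()
    tmp = tempfile.mkdtemp()
    path = os.path.join(tmp, "optimizer.pdopt")
    signal_path = os.path.join(tmp, "saved_signal_0")

    saver.run(state, path, saved_signal_path=signal_path)

    saver.shutdown()  # close + join the pool, waiting for the save to finish
    assert os.path.exists(signal_path)
```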