Skip to content

Commit

Permalink
Fix load RNG compatibility. (#8451)
Browse files Browse the repository at this point in the history
  • Loading branch information
ZHUI authored May 16, 2024
1 parent debb2ad commit fc860a3
Showing 1 changed file with 6 additions and 9 deletions.
15 changes: 6 additions & 9 deletions paddlenlp/trainer/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -1591,16 +1591,13 @@ def _load_rng_state(self, checkpoint):
if os.path.isfile(rng_file):
rng_file_list = paddle.load(rng_file, return_numpy=True)
paddle.distributed.broadcast_object_list(rng_file_list, src=0)
# if rng_file_list still empty, then use old style rng_state
# if rng_file_list still empty, not log rng state.
if rng_file_list[0] is None:
rng_file = os.path.join(checkpoint, f"rng_state_{process_index}.pth")
if not os.path.isfile(rng_file):
logger.info(
f"Didn't find an RNG file for process {process_index}, if you are resuming a training that "
"wasn't launched in a distributed fashion, reproducibility is not guaranteed."
)
return
checkpoint_rng_state = paddle.load(rng_file, return_numpy=True)
logger.info(
f"Didn't find an RNG file for process {process_index}, if you are resuming a training that "
"wasn't launched in a distributed fashion, reproducibility is not guaranteed."
)
return
else:
checkpoint_rng_state = rng_file_list[process_index]
else:
Expand Down

0 comments on commit fc860a3

Please sign in to comment.