support bf16 loss in static (#7874)
heavyrain-lzy authored Jan 24, 2024
1 parent eafa066 commit ca79444
Showing 1 changed file with 5 additions and 1 deletion.
6 changes: 5 additions & 1 deletion llm/llama/auto_parallel/run_pretrain_auto.py
@@ -27,6 +27,7 @@
 import paddle
 import paddle.distributed as dist
 import paddle.distributed.auto_parallel as auto
+from paddle.base.data_feeder import convert_uint16_to_float
 from paddle.profiler.utils import job_schedule_profiler_range

 from paddlenlp.ops import Topology
@@ -668,7 +669,10 @@ def loss_func(loss, outputs):
             outs = engine.run(micro_batch, mode="train")

             if "loss" in outs:
-                tr_loss_step = np.sum(outs["loss"])
+                if outs["loss"].dtype == np.uint16:
+                    tr_loss_step = np.sum(convert_uint16_to_float(outs["loss"]))
+                else:
+                    tr_loss_step = np.sum(outs["loss"])
             else:
                 tr_loss_step = float(0)

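For context (not part of the commit): Paddle returns bf16 tensors to the host as numpy arrays of dtype uint16, since numpy has no native bfloat16 type. That is why the patched branch checks outs["loss"].dtype == np.uint16 before summing. Below is a minimal standalone sketch, assuming paddle.base.data_feeder.convert_uint16_to_float works by reinterpreting each uint16 as the upper half of a float32; bf16_uint16_to_float32 and the demo values are illustrative, not from the commit.

import numpy as np

# Sketch of the bf16 -> float32 widening (assumption: this mirrors what
# convert_uint16_to_float does). A bf16 value is the upper 16 bits of a
# float32, so placing each uint16 element in the high half of a uint32
# and reinterpreting the bytes as float32 recovers the value exactly.
def bf16_uint16_to_float32(x):
    assert x.dtype == np.uint16
    return (x.astype(np.uint32) << 16).view(np.float32)

# Round-trip demo: truncate a float32 loss to its bf16 bits, then widen back.
loss_fp32 = np.array([2.375], dtype=np.float32)
loss_bf16 = (loss_fp32.view(np.uint32) >> 16).astype(np.uint16)
print(np.sum(bf16_uint16_to_float32(loss_bf16)))  # 2.375

Converting before np.sum is the point of the guard in the diff: summing the raw uint16 bit patterns would add integer bit representations rather than loss values.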
