Skip to content

Commit 0a3faee

Browse files
authored
gpt benchmark fix v2 (#1244)
1 parent 48e58f0 commit 0a3faee

File tree

2 files changed, with 2 additions and 3 deletions.

2 files changed, with 2 additions and 3 deletions.

examples/language_model/gpt/run_pretrain.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -252,14 +252,13 @@ def do_train(args):
 lr_scheduler.step()
 optimizer.clear_grad()
 
-paddle.device.cuda.synchronize()
+loss_numpy = loss.numpy()
 train_run_cost += time.time() - train_start
 
 # Profile for model benchmark
 profiler.add_profiler_step(args.profiler_options)
 
 if global_step % args.logging_freq == 0:
-loss_numpy = loss.numpy()
 speed = args.logging_freq / (
 train_reader_cost + train_run_cost)
 avg_reader_cost = train_reader_cost / args.logging_freq

tests/benchmark/run_benchmark.sh

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -55,7 +55,7 @@ function _train(){
 
 if [ $fp_item = "fp16" ]; then
 use_fp16_cmd="--use_amp true"
-if [ $dygraph_name = "dygraph" && $gpt_repo = "gpt3" ]; then
+if [ $dygraph_name = "dygraph" ] && [ $gpt_repo = "gpt3" ]; then
 use_fp16_cmd="--use_pure_fp16 true"
 fi
 fi

0 commit comments

Comments
 (0)