19
19
20
20
import numpy as np
21
21
22
- flags .DEFINE_integer ("num_core_per_host" , default = 8 ,
23
- help = "Number of cores per host" )
24
22
flags .DEFINE_bool ('horovod' , True , 'Use Horovod ' )
25
23
# Experiment (data/checkpoint/directory) config
26
24
flags .DEFINE_string ("raport_file" , default = "summary.json" ,
41
39
help = "Checkpoint path for do_test evaluation."
42
40
"If set, model_dir will be ignored."
43
41
"If unset, will use the latest ckpt in model_dir." )
44
- flags .DEFINE_bool ("fp16 " , default = False ,
45
- help = "Whether to enable AMP ops." )
42
+ flags .DEFINE_bool ("amp " , default = False ,
43
+ help = "Whether to enable AMP ops. When false, uses TF32 on A100 and FP32 on V100 GPUS. " )
46
44
flags .DEFINE_bool ("jit_optimizer" , default = True ,
47
45
help = "Whether to enable XLA on optimizer" )
48
46
@@ -211,10 +209,10 @@ def single_core_graph(n_token, cutoffs, is_training, inp, tgt, mems):
211
209
return model_ret
212
210
213
211
214
- def train (n_token , cutoffs , rank , local_rank , size ):
212
+ def train (n_token , cutoffs , rank , local_rank , num_core_per_host ):
215
213
216
214
meters = {}
217
- warmup = 2 + 12 / size
215
+ warmup = 3
218
216
meters ['train_throughput' ] = AverageMeter (warmup = warmup )
219
217
train_batch_size = FLAGS .train_batch_size // FLAGS .batch_chunk
220
218
##### Get input function and model function
@@ -223,7 +221,7 @@ def train(n_token, cutoffs, rank, local_rank, size):
223
221
split = "train" ,
224
222
per_host_bsz = train_batch_size ,
225
223
tgt_len = FLAGS .tgt_len ,
226
- num_core_per_host = FLAGS . num_core_per_host ,
224
+ num_core_per_host = num_core_per_host ,
227
225
num_hosts = 1 )
228
226
229
227
tf .logging .info ("num of batches {}" .format (train_record_info ["num_batch" ]))
@@ -235,7 +233,7 @@ def train(n_token, cutoffs, rank, local_rank, size):
235
233
236
234
inputs , labels = train_set .make_one_shot_iterator ().get_next ()
237
235
238
- per_core_bsz = train_batch_size // FLAGS . num_core_per_host
236
+ per_core_bsz = train_batch_size // num_core_per_host
239
237
240
238
with tf .variable_scope (tf .get_variable_scope ()):
241
239
mems = [tf .Variable (tf .zeros ([FLAGS .mem_len , per_core_bsz , FLAGS .d_model ], tf .float32 ), trainable = False )
@@ -327,7 +325,7 @@ def train(n_token, cutoffs, rank, local_rank, size):
327
325
328
326
if curr_step > 0 and curr_step % FLAGS .log_interval == 0 :
329
327
curr_loss = total_loss / (curr_step - prev_step )
330
- throughput = target_tokens * size / (time .time ()- start_time )
328
+ throughput = target_tokens * num_core_per_host / (time .time ()- start_time )
331
329
meters ['train_throughput' ].update (throughput )
332
330
if rank == 0 :
333
331
tf .logging .info ("step {} | lr {:8.9f} "
@@ -367,7 +365,7 @@ def evaluate(n_token, cutoffs):
367
365
split = FLAGS .eval_split ,
368
366
per_host_bsz = FLAGS .eval_batch_size ,
369
367
tgt_len = FLAGS .tgt_len ,
370
- num_core_per_host = FLAGS . num_core_per_host ,
368
+ num_core_per_host = 1 , #multicore inference is not supported
371
369
num_hosts = 1 )
372
370
373
371
meters = {}
@@ -417,7 +415,8 @@ def evaluate(n_token, cutoffs):
417
415
else :
418
416
eval_ckpt_path = FLAGS .eval_ckpt_path
419
417
tf .logging .info ("Evaluate {}" .format (eval_ckpt_path ))
420
- saver .restore (sess , eval_ckpt_path )
418
+ if FLAGS .eval_ckpt_path != "random" :
419
+ saver .restore (sess , eval_ckpt_path )
421
420
422
421
fetches = [loss , new_mems , target_tokens ]
423
422
@@ -457,7 +456,7 @@ def evaluate(n_token, cutoffs):
457
456
start_time = time .time ()
458
457
avg_loss = total_loss / total_cnt
459
458
latency_data = np .array (meters ['eval_latency' ].vals )
460
- tf .logging .info ("Evaluating with: bs {}, math {} " .format (FLAGS .eval_batch_size , "fp16 " if FLAGS .fp16 else "fp32" ))
459
+ tf .logging .info ("Evaluating with: bs {}, math {} " .format (FLAGS .eval_batch_size , "amp " if FLAGS .amp else "fp32" ))
461
460
tf .logging .info ("| loss {:.2f} | pplx {:>7.2f}, bpc {:>7.4f}, tok/s {:>6.1f}, ms/batch {:>4.2f}" .format (
462
461
avg_loss , math .exp (avg_loss ), avg_loss / math .log (2 ), meters ['eval_throughput' ].avg , meters ['eval_latency' ].avg ))
463
462
summary = {
@@ -476,17 +475,17 @@ def evaluate(n_token, cutoffs):
476
475
477
476
478
477
def main (unused_argv ):
479
- rank , local_rank , size = 0 , 0 , 1
478
+ rank , local_rank , num_core_per_host = 0 , 0 , 1
480
479
if FLAGS .horovod :
481
480
hvd .init ()
482
481
rank = hvd .rank ()
483
482
local_rank = hvd .local_rank ()
484
- size = hvd .size ()
483
+ num_core_per_host = hvd .size () #singlenode support
485
484
del unused_argv # Unused
486
485
487
486
tf .logging .set_verbosity (tf .logging .INFO )
488
487
489
- if FLAGS .fp16 :
488
+ if FLAGS .amp :
490
489
os .environ ["TF_ENABLE_AUTO_MIXED_PRECISION" ] = "1"
491
490
else :
492
491
os .environ ["TF_ENABLE_AUTO_MIXED_PRECISION" ] = "0"
@@ -500,7 +499,7 @@ def main(unused_argv):
500
499
setup_dllogger (enabled = True , filename = FLAGS .raport_file , rank = rank )
501
500
502
501
if FLAGS .do_train :
503
- train (n_token , cutoffs , rank , local_rank , size )
502
+ train (n_token , cutoffs , rank , local_rank , num_core_per_host )
504
503
if FLAGS .do_eval :
505
504
evaluate (n_token , cutoffs )
506
505
0 commit comments