
Commit 6e12b5a

[Transformer-XL/TF] Updating for Ampere
1 parent 2860d6f commit 6e12b5a

File tree

7 files changed: +236 additions, -117 deletions


TensorFlow/LanguageModeling/Transformer-XL/Dockerfile

Lines changed: 1 addition & 1 deletion
@@ -1,4 +1,4 @@
-ARG FROM_IMAGE_NAME=nvcr.io/nvidia/tensorflow:19.12-tf1-py3
+ARG FROM_IMAGE_NAME=nvcr.io/nvidia/tensorflow:20.06-tf1-py3
 FROM ${FROM_IMAGE_NAME}
 
 WORKDIR /workspace/transformer-xl/tf
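The base-image bump to the 20.06 TF1 container brings CUDA 11 and Ampere (A100) support. As a rough usage sketch, the container can be built and started as below; the image tag and the mounted data path are illustrative, not part of this commit:

    # Build the updated image (tag name is an example)
    docker build -t transformer-xl-tf .

    # Run interactively with all GPUs and a host data directory mounted
    docker run --gpus all -it --rm \
        -v $PWD/data:/workspace/transformer-xl/tf/data \
        transformer-xl-tf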

TensorFlow/LanguageModeling/Transformer-XL/README.md

Lines changed: 205 additions & 87 deletions
Large diffs are not rendered by default.

TensorFlow/LanguageModeling/Transformer-XL/getdata.sh

Lines changed: 14 additions & 10 deletions
@@ -34,22 +34,26 @@ echo "---"
 mkdir -p data
 cd data
 
-if [[ ! -d 'wikitext-2' ]]; then
-    echo "- Downloading WikiText-2 (WT2)"
-    wget --quiet --continue https://s3.amazonaws.com/research.metamind.io/wikitext/wikitext-2-v1.zip
-    unzip -q wikitext-2-v1.zip
-    cd wikitext-2
+echo "- Downloading WikiText-103 (WT2)"
+if [[ ! -d 'wikitext-103' ]]; then
+    wget --continue https://s3.amazonaws.com/research.metamind.io/wikitext/wikitext-103-v1.zip
+    unzip -q wikitext-103-v1.zip
+    cd wikitext-103
     mv wiki.train.tokens train.txt
     mv wiki.valid.tokens valid.txt
     mv wiki.test.tokens test.txt
     cd ..
 fi
 
-echo "- Downloading WikiText-103 (WT2)"
-if [[ ! -d 'wikitext-103' ]]; then
-    wget --continue https://s3.amazonaws.com/research.metamind.io/wikitext/wikitext-103-v1.zip
-    unzip -q wikitext-103-v1.zip
-    cd wikitext-103
+if [[ $1 != 'full' ]]; then
+    exit 0
+fi
+
+if [[ ! -d 'wikitext-2' ]]; then
+    echo "- Downloading WikiText-2 (WT2)"
+    wget --quiet --continue https://s3.amazonaws.com/research.metamind.io/wikitext/wikitext-2-v1.zip
+    unzip -q wikitext-2-v1.zip
+    cd wikitext-2
     mv wiki.train.tokens train.txt
     mv wiki.valid.tokens valid.txt
     mv wiki.test.tokens test.txt
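With this reordering, a plain invocation now prepares only WikiText-103 and exits early; WikiText-2 is fetched only when the first argument is 'full'. A quick sketch of both call patterns, run from the directory containing the script:

    # Default: download and preprocess WikiText-103 only, then exit
    bash getdata.sh

    # Additionally download and preprocess WikiText-2
    bash getdata.sh full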
One binary file changed (-311 KB).

TensorFlow/LanguageModeling/Transformer-XL/tf/main.py

Lines changed: 15 additions & 16 deletions
@@ -19,8 +19,6 @@
 
 import numpy as np
 
-flags.DEFINE_integer("num_core_per_host", default=8,
-      help="Number of cores per host")
 flags.DEFINE_bool('horovod', True, 'Use Horovod ')
 # Experiment (data/checkpoint/directory) config
 flags.DEFINE_string("raport_file", default="summary.json",
@@ -41,8 +39,8 @@
       help="Checkpoint path for do_test evaluation."
       "If set, model_dir will be ignored."
       "If unset, will use the latest ckpt in model_dir.")
-flags.DEFINE_bool("fp16", default=False,
-      help="Whether to enable AMP ops.")
+flags.DEFINE_bool("amp", default=False,
+      help="Whether to enable AMP ops. When false, uses TF32 on A100 and FP32 on V100 GPUS.")
 flags.DEFINE_bool("jit_optimizer", default=True,
       help="Whether to enable XLA on optimizer")
 
@@ -211,10 +209,10 @@ def single_core_graph(n_token, cutoffs, is_training, inp, tgt, mems):
   return model_ret
 
 
-def train(n_token, cutoffs, rank, local_rank, size):
+def train(n_token, cutoffs, rank, local_rank, num_core_per_host):
 
   meters = {}
-  warmup = 2 + 12/size
+  warmup = 3
   meters['train_throughput'] = AverageMeter(warmup=warmup)
   train_batch_size = FLAGS.train_batch_size // FLAGS.batch_chunk
   ##### Get input function and model function
@@ -223,7 +221,7 @@ def train(n_token, cutoffs, rank, local_rank, size):
       split="train",
       per_host_bsz=train_batch_size,
       tgt_len=FLAGS.tgt_len,
-      num_core_per_host=FLAGS.num_core_per_host,
+      num_core_per_host=num_core_per_host,
       num_hosts=1)
 
   tf.logging.info("num of batches {}".format(train_record_info["num_batch"]))
@@ -235,7 +233,7 @@ def train(n_token, cutoffs, rank, local_rank, size):
 
   inputs, labels = train_set.make_one_shot_iterator().get_next()
 
-  per_core_bsz = train_batch_size // FLAGS.num_core_per_host
+  per_core_bsz = train_batch_size // num_core_per_host
 
   with tf.variable_scope(tf.get_variable_scope()):
     mems = [tf.Variable(tf.zeros([FLAGS.mem_len, per_core_bsz, FLAGS.d_model], tf.float32), trainable=False)
@@ -327,7 +325,7 @@ def train(n_token, cutoffs, rank, local_rank, size):
 
       if curr_step > 0 and curr_step % FLAGS.log_interval == 0:
         curr_loss = total_loss / (curr_step - prev_step)
-        throughput = target_tokens * size / (time.time()-start_time)
+        throughput = target_tokens * num_core_per_host / (time.time()-start_time)
         meters['train_throughput'].update(throughput)
         if rank == 0:
           tf.logging.info("step {} | lr {:8.9f} "
@@ -367,7 +365,7 @@ def evaluate(n_token, cutoffs):
       split=FLAGS.eval_split,
       per_host_bsz=FLAGS.eval_batch_size,
       tgt_len=FLAGS.tgt_len,
-      num_core_per_host=FLAGS.num_core_per_host,
+      num_core_per_host=1, #multicore inference is not supported
       num_hosts=1)
 
   meters = {}
@@ -417,7 +415,8 @@ def evaluate(n_token, cutoffs):
     else:
       eval_ckpt_path = FLAGS.eval_ckpt_path
     tf.logging.info("Evaluate {}".format(eval_ckpt_path))
-    saver.restore(sess, eval_ckpt_path)
+    if FLAGS.eval_ckpt_path != "random":
+      saver.restore(sess, eval_ckpt_path)
 
     fetches = [loss, new_mems, target_tokens]
 
@@ -457,7 +456,7 @@ def evaluate(n_token, cutoffs):
       start_time = time.time()
   avg_loss = total_loss / total_cnt
   latency_data = np.array(meters['eval_latency'].vals)
-  tf.logging.info("Evaluating with: bs {}, math {} ".format(FLAGS.eval_batch_size, "fp16" if FLAGS.fp16 else "fp32"))
+  tf.logging.info("Evaluating with: bs {}, math {} ".format(FLAGS.eval_batch_size, "amp" if FLAGS.amp else "fp32"))
   tf.logging.info("| loss {:.2f} | pplx {:>7.2f}, bpc {:>7.4f}, tok/s {:>6.1f}, ms/batch {:>4.2f}".format(
     avg_loss, math.exp(avg_loss), avg_loss / math.log(2), meters['eval_throughput'].avg, meters['eval_latency'].avg))
   summary = {
@@ -476,17 +475,17 @@ def evaluate(n_token, cutoffs):
 
 
 def main(unused_argv):
-  rank, local_rank, size = 0, 0, 1
+  rank, local_rank, num_core_per_host = 0, 0, 1
   if FLAGS.horovod:
     hvd.init()
     rank = hvd.rank()
     local_rank = hvd.local_rank()
-    size = hvd.size()
+    num_core_per_host = hvd.size() #singlenode support
   del unused_argv # Unused
 
   tf.logging.set_verbosity(tf.logging.INFO)
 
-  if FLAGS.fp16:
+  if FLAGS.amp:
     os.environ["TF_ENABLE_AUTO_MIXED_PRECISION"] = "1"
   else:
     os.environ["TF_ENABLE_AUTO_MIXED_PRECISION"] = "0"
@@ -500,7 +499,7 @@ def main(unused_argv):
   setup_dllogger(enabled=True, filename=FLAGS.raport_file, rank=rank)
 
   if FLAGS.do_train:
-    train(n_token, cutoffs, rank, local_rank, size)
+    train(n_token, cutoffs, rank, local_rank, num_core_per_host)
   if FLAGS.do_eval:
     evaluate(n_token, cutoffs)
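Taken together, callers now request mixed precision with --amp instead of --fp16, and the per-host core count is no longer a flag: it is taken from hvd.size() for single-node Horovod runs. Passing --eval_ckpt_path=random also skips checkpoint restore, which is useful for throughput measurements with random weights. A hedged launch sketch follows; the mpiexec launcher and the elided flags are assumptions, and the supported entry point remains run_wt103_base.sh:

    # Old interface (removed):  python main.py ... --fp16 --num_core_per_host=8
    # New interface: AMP flag renamed, core count derived from the Horovod world size
    mpiexec -np 8 python main.py --do_train=True --horovod=True --amp=True ...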

TensorFlow/LanguageModeling/Transformer-XL/tf/run_wt103_base.sh

Lines changed: 0 additions & 2 deletions
@@ -64,7 +64,6 @@ elif [[ $1 == 'train' ]]; then
         --warmup_steps=1000 \
         --tgt_len=${TGT_LEN} \
         --mem_len=${MEM_LEN} \
-        --num_core_per_host=${NUM_CORE} \
         ${@:3}
 elif [[ $1 == 'eval' ]]; then
     echo 'Run evaluation...'
@@ -87,7 +86,6 @@ elif [[ $1 == 'eval' ]]; then
         --mem_len=${TEST_MEM_LEN} \
         --clamp_len=${TEST_CLAMP_LEN} \
         --same_length=True \
-        --num_core_per_host=${TEST_NUM_CORE} \
         --do_train=False \
         --do_eval=True \
         --horovod=False \
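Because the wrapper no longer injects --num_core_per_host, extra flags such as the new --amp switch reach main.py through ${@:3}. A hedged usage sketch; the second positional argument is assumed to follow the script's existing convention for the GPU count, and the eval branch is assumed to forward extra arguments the same way:

    # Train on 8 GPUs with mixed precision; --amp is forwarded through ${@:3}
    bash run_wt103_base.sh train 8 --amp

    # Evaluate; evaluation now always runs on a single core
    bash run_wt103_base.sh eval 1 --amp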

TensorFlow/LanguageModeling/Transformer-XL/tf/scripts/inference_benchmark.sh

Lines changed: 1 addition & 1 deletion
@@ -16,7 +16,7 @@
 
 BATCH_SIZES=(1 2 4 8 16 32)
 # "empty" MATH corresponds to fp32
-MATHS=("" "--fp16")
+MATHS=("" "--amp")
 
 
 for (( j = 0; j < ${#BATCH_SIZES[@]}; j++ )); do
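Only the math flag string changes here; the loop structure is untouched. For orientation, a minimal sketch of how such a benchmark loop typically consumes the two arrays; the inner run_wt103_base.sh call and its arguments are assumptions, not shown in this diff:

    BATCH_SIZES=(1 2 4 8 16 32)
    MATHS=("" "--amp")   # "" selects fp32/TF32, "--amp" enables automatic mixed precision

    for (( j = 0; j < ${#BATCH_SIZES[@]}; j++ )); do
        for (( k = 0; k < ${#MATHS[@]}; k++ )); do
            # Hypothetical inner call: forward the batch size and math flag to the eval wrapper
            bash run_wt103_base.sh eval 1 --eval_batch_size=${BATCH_SIZES[$j]} ${MATHS[$k]}
        done
    done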
