Commit 1ed4262

added more comments to enwik8 transformer example
1 parent 684bf86 commit 1ed4262

1 file changed: +41 −22 lines


examples/transformer/enwik8.py

Lines changed: 41 additions & 22 deletions
@@ -26,26 +26,29 @@ def layernorm(x, scope, epsilon=1e-5, relu=False):
         return bs.layer_norm(x, gain, bias, axis=-1, epsilon=epsilon, relu=relu)


-def conv1d(x, scope, nf, relu=False, fast_gelu=False):
+def conv1d(x, scope, nf, std=0.02, relu=False, fast_gelu=False):
     with tf.variable_scope(scope):
         nx = x.shape[-1].value
         ndims = x.shape.ndims

-        w = tf.get_variable("w", [nx, nf], initializer=tf.random_normal_initializer(stddev=0.02))
+        # Note: param initializers are not particularly well tuned in this code
+        w = tf.get_variable("w", [nx, nf], initializer=tf.random_normal_initializer(stddev=std))
         b = tf.get_variable("b", [ nf], initializer=tf.constant_initializer(0.0))

         if hps.float16:
-            # by setting dx_dtype to float16 we prevent useless casting in the backwards pass
-            # our all-reduce and fused optimizers can accept fp16 natively.
+            # By setting dx_dtype to float16 we prevent useless casting back to fp32 in the backwards pass.
+            # Our all-reduce and fused optimizers can accept fp16 natively.
             w = bs.float_cast(w, dtype=tf.float16, dx_dtype=tf.float16)

         # merge context and batch dims for more efficient matmul
         if ndims > 2:
             y_shape = tf.concat([tf.shape(x)[: ndims - 1], [nf]], axis=0)
             x = tf.reshape(x, [-1, nx])

+        y = tf.matmul(x, w)
+
         # avoid atomics in bias grad, but be careful as tf handles temp memory badly in the presense of async ops like all-reduce
-        y = bs.bias_relu(tf.matmul(x, w), b, relu=relu, fast_gelu=fast_gelu, atomics=False)
+        y = bs.bias_relu(y, b, relu=relu, fast_gelu=fast_gelu, atomics=False)

         if ndims > 2:
             y = tf.reshape(y, y_shape)
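
For readers less familiar with the blocksparse helpers, here is a rough plain-TF sketch of what conv1d above computes, leaving out the bs.float_cast fp16 handling and the fused bs.bias_relu kernel (the name dense_1x1 is made up for this illustration, and TF 1.x style APIs are assumed):

import tensorflow as tf

def dense_1x1(x, scope, nf, std=0.02, activation=None):
    # [batch, time, nx] -> [batch, time, nf] via a single 2D matmul
    with tf.variable_scope(scope):
        nx = x.shape[-1].value
        w = tf.get_variable("w", [nx, nf], initializer=tf.random_normal_initializer(stddev=std))
        b = tf.get_variable("b", [nf], initializer=tf.constant_initializer(0.0))
        y_shape = tf.concat([tf.shape(x)[:-1], [nf]], axis=0)
        y = tf.matmul(tf.reshape(x, [-1, nx]), w) + b   # merge batch/context dims first
        if activation is not None:
            y = activation(y)
        return tf.reshape(y, y_shape)
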
@@ -71,10 +74,12 @@ def causal_subblock_mask(blk_shape, head_idx, query_idx, key_idx, blk_idx):
 # Coarse sparse structure
 # Only layout[q,k] == 1 blocks are computed and materialized in memory
 # Block sizes of 8, 16, 32 and 64 are supported on volta fp16 tensorcores (64 being most appropriate for dense attention)
-# Only blocoksize 32 currently supported in fp32 on on other gpus.
+# Only blocksize 32 currently supported in fp32 on other gpus (sm >= 3.5).
 def get_blocksparse_transformer(n_timesteps, n_heads):
     blocksize = 64 if hps.float16 else 32
     n_time_blocks = n_timesteps // blocksize
+    # The block layout can also include a head dimension if you don't want the same layout shared by all heads.
+    # Each head just has to have the same number of active blocks (but you can always mask them away).
     layout = np.ones([n_time_blocks, n_time_blocks], dtype=np.bool)
     # No query blocks may attend to key blocks in the future.
     # Much more elaborate structures can be defined here aside from the usual lower triangular.
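
As a side note, the dense starting layout plus the causal restriction described in these comments amounts to a lower-triangular pattern over blocks. A small numpy-only sketch (the sizes are chosen just for illustration):

import numpy as np

n_timesteps, blocksize = 1024, 64          # e.g. the fp16 tensorcore path
n_time_blocks = n_timesteps // blocksize   # 16 blocks per side
# layout[q, k] == 1 means the (q, k) block of attention is computed at all
layout = np.tril(np.ones([n_time_blocks, n_time_blocks], dtype=np.bool_))
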
@@ -101,33 +106,38 @@ def transformer_block(x, scope, train=False):
         k = conv1d(h, 'proj_k', n_state)
         v = conv1d(h, 'proj_v', n_state)

-        bst = hps.bst_cache.get(scope)
+        # only need to create one bst per config
+        # we could pass this in as an external param but I like to keep the code more local
+        bst_params = (hps.n_timesteps, hps.n_head)
+        bst = bst_cache.get(bst_params)
         if bst is None:
-            bst = get_blocksparse_transformer(hps.n_timesteps, hps.n_head)
-            hps.bst_cache[scope] = bst
+            bst = bst_cache[bst_params] = get_blocksparse_transformer(*bst_params)

+        # run the core bst ops, transposes for dealing with heads are fused in here.
         w = bst.query_key_op(q, k)
         w = bst.masked_softmax(w, scale=1.0/np.sqrt(n_state / hps.n_head))
         a = bst.weight_value_op(w, v)

-        a = conv1d(a, 'proj_a', n_state)
+        a = conv1d(a, 'proj_a', n_state, std=0.02/hps.n_layer)

         if train and hps.resid_pdrop > 0.0:
             # preserve the dropout mask through recompute
             key = scope + "_dropout_a"
-            a, hps.dropout_cache[key] = bs.dropout(a, keep_prob=1.0 - hps.resid_pdrop, mask=hps.dropout_cache.get(key))
+            a, dropout_cache[key] = bs.dropout(a, keep_prob=1.0 - hps.resid_pdrop, mask=dropout_cache.get(key))

+        # many basic tf ops are about half as fast as they should be in fp16
         x = bs.add(x, a)

         m = layernorm(x, "norm_m")

+        # fast_gelu: x * sigmoid(1.702 * x)
         m = conv1d(m, 'proj_m1', n_state * hps.mlp_ratio, fast_gelu=True)
         m = conv1d(m, 'proj_m2', n_state)

         if train and hps.resid_pdrop > 0.0:
             # preserve the dropout mask through recompute
             key = scope + "_dropout_m"
-            m, hps.dropout_cache[key] = bs.dropout(m, keep_prob=1.0 - hps.resid_pdrop, mask=hps.dropout_cache.get(key))
+            m, dropout_cache[key] = bs.dropout(m, keep_prob=1.0 - hps.resid_pdrop, mask=dropout_cache.get(key))

         return bs.add(x, m)
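
The new fast_gelu comment already gives the whole formula; written out as a standalone op it is just the sigmoid approximation of GELU (a plain-TF sketch, not the fused blocksparse kernel used above):

import tensorflow as tf

def fast_gelu(x):
    # approximates the GELU x * Phi(x) with x * sigmoid(1.702 * x)
    return x * tf.sigmoid(1.702 * x)
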

@@ -139,7 +149,7 @@ def model(xs, ys, loss_scale=None, train=False):
     with tf.device("/cpu:0"):
         if train:
             grad_scale = tf.reciprocal(loss_scale) if hps.float16 else 1.0
-            global_step = tf.Variable(1.0, trainable=False)
+            global_step = tf.get_variable("global_step", [], initializer=tf.ones_initializer(), trainable=False)
             learning_rate = tf.minimum(global_step * (1.0/hps.warmup_iters), 1.0) * hps.lr
         mpi_scale = tf.constant(1.0 / mpi_size)
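
The learning_rate line this hunk touches is a simple linear warmup. In plain Python it reduces to the function below (the lr value in the comment is only an example, not the script's default):

def warmup_lr(step, lr, warmup_iters=1000):
    # ramps linearly from lr/warmup_iters up to lr, then stays flat
    return min(step / warmup_iters, 1.0) * lr

# e.g. with lr=0.0005: step 1 -> 5e-07, step 500 -> 0.00025, step >= 1000 -> 0.0005
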

@@ -160,9 +170,11 @@ def model(xs, ys, loss_scale=None, train=False):
             x_embed = bs.float_cast(x_embed, dtype=tf.float16, dx_dtype=tf.float16)
             p_embed = bs.float_cast(p_embed, dtype=tf.float16, dx_dtype=tf.float16)

+        # bs.embedding_lookup can be much faster than tf version for low entropy indexes or small vocabs
         x = bs.embedding_lookup(x_embed, xs)

         if train and hps.embed_pdrop > 0.0:
+            # this part of the code is not recomputed so no need to remember the generated mask returned by bs.dropout
             x, _ = bs.dropout(x, keep_prob=1.0 - hps.embed_pdrop)
             p_embed, _ = bs.dropout(p_embed, keep_prob=1.0 - hps.embed_pdrop)
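
For comparison with the new comment, the stock TF op it refers to looks like this (shapes and variable names below are made up for illustration; tf.nn.embedding_lookup wants integer ids, hence the cast from the uint8 placeholders this script uses):

import tensorflow as tf

ids = tf.placeholder(tf.uint8, shape=[8, 64])                   # [batch, time] byte ids
x_embed = tf.get_variable("x_embed_demo", [256, 512])           # [vocab, n_embd]
vecs = tf.nn.embedding_lookup(x_embed, tf.cast(ids, tf.int32))  # [batch, time, n_embd]
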

@@ -171,6 +183,8 @@ def model(xs, ys, loss_scale=None, train=False):

         for l in range(hps.n_layer):
             layer_name = 'layer_%d' % l
+            # enable the recompute decorator in training
+            # see blocksparse/grads.py if you want to understand how this works
             h = transformer_block(h, layer_name, train=train, recompute=train and hps.recompute)
             grad_groups.insert(0, layer_name)

@@ -207,7 +221,7 @@ def model(xs, ys, loss_scale=None, train=False):
         grads = [bs.scale_tensor(g, mpi_scale) for g in grads]

         # allreduce in an mpi context
-        # bias and gain grads will in in fp32, but have them fp16 cast prior to allreduce
+        # bias, gain and x_embed grads will be in fp32, but have them fp16 cast prior to allreduce
         cast_all = tf.float16 if H.float16 else None
         loss = bs.allreduce(loss)
         grads = bs.group_allreduce(grads, params, search_strings=grad_groups, cast_all=cast_all)
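
Conceptually, the mpi_scale multiply plus the grouped allreduce above computes the mean gradient across ranks. A minimal mpi4py/numpy sketch of that reduction (not the fused blocksparse ops, and the tensor size is arbitrary):

import numpy as np
from mpi4py import MPI

comm = MPI.COMM_WORLD
local_grad = np.random.randn(1024).astype(np.float32) / comm.Get_size()  # pre-scale by 1/mpi_size
mean_grad = np.empty_like(local_grad)
comm.Allreduce(local_grad, mean_grad, op=MPI.SUM)                        # sum of scaled grads = mean grad
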
@@ -292,14 +306,15 @@ def print_rank0(*args):
     parser.add_argument('--warmup_iters', type=int, default=1000)
     parser.add_argument('--enwik8_path', type=str, default='/home/scott/datasets/enwik8') # obviously change to your local path
     parser.add_argument('--log_interval', type=int, default=200)
-    parser.add_argument('--profile', type=int, default=3) # exit early for nvprof profiling
+    parser.add_argument('--profile', type=int, default=0) # exit early for nvprof profiling
     parser.add_argument('--float16', type=int, default=0) # only sm >= 7.0 (tensorcores)
     parser.add_argument('--recompute', type=int, default=0) # allow use of large contexts and/or lots of layers/params

+    # use some global vars for convenience
     hps = parser.parse_args()

-    hps.dropout_cache = dict()
-    hps.bst_cache = dict()
+    bst_cache = dict()
+    dropout_cache = dict()

     comm = MPI.COMM_WORLD
     mpi_size = comm.Get_size()

@@ -314,7 +329,7 @@ def print_rank0(*args):
     X = tf.placeholder(tf.uint8, shape=[hps.n_batch, hps.n_timesteps])
     Y = tf.placeholder(tf.uint8, shape=[hps.n_batch, hps.n_timesteps])

-    # loss_scale and grad_scale are host side scalars
+    # loss_scale is a host side scalar
     with tf.device("/cpu:0"):
         loss_scale = tf.placeholder(tf.float32, shape=[])

@@ -326,12 +341,13 @@ def print_rank0(*args):
     cur_loss_scale = hps.loss_scale
     loss_count = 0

+    # build the models for training and testing/validation
     train_loss, train_op, gn, ns = model(X, Y, loss_scale, train=True)
     valid_loss = model(X, Y)

-    # Free up some python memory
-    hps.bst_cache = None
-    hps.dropout_cache = None
+    # Free up some python memory now that models are built
+    bst_cache = None
+    dropout_cache = None
     bs.clear_bst_constants()

     config = tf.ConfigProto()

@@ -355,9 +371,11 @@ def print_rank0(*args):

                 loss, global_norm, norm_scale, _ = sess.run([train_loss, gn, ns, train_op], feed_dict={X: x, Y: y, loss_scale: cur_loss_scale})

-                if hps.float16:
+                # auto loss scaling for fp16.
+                if hps.float16 and np.isfinite(loss):
                     # slowly increase loss scale but quickly drop it when inf or nan is detected in the gradients
                     # norm_scale will be zero when this happens
+                    # You may also want to limit the change in loss_scale from any single minibatch and throw them away when this limit is exceeded.
                     if norm_scale == 0.0:
                         cur_loss_scale *= 0.5
                         loss_count = 0

@@ -371,6 +389,7 @@ def print_rank0(*args):
                         else:
                             loss_count += 1
                 else:
+                    # if forward pass is not finite skip any further auto loss scaling.
                     retry = False

             if iteration % hps.log_interval == 0:
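
Taken together, the two hunks above implement dynamic loss scaling. The policy, reduced to plain Python (the growth threshold here is illustrative, not the script's actual value): halve the scale whenever norm_scale comes back zero (inf/nan in the gradients), and double it again only after enough consecutive clean steps.

def update_loss_scale(loss_scale, norm_scale, good_steps, growth_interval=1000):
    # norm_scale == 0.0 signals inf/nan in the gradients: back off quickly
    if norm_scale == 0.0:
        return loss_scale * 0.5, 0
    # otherwise grow slowly after enough consecutive clean steps
    good_steps += 1
    if good_steps >= growth_interval:
        return loss_scale * 2.0, 0
    return loss_scale, good_steps
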
