@@ -26,26 +26,29 @@ def layernorm(x, scope, epsilon=1e-5, relu=False):
    return bs.layer_norm(x, gain, bias, axis=-1, epsilon=epsilon, relu=relu)


-def conv1d(x, scope, nf, relu=False, fast_gelu=False):
+def conv1d(x, scope, nf, std=0.02, relu=False, fast_gelu=False):
    with tf.variable_scope(scope):
        nx = x.shape[-1].value
        ndims = x.shape.ndims

-        w = tf.get_variable("w", [nx, nf], initializer=tf.random_normal_initializer(stddev=0.02))
+        # Note: param initializers are not particularly well tuned in this code
+        w = tf.get_variable("w", [nx, nf], initializer=tf.random_normal_initializer(stddev=std))
        b = tf.get_variable("b", [nf], initializer=tf.constant_initializer(0.0))

        if hps.float16:
-            # by setting dx_dtype to float16 we prevent useless casting in the backwards pass
-            # our all-reduce and fused optimizers can accept fp16 natively.
+            # By setting dx_dtype to float16 we prevent useless casting back to fp32 in the backwards pass.
+            # Our all-reduce and fused optimizers can accept fp16 natively.
            w = bs.float_cast(w, dtype=tf.float16, dx_dtype=tf.float16)

        # merge context and batch dims for more efficient matmul
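        # e.g. a [n_batch, n_timesteps, nx] input is flattened to [n_batch * n_timesteps, nx]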
        if ndims > 2:
            y_shape = tf.concat([tf.shape(x)[:ndims-1], [nf]], axis=0)
            x = tf.reshape(x, [-1, nx])

+        y = tf.matmul(x, w)
+
        # avoid atomics in bias grad, but be careful as tf handles temp memory badly in the presence of async ops like all-reduce
-        y = bs.bias_relu(tf.matmul(x, w), b, relu=relu, fast_gelu=fast_gelu, atomics=False)
+        y = bs.bias_relu(y, b, relu=relu, fast_gelu=fast_gelu, atomics=False)

        if ndims > 2:
            y = tf.reshape(y, y_shape)
@@ -71,10 +74,12 @@ def causal_subblock_mask(blk_shape, head_idx, query_idx, key_idx, blk_idx):
# Coarse sparse structure
# Only layout[q,k] == 1 blocks are computed and materialized in memory
# Block sizes of 8, 16, 32 and 64 are supported on volta fp16 tensorcores (64 being most appropriate for dense attention)
-# Only blocoksize 32 currently supported in fp32 on on other gpus.
+# Only blocksize 32 currently supported in fp32 on other gpus (sm >= 3.5).
def get_blocksparse_transformer(n_timesteps, n_heads):
    blocksize = 64 if hps.float16 else 32
    n_time_blocks = n_timesteps // blocksize
+    # The block layout can also include a head dimension if you don't want the same layout shared by all heads.
+    # Each head just has to have the same number of active blocks (but you can always mask them away).
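+    # e.g. a per-head layout would have shape [n_heads, n_time_blocks, n_time_blocks]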
    layout = np.ones([n_time_blocks, n_time_blocks], dtype=np.bool)
    # No query blocks may attend to key blocks in the future.
    # Much more elaborate structures can be defined here aside from the usual lower triangular.
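    # e.g. one possible local + strided variant (a sketch, not what this example uses):
    #   for q, k in np.ndindex(n_time_blocks, n_time_blocks):
    #       layout[q, k] = k <= q and (q - k < 8 or (q - k) % 8 == 0)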
@@ -101,33 +106,38 @@ def transformer_block(x, scope, train=False):
    k = conv1d(h, 'proj_k', n_state)
    v = conv1d(h, 'proj_v', n_state)

-    bst = hps.bst_cache.get(scope)
+    # only need to create one bst per config
+    # we could pass this in as an external param but I like to keep the code more local
+    bst_params = (hps.n_timesteps, hps.n_head)
+    bst = bst_cache.get(bst_params)
    if bst is None:
-        bst = get_blocksparse_transformer(hps.n_timesteps, hps.n_head)
-        hps.bst_cache[scope] = bst
+        bst = bst_cache[bst_params] = get_blocksparse_transformer(*bst_params)

+    # run the core bst ops, transposes for dealing with heads are fused in here.
    w = bst.query_key_op(q, k)
    w = bst.masked_softmax(w, scale=1.0 / np.sqrt(n_state / hps.n_head))
    a = bst.weight_value_op(w, v)

-    a = conv1d(a, 'proj_a', n_state)
+    a = conv1d(a, 'proj_a', n_state, std=0.02 / hps.n_layer)
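+    # (the 1/n_layer init scaling keeps the residual stream from growing with depth)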

    if train and hps.resid_pdrop > 0.0:
        # preserve the dropout mask through recompute
        key = scope + "_dropout_a"
-        a, hps.dropout_cache[key] = bs.dropout(a, keep_prob=1.0 - hps.resid_pdrop, mask=hps.dropout_cache.get(key))
+        a, dropout_cache[key] = bs.dropout(a, keep_prob=1.0 - hps.resid_pdrop, mask=dropout_cache.get(key))

+    # many basic tf ops are about half as fast as they should be in fp16
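+    # (hence bs.add rather than tf.add for the residual connection)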
    x = bs.add(x, a)

    m = layernorm(x, "norm_m")
+    # fast_gelu: x * sigmoid(1.702 * x)
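+    # (a cheap sigmoid approximation of the gelu activation)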
    m = conv1d(m, 'proj_m1', n_state * hps.mlp_ratio, fast_gelu=True)
    m = conv1d(m, 'proj_m2', n_state)

    if train and hps.resid_pdrop > 0.0:
        # preserve the dropout mask through recompute
        key = scope + "_dropout_m"
-        m, hps.dropout_cache[key] = bs.dropout(m, keep_prob=1.0 - hps.resid_pdrop, mask=hps.dropout_cache.get(key))
+        m, dropout_cache[key] = bs.dropout(m, keep_prob=1.0 - hps.resid_pdrop, mask=dropout_cache.get(key))

    return bs.add(x, m)
@@ -139,7 +149,7 @@ def model(xs, ys, loss_scale=None, train=False):
    with tf.device("/cpu:0"):
        if train:
            grad_scale = tf.reciprocal(loss_scale) if hps.float16 else 1.0
-            global_step = tf.Variable(1.0, trainable=False)
+            global_step = tf.get_variable("global_step", [], initializer=tf.ones_initializer(), trainable=False)
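            # linearly ramp the learning rate up to hps.lr over the first warmup_iters iterations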
            learning_rate = tf.minimum(global_step * (1.0 / hps.warmup_iters), 1.0) * hps.lr
            mpi_scale = tf.constant(1.0 / mpi_size)
@@ -160,9 +170,11 @@ def model(xs, ys, loss_scale=None, train=False):
            x_embed = bs.float_cast(x_embed, dtype=tf.float16, dx_dtype=tf.float16)
            p_embed = bs.float_cast(p_embed, dtype=tf.float16, dx_dtype=tf.float16)

+        # bs.embedding_lookup can be much faster than the tf version for low entropy indexes or small vocabs
        x = bs.embedding_lookup(x_embed, xs)

        if train and hps.embed_pdrop > 0.0:
+            # this part of the code is not recomputed so there is no need to remember the generated mask returned by bs.dropout
            x, _ = bs.dropout(x, keep_prob=1.0 - hps.embed_pdrop)
            p_embed, _ = bs.dropout(p_embed, keep_prob=1.0 - hps.embed_pdrop)
@@ -171,6 +183,8 @@ def model(xs, ys, loss_scale=None, train=False):

        for l in range(hps.n_layer):
            layer_name = 'layer_%d' % l
+            # enable the recompute decorator in training
+            # see blocksparse/grads.py if you want to understand how this works
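+            # (recompute trades compute for memory: activations are regenerated during the backward pass)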
            h = transformer_block(h, layer_name, train=train, recompute=train and hps.recompute)
            grad_groups.insert(0, layer_name)
@@ -207,7 +221,7 @@ def model(xs, ys, loss_scale=None, train=False):
            grads = [bs.scale_tensor(g, mpi_scale) for g in grads]

        # allreduce in an mpi context
-        # bias and gain grads will in in fp32, but have them fp16 cast prior to allreduce
+        # bias, gain and x_embed grads will be in fp32, but have them cast to fp16 prior to allreduce
        cast_all = tf.float16 if hps.float16 else None
        loss = bs.allreduce(loss)
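        # grouping grads by layer name lets each layer's allreduce start before the full backward pass finishes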
        grads = bs.group_allreduce(grads, params, search_strings=grad_groups, cast_all=cast_all)
@@ -292,14 +306,15 @@ def print_rank0(*args):
    parser.add_argument('--warmup_iters', type=int, default=1000)
    parser.add_argument('--enwik8_path', type=str, default='/home/scott/datasets/enwik8') # obviously change to your local path
    parser.add_argument('--log_interval', type=int, default=200)
-    parser.add_argument('--profile', type=int, default=3) # exit early for nvprof profiling
+    parser.add_argument('--profile', type=int, default=0) # exit early for nvprof profiling
    parser.add_argument('--float16', type=int, default=0) # only sm >= 7.0 (tensorcores)
    parser.add_argument('--recompute', type=int, default=0) # allow use of large contexts and/or lots of layers/params

+    # use some global vars for convenience
    hps = parser.parse_args()

-    hps.dropout_cache = dict()
-    hps.bst_cache = dict()
+    bst_cache = dict()
+    dropout_cache = dict()

    comm = MPI.COMM_WORLD
    mpi_size = comm.Get_size()
@@ -314,7 +329,7 @@ def print_rank0(*args):
    X = tf.placeholder(tf.uint8, shape=[hps.n_batch, hps.n_timesteps])
    Y = tf.placeholder(tf.uint8, shape=[hps.n_batch, hps.n_timesteps])

-    # loss_scale and grad_scale are host side scalars
+    # loss_scale is a host side scalar
    with tf.device("/cpu:0"):
        loss_scale = tf.placeholder(tf.float32, shape=[])
@@ -326,12 +341,13 @@ def print_rank0(*args):
    cur_loss_scale = hps.loss_scale
    loss_count = 0

+    # build the models for training and testing/validation
    train_loss, train_op, gn, ns = model(X, Y, loss_scale, train=True)
    valid_loss = model(X, Y)

-    # Free up some python memory
-    hps.bst_cache = None
-    hps.dropout_cache = None
+    # Free up some python memory now that the models are built
+    bst_cache = None
+    dropout_cache = None
    bs.clear_bst_constants()

    config = tf.ConfigProto()
@@ -355,9 +371,11 @@ def print_rank0(*args):

        loss, global_norm, norm_scale, _ = sess.run([train_loss, gn, ns, train_op], feed_dict={X: x, Y: y, loss_scale: cur_loss_scale})

-        if hps.float16:
+        # auto loss scaling for fp16.
+        if hps.float16 and np.isfinite(loss):
            # slowly increase loss scale but quickly drop it when inf or nan is detected in the gradients
            # norm_scale will be zero when this happens
+            # You may also want to limit the change in loss_scale from any single minibatch and throw the minibatch away when this limit is exceeded.
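            # e.g. a loss_scale of 2**16 drops to 2**15 the first time non-finite grads are detected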
            if norm_scale == 0.0:
                cur_loss_scale *= 0.5
                loss_count = 0
@@ -371,6 +389,7 @@ def print_rank0(*args):
            else:
                loss_count += 1
        else:
+            # if the forward pass loss is not finite, skip any further auto loss scaling.
            retry = False

        if iteration % hps.log_interval == 0: