
Commit 754284a
Training works, just needs tweaks with pad tokens.
ben-baran committed Aug 18, 2018
1 parent 8752e9b commit 754284a
Showing 3 changed files with 46 additions and 43 deletions.
4 changes: 2 additions & 2 deletions test_loader.py
@@ -17,8 +17,8 @@
#for i in range(1):
#print('inputs:', vocab.to_tokens([int(x) for x in in_tokens[:, i].asnumpy()]))
#print('outputs:', vocab.to_tokens([int(x) for x in out_tokens.reshape((-1, 32))[:, i].asnumpy()]))
- #print('pre context:', vocab.to_tokens([int(x) for x in pre_contexts[0, :, i].asnumpy()]))
- #print('post context:', vocab.to_tokens([int(x) for x in post_contexts[0, :, i].asnumpy()]))
+ #print('pre context:', vocab.to_tokens([int(x) for x in pre_contexts[i, :, 0].asnumpy()]))
+ #print('post context:', vocab.to_tokens([int(x) for x in post_contexts[i, :, 0].asnumpy()]))

toc = time.time()
print("Total number of context groups:", n_context_groups)
81 changes: 42 additions & 39 deletions train.py
@@ -21,7 +21,7 @@
argument_parser.add_argument("-b", "--batch-size", default = 32, type = int,
help = "Number of frame groups to process at a time")

argument_parser.add_argument("-nh", "--num-hidden-per-side", default = 256, type = int,
argument_parser.add_argument("-nh", "--hidden-per-side", default = 256, type = int,
help = "Number of hidden weights per GRU layer in the forward/backward networks")

argument_parser.add_argument("-nl", "--num-layers", default = 2, type = int,
@@ -30,6 +30,9 @@
argument_parser.add_argument("--embedding-size", default = 256, type = int,
help = "Size of embeddings")

argument_parser.add_argument("--dropout", default = 0.5, type = float,
help = "Amount to apply dropout.")

argument_parser.add_argument("--clip", default = 0.2, type = float,
help = "Amount to clip weights by.")

@@ -42,7 +45,10 @@
argument_parser.add_argument("--time-limit", default = int(1e12), type = int,
help = "Maximum number of seconds this can train.")

+ argument_parser.add_argument('--tie_weights', action = 'store_true',
+ help = 'If you tie the weights of the decoder and encoder, you have to have an extra funneling step.')
+
argument_parser.set_defaults(tie_weights = False)
options = argument_parser.parse_args()

with open('data_options.json') as data_options_f:
@@ -69,64 +75,61 @@ def __init__(self, vocab, hidden_per_side, num_layers, embedding_size,
self.hidden_full = 2 * hidden_per_side
self.drop = mx.gluon.nn.Dropout(dropout)
self.embedder = mx.gluon.nn.Embedding(vocab.total_tokens, embedding_size, weight_initializer = mx.init.Uniform(0.1))
- self.forward_rnn = mx.gluon.rnn.GRU(hidden_per_side, num_layers, dropout = dropout, input_size = embedding_size)
- self.backward_rnn = mx.gluon.rnn.GRU(hidden_per_side, num_layers, dropout = dropout, input_size = embedding_size)
- self.output_rnn = mx.gluon.rnn.GRU(self.hidden_full, num_layers, dropout = dropout, input_size = embedding_size)
+ self.forward_rnn = mx.gluon.rnn.GRU(hidden_per_side, num_layers, dropout = dropout, input_size = embedding_size, prefix = 'forward_rnn_')
+ self.backward_rnn = mx.gluon.rnn.GRU(hidden_per_side, num_layers, dropout = dropout, input_size = embedding_size, prefix = 'backward_rnn_')
+ self.output_rnn = mx.gluon.rnn.GRU(self.hidden_full, num_layers, dropout = dropout, input_size = embedding_size, prefix = 'output_rnn_')
if tie_weights:
- contract_hiddens = mx.gluon.nn.Dense(hidden_per_side, in_units = self.hidden_full)
- decoder = mx.gluon.nn.Dense(vocab.total_tokens, in_units = hidden_per_side, params = self.encoder.params)
- self.decoder = lambda x: decoder(contract_hiddens(x))
+ self._contract_hiddens = mx.gluon.nn.Dense(hidden_per_side, in_units = self.hidden_full)
+ self._decoder = mx.gluon.nn.Dense(vocab.total_tokens, in_units = hidden_per_side, params = self.encoder.params)
+ self.decoder = lambda x: self._decoder(self._contract_hiddens(x))
else:
- decoder = mx.gluon.nn.Dense(vocab.total_tokens, in_units = self.hidden_full)
+ self.decoder = mx.gluon.nn.Dense(vocab.total_tokens, in_units = self.hidden_full)

- def forward(self, forward_contexts, backward_contexts, predict_in, is_training = True):
- dropout_mode = 'training' if is_training else 'always'
+ def forward(self, forward_contexts, backward_contexts, predict_in):
batch_size = predict_in.shape[1]
n_contexts = forward_contexts.shape[0]

# TODO(Ben) also add in option for variance?
- avg_f_hidden, avg_b_hidden = None, None
+ f_hiddens, b_hiddens = [], []
for ci in range(n_contexts):
- f_embed = self.drop(self.embedder(forward_contexts[ci]), mode = dropout_mode)
- b_embed = self.drop(self.embedder(backward_contexts[ci]), mode = dropout_mode)
+ f_embed = self.drop(self.embedder(forward_contexts[ci]))
+ b_embed = self.drop(self.embedder(backward_contexts[ci]))

f_hidden = self.forward_rnn.begin_state(func = mx.nd.zeros, batch_size = batch_size, ctx = ctx)
_, f_hidden = self.forward_rnn(f_embed, f_hidden)
- b_hidden = self.nackward_rnn.begin_state(func = mx.nd.zeros, batch_size = batch_size, ctx = ctx)
- _, b_hidden = self.forward_rnn(f_embed, b_hidden)
+ b_hidden = self.backward_rnn.begin_state(func = mx.nd.zeros, batch_size = batch_size, ctx = ctx)
+ _, b_hidden = self.backward_rnn(f_embed, b_hidden)

- if avg_f_hidden is None:
- avg_f_hidden = f_hidden
- avg_b_hidden = b_hidden
- else:
- for hid_i in range(len(f_hidden)):
- avg_f_hidden[hid_i] += f_hidden[hid_i]
- avg_b_hidden[hid_i] += b_hidden[hid_i]
- combined_hidden = [mx.nd.concatenate((f_h, b_h)) / n_contexts for f_h, b_h in zip(avg_f_hidden, avg_b_hidden)]
+ f_hiddens.append(f_hidden[0]) # always seems to be of length 1. Is there an edge case?
+ b_hiddens.append(b_hidden[0])
+ f_hidden_sum = mx.nd.add_n(*f_hiddens)
+ b_hidden_sum = mx.nd.add_n(*b_hiddens)
+ # combined_hidden = [mx.nd.concatenate((f_h, b_h)) / n_contexts for f_h, b_h in zip(avg_f_hidden, avg_b_hidden)]
+ combined_hidden = mx.nd.concat(f_hidden_sum, b_hidden_sum, dim = 2) / n_contexts

- predict_embed = self.drop(self.embedder(predict_in), mode = dropout_mode)
+ predict_embed = self.drop(self.embedder(predict_in))
output, _ = self.output_rnn(predict_embed, combined_hidden)
- output = self.drop(output, mode = dropout_mode)
- output_decoded = self.decoder(output.reshape((-1, 2 * self.num_hidden)))
+ output = self.drop(output)
+ output_decoded = self.decoder(output.reshape((-1, self.hidden_full)))
return output_decoded

model = BidirectionalModel(vocab,
hidden_per_side = options.hidden_per_side,
num_layers = options.num_layers,
- embedding_sizee = options.embedding_size,
+ embedding_size = options.embedding_size,
dropout = options.dropout,
tie_weights = options.tie_weights)
model.hybridize()
model.collect_params().initialize(mx.init.Xavier(), ctx = ctx)
trainer = mx.gluon.Trainer(model.collect_params(), 'adam', {'learning_rate': options.learning_rate})
loss = mx.gluon.loss.SoftmaxCrossEntropyLoss()

- validation_data = ContextLoader(vocab, batch_size = options.batch_size, folder_name = 'data/val_small')
+ validation_data = ContextLoader(vocab, batch_size = options.batch_size, folder_name = 'data/val_small/', ctx = ctx)
def validation_loss(num_minibatches = 10, verbose = False):
global test_batches
i = 0
avg_loss = 0.0
- for pre_contexts, post_contexts, in_tokens, target_tokens in validation_data:
+ for pre_contexts, post_contexts, in_tokens, target_tokens in validation_data.iterator():
i += 1
output_tokens = model(pre_contexts, post_contexts, in_tokens)

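Note on the rewritten combination step above: the new code sums the per-context GRU states with add_n and concatenates the forward and backward halves along the feature axis, instead of accumulating averages inside the loop. A minimal shape sketch of that step, with invented sizes (a Gluon GRU state has shape (num_layers, batch, hidden)):

```python
import mxnet as mx

# Invented toy sizes for illustration only.
num_layers, batch_size, hidden_per_side, n_contexts = 2, 4, 8, 3

f_hiddens = [mx.nd.ones((num_layers, batch_size, hidden_per_side)) for _ in range(n_contexts)]
b_hiddens = [mx.nd.ones((num_layers, batch_size, hidden_per_side)) for _ in range(n_contexts)]

f_hidden_sum = mx.nd.add_n(*f_hiddens)  # elementwise sum over the context groups
b_hidden_sum = mx.nd.add_n(*b_hiddens)
combined = mx.nd.concat(f_hidden_sum, b_hidden_sum, dim = 2) / n_contexts

print(combined.shape)  # (2, 4, 16): num_layers x batch x 2*hidden_per_side
```

Concatenating along dim 2 is what makes the result match hidden_full = 2 * hidden_per_side, the state size the output GRU was built with.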
@@ -154,31 +157,31 @@ def validation_loss(num_minibatches = 10, verbose = False):
validation_losses = []
start_global = time.time()
for epoch in range(options.epochs):
- i = 0
+ iteration = 0
cur_lr = options.learning_rate * options.lr_decay ** epoch
print("\nCurrent learning rate: %f" % cur_lr)
trainer.set_learning_rate(cur_lr)
- train_data = ContextLoader(vocab, batch_size = 32, folder_name = 'data/train_small', ctx = ctx)
- for pre_contexts, post_contexts, in_tokens, target_tokens in train_data:
+ train_data = ContextLoader(vocab, batch_size = 32, folder_name = 'data/train_small/', ctx = ctx)
+ for pre_contexts, post_contexts, in_tokens, target_tokens in train_data.iterator():
with mx.autograd.record():
output_tokens = model(pre_contexts, post_contexts, in_tokens)
L = loss(output_tokens, target_tokens) # TODO(Ben): zero out the irrelevant padding tokens
L.backward()
grads = [i.grad(ctx) for i in model.collect_params().values()]
mx.gluon.utils.clip_global_norm(grads, options.clip * data_options['context_width'] * options.batch_size) # TODO(Ben): adjust clipping for type of network
- trainer.step(batch_size)
+ trainer.step(options.batch_size)

- if i % 100 == 0:
+ if iteration % 100 == 0:
validation_losses.append(validation_loss())
- if i % 500 == 0:
- print("i = %d. Saving." % i)
- model.save_params(save_dir + 'epoch-%d-i-%.6d.params' % (epoch, len(validation_losses)))
+ if iteration % 500 == 0:
+ print("%dth iteration. Saving." % iteration)
+ model.save_params(save_dir + 'epoch-%d-%.6d.params' % (epoch, len(validation_losses)))
validation_loss(num_minibatches = 1, verbose = True)
np.save(save_dir + 'validation_losses', np.array(validation_losses))
if time.time() - start_global > options.time_limit:
print("Time limit reached. Ending epoch.")
break
- i += 1
- print("Epoch completed. %d iterations" % i)
+ iteration += 1
+ print("Epoch completed. %d iterations" % iteration)
if time.time() - start_global > options.time_limit:
break
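The commit message and the TODO on the loss line both point at the remaining pad-token issue: padded positions in target_tokens still contribute to the cross-entropy. One possible tweak, sketched here and not part of this commit, is to pass a mask as the sample weight that Gluon losses accept; it assumes the loader exposes the id it pads with (utils.py pads using self.pad_token_id) as a pad_token_id attribute:

```python
# Hypothetical follow-up, not in this commit: zero out padding positions in the loss.
pad_id = train_data.pad_token_id  # assumption: ContextLoader exposes the id it pads with
mask = (target_tokens != pad_id).astype('float32').reshape((-1, 1))
L = loss(output_tokens, target_tokens, mask)  # sample_weight multiplies the per-token loss
```

Renormalizing by the number of real (non-pad) tokens rather than the padded length would additionally keep the loss scale comparable across batches.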
4 changes: 2 additions & 2 deletions utils.py
@@ -198,7 +198,7 @@ def _try_load(self, n_contexts, fin):
output_tokens = [u + [self.pad_token_id for i in range(longest_var - len(u))] for u in output_tokens]
input_tokens = mx.nd.array(input_tokens, ctx = self.ctx).T
output_tokens = mx.nd.array(output_tokens, ctx = self.ctx).T.reshape((-1,))
- pre_contexts = mx.nd.array(pre_contexts, ctx = self.ctx).transpose((0, 2, 1))
- post_contexts = mx.nd.array(post_contexts, ctx = self.ctx).transpose((0, 2, 1))
+ pre_contexts = mx.nd.array(pre_contexts, ctx = self.ctx).transpose((1, 2, 0))
+ post_contexts = mx.nd.array(post_contexts, ctx = self.ctx).transpose((1, 2, 0))

return pre_contexts.T, post_contexts.T, input_tokens, output_tokens
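On the utils.py change: combined with the trailing .T on the return value, the new transpose appears to hand back context tensors laid out as (n_contexts, seq_len, batch_size), which is what the per-context loop in train.py indexes and what the updated debug prints in test_loader.py assume. A small axis-bookkeeping sketch, with invented sizes and with the raw nested-list layout of [contexts][batch][seq] taken as an assumption:

```python
import mxnet as mx

n_contexts, batch_size, seq_len = 3, 4, 5
raw = mx.nd.zeros((n_contexts, batch_size, seq_len))  # assumed raw layout: [contexts][batch][seq]

old = raw.transpose((0, 2, 1)).T  # -> (batch, seq, contexts)
new = raw.transpose((1, 2, 0)).T  # -> (contexts, seq, batch)

print(old.shape, new.shape)  # (4, 5, 3) (3, 5, 4)
```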
