
Commit ed6a80d

Learning correct sequences, fixed save bug, added hyperparameters pass
1 parent ba05854 commit ed6a80d

1 file changed: train.py (27 additions, 8 deletions)
@@ -16,11 +16,11 @@
                     help="Initial learning rate.")
 parser.add_argument('--vocabulary_size', type=int, default=20000,
                     help="Keep only the n most common words of the training data.")
-parser.add_argument('--batch_size', type=int, default=16,
+parser.add_argument('--batch_size', type=int, default=128,
                     help="Stochastic gradient descent minibatch size.")
 parser.add_argument('--output_size', type=int, default=512,
                     help="Number of hidden units for the encoder and decoder GRUs.")
-parser.add_argument('--max_length', type=int, default=40,
+parser.add_argument('--max_sequence_length', type=int, default=40,
                     help="Truncate input and output sentences to maximum length n.")
 parser.add_argument('--sample_prob', type=float, default=0.,
                     help="Decoder probability to sample from its predictions duing training.")
@@ -58,15 +58,19 @@ def parse_and_pad(seq):
       serialized=seq, sequence_features=sequence_features)
   # Pad the sequence
   t = sequence_parsed["tokens"]
-  return tf.pad(t, [[0, FLAGS.max_length - tf.shape(t)[0]]])
+  if FLAGS.eos_token:
+    t = tf.pad(t, [[0, 1]], constant_values=3)
+  return tf.pad(t, [[0, FLAGS.max_sequence_length - tf.shape(t)[0]]])
 
 
 def train_iterator(filenames):
   """Build the input pipeline for training.."""
 
   def _single_iterator(skip):
     dataset = tf.data.TFRecordDataset(filenames)
-    dataset = dataset.map(parse_and_pad)  # TODO: add option for parallel calls
+    if skip:
+      dataset = dataset.skip(skip)
+    dataset = dataset.map(parse_and_pad, num_parallel_calls=2)
     return dataset.apply(
         tf.contrib.data.batch_and_drop_remainder(FLAGS.batch_size))
 
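A note on the parse_and_pad change above: when --eos_token is set, the parsed sequence is first extended by one position filled with token id 3 (hard-coded in the diff, presumably the reserved end-of-sequence id) and only then zero-padded up to --max_sequence_length. A minimal standalone sketch of that behavior, with a shortened maximum length and made-up token ids for illustration:

import tensorflow as tf

EOS_ID = 3   # assumption: id 3 is the reserved end-of-sequence token, as hard-coded in the commit
MAX_LEN = 6  # shortened stand-in for FLAGS.max_sequence_length

tokens = tf.constant([17, 42, 8], dtype=tf.int64)              # a parsed token sequence
tokens = tf.pad(tokens, [[0, 1]], constant_values=EOS_ID)      # append EOS -> [17, 42, 8, 3]
padded = tf.pad(tokens, [[0, MAX_LEN - tf.shape(tokens)[0]]])  # zero-pad -> [17, 42, 8, 3, 0, 0]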

@@ -98,8 +102,17 @@ def _single_iterator(skip):
   filenames = [os.path.join(FLAGS.input, f) for f in os.listdir(FLAGS.input)]
   iterator = train_iterator(filenames)
 
-  # TODO: add hyperparameters from argparse
-  m = SkipThoughts(w2v_model, train=iterator)
+  m = SkipThoughts(w2v_model, train=iterator,
+                   vocabulary_size=FLAGS.vocabulary_size,
+                   batch_size=FLAGS.batch_size,
+                   output_size=FLAGS.output_size,
+                   max_sequence_length=FLAGS.max_sequence_length,
+                   learning_rate=FLAGS.initial_lr,
+                   sample_prob=FLAGS.sample_prob,
+                   max_grad_norm=FLAGS.max_grad_norm,
+                   concat=FLAGS.concat,
+                   train_special_embeddings=FLAGS.train_special_embeddings,
+                   train_word_embeddings=FLAGS.train_word_embeddings)
 
   duration = time.time() - start
   print("Done ({:0.4f}s).".format(duration))
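The constructor call above resolves the earlier TODO by forwarding every argparse flag into SkipThoughts. An equivalent, slightly more compact sketch (assuming SkipThoughts accepts exactly these keyword arguments, as the diff suggests) would gather them in a dict first, so a future flag only has to be added in one place:

# Hypothetical refactor, not part of this commit.
hparams = dict(
    vocabulary_size=FLAGS.vocabulary_size,
    batch_size=FLAGS.batch_size,
    output_size=FLAGS.output_size,
    max_sequence_length=FLAGS.max_sequence_length,
    learning_rate=FLAGS.initial_lr,
    sample_prob=FLAGS.sample_prob,
    max_grad_norm=FLAGS.max_grad_norm,
    concat=FLAGS.concat,
    train_special_embeddings=FLAGS.train_special_embeddings,
    train_word_embeddings=FLAGS.train_word_embeddings)
m = SkipThoughts(w2v_model, train=iterator, **hparams)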
@@ -121,12 +134,18 @@ def _single_iterator(skip):
   # Avoid crashes due to directory not existing.
   if not os.path.exists(output_dir):
     os.makedirs(output_dir)
-
+  #i = 1000 ##
   while True:
     start = time.time()
     loss_, _ = sess.run([m.loss, m.train_op])
     duration = time.time() - start
     current_step = sess.run(m.global_step)
+    #i = min(i, duration) ##
+    #if current_step > 100: ##
+    #  print(i) ##
+    #  exit() ##
+    #else: ##
+    #  continue ##
     print(
         "Step", current_step,
         "(loss={:0.4f}, time={:0.4f}s)".format(loss_, duration))
@@ -136,4 +155,4 @@ def _single_iterator(skip):
       saver.save(
           sess,
           os.path.join('output', FLAGS.model_name, 'checkpoint.ckpt'),
-          global_step=current_step)
+          global_step=m.global_step)
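On the save fix in the last hunk: tf.train.Saver.save accepts either a plain integer or the global-step tensor itself for its global_step argument; given the tensor, the saver evaluates it at save time and appends the step number to the checkpoint filename (e.g. checkpoint.ckpt-158), so the call no longer depends on the separately fetched current_step value. A minimal TF1-style sketch of that behavior, independent of this repository's model class:

import tensorflow as tf

global_step = tf.train.get_or_create_global_step()
increment = tf.assign_add(global_step, 1)
saver = tf.train.Saver()

with tf.Session() as sess:
  sess.run(tf.global_variables_initializer())
  sess.run(increment)
  # The saver reads global_step itself and writes /tmp/checkpoint.ckpt-1 (plus index/data files).
  saver.save(sess, '/tmp/checkpoint.ckpt', global_step=global_step)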
