    help="Initial learning rate.")
parser.add_argument('--vocabulary_size', type=int, default=20000,
    help="Keep only the n most common words of the training data.")
- parser.add_argument('--batch_size', type=int, default=16,
+ parser.add_argument('--batch_size', type=int, default=128,
    help="Stochastic gradient descent minibatch size.")
parser.add_argument('--output_size', type=int, default=512,
    help="Number of hidden units for the encoder and decoder GRUs.")
- parser.add_argument('--max_length', type=int, default=40,
+ parser.add_argument('--max_sequence_length', type=int, default=40,
    help="Truncate input and output sentences to maximum length n.")
parser.add_argument('--sample_prob', type=float, default=0.,
    help="Decoder probability to sample from its predictions during training.")
@@ -58,15 +58,19 @@ def parse_and_pad(seq):
        serialized=seq, sequence_features=sequence_features)
    # Pad the sequence
    t = sequence_parsed["tokens"]
-     return tf.pad(t, [[0, FLAGS.max_length - tf.shape(t)[0]]])
+     if FLAGS.eos_token:
+         t = tf.pad(t, [[0, 1]], constant_values=3)
+     return tf.pad(t, [[0, FLAGS.max_sequence_length - tf.shape(t)[0]]])


def train_iterator(filenames):
    """Build the input pipeline for training."""

    def _single_iterator(skip):
        dataset = tf.data.TFRecordDataset(filenames)
-         dataset = dataset.map(parse_and_pad)  # TODO: add option for parallel calls
+         if skip:
+             dataset = dataset.skip(skip)
+         dataset = dataset.map(parse_and_pad, num_parallel_calls=2)
        return dataset.apply(
            tf.contrib.data.batch_and_drop_remainder(FLAGS.batch_size))
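The reworked `parse_and_pad` now appends a single EOS token (id 3 in this commit) when `--eos_token` is set, and only then zero-pads to `--max_sequence_length`. A minimal standalone sketch of those two `tf.pad` calls, using a literal length 8 in place of the flag:

    import tensorflow as tf

    t = tf.constant([5, 9, 2], dtype=tf.int64)   # a parsed token sequence
    t = tf.pad(t, [[0, 1]], constant_values=3)   # append EOS -> [5, 9, 2, 3]
    t = tf.pad(t, [[0, 8 - tf.shape(t)[0]]])     # zero-pad -> length 8
    with tf.Session() as sess:
        print(sess.run(t))                       # [5 9 2 3 0 0 0 0]

The new `skip` argument to `_single_iterator` presumably lets the caller build time-shifted copies of the corpus (e.g. offsets 0, 1, 2) and zip them into the (previous, current, next) sentence triplets that skip-thoughts trains on.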
@@ -98,8 +102,17 @@ def _single_iterator(skip):
filenames = [os.path.join(FLAGS.input, f) for f in os.listdir(FLAGS.input)]
iterator = train_iterator(filenames)

- # TODO: add hyperparameters from argparse
- m = SkipThoughts(w2v_model, train=iterator)
+ m = SkipThoughts(w2v_model, train=iterator,
+                  vocabulary_size=FLAGS.vocabulary_size,
+                  batch_size=FLAGS.batch_size,
+                  output_size=FLAGS.output_size,
+                  max_sequence_length=FLAGS.max_sequence_length,
+                  learning_rate=FLAGS.initial_lr,
+                  sample_prob=FLAGS.sample_prob,
+                  max_grad_norm=FLAGS.max_grad_norm,
+                  concat=FLAGS.concat,
+                  train_special_embeddings=FLAGS.train_special_embeddings,
+                  train_word_embeddings=FLAGS.train_word_embeddings)

duration = time.time() - start
print("Done ({:0.4f}s).".format(duration))
@@ -121,12 +134,18 @@ def _single_iterator(skip):
# Avoid crashes due to directory not existing.
if not os.path.exists(output_dir):
    os.makedirs(output_dir)
-
+ #i = 1000 ##
while True:
    start = time.time()
    loss_, _ = sess.run([m.loss, m.train_op])
    duration = time.time() - start
    current_step = sess.run(m.global_step)
+     #i = min(i, duration) ##
+     #if current_step > 100: ##
+     #    print(i) ##
+     #    exit() ##
+     #else: ##
+     #    continue ##
    print(
        "Step", current_step,
        "(loss={:0.4f}, time={:0.4f}s)".format(loss_, duration))
@@ -136,4 +155,4 @@ def _single_iterator(skip):
saver.save(
    sess,
    os.path.join('output', FLAGS.model_name, 'checkpoint.ckpt'),
-     global_step=current_step)
+     global_step=m.global_step)
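Switching `global_step=current_step` to the tensor `m.global_step` works because `tf.train.Saver.save` accepts an integer, a `Tensor`, or a tensor name for `global_step` and evaluates a tensor at save time, so the checkpoint suffix always reflects the in-graph counter rather than a possibly stale Python copy:

    # Resolved when save() runs; writes e.g. checkpoint.ckpt-1500 at step 1500.
    saver.save(sess, os.path.join('output', FLAGS.model_name, 'checkpoint.ckpt'),
               global_step=m.global_step)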