5
5
6
6
import tensorflow as tf
7
7
from tqdm import tqdm
8
- from data_load import get_batch
8
+ from data_load import get_batch , get_dev
9
9
from params import Params
10
10
from layers import *
11
- from GRU import gated_attention_GRUCell
11
+ from GRU import gated_attention_GRUCell , GRUCell
12
12
from evaluate import *
13
13
import numpy as np
14
14
import cPickle as pickle
@@ -26,19 +26,19 @@ def __init__(self,is_training = True):
26
26
self .graph = tf .Graph ()
27
27
with self .graph .as_default ():
28
28
self .global_step = tf .Variable (0 , name = 'global_step' , trainable = False )
29
- data , self .num_batch = get_batch (is_training = is_training )
29
+ self . data , self .num_batch = get_batch (is_training = is_training )
30
30
(self .passage_w ,
31
31
self .question_w ,
32
32
self .passage_c ,
33
33
self .question_c ,
34
- self .passage_w_len ,
35
- self .question_w_len ,
34
+ self .passage_w_len_ ,
35
+ self .question_w_len_ ,
36
36
self .passage_c_len ,
37
37
self .question_c_len ,
38
- self .indices ) = data
38
+ self .indices ) = self . data
39
39
40
- self .passage_w_len = tf .squeeze (self .passage_w_len )
41
- self .question_w_len = tf .squeeze (self .question_w_len )
40
+ self .passage_w_len = tf .squeeze (self .passage_w_len_ )
41
+ self .question_w_len = tf .squeeze (self .question_w_len_ )
42
42
43
43
self .encode_ids ()
44
44
self .params = get_attn_params (Params .attn_size , initializer = tf .contrib .layers .xavier_initializer )
@@ -73,28 +73,36 @@ def encode_ids(self):
73
73
word_embeddings = self .word_embeddings ,
74
74
char_embeddings = self .char_embeddings ,
75
75
scope = "question_embeddings" )
76
+ #cell = [GRUCell(Params.attn_size, is_training = self.is_training) for _ in range(2)]
76
77
self .passage_char_encoded = bidirectional_GRU (self .passage_char_encoded ,
77
78
self .passage_c_len ,
79
+ # cell = cell,
78
80
scope = "passage_char_encoding" ,
79
81
output = 1 ,
80
82
is_training = self .is_training )
83
+ #cell = [GRUCell(Params.attn_size, is_training = self.is_training) for _ in range(2)]
81
84
self .question_char_encoded = bidirectional_GRU (self .question_char_encoded ,
82
85
self .question_c_len ,
86
+ # cell = cell,
83
87
scope = "question_char_encoding" ,
84
88
output = 1 ,
85
89
is_training = self .is_training )
86
90
self .passage_encoding = tf .concat ((self .passage_word_encoded , self .passage_char_encoded ),axis = 2 )
87
91
self .question_encoding = tf .concat ((self .question_word_encoded , self .question_char_encoded ),axis = 2 )
88
92
89
93
# Passage and question encoding
94
+ #cell = [MultiRNNCell([GRUCell(Params.attn_size, is_training = self.is_training) for _ in range(3)]) for _ in range(2)]
90
95
self .passage_encoding = bidirectional_GRU (self .passage_encoding ,
91
96
self .passage_w_len ,
97
+ # cell = cell,
92
98
layers = Params .num_layers ,
93
99
scope = "passage_encoding" ,
94
100
output = 0 ,
95
101
is_training = self .is_training )
102
+ #cell = [MultiRNNCell([GRUCell(Params.attn_size, is_training = self.is_training) for _ in range(3)]) for _ in range(2)]
96
103
self .question_encoding = bidirectional_GRU (self .question_encoding ,
97
104
self .question_w_len ,
105
+ # cell = cell,
98
106
layers = Params .num_layers ,
99
107
scope = "question_encoding" ,
100
108
output = 0 ,
@@ -165,10 +173,13 @@ def summary(self):
165
173
self .F1_placeholder = tf .placeholder (tf .float32 , shape = (), name = "F1_placeholder" )
166
174
self .EM = tf .Variable (tf .constant (0.0 , shape = (), dtype = tf .float32 ),trainable = False , name = "EM" )
167
175
self .EM_placeholder = tf .placeholder (tf .float32 , shape = (), name = "EM_placeholder" )
168
- self .metric_assign = tf .group (tf .assign (self .F1 , self .F1_placeholder ),tf .assign (self .EM , self .EM_placeholder ))
169
- tf .summary .scalar ('mean_loss' , self .mean_loss )
170
- tf .summary .scalar ("training_F1_Score" ,self .F1 )
171
- tf .summary .scalar ("training_Exact_Match" ,self .EM )
176
+ self .dev_loss = tf .Variable (tf .constant (5.0 , shape = (), dtype = tf .float32 ),trainable = False , name = "dev_loss" )
177
+ self .dev_loss_placeholder = tf .placeholder (tf .float32 , shape = (), name = "dev_loss" )
178
+ self .metric_assign = tf .group (tf .assign (self .F1 , self .F1_placeholder ),tf .assign (self .EM , self .EM_placeholder ),tf .assign (self .dev_loss , self .dev_loss_placeholder ))
179
+ tf .summary .scalar ('loss_training' , self .mean_loss )
180
+ tf .summary .scalar ('loss_dev' , self .dev_loss )
181
+ tf .summary .scalar ("F1_Score" ,self .F1 )
182
+ tf .summary .scalar ("Exact_Match" ,self .EM )
172
183
tf .summary .scalar ('learning_rate' , Params .opt_arg [Params .optimizer ]['learning_rate' ])
173
184
self .merged = tf .summary .merge_all ()
174
185
@@ -198,6 +209,7 @@ def main():
198
209
model = Model (is_training = True ); print ("Built model" )
199
210
dict_ = pickle .load (open (Params .data_dir + "dictionary.pkl" ,"r" ))
200
211
init = False
212
+ devdata , dev_ind = get_dev ()
201
213
if not os .path .isfile (os .path .join (Params .logdir ,"checkpoint" )):
202
214
init = True
203
215
glove = np .memmap (Params .data_dir + "glove.np" , dtype = np .float32 , mode = "r" )
@@ -218,19 +230,19 @@ def main():
218
230
for step in tqdm (range (model .num_batch ), total = model .num_batch , ncols = 70 , leave = False , unit = 'b' ):
219
231
sess .run (model .train_op )
220
232
if step % Params .save_steps == 0 :
221
- gs = sess . run ( model . global_step )
222
- sv . saver . save ( sess , Params . logdir + '/model_epoch_%d_step_%d' % ( gs // model . num_batch , gs % model .num_batch ))
223
- index , ground_truth , passage = sess .run ([model .points_logits , model .indices , model .passage_w ] )
224
- index = np .argmax (index , axis = 2 )
233
+ sample_ind = np . random . choice ( dev_ind , Params . batch_size )
234
+ feed_dict = { data : devdata [ i ][ sample_ind ] for i , data in enumerate ( model .data )}
235
+ logits , dev_loss , gs = sess .run ([model .points_logits , model .mean_loss , model .global_step ], feed_dict = feed_dict )
236
+ index = np .argmax (logits , axis = 2 )
225
237
F1 , EM = 0.0 , 0.0
226
238
for batch in range (Params .batch_size ):
227
- f1 , em = f1_and_EM (index [batch ], ground_truth [ batch ], passage [batch ], dict_ )
239
+ f1 , em = f1_and_EM (index [batch ], devdata [ 8 ][ sample_ind ][ batch ], devdata [ 0 ][ sample_ind ] [batch ], dict_ )
228
240
F1 += f1
229
241
EM += em
230
242
F1 /= float (Params .batch_size )
231
243
EM /= float (Params .batch_size )
232
- sess .run (model .metric_assign ,{model .F1_placeholder : F1 , model .EM_placeholder : EM })
233
- print ("\n Exact_match : {}\n F1_score : {}" .format (EM ,F1 ))
244
+ sess .run (model .metric_assign ,{model .F1_placeholder : F1 , model .EM_placeholder : EM , model . dev_loss_placeholder : dev_loss })
245
+ print ("\n Dev_loss : {}\n Dev_Exact_match : {}\n Dev_F1_score: {} " .format (dev_loss , EM ,F1 ))
234
246
235
247
if __name__ == '__main__' :
236
248
if Params .mode .lower () == "debug" :
0 commit comments