@@ -31,7 +31,9 @@ def __init__(self, options):
         self.options = options

         # Char embeddings
-        self.char_emb_mat = self.random_weight(self.options['char_vocab_size'], self.options['char_emb_mat_dim'], name='char_emb_matrix')
+        if options['char_emb']:
+            self.char_emb_mat = self.random_weight(self.options['char_vocab_size'],
+                                                   self.options['char_emb_mat_dim'], name='char_emb_matrix')

         # Weights
         self.W_uQ = self.random_weight(2 * options['state_size'], options['state_size'], name='W_uQ')
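The cells created in the next hunk and the matching layers further down call two helpers that this diff does not touch, self.DropoutWrappedGRUCell and self.mat_weight_mul. A minimal sketch of both, assuming the former pairs tf.contrib.rnn.GRUCell with a DropoutWrapper on its inputs and the latter flattens the batch and time axes before a single tf.matmul (the names and call signatures come from this file, the bodies are assumptions):

    def DropoutWrappedGRUCell(self, hidden_size, in_keep_prob):
        # GRU cell with dropout applied to its inputs; in_keep_prob=1.0 disables dropout,
        # which is what the lines this patch removes were doing.
        cell = tf.contrib.rnn.GRUCell(hidden_size)
        return tf.contrib.rnn.DropoutWrapper(cell, input_keep_prob=in_keep_prob)

    def mat_weight_mul(self, mat, weight):
        # mat: [batch_size, n, m], weight: [m, k] -> [batch_size, n, k]
        mat_shape = mat.get_shape().as_list()
        weight_shape = weight.get_shape().as_list()
        flat = tf.reshape(mat, [-1, mat_shape[-1]])
        return tf.reshape(tf.matmul(flat, weight), [-1, mat_shape[1], weight_shape[-1]])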
@@ -55,86 +57,98 @@ def __init__(self, options):

         # QP_match
         with tf.variable_scope('QP_match') as scope:
-            self.QPmatch_cell = self.DropoutWrappedGRUCell(self.options['state_size'], 1.0)
+            self.QPmatch_cell = self.DropoutWrappedGRUCell(self.options['state_size'], self.options['in_keep_prob'])
             self.QPmatch_state = self.QPmatch_cell.zero_state(self.options['batch_size'], dtype=tf.float32)

         # Ans Ptr
         with tf.variable_scope('Ans_ptr') as scope:
-            self.AnsPtr_cell = self.DropoutWrappedGRUCell(2 * self.options['state_size'], 1.0)
+            self.AnsPtr_cell = self.DropoutWrappedGRUCell(2 * self.options['state_size'], self.options['in_keep_prob'])

     def build_model(self):
-        paragraph = tf.placeholder(tf.float32, [self.options['batch_size'], self.options['p_length'], self.options['emb_dim']])
-        paragraph_c = tf.placeholder(tf.int32, [self.options['batch_size'], self.options['p_length'], self.options['char_max_length']])
-        question = tf.placeholder(tf.float32, [self.options['batch_size'], self.options['q_length'], self.options['emb_dim']])
-        question_c = tf.placeholder(tf.int32, [self.options['batch_size'], self.options['q_length'], self.options['char_max_length']])
-        answer_si = tf.placeholder(tf.float32, [self.options['batch_size'], self.options['p_length']])
-        answer_ei = tf.placeholder(tf.float32, [self.options['batch_size'], self.options['p_length']])
+        opts = self.options
+
+        # placeholders
+        paragraph = tf.placeholder(tf.float32, [opts['batch_size'], opts['p_length'], opts['emb_dim']])
+        question = tf.placeholder(tf.float32, [opts['batch_size'], opts['q_length'], opts['emb_dim']])
+        answer_si = tf.placeholder(tf.float32, [opts['batch_size'], opts['p_length']])
+        answer_ei = tf.placeholder(tf.float32, [opts['batch_size'], opts['p_length']])
+        if opts['char_emb']:
+            paragraph_c = tf.placeholder(tf.int32, [opts['batch_size'], opts['p_length'], opts['char_max_length']])
+            question_c = tf.placeholder(tf.int32, [opts['batch_size'], opts['q_length'], opts['char_max_length']])

         print('Question and Passage Encoding')
-        # char embedding -> word level char embedding
-        paragraph_c_emb = tf.nn.embedding_lookup(self.char_emb_mat, paragraph_c)  # [batch_size, p_length, char_max_length, char_emb_dim]
-        question_c_emb = tf.nn.embedding_lookup(self.char_emb_mat, question_c)
-        paragraph_c_list = [tf.squeeze(w, [1]) for w in tf.split(paragraph_c_emb, self.options['p_length'], axis=1)]
-        question_c_list = [tf.squeeze(w, [1]) for w in tf.split(question_c_emb, self.options['q_length'], axis=1)]
+        if opts['char_emb']:
+            # char embedding -> word level char embedding
+            paragraph_c_emb = tf.nn.embedding_lookup(self.char_emb_mat, paragraph_c)  # [batch_size, p_length, char_max_length, char_emb_dim]
+            question_c_emb = tf.nn.embedding_lookup(self.char_emb_mat, question_c)
+            paragraph_c_list = [tf.squeeze(w, [1]) for w in tf.split(paragraph_c_emb, opts['p_length'], axis=1)]
+            question_c_list = [tf.squeeze(w, [1]) for w in tf.split(question_c_emb, opts['q_length'], axis=1)]
+
+            c_Q = []
+            c_P = []
+            with tf.variable_scope('char_emb_rnn') as scope:
+                char_emb_fw_cell = self.DropoutWrappedGRUCell(opts['emb_dim'], 1.0)
+                char_emb_bw_cell = self.DropoutWrappedGRUCell(opts['emb_dim'], 1.0)
+                for t in range(opts['q_length']):
+                    unstacked_q_c = tf.unstack(question_c_list[t], opts['char_max_length'], 1)
+                    if t > 0:
+                        tf.get_variable_scope().reuse_variables()
+                    q_c_e_outputs, q_c_e_final_fw, q_c_e_final_bw = tf.contrib.rnn.static_bidirectional_rnn(
+                        char_emb_fw_cell, char_emb_bw_cell, unstacked_q_c, dtype=tf.float32, scope='char_emb')
+                    c_q_t = tf.concat([q_c_e_final_fw[1], q_c_e_final_bw[1]], 1)
+                    c_Q.append(c_q_t)
+                for t in range(opts['p_length']):
+                    unstacked_p_c = tf.unstack(paragraph_c_list[t], opts['char_max_length'], 1)
+                    p_c_e_outputs, p_c_e_final_fw, p_c_e_final_bw = tf.contrib.rnn.static_bidirectional_rnn(
+                        char_emb_fw_cell, char_emb_bw_cell, unstacked_p_c, dtype=tf.float32, scope='char_emb')
+                    c_p_t = tf.concat([p_c_e_final_fw[1], p_c_e_final_bw[1]], 1)
+                    c_P.append(c_p_t)
+            c_Q = tf.stack(c_Q, 1)
+            c_P = tf.stack(c_P, 1)
+            print('c_Q', c_Q)
+            print('c_P', c_P)
+
+            # Concat e and c
+            eQcQ = tf.concat([question, c_Q], 2)
+            ePcP = tf.concat([paragraph, c_P], 2)
+        else:
+            eQcQ = question
+            ePcP = paragraph

-        c_Q = []
-        c_P = []
-        with tf.variable_scope('char_emb_rnn') as scope:
-            char_emb_fw_cell = self.DropoutWrappedGRUCell(self.options['emb_dim'], 1.0)
-            char_emb_bw_cell = self.DropoutWrappedGRUCell(self.options['emb_dim'], 1.0)
-            for t in range(self.options['q_length']):
-                unstacked_q_c = tf.unstack(question_c_list[t], self.options['char_max_length'], 1)
-                if t > 0:
-                    tf.get_variable_scope().reuse_variables()
-                q_c_e_outputs, q_c_e_final_fw, q_c_e_final_bw = tf.contrib.rnn.static_bidirectional_rnn(
-                    char_emb_fw_cell, char_emb_bw_cell, unstacked_q_c, dtype=tf.float32, scope='char_emb')
-                c_q_t = tf.concat([q_c_e_final_fw[1], q_c_e_final_bw[1]], 1)
-                c_Q.append(c_q_t)
-            for t in range(self.options['p_length']):
-                unstacked_p_c = tf.unstack(paragraph_c_list[t], self.options['char_max_length'], 1)
-                p_c_e_outputs, p_c_e_final_fw, p_c_e_final_bw = tf.contrib.rnn.static_bidirectional_rnn(
-                    char_emb_fw_cell, char_emb_bw_cell, unstacked_p_c, dtype=tf.float32, scope='char_emb')
-                c_p_t = tf.concat([p_c_e_final_fw[1], p_c_e_final_bw[1]], 1)
-                c_P.append(c_p_t)
-        c_Q = tf.stack(c_Q, 1)
-        c_P = tf.stack(c_P, 1)
-        print('c_Q', c_Q)
-        print('c_P', c_P)
-        # Concat e and c
-        eQcQ = tf.concat([question, c_Q], 2)
-        ePcP = tf.concat([paragraph, c_P], 2)
-        unstacked_eQcQ = tf.unstack(eQcQ, self.options['q_length'], 1)
-        unstacked_ePcP = tf.unstack(ePcP, self.options['p_length'], 1)
+        unstacked_eQcQ = tf.unstack(eQcQ, opts['q_length'], 1)
+        unstacked_ePcP = tf.unstack(ePcP, opts['p_length'], 1)

         with tf.variable_scope('encoding') as scope:
-            enc_fw_cell = self.DropoutWrappedGRUCell(self.options['state_size'], 1.0)
-            enc_bw_cell = self.DropoutWrappedGRUCell(self.options['state_size'], 1.0)
-            q_enc_outputs, q_enc_final_fw, q_enc_final_bw = tf.contrib.rnn.static_bidirectional_rnn(
-                enc_fw_cell, enc_bw_cell, unstacked_eQcQ, dtype=tf.float32, scope='context_encoding')
+            stacked_enc_fw_cells = [self.DropoutWrappedGRUCell(opts['state_size'], opts['in_keep_prob']) for _ in range(2)]
+            stacked_enc_bw_cells = [self.DropoutWrappedGRUCell(opts['state_size'], opts['in_keep_prob']) for _ in range(2)]
+            q_enc_outputs, q_enc_final_fw, q_enc_final_bw = tf.contrib.rnn.stack_bidirectional_rnn(
+                stacked_enc_fw_cells, stacked_enc_bw_cells, unstacked_eQcQ, dtype=tf.float32, scope='context_encoding')
             tf.get_variable_scope().reuse_variables()
-            p_enc_outputs, p_enc_final_fw, p_enc_final_bw = tf.contrib.rnn.static_bidirectional_rnn(
-                enc_fw_cell, enc_bw_cell, unstacked_ePcP, dtype=tf.float32, scope='context_encoding')
+            p_enc_outputs, p_enc_final_fw, p_enc_final_bw = tf.contrib.rnn.stack_bidirectional_rnn(
+                stacked_enc_fw_cells, stacked_enc_bw_cells, unstacked_ePcP, dtype=tf.float32, scope='context_encoding')
         u_Q = tf.stack(q_enc_outputs, 1)
         u_P = tf.stack(p_enc_outputs, 1)
+        u_Q = tf.nn.dropout(u_Q, opts['in_keep_prob'])
+        u_P = tf.nn.dropout(u_P, opts['in_keep_prob'])
         print(u_Q)
         print(u_P)

         v_P = []
         print('Question-Passage Matching')
-        for t in range(self.options['p_length']):
+        for t in range(opts['p_length']):
             # Calculate c_t
             W_uQ_u_Q = self.mat_weight_mul(u_Q, self.W_uQ)  # [batch_size, q_length, state_size]
-            tiled_u_tP = tf.concat([tf.reshape(u_P[:, t, :], [self.options['batch_size'], 1, -1])] * self.options['q_length'], 1)
+            tiled_u_tP = tf.concat([tf.reshape(u_P[:, t, :], [opts['batch_size'], 1, -1])] * opts['q_length'], 1)
             W_uP_u_tP = self.mat_weight_mul(tiled_u_tP, self.W_uP)

             if t == 0:
                 tanh = tf.tanh(W_uQ_u_Q + W_uP_u_tP)
             else:
-                tiled_v_t1P = tf.concat([tf.reshape(v_P[t - 1], [self.options['batch_size'], 1, -1])] * self.options['q_length'], 1)
+                tiled_v_t1P = tf.concat([tf.reshape(v_P[t - 1], [opts['batch_size'], 1, -1])] * opts['q_length'], 1)
                 W_vP_v_t1P = self.mat_weight_mul(tiled_v_t1P, self.W_vP)
                 tanh = tf.tanh(W_uQ_u_Q + W_uP_u_tP + W_vP_v_t1P)
             s_t = tf.squeeze(self.mat_weight_mul(tanh, tf.reshape(self.B_v_QP, [-1, 1])))
             a_t = tf.nn.softmax(s_t, 1)
-            tiled_a_t = tf.concat([tf.reshape(a_t, [self.options['batch_size'], -1, 1])] * 2 * self.options['state_size'], 2)
+            tiled_a_t = tf.concat([tf.reshape(a_t, [opts['batch_size'], -1, 1])] * 2 * opts['state_size'], 2)
             c_t = tf.reduce_sum(tf.multiply(tiled_a_t, u_Q), 1)  # [batch_size, 2 * state_size]

             # gate
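The gate body itself falls in the lines this hunk elides; from the context that follows (u_tP_c_t_star feeding QPmatch_cell) it is the R-Net gated attention input, roughly as below. This is a reconstruction, not part of the diff, and the weight name self.W_g is an assumption:

    u_tP_c_t = tf.expand_dims(tf.concat([u_P[:, t, :], c_t], 1), 1)  # word encoding + attended question vector
    g_t = tf.sigmoid(self.mat_weight_mul(u_tP_c_t, self.W_g))        # element-wise sigmoid gate
    u_tP_c_t_star = tf.squeeze(tf.multiply(u_tP_c_t, g_t))           # gated input, consumed by QPmatch_cell below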
@@ -148,20 +162,21 @@ def build_model(self):
                 output, self.QPmatch_state = self.QPmatch_cell(u_tP_c_t_star, self.QPmatch_state)
             v_P.append(output)
         v_P = tf.stack(v_P, 1)
+        v_P = tf.nn.dropout(v_P, opts['in_keep_prob'])
         print('v_P', v_P)

         print('Self-Matching Attention')
         SM_star = []
-        for t in range(self.options['p_length']):
+        for t in range(opts['p_length']):
             # Calculate s_t
             W_p1_v_P = self.mat_weight_mul(v_P, self.W_smP1)  # [batch_size, p_length, state_size]
-            tiled_v_tP = tf.concat([tf.reshape(v_P[:, t, :], [self.options['batch_size'], 1, -1])] * self.options['p_length'], 1)
+            tiled_v_tP = tf.concat([tf.reshape(v_P[:, t, :], [opts['batch_size'], 1, -1])] * opts['p_length'], 1)
             W_p2_v_tP = self.mat_weight_mul(tiled_v_tP, self.W_smP2)

             tanh = tf.tanh(W_p1_v_P + W_p2_v_tP)
             s_t = tf.squeeze(self.mat_weight_mul(tanh, tf.reshape(self.B_v_SM, [-1, 1])))
             a_t = tf.nn.softmax(s_t, 1)
-            tiled_a_t = tf.concat([tf.reshape(a_t, [self.options['batch_size'], -1, 1])] * self.options['state_size'], 2)
+            tiled_a_t = tf.concat([tf.reshape(a_t, [opts['batch_size'], -1, 1])] * opts['state_size'], 2)
             c_t = tf.reduce_sum(tf.multiply(tiled_a_t, v_P), 1)  # [batch_size, 2 * state_size]

             # gate
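For reference, the self-matching attention that the loop above computes position by position can also be written as one batched operation; a sketch of the equivalent vectorized form, shown only to clarify what the loop does (not part of the diff, and memory-hungry because it materialises a [batch, p_length, p_length, state_size] intermediate):

    # scores[b, i, j] = B_v_SM . tanh(W_smP1 v_P[b, j] + W_smP2 v_P[b, i])
    W1_vP = self.mat_weight_mul(v_P, self.W_smP1)                             # [batch, p_length, state]
    W2_vP = self.mat_weight_mul(v_P, self.W_smP2)                             # [batch, p_length, state]
    pair = tf.tanh(tf.expand_dims(W2_vP, 2) + tf.expand_dims(W1_vP, 1))       # [batch, p_length, p_length, state]
    scores = tf.reduce_sum(pair * tf.reshape(self.B_v_SM, [1, 1, 1, -1]), 3)  # [batch, p_length, p_length]
    alpha = tf.nn.softmax(scores, 2)
    c_all = tf.matmul(alpha, v_P)                                             # attended context for every position at once

The looped form used here avoids that large intermediate at the price of a much bigger graph.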
@@ -170,25 +185,27 @@ def build_model(self):
             v_tP_c_t_star = tf.squeeze(tf.multiply(v_tP_c_t, g_t))
             SM_star.append(v_tP_c_t_star)
         SM_star = tf.stack(SM_star, 1)
-        unstacked_SM_star = tf.unstack(SM_star, self.options['p_length'], 1)
+        unstacked_SM_star = tf.unstack(SM_star, opts['p_length'], 1)

         with tf.variable_scope('Self_match') as scope:
-            SM_fw_cell = self.DropoutWrappedGRUCell(self.options['state_size'], 1.0)
-            SM_bw_cell = self.DropoutWrappedGRUCell(self.options['state_size'], 1.0)
+            SM_fw_cell = self.DropoutWrappedGRUCell(opts['state_size'], opts['in_keep_prob'])
+            SM_bw_cell = self.DropoutWrappedGRUCell(opts['state_size'], opts['in_keep_prob'])
             SM_outputs, SM_final_fw, SM_final_bw = tf.contrib.rnn.static_bidirectional_rnn(SM_fw_cell, SM_bw_cell, unstacked_SM_star, dtype=tf.float32)
         h_P = tf.stack(SM_outputs, 1)
+        h_P = tf.nn.dropout(h_P, opts['in_keep_prob'])
         print('h_P', h_P)

         print('Output Layer')
         # calculate r_Q
         W_ruQ_u_Q = self.mat_weight_mul(u_Q, self.W_ruQ)  # [batch_size, q_length, 2 * state_size]
         W_vQ_V_rQ = tf.matmul(self.W_VrQ, self.W_vQ)
-        W_vQ_V_rQ = tf.stack([W_vQ_V_rQ] * self.options['batch_size'], 0)  # stack -> [batch_size, state_size, state_size]
+        W_vQ_V_rQ = tf.stack([W_vQ_V_rQ] * opts['batch_size'], 0)  # stack -> [batch_size, state_size, state_size]

         tanh = tf.tanh(W_ruQ_u_Q + W_vQ_V_rQ)
         s_t = tf.squeeze(self.mat_weight_mul(tanh, tf.reshape(self.B_v_rQ, [-1, 1])))
         a_t = tf.nn.softmax(s_t, 1)
-        tiled_a_t = tf.concat([tf.reshape(a_t, [self.options['batch_size'], -1, 1])] * 2 * self.options['state_size'], 2)
+        tiled_a_t = tf.concat([tf.reshape(a_t, [opts['batch_size'], -1, 1])] * 2 * opts['state_size'], 2)
         r_Q = tf.reduce_sum(tf.multiply(tiled_a_t, u_Q), 1)  # [batch_size, 2 * state_size]
+        r_Q = tf.nn.dropout(r_Q, opts['in_keep_prob'])
         print('r_Q', r_Q)

         # r_Q as initial state of ans ptr
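The lines elided between this hunk and the next set up the two-step pointer loop that the context below continues; roughly as follows (reconstructed from the surrounding context, the exact elided code may differ):

    p = [None, None]                                # p[0]: start distribution, p[1]: end distribution
    W_hP_h_P = self.mat_weight_mul(h_P, self.W_hP)  # [batch_size, p_length, state_size], reused in both steps
    for t in range(2):
        if t == 0:
            h_t1a = r_Q                             # the question summary r_Q seeds the pointer GRU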
@@ -202,19 +219,19 @@ def build_model(self):
             else:
                 h_t1a = h_a
             print('h_t1a', h_t1a)
-            tiled_h_t1a = tf.concat([tf.reshape(h_t1a, [self.options['batch_size'], 1, -1])] * self.options['p_length'], 1)
+            tiled_h_t1a = tf.concat([tf.reshape(h_t1a, [opts['batch_size'], 1, -1])] * opts['p_length'], 1)
             W_ha_h_t1a = self.mat_weight_mul(tiled_h_t1a, self.W_ha)

             tanh = tf.tanh(W_hP_h_P + W_ha_h_t1a)
             s_t = tf.squeeze(self.mat_weight_mul(tanh, tf.reshape(self.B_v_ap, [-1, 1])))
             a_t = tf.nn.softmax(s_t, 1)
             p[t] = a_t

-            tiled_a_t = tf.concat([tf.reshape(a_t, [self.options['batch_size'], -1, 1])] * 2 * self.options['state_size'], 2)
+            tiled_a_t = tf.concat([tf.reshape(a_t, [opts['batch_size'], -1, 1])] * 2 * opts['state_size'], 2)
             c_t = tf.reduce_sum(tf.multiply(tiled_a_t, h_P), 1)  # [batch_size, 2 * state_size]

             if t == 0:
-                AnsPtr_state = self.AnsPtr_cell.zero_state(self.options['batch_size'], dtype=tf.float32)
+                AnsPtr_state = self.AnsPtr_cell.zero_state(opts['batch_size'], dtype=tf.float32)
             h_a, _ = self.AnsPtr_cell(c_t, (AnsPtr_state, r_Q))
             h_a = h_a[1]
             print(h_a)
@@ -234,7 +251,7 @@ def build_model(self):
         loss = loss_si + loss_ei
         """

-        batch_idx = tf.reshape(tf.range(0, self.options['batch_size']), [-1, 1])
+        batch_idx = tf.reshape(tf.range(0, opts['batch_size']), [-1, 1])
         answer_si_re = tf.reshape(answer_si_idx, [-1, 1])
         batch_idx_si = tf.concat([batch_idx, answer_si_re], 1)
         answer_ei_re = tf.reshape(answer_ei_idx, [-1, 1])
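The index pairs built above are presumably consumed by the elided lines that follow, which gather each example's predicted probability at the gold start and end positions to form the loss; a hedged sketch of that pattern (the elided code may differ in reduction or smoothing):

    batch_idx_ei = tf.concat([batch_idx, answer_ei_re], 1)
    # pick p1[b, answer_si_idx[b]] and p2[b, answer_ei_idx[b]] for every row b
    loss_si = -tf.reduce_mean(tf.log(tf.gather_nd(p1, batch_idx_si) + 1e-8))
    loss_ei = -tf.reduce_mean(tf.log(tf.gather_nd(p2, batch_idx_ei) + 1e-8))
    loss = loss_si + loss_ei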
@@ -245,26 +262,26 @@ def build_model(self):

         # Search
         prob = []
-        search_range = self.options['p_length'] - self.options['span_length']
+        search_range = opts['p_length'] - opts['span_length']
         for i in range(search_range):
-            for j in range(self.options['span_length']):
+            for j in range(opts['span_length']):
                 prob.append(tf.multiply(p1[:, i], p2[:, i + j]))
         prob = tf.stack(prob, axis=1)
         argmax_idx = tf.argmax(prob, axis=1)
-        pred_si = argmax_idx / self.options['span_length']
-        pred_ei = pred_si + tf.cast(tf.mod(argmax_idx, self.options['span_length']), tf.float64)
+        pred_si = argmax_idx / opts['span_length']
+        pred_ei = pred_si + tf.cast(tf.mod(argmax_idx, opts['span_length']), tf.float64)
         correct = tf.logical_and(tf.equal(tf.cast(pred_si, tf.int64), tf.cast(answer_si_idx, tf.int64)),
                                  tf.equal(tf.cast(pred_ei, tf.int64), tf.cast(answer_ei_idx, tf.int64)))
         accuracy = tf.reduce_mean(tf.cast(correct, tf.float32))

         input_tensors = {
             'p': paragraph,
             'q': question,
-            'pc': paragraph_c,
-            'qc': question_c,
             'a_si': answer_si,
             'a_ei': answer_ei,
         }
+        if opts['char_emb']:
+            input_tensors.update({'pc': paragraph_c, 'qc': question_c})

         print('Model built')
         for v in tf.global_variables():
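For reference, entry k of the flattened prob tensor built by the search loop corresponds to start i = k // span_length and end i + (k % span_length), which is exactly what the pred_si and pred_ei lines recover (in Python 3, // rather than / would keep pred_si integral). The same arithmetic outside the graph, assuming span_length = 5 purely for illustration:

    span_length = 5          # assumed value for illustration
    k = 23                   # index returned by argmax over the flattened (start, offset) grid
    pred_si = k // span_length             # 4  -> answer starts at token 4
    pred_ei = pred_si + k % span_length    # 7  -> answer ends at token 7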