5
5
from utils import show_all_variables
6
6
7
7
class Model (object ):
8
- def __init__ (self , config , data_loader ):
8
+ def __init__ (self , config , data_loader , is_critic = False ):
9
9
self .data_loader = data_loader
10
10
11
11
self .task = config .task
@@ -15,33 +15,46 @@ def __init__(self, config, data_loader):
15
15
self .input_dim = config .input_dim
16
16
self .hidden_dim = config .hidden_dim
17
17
self .num_layers = config .num_layers
18
+
18
19
self .max_enc_length = config .max_enc_length
19
20
self .max_dec_length = config .max_dec_length
20
21
self .num_glimpse = config .num_glimpse
21
22
23
+ self .init_min_val = config .init_min_val
24
+ self .init_max_val = config .init_max_val
25
+ self .initializer = \
26
+ tf .random_uniform_initializer (self .init_min_val , self .init_max_val )
27
+
22
28
self .use_terminal_symbol = config .use_terminal_symbol
23
29
24
- self .reg_scale = config .reg_scale
25
30
self .lr_start = config .lr_start
31
+ self .lr_decay_step = config .lr_decay_step
32
+ self .lr_decay_rate = config .lr_decay_rate
26
33
self .max_grad_norm = config .max_grad_norm
27
- self .batch_size = config .batch_size
28
34
29
35
self .layer_dict = {}
30
36
31
- with arg_scope ([ linear , LSTMCell ], \
32
- initializer = tf . random_normal_initializer ( 0 , 0.001 )) :
33
- self ._build_model ()
37
+ self . _build_model ()
38
+ if is_critic :
39
+ self ._build_critic_model ()
34
40
35
41
self ._build_optim ()
42
+ self ._build_summary ()
36
43
37
44
show_all_variables ()
38
45
46
+ def _build_summary (self ):
47
+ tf .summary .scalar ("learning_rate" , self .lr )
48
+
49
+ def _build_critic_model (self ):
50
+ pass
51
+
39
52
def _build_model (self ):
40
53
self .global_step = tf .Variable (0 , trainable = False )
41
54
42
- initializer = None
43
55
input_weight = tf .get_variable (
44
- "input_weight" , [1 , self .input_dim , self .hidden_dim ])
56
+ "input_weight" , [1 , self .input_dim , self .hidden_dim ],
57
+ initializer = self .initializer )
45
58
46
59
with tf .variable_scope ("encoder" ):
47
60
self .enc_seq_length = tf .placeholder (
@@ -53,22 +66,27 @@ def _build_model(self):
53
66
self .enc_inputs , input_weight , 1 , "VALID" )
54
67
55
68
batch_size = tf .shape (self .enc_inputs )[0 ]
56
- tiled_zeros = tf .tile (tf .zeros (
57
- [1 , self .hidden_dim ]), [batch_size , 1 ], name = "tiled_zeros" )
58
-
59
69
with tf .variable_scope ("encoder" ):
60
- self .enc_cell = LSTMCell (self .hidden_dim )
70
+ self .enc_cell = LSTMCell (
71
+ self .hidden_dim ,
72
+ initializer = self .initializer )
73
+
61
74
if self .num_layers > 1 :
62
75
cells = [self .enc_cell ] * self .num_layers
63
76
self .enc_cell = MultiRNNCell (cells )
64
- self .enc_init_state = trainable_initial_state (batch_size , self .enc_cell .state_size )
77
+ self .enc_init_state = trainable_initial_state (
78
+ batch_size , self .enc_cell .state_size )
65
79
66
80
# self.encoder_outputs : [None, max_time, output_size]
67
81
self .enc_outputs , self .enc_final_states = tf .nn .dynamic_rnn (
68
- self .enc_cell , self .transformed_enc_inputs , self .enc_seq_length , self .enc_init_state )
82
+ self .enc_cell , self .transformed_enc_inputs ,
83
+ self .enc_seq_length , self .enc_init_state )
69
84
70
85
if self .use_terminal_symbol :
71
- self .enc_outputs = [tiled_zeros ] + self .enc_outputs
86
+ tiled_zeros = tf .tile (tf .zeros (
87
+ [1 , self .hidden_dim ]), [batch_size , 1 ], name = "tiled_zeros" )
88
+ expanded_tiled_zeros = tf .expand_dims (tiled_zeros , axis = 1 )
89
+ self .enc_outputs = tf .concat_v2 ([expanded_tiled_zeros , self .enc_outputs ], axis = 1 )
72
90
73
91
with tf .variable_scope ("dencoder" ):
74
92
#self.first_decoder_input = \
@@ -86,19 +104,28 @@ def _build_model(self):
86
104
87
105
idx_pairs = index_matrix_to_pairs (self .dec_idx_inputs )
88
106
self .dec_inputs = tf .gather_nd (self .enc_inputs , idx_pairs )
89
- self .transformed_dec_inputs = tf .gather_nd (self .transformed_enc_inputs , idx_pairs )
107
+ self .transformed_dec_inputs = \
108
+ tf .gather_nd (self .transformed_enc_inputs , idx_pairs )
90
109
91
110
#dec_inputs = [
92
111
# tf.expand_dims(self.first_decoder_input, 1),
93
112
# dec_inputs_without_first,
94
113
#]
95
114
#self.dec_inputs = tf.concat_v2(dec_inputs, axis=1)
96
115
97
- self .dec_targets = tf .placeholder (tf .float32 ,
98
- [None , self .max_enc_length + 1 ], name = "dec_targets" )
116
+ if self .use_terminal_symbol :
117
+ dec_target_dims = [None , self .max_enc_length + 1 ]
118
+ else :
119
+ dec_target_dims = [None , self .max_enc_length ]
120
+
121
+ self .dec_targets = tf .placeholder (
122
+ tf .int32 , dec_target_dims , name = "dec_targets" )
99
123
self .is_train = tf .placeholder (tf .bool , name = "is_train" )
100
124
101
- self .dec_cell = LSTMCell (self .hidden_dim )
125
+ self .dec_cell = LSTMCell (
126
+ self .hidden_dim ,
127
+ initializer = self .initializer )
128
+
102
129
if self .num_layers > 1 :
103
130
cells = [self .dec_cell ] * self .num_layers
104
131
self .dec_cell = MultiRNNCell (cells )
@@ -107,19 +134,29 @@ def _build_model(self):
107
134
self .dec_cell , self .transformed_dec_inputs ,
108
135
self .enc_outputs , self .enc_final_states ,
109
136
self .enc_seq_length , self .hidden_dim , self .num_glimpse ,
110
- self .max_dec_length , batch_size , is_train = True )
137
+ self .max_dec_length , batch_size , is_train = True ,
138
+ initializer = self .initializer )
111
139
112
140
with tf .variable_scope ("dencoder" , reuse = True ):
113
141
self .dec_outputs , _ , self .predictions = decoder_rnn (
114
142
self .dec_cell , self .transformed_dec_inputs ,
115
143
self .enc_outputs , self .enc_final_states ,
116
144
self .enc_seq_length , self .hidden_dim , self .num_glimpse ,
117
- self .max_dec_length , batch_size , is_train = False )
145
+ self .max_dec_length , batch_size , is_train = False ,
146
+ initializer = self .initializer )
118
147
119
148
def _build_optim (self ):
120
- self .loss = tf .reduce_mean (self .output - self .targets )
149
+ self .loss = tf .nn .sparse_softmax_cross_entropy_with_logits (
150
+ logits = self .dec_output_logits , labels = self .dec_targets )
151
+
152
+ # TODO: length masking
153
+ #mask = tf.sign(tf.to_float(targets_flat))
154
+ #masked_losses = mask * self.loss
155
+
156
+ self .lr = tf .train .exponential_decay (
157
+ self .lr_start , self .global_step , self .lr_decay_step ,
158
+ self .lr_decay_rate , staircase = True , name = "learning_rate" )
121
159
122
- self .lr = tf .Variable (self .lr_start )
123
160
optimizer = tf .train .AdamOptimizer (self .lr )
124
161
125
162
if self .max_grad_norm != None :
0 commit comments