Skip to content

Commit 763adf4

Browse files
committed
config update and learning rate decay update
1 parent 77e9032 commit 763adf4

File tree

5 files changed

+90
-43
lines changed

5 files changed

+90
-43
lines changed

config.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,8 @@ def add_argument_group(name):
1919
net_arg.add_argument('--input_dim', type=int, default=2, help='')
2020
net_arg.add_argument('--max_enc_length', type=int, default=20, help='')
2121
net_arg.add_argument('--max_dec_length', type=int, default=33, help='')
22+
net_arg.add_argument('--init_min_val', type=float, default=-0.08, help='for uniform random initializer')
23+
net_arg.add_argument('--init_max_val', type=float, default=+0.08, help='for uniform random initializer')
2224
net_arg.add_argument('--num_glimpse', type=int, default=1, help='')
2325
net_arg.add_argument('--use_terminal_symbol', type=str2bool, default=True, help='Not implemented yet')
2426

@@ -34,10 +36,9 @@ def add_argument_group(name):
3436
train_arg.add_argument('--is_train', type=str2bool, default=True, help='')
3537
train_arg.add_argument('--optimizer', type=str, default='rmsprop', help='')
3638
train_arg.add_argument('--max_step', type=int, default=10000, help='')
37-
train_arg.add_argument('--reg_scale', type=float, default=0.5, help='')
38-
train_arg.add_argument('--batch_size', type=int, default=512, help='')
3939
train_arg.add_argument('--lr_start', type=float, default=0.001, help='')
4040
train_arg.add_argument('--lr_decay_step', type=int, default=5000, help='')
41+
train_arg.add_argument('--lr_decay_rate', type=float, default=0.96, help='')
4142
train_arg.add_argument('--max_grad_norm', type=float, default=1.0, help='')
4243
train_arg.add_argument('--checkpoint_secs', type=int, default=300, help='')
4344

data_loader.py

Lines changed: 16 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
# Most of the code is from https://github.com/vshallc/PtrNets/blob/master/pointer/misc/tsp.py
22
import os
3-
import np as np
3+
import itertools
4+
import numpy as np
5+
from tqdm import trange
46

57
def length(x, y):
68
return np.linalg.norm(np.asarray(x) - np.asarray(y))
@@ -24,10 +26,11 @@ def solve_tsp_dynamic(points):
2426
def generate_one_example(n_nodes):
2527
nodes = np.random.rand(n_nodes, 2)
2628
res = solve_tsp_dynamic(nodes)
29+
return nodes, res
2730

28-
def generate_examples(num, n_min, n_max):
31+
def generate_examples(num, n_min, n_max, desc=""):
2932
examples = []
30-
for i in range(num):
33+
for i in trange(num, desc=desc):
3134
n_nodes = np.random.randint(n_min, n_max + 1)
3235
nodes, res = generate_one_example(n_nodes)
3336
examples.append((nodes, res))
@@ -45,18 +48,22 @@ def __init__(self, config, rng=None):
4548
self.task_name = "{}_{}_{}".format(self.task, self.min_length, self.max_length)
4649
self.npz_path = os.path.join(config.data_dir, "{}.npz".format(self.task_name))
4750

48-
def maybe_generate_and_save(self):
51+
self._maybe_generate_and_save()
52+
53+
def _maybe_generate_and_save(self):
4954
if not os.path.exists(self.npz_path):
5055
print("[*] Creating dataset for {}".format(self.task))
5156

52-
train = generate_examples(1048576, self.min_length, self.max_length)
53-
valid = generate_examples(1000, self.min_length, self.max_length)
54-
test = generate_examples(1000, self.max_length, self.max_length)
57+
train = generate_examples(
58+
1000000, self.min_length, self.max_length, "Train data..")
59+
valid = generate_examples(
60+
1000, self.min_length, self.max_length, "Valid data..")
61+
test = generate_examples(
62+
1000, self.max_length, self.max_length, "Test data..")
5563

5664
np.savez(self.npz_path, train=train, test=test, valid=valid)
5765
else:
5866
print("[*] Loading dataset for {}".format(self.task))
59-
data = np.load(self.npz_path, train=, test=, val=)
67+
data = np.load(self.npz_path)
6068
self.train, self.test, self.valid = \
6169
data['train'], data['test'], data['valid']
62-

layers.py

Lines changed: 10 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,6 @@
44
from tensorflow.contrib import seq2seq
55
from tensorflow.python.util import nest
66

7-
linear = layers.linear
87
LSTMCell = rnn.LSTMCell
98
MultiRNNCell = rnn.MultiRNNCell
109
dynamic_rnn_decoder = seq2seq.dynamic_rnn_decoder
@@ -13,16 +12,20 @@
1312
def decoder_rnn(cell, inputs,
1413
enc_outputs, enc_final_states,
1514
seq_length, hidden_dim, num_glimpse,
16-
max_dec_length, batch_size, is_train, end_of_sequence_id=0):
15+
max_dec_length, batch_size, is_train,
16+
end_of_sequence_id=0, initializer=None):
1717
with tf.variable_scope("decoder_rnn") as scope:
1818
first_decoder_input = trainable_initial_state(
1919
batch_size, hidden_dim, name="first_decoder_input")
2020

2121
def attention(ref, query, with_softmax=True, scope="attention"):
2222
with tf.variable_scope(scope):
23-
W_ref = tf.get_variable("W_ref", [1, hidden_dim, hidden_dim])
24-
W_q = tf.get_variable("W_q", [hidden_dim, hidden_dim])
25-
v = tf.get_variable("v", [hidden_dim])
23+
W_ref = tf.get_variable(
24+
"W_ref", [1, hidden_dim, hidden_dim], initializer=initializer)
25+
W_q = tf.get_variable(
26+
"W_q", [hidden_dim, hidden_dim], initializer=initializer)
27+
v = tf.get_variable(
28+
"v", [hidden_dim], initializer=initializer)
2629

2730
encoded_ref = tf.nn.conv1d(ref, W_ref, 1, "VALID")
2831
encoded_query = tf.matmul(tf.reshape(query, [-1, hidden_dim]), W_q)
@@ -85,7 +88,7 @@ def decoder_fn_inference(
8588
output_logit = output_fn(enc_outputs, output, num_glimpse)
8689
scope.reuse_variables()
8790
output_logits.append(output_logit)
88-
outputs = tf.stack(output_logits, 1)
91+
outputs = tf.stack(output_logits, axis=1)
8992

9093
return outputs, final_state, final_context_state
9194

@@ -99,7 +102,6 @@ def trainable_initial_state(batch_size, state_size,
99102
flat_initializer = tuple(tf.zeros_initializer for initializer in flat_state_size)
100103

101104
names = ["{}_{}".format(name, i) for i in xrange(len(flat_state_size))]
102-
103105
tiled_states = []
104106

105107
for name, size, init in zip(names, flat_state_size, flat_initializer):
@@ -118,4 +120,4 @@ def index_matrix_to_pairs(index_matrix):
118120
replicated_first_indices = tf.tile(
119121
tf.expand_dims(tf.range(tf.shape(index_matrix)[0]), dim=1),
120122
[1, tf.shape(index_matrix)[1]])
121-
return tf.pack([replicated_first_indices, index_matrix], axis=2)
123+
return tf.stack([replicated_first_indices, index_matrix], axis=2)

model.py

Lines changed: 60 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
from utils import show_all_variables
66

77
class Model(object):
8-
def __init__(self, config, data_loader):
8+
def __init__(self, config, data_loader, is_critic=False):
99
self.data_loader = data_loader
1010

1111
self.task = config.task
@@ -15,33 +15,46 @@ def __init__(self, config, data_loader):
1515
self.input_dim = config.input_dim
1616
self.hidden_dim = config.hidden_dim
1717
self.num_layers = config.num_layers
18+
1819
self.max_enc_length = config.max_enc_length
1920
self.max_dec_length = config.max_dec_length
2021
self.num_glimpse = config.num_glimpse
2122

23+
self.init_min_val = config.init_min_val
24+
self.init_max_val = config.init_max_val
25+
self.initializer = \
26+
tf.random_uniform_initializer(self.init_min_val, self.init_max_val)
27+
2228
self.use_terminal_symbol = config.use_terminal_symbol
2329

24-
self.reg_scale = config.reg_scale
2530
self.lr_start = config.lr_start
31+
self.lr_decay_step = config.lr_decay_step
32+
self.lr_decay_rate = config.lr_decay_rate
2633
self.max_grad_norm = config.max_grad_norm
27-
self.batch_size = config.batch_size
2834

2935
self.layer_dict = {}
3036

31-
with arg_scope([linear, LSTMCell], \
32-
initializer=tf.random_normal_initializer(0, 0.001)):
33-
self._build_model()
37+
self._build_model()
38+
if is_critic:
39+
self._build_critic_model()
3440

3541
self._build_optim()
42+
self._build_summary()
3643

3744
show_all_variables()
3845

46+
def _build_summary(self):
47+
tf.summary.scalar("learning_rate", self.lr)
48+
49+
def _build_critic_model(self):
50+
pass
51+
3952
def _build_model(self):
4053
self.global_step = tf.Variable(0, trainable=False)
4154

42-
initializer = None
4355
input_weight = tf.get_variable(
44-
"input_weight", [1, self.input_dim, self.hidden_dim])
56+
"input_weight", [1, self.input_dim, self.hidden_dim],
57+
initializer=self.initializer)
4558

4659
with tf.variable_scope("encoder"):
4760
self.enc_seq_length = tf.placeholder(
@@ -53,22 +66,27 @@ def _build_model(self):
5366
self.enc_inputs, input_weight, 1, "VALID")
5467

5568
batch_size = tf.shape(self.enc_inputs)[0]
56-
tiled_zeros = tf.tile(tf.zeros(
57-
[1, self.hidden_dim]), [batch_size, 1], name="tiled_zeros")
58-
5969
with tf.variable_scope("encoder"):
60-
self.enc_cell = LSTMCell(self.hidden_dim)
70+
self.enc_cell = LSTMCell(
71+
self.hidden_dim,
72+
initializer=self.initializer)
73+
6174
if self.num_layers > 1:
6275
cells = [self.enc_cell] * self.num_layers
6376
self.enc_cell = MultiRNNCell(cells)
64-
self.enc_init_state = trainable_initial_state(batch_size, self.enc_cell.state_size)
77+
self.enc_init_state = trainable_initial_state(
78+
batch_size, self.enc_cell.state_size)
6579

6680
# self.enc_outputs : [None, max_time, output_size]
6781
self.enc_outputs, self.enc_final_states = tf.nn.dynamic_rnn(
68-
self.enc_cell, self.transformed_enc_inputs, self.enc_seq_length, self.enc_init_state)
82+
self.enc_cell, self.transformed_enc_inputs,
83+
self.enc_seq_length, self.enc_init_state)
6984

7085
if self.use_terminal_symbol:
71-
self.enc_outputs = [tiled_zeros] + self.enc_outputs
86+
tiled_zeros = tf.tile(tf.zeros(
87+
[1, self.hidden_dim]), [batch_size, 1], name="tiled_zeros")
88+
expanded_tiled_zeros = tf.expand_dims(tiled_zeros, axis=1)
89+
self.enc_outputs = tf.concat_v2([expanded_tiled_zeros, self.enc_outputs], axis=1)
7290

7391
with tf.variable_scope("dencoder"):
7492
#self.first_decoder_input = \
@@ -86,19 +104,28 @@ def _build_model(self):
86104

87105
idx_pairs = index_matrix_to_pairs(self.dec_idx_inputs)
88106
self.dec_inputs = tf.gather_nd(self.enc_inputs, idx_pairs)
89-
self.transformed_dec_inputs = tf.gather_nd(self.transformed_enc_inputs, idx_pairs)
107+
self.transformed_dec_inputs = \
108+
tf.gather_nd(self.transformed_enc_inputs, idx_pairs)
90109

91110
#dec_inputs = [
92111
# tf.expand_dims(self.first_decoder_input, 1),
93112
# dec_inputs_without_first,
94113
#]
95114
#self.dec_inputs = tf.concat_v2(dec_inputs, axis=1)
96115

97-
self.dec_targets = tf.placeholder(tf.float32,
98-
[None, self.max_enc_length + 1], name="dec_targets")
116+
if self.use_terminal_symbol:
117+
dec_target_dims = [None, self.max_enc_length + 1]
118+
else:
119+
dec_target_dims = [None, self.max_enc_length]
120+
121+
self.dec_targets = tf.placeholder(
122+
tf.int32, dec_target_dims, name="dec_targets")
99123
self.is_train = tf.placeholder(tf.bool, name="is_train")
100124

101-
self.dec_cell = LSTMCell(self.hidden_dim)
125+
self.dec_cell = LSTMCell(
126+
self.hidden_dim,
127+
initializer=self.initializer)
128+
102129
if self.num_layers > 1:
103130
cells = [self.dec_cell] * self.num_layers
104131
self.dec_cell = MultiRNNCell(cells)
@@ -107,19 +134,29 @@ def _build_model(self):
107134
self.dec_cell, self.transformed_dec_inputs,
108135
self.enc_outputs, self.enc_final_states,
109136
self.enc_seq_length, self.hidden_dim, self.num_glimpse,
110-
self.max_dec_length, batch_size, is_train=True)
137+
self.max_dec_length, batch_size, is_train=True,
138+
initializer=self.initializer)
111139

112140
with tf.variable_scope("dencoder", reuse=True):
113141
self.dec_outputs, _, self.predictions = decoder_rnn(
114142
self.dec_cell, self.transformed_dec_inputs,
115143
self.enc_outputs, self.enc_final_states,
116144
self.enc_seq_length, self.hidden_dim, self.num_glimpse,
117-
self.max_dec_length, batch_size, is_train=False)
145+
self.max_dec_length, batch_size, is_train=False,
146+
initializer=self.initializer)
118147

119148
def _build_optim(self):
120-
self.loss = tf.reduce_mean(self.output - self.targets)
149+
self.loss = tf.nn.sparse_softmax_cross_entropy_with_logits(
150+
logits=self.dec_output_logits, labels=self.dec_targets)
151+
152+
# TODO: length masking
153+
#mask = tf.sign(tf.to_float(targets_flat))
154+
#masked_losses = mask * self.loss
155+
156+
self.lr = tf.train.exponential_decay(
157+
self.lr_start, self.global_step, self.lr_decay_step,
158+
self.lr_decay_rate, staircase=True, name="learning_rate")
121159

122-
self.lr = tf.Variable(self.lr_start)
123160
optimizer = tf.train.AdamOptimizer(self.lr)
124161

125162
if self.max_grad_norm != None:

trainer.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,7 @@ def _build_session(self):
4141
summary_writer=self.summary_writer,
4242
save_summaries_secs=300,
4343
save_model_secs=self.checkpoint_secs,
44-
global_step=self.model.discrim_step)
44+
global_step=self.model.global_step)
4545

4646
gpu_options = tf.GPUOptions(
4747
per_process_gpu_memory_fraction=self.gpu_memory_fraction,

0 commit comments

Comments
 (0)