Commit d5d818b

update pointer network from other repo
1 parent: 7eaa2a0

7 files changed (+276, -103 lines)


config.py

Lines changed: 5 additions & 4 deletions
```diff
@@ -14,7 +14,7 @@ def add_argument_group(name):
 
 # Network
 net_arg = add_argument_group('Network')
-net_arg.add_argument('--hidden_dim', type=int, default=128, help='')
+net_arg.add_argument('--hidden_dim', type=int, default=256, help='')
 net_arg.add_argument('--num_layers', type=int, default=1, help='')
 net_arg.add_argument('--input_dim', type=int, default=2, help='')
 net_arg.add_argument('--max_enc_length', type=int, default=None, help='')
@@ -26,7 +26,7 @@ def add_argument_group(name):
 
 # Data
 data_arg = add_argument_group('Data')
-data_arg.add_argument('--task', type=str, default='TSP')
+data_arg.add_argument('--task', type=str, default='tsp')
 data_arg.add_argument('--batch_size', type=int, default=128)
 data_arg.add_argument('--min_data_length', type=int, default=5)
 data_arg.add_argument('--max_data_length', type=int, default=10)
@@ -42,12 +42,13 @@ def add_argument_group(name):
 train_arg.add_argument('--lr_start', type=float, default=0.001, help='')
 train_arg.add_argument('--lr_decay_step', type=int, default=5000, help='')
 train_arg.add_argument('--lr_decay_rate', type=float, default=0.96, help='')
-train_arg.add_argument('--max_grad_norm', type=float, default=1.0, help='')
+train_arg.add_argument('--max_grad_norm', type=float, default=2.0, help='')
 train_arg.add_argument('--checkpoint_secs', type=int, default=300, help='')
 
 # Misc
 misc_arg = add_argument_group('Misc')
-misc_arg.add_argument('--log_step', type=int, default=20, help='')
+misc_arg.add_argument('--log_step', type=int, default=50, help='')
+misc_arg.add_argument('--num_log_samples', type=int, default=3, help='')
 misc_arg.add_argument('--log_level', type=str, default='INFO', choices=['INFO', 'DEBUG', 'WARN'], help='')
 misc_arg.add_argument('--log_dir', type=str, default='logs')
 misc_arg.add_argument('--data_dir', type=str, default='data')
```
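
For orientation, these flags are plain argparse arguments collected through the `add_argument_group` helper named in the hunk headers. A minimal, self-contained sketch of a `config.py` laid out this way; the `get_config` wrapper and the `__main__` block are illustrative assumptions, not part of this commit:

```python
import argparse

parser = argparse.ArgumentParser()

def add_argument_group(name):
    # each group just namespaces related flags in --help output
    return parser.add_argument_group(name)

net_arg = add_argument_group('Network')
net_arg.add_argument('--hidden_dim', type=int, default=256, help='')

misc_arg = add_argument_group('Misc')
misc_arg.add_argument('--num_log_samples', type=int, default=3, help='')

def get_config():
    config, unparsed = parser.parse_known_args()
    return config, unparsed

if __name__ == "__main__":
    config, _ = get_config()
    print(config.hidden_dim, config.num_log_samples)  # 256 3
```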

data_loader.py

Lines changed: 97 additions & 12 deletions
```diff
@@ -1,13 +1,26 @@
 # Most of the codes are from
 # https://github.com/vshallc/PtrNets/blob/master/pointer/misc/tsp.py
 import os
+import re
+import zipfile
 import itertools
 import threading
 import numpy as np
-from tqdm import trange
+from tqdm import trange, tqdm
 from collections import namedtuple
 
 import tensorflow as tf
+from download import download_file_from_google_drive
+
+GOOGLE_DRIVE_IDS = {
+    'tsp5_train.zip': '0B2fg8yPGn2TCSW1pNTJMXzFPYTg',
+    'tsp10_train.zip': '0B2fg8yPGn2TCbHowM0hfOTJCNkU',
+    'tsp5-20_train.zip': '0B2fg8yPGn2TCTWNxX21jTDBGeXc',
+    'tsp50_train.zip': '0B2fg8yPGn2TCaVQxSl9ab29QajA',
+    'tsp20_test.txt': '0B2fg8yPGn2TCdF9TUU5DZVNCNjQ',
+    'tsp40_test.txt': '0B2fg8yPGn2TCcjFrYk85SGFVNlU',
+    'tsp50_test.txt.zip': '0B2fg8yPGn2TCUVlCQmQtelpZTTQ',
+}
 
 TSP = namedtuple('TSP', ['x', 'y', 'name'])
 
@@ -35,22 +48,34 @@ def generate_one_example(n_nodes, rng):
     solutions = solve_tsp_dynamic(nodes)
     return nodes, solutions
 
+def read_paper_dataset(paths, max_length):
+    x, y = [], []
+    for path in paths:
+        tf.logging.info("Read dataset {} which is used in the paper..".format(path))
+        length = max(re.findall('\d+', path))
+        with open(path) as f:
+            for l in tqdm(f):
+                inputs, outputs = l.split(' output ')
+                x.append(np.array(inputs.split(), dtype=np.float32).reshape([-1, 2]))
+                y.append(np.array(outputs.split(), dtype=np.int32)[:-1])  # skip the last one
+    return x, y
+
 class TSPDataLoader(object):
     def __init__(self, config, rng=None):
         self.config = config
         self.rng = rng
 
-        self.task = config.task
+        self.task = config.task.lower()
         self.batch_size = config.batch_size
         self.min_length = config.min_data_length
         self.max_length = config.max_data_length
 
         self.is_train = config.is_train
         self.use_terminal_symbol = config.use_terminal_symbol
+        self.random_seed = config.random_seed
 
         self.data_num = {}
         self.data_num['train'] = config.train_num
-        self.data_num['valid'] = config.valid_num
         self.data_num['test'] = config.test_num
 
         self.data_dir = config.data_dir
@@ -63,7 +88,13 @@ def __init__(self, config, rng=None):
         self.queue_ops, self.enqueue_ops = None, None
         self.x, self.y, self.seq_length, self.mask = None, None, None, None
 
-        self._maybe_generate_and_save()
+        paths = self.download_google_drive_file()
+        if len(paths) != 0:
+            self._maybe_generate_and_save(except_list=paths.keys())
+            for name, path in paths.items():
+                self.read_zip_and_update_data(path, name)
+        else:
+            self._maybe_generate_and_save()
         self._create_input_queue()
 
     def _create_input_queue(self, queue_capacity_factor=16):
@@ -78,11 +109,13 @@ def _create_input_queue(self, queue_capacity_factor=16):
             min_after_dequeue = 1000
             capacity = min_after_dequeue + 3 * self.batch_size
 
-            self.queue_ops[name] = tf.PaddingFIFOQueue(
+            self.queue_ops[name] = tf.RandomShuffleQueue(
                 capacity=capacity,
+                min_after_dequeue=min_after_dequeue,
                 dtypes=[tf.float32, tf.int32],
-                shapes=[[None, 2,], [None]],
-                name="fifo_{}".format(name))
+                shapes=[[self.max_length, 2,], [self.max_length]],
+                seed=self.random_seed,
+                name="random_queue_{}".format(name))
             self.enqueue_ops[name] = \
                 self.queue_ops[name].enqueue([self.input_ops[name], self.target_ops[name]])
 
@@ -127,21 +160,26 @@ def stop_input_queue(self):
         self.coord.request_stop()
         self.coord.join(threads)
 
-    def _maybe_generate_and_save(self):
+    def _maybe_generate_and_save(self, except_list=[]):
         self.data = {}
 
         for name, num in self.data_num.items():
+            if name in except_list:
+                tf.logging.info("Skip creating {} because of given except_list {}".format(name, except_list))
+                continue
            path = self.get_path(name)
 
            if not os.path.exists(path):
                tf.logging.info("Creating {} for [{}]".format(path, self.task))
 
-                x, y = [], []
-                for i in trange(num, desc="Create {} data".format(name)):
+                x = np.zeros([num, self.max_length, 2], dtype=np.float32)
+                y = np.zeros([num, self.max_length], dtype=np.int32)
+
+                for idx in trange(num, desc="Create {} data".format(name)):
                    n_nodes = self.rng.randint(self.min_length, self.max_length + 1)
                    nodes, res = generate_one_example(n_nodes, self.rng)
-                    x.append(nodes)
-                    y.append(res)
+                    x[idx, :len(nodes)] = nodes
+                    y[idx, :len(res)] = res
 
                np.savez(path, x=x, y=y)
                self.data[name] = TSP(x=x, y=y, name=name)
@@ -154,3 +192,50 @@ def get_path(self, name):
         return os.path.join(
             self.data_dir, "{}_{}={}.npz".format(
                 self.task_name, name, self.data_num[name]))
+
+    def download_google_drive_file(self):
+        paths = {}
+        for mode in ['train', 'test']:
+            candidates = []
+            candidates.append(
+                '{}{}_{}'.format(self.task, self.max_length, mode))
+            candidates.append(
+                '{}{}-{}_{}'.format(self.task, self.min_length, self.max_length, mode))
+
+            for key in candidates:
+                for search_key in GOOGLE_DRIVE_IDS.keys():
+                    if search_key.startswith(key):
+                        path = os.path.join(self.data_dir, search_key)
+                        tf.logging.info("Download dataset of the paper to {}".format(path))
+
+                        if not os.path.exists(path):
+                            download_file_from_google_drive(GOOGLE_DRIVE_IDS[search_key], path)
+                            if path.endswith('zip'):
+                                with zipfile.ZipFile(path, 'r') as z:
+                                    z.extractall(self.data_dir)
+                        paths[mode] = path
+
+        tf.logging.info("Can't find dataset from the paper!")
+        return paths
+
+    def read_zip_and_update_data(self, path, name):
+        if path.endswith('zip'):
+            filenames = zipfile.ZipFile(path).namelist()
+            paths = [os.path.join(self.data_dir, filename) for filename in filenames]
+        else:
+            paths = [path]
+
+        x_list, y_list = read_paper_dataset(paths, self.max_length)
+
+        x = np.zeros([len(x_list), self.max_length, 2], dtype=np.float32)
+        y = np.zeros([len(y_list), self.max_length], dtype=np.int32)
+
+        for idx, (nodes, res) in enumerate(tqdm(zip(x_list, y_list))):
+            x[idx, :len(nodes)] = nodes
+            y[idx, :len(res)] = res
+
+        if self.data is None:
+            self.data = {}
+
+        tf.logging.info("Update [{}] data with {} used in the paper".format(name, path))
+        self.data[name] = TSP(x=x, y=y, name=name)
```
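
`read_paper_dataset` expects each line of the downloaded files to hold flattened 2-D coordinates, the literal separator ` output `, and a 1-indexed tour that repeats the starting node at the end; the `[:-1]` slice drops that repeat. A small sketch of the parsing on a hypothetical sample line (the line itself is made up, the parsing mirrors the diff):

```python
import numpy as np

# Hypothetical line in the paper-dataset format:
# "x1 y1 x2 y2 ... output i1 i2 ... i1"
line = "0.1 0.2 0.6 0.4 0.3 0.9 output 1 3 2 1"

inputs, outputs = line.split(' output ')
nodes = np.array(inputs.split(), dtype=np.float32).reshape([-1, 2])  # (3, 2) coordinates
tour = np.array(outputs.split(), dtype=np.int32)[:-1]                # [1 3 2], repeated start dropped

print(nodes.shape, tour)
```

Both `_maybe_generate_and_save` and `read_zip_and_update_data` then zero-pad every example to `max_length`, which is what allows the new `tf.RandomShuffleQueue` to use fixed shapes in place of `PaddingFIFOQueue`'s `[None, 2]` shapes.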

download.py

Lines changed: 34 additions & 0 deletions
```diff
@@ -0,0 +1,34 @@
+# Code based on
+# http://stackoverflow.com/questions/25010369/wget-curl-large-file-from-google-drive/39225039#39225039
+import requests
+from tqdm import tqdm
+
+def download_file_from_google_drive(id, destination):
+    URL = "https://docs.google.com/uc?export=download"
+
+    session = requests.Session()
+
+    response = session.get(URL, params={'id': id}, stream=True)
+    token = get_confirm_token(response)
+
+    if token:
+        params = {'id': id, 'confirm': token}
+        response = session.get(URL, params=params, stream=True)
+
+    save_response_content(response, destination)
+    return True
+
+def get_confirm_token(response):
+    for key, value in response.cookies.items():
+        if key.startswith('download_warning'):
+            return value
+
+    return None
+
+def save_response_content(response, destination):
+    CHUNK_SIZE = 32768
+
+    with open(destination, "wb") as f:
+        for chunk in tqdm(response.iter_content(CHUNK_SIZE)):
+            if chunk:  # filter out keep-alive new chunks
+                f.write(chunk)
```
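
A usage sketch for the new helper; the file id is one of the `GOOGLE_DRIVE_IDS` entries added to `data_loader.py`, and the destination path is an assumption:

```python
from download import download_file_from_google_drive

# Fetch the TSP20 test split used in the paper into the data directory.
download_file_from_google_drive('0B2fg8yPGn2TCdF9TUU5DZVNCNjQ', 'data/tsp20_test.txt')
```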

layers.py

Lines changed: 36 additions & 33 deletions
```diff
@@ -11,9 +11,10 @@
 
 def decoder_rnn(cell, inputs,
                 enc_outputs, enc_final_states,
-                seq_length, hidden_dim, num_glimpse,
-                max_dec_length, batch_size, is_train,
-                end_of_sequence_id=0, initializer=None):
+                seq_length, hidden_dim,
+                num_glimpse, batch_size, is_train,
+                end_of_sequence_id=0, initializer=None,
+                max_length=None):
     with tf.variable_scope("decoder_rnn") as scope:
         def attention(ref, query, with_softmax, scope="attention"):
             with tf.variable_scope(scope):
@@ -41,37 +42,37 @@ def glimpse(ref, query, scope="glimpse"):
             return tf.reduce_sum(alignments * ref, [1])
 
         def output_fn(ref, query, num_glimpse):
-            for idx in range(num_glimpse):
-                query = glimpse(ref, query, "glimpse_{}".format(idx))
-            return attention(ref, query, with_softmax=False, scope="attention")
-
-        maximum_length = tf.convert_to_tensor(max_dec_length, tf.int32)
-        def decoder_fn_inference(
-                time, cell_state, cell_input, cell_output, context_state):
-            if context_state is None:
-                context_state = tf.TensorArray(tf.float32, size=maximum_length)
-
-            if cell_output is None:
-                # invariant tha this is time == 0
-                cell_state = enc_final_states
-                cell_input = inputs[:,0,:]
-                done = tf.zeros([batch_size,], dtype=tf.bool)
+            if query is None:
+                return tf.zeros([max_length], tf.float32)  # only used for shape inference
             else:
-                output_logit = output_fn(enc_outputs, cell_output, num_glimpse)
-                sampled_idx = tf.multinomial(output_logit, 1)
+                for idx in range(num_glimpse):
+                    query = glimpse(ref, query, "glimpse_{}".format(idx))
+                return attention(ref, query, with_softmax=False, scope="attention")
 
-                context_state.write(time, output_logit)
-                done = tf.squeeze(tf.equal(sampled_idx, end_of_sequence_id), -1)
-
-            done = tf.cond(tf.greater(time, maximum_length),
-                lambda: tf.ones([batch_size,], dtype=tf.bool),
-                lambda: done)
-            return (done, cell_state, cell_input, cell_output, context_state)
+        def input_fn(sampled_idx):
+            return tf.stop_gradient(
+                tf.gather_nd(enc_outputs, index_matrix_to_pairs(sampled_idx)))
 
         if is_train:
             decoder_fn = simple_decoder_fn_train(enc_final_states)
         else:
-            decoder_fn = decoder_fn_inference
+            maximum_length = tf.convert_to_tensor(max_length, tf.int32)
+
+            def decoder_fn(time, cell_state, cell_input, cell_output, context_state):
+                cell_output = output_fn(enc_outputs, cell_output, num_glimpse)
+                if cell_state is None:
+                    cell_state = enc_final_states
+                    next_input = cell_input
+                    done = tf.zeros([batch_size,], dtype=tf.bool)
+                else:
+                    sampled_idx = tf.cast(tf.argmax(cell_output, 1), tf.int32)
+                    next_input = input_fn(sampled_idx)
+                    done = tf.equal(sampled_idx, end_of_sequence_id)
+
+                done = tf.cond(tf.greater(time, maximum_length),
+                    lambda: tf.ones([batch_size,], dtype=tf.bool),
+                    lambda: done)
+                return (done, cell_state, next_input, cell_output, context_state)
 
         outputs, final_state, final_context_state = \
             dynamic_rnn_decoder(cell, decoder_fn, inputs=inputs,
```
111112
def index_matrix_to_pairs(index_matrix):
112113
# [[3,1,2], [2,3,1]] -> [[[0, 3], [1, 1], [2, 2]],
113114
# [[0, 2], [1, 3], [2, 1]]]
114-
replicated_first_indices = tf.tile(
115-
tf.expand_dims(tf.range(tf.shape(index_matrix)[0]), dim=1),
116-
[1, tf.shape(index_matrix)[1]])
117-
return tf.stack([replicated_first_indices, index_matrix], axis=2)
118-
115+
replicated_first_indices = tf.range(tf.shape(index_matrix)[0])
116+
rank = len(index_matrix.get_shape())
117+
if rank == 2:
118+
replicated_first_indices = tf.tile(
119+
tf.expand_dims(replicated_first_indices, dim=1),
120+
[1, tf.shape(index_matrix)[1]])
121+
return tf.stack([replicated_first_indices, index_matrix], axis=rank)
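
The generalized `index_matrix_to_pairs` pairs every index with its batch row, so the result can drive `tf.gather_nd` over `enc_outputs` of shape `[batch, seq_len, hidden]` in both the rank-1 case used by `input_fn` and the rank-2 case. A numpy re-implementation of the rank-2 case, for illustration only:

```python
import numpy as np

index_matrix = np.array([[3, 1, 2],
                         [2, 3, 1]])
batch_rows = np.tile(np.arange(index_matrix.shape[0])[:, None],
                     [1, index_matrix.shape[1]])
pairs = np.stack([batch_rows, index_matrix], axis=2)
# pairs == [[[0, 3], [0, 1], [0, 2]],
#           [[1, 2], [1, 3], [1, 1]]]

enc_outputs = np.arange(2 * 4 * 2).reshape(2, 4, 2)   # [batch, seq_len, hidden]
gathered = enc_outputs[pairs[..., 0], pairs[..., 1]]  # same selection tf.gather_nd would make
print(gathered.shape)                                 # (2, 3, 2)
```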
