@@ -1,13 +1,26 @@
 # Most of the code is from
 # https://github.com/vshallc/PtrNets/blob/master/pointer/misc/tsp.py
 import os
+import re
+import zipfile
 import itertools
 import threading
 import numpy as np
-from tqdm import trange
+from tqdm import trange, tqdm
 from collections import namedtuple

 import tensorflow as tf
+from download import download_file_from_google_drive
+
+GOOGLE_DRIVE_IDS = {
+  'tsp5_train.zip': '0B2fg8yPGn2TCSW1pNTJMXzFPYTg',
+  'tsp10_train.zip': '0B2fg8yPGn2TCbHowM0hfOTJCNkU',
+  'tsp5-20_train.zip': '0B2fg8yPGn2TCTWNxX21jTDBGeXc',
+  'tsp50_train.zip': '0B2fg8yPGn2TCaVQxSl9ab29QajA',
+  'tsp20_test.txt': '0B2fg8yPGn2TCdF9TUU5DZVNCNjQ',
+  'tsp40_test.txt': '0B2fg8yPGn2TCcjFrYk85SGFVNlU',
+  'tsp50_test.txt.zip': '0B2fg8yPGn2TCUVlCQmQtelpZTTQ',
+}

 TSP = namedtuple('TSP', ['x', 'y', 'name'])

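The `download` module imported above is not part of this diff. For reference, a minimal sketch of what `download_file_from_google_drive(file_id, destination)` could look like, assuming the `requests` package and Google Drive's usual confirm-token flow; names and details are illustrative, not the repo's actual helper:

```python
# Sketch only: the real helper lives in download.py, which this diff does not show.
import requests

def download_file_from_google_drive(file_id, destination, chunk_size=32768):
  url = "https://docs.google.com/uc?export=download"
  session = requests.Session()
  response = session.get(url, params={'id': file_id}, stream=True)

  # Large files return a confirmation page first; the token sits in a cookie.
  token = next((v for k, v in response.cookies.items()
                if k.startswith('download_warning')), None)
  if token:
    response = session.get(url, params={'id': file_id, 'confirm': token}, stream=True)

  with open(destination, 'wb') as f:
    for chunk in response.iter_content(chunk_size):
      if chunk:  # skip keep-alive chunks
        f.write(chunk)
```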
@@ -35,22 +48,34 @@ def generate_one_example(n_nodes, rng):
   solutions = solve_tsp_dynamic(nodes)
   return nodes, solutions

+def read_paper_dataset(paths, max_length):
+  x, y = [], []
+  for path in paths:
+    tf.logging.info("Reading dataset {}, which is used in the paper..".format(path))
+    length = max(map(int, re.findall(r'\d+', path)))  # compare as ints, not strings; currently unused
+    with open(path) as f:
+      for l in tqdm(f):
+        inputs, outputs = l.split(' output ')
+        x.append(np.array(inputs.split(), dtype=np.float32).reshape([-1, 2]))
+        y.append(np.array(outputs.split(), dtype=np.int32)[:-1])  # drop the last index, which repeats the first
+  return x, y
+
 class TSPDataLoader(object):
   def __init__(self, config, rng=None):
     self.config = config
     self.rng = rng

-    self.task = config.task
+    self.task = config.task.lower()
     self.batch_size = config.batch_size
     self.min_length = config.min_data_length
     self.max_length = config.max_data_length

     self.is_train = config.is_train
     self.use_terminal_symbol = config.use_terminal_symbol
+    self.random_seed = config.random_seed

     self.data_num = {}
     self.data_num['train'] = config.train_num
-    self.data_num['valid'] = config.valid_num
     self.data_num['test'] = config.test_num

     self.data_dir = config.data_dir
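For reference, `read_paper_dataset` above assumes each line of the paper's files holds flattened x-y coordinates, the literal separator `output`, and a 1-indexed tour that ends back at its starting node. A made-up 5-node line (values illustrative, not taken from the real files):

```python
import numpy as np

line = "0.61 0.27 0.14 0.85 0.53 0.64 0.94 0.18 0.31 0.47 output 1 3 5 4 2 1\n"
inputs, outputs = line.split(' output ')
nodes = np.array(inputs.split(), dtype=np.float32).reshape([-1, 2])  # shape (5, 2)
tour = np.array(outputs.split(), dtype=np.int32)[:-1]                # [1 3 5 4 2], repeated start dropped
```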
@@ -63,7 +88,13 @@ def __init__(self, config, rng=None):
     self.queue_ops, self.enqueue_ops = None, None
     self.x, self.y, self.seq_length, self.mask = None, None, None, None

-    self._maybe_generate_and_save()
+    paths = self.download_google_drive_file()
+    if len(paths) != 0:
+      self._maybe_generate_and_save(except_list=paths.keys())
+      for name, path in paths.items():
+        self.read_zip_and_update_data(path, name)
+    else:
+      self._maybe_generate_and_save()
     self._create_input_queue()

   def _create_input_queue(self, queue_capacity_factor=16):
@@ -78,11 +109,13 @@ def _create_input_queue(self, queue_capacity_factor=16):
       min_after_dequeue = 1000
       capacity = min_after_dequeue + 3 * self.batch_size

-      self.queue_ops[name] = tf.PaddingFIFOQueue(
+      self.queue_ops[name] = tf.RandomShuffleQueue(
           capacity=capacity,
+          min_after_dequeue=min_after_dequeue,
           dtypes=[tf.float32, tf.int32],
-          shapes=[[None, 2], [None]],
-          name="fifo_{}".format(name))
+          shapes=[[self.max_length, 2], [self.max_length]],
+          seed=self.random_seed,
+          name="random_queue_{}".format(name))
       self.enqueue_ops[name] = \
           self.queue_ops[name].enqueue([self.input_ops[name], self.target_ops[name]])

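Unlike `PaddingFIFOQueue`, `RandomShuffleQueue` cannot pad its elements dynamically, so every element needs a fully defined shape; that is why examples are now zero-padded to `max_length` before being enqueued. A minimal TF 1.x sketch of the round trip, with illustrative sizes:

```python
import tensorflow as tf  # TF 1.x graph-mode API

max_length, batch_size = 10, 32
queue = tf.RandomShuffleQueue(
    capacity=1000 + 3 * batch_size,
    min_after_dequeue=1000,                  # keep a buffer this full so shuffling is meaningful
    dtypes=[tf.float32, tf.int32],
    shapes=[[max_length, 2], [max_length]],  # must be fully defined for this queue type
    seed=123)

x_in = tf.placeholder(tf.float32, [max_length, 2])
y_in = tf.placeholder(tf.int32, [max_length])
enqueue_op = queue.enqueue([x_in, y_in])
x_batch, y_batch = queue.dequeue_many(batch_size)  # a shuffled mini-batch
```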
@@ -127,21 +160,26 @@ def stop_input_queue(self):
     self.coord.request_stop()
     self.coord.join(threads)

-  def _maybe_generate_and_save(self):
+  def _maybe_generate_and_save(self, except_list=()):  # immutable default avoids the mutable-default pitfall
     self.data = {}

     for name, num in self.data_num.items():
+      if name in except_list:
+        tf.logging.info("Skip creating {} because it is in except_list {}".format(name, except_list))
+        continue
       path = self.get_path(name)

       if not os.path.exists(path):
         tf.logging.info("Creating {} for [{}]".format(path, self.task))

-        x, y = [], []
-        for i in trange(num, desc="Create {} data".format(name)):
+        x = np.zeros([num, self.max_length, 2], dtype=np.float32)
+        y = np.zeros([num, self.max_length], dtype=np.int32)
+
+        for idx in trange(num, desc="Create {} data".format(name)):
           n_nodes = self.rng.randint(self.min_length, self.max_length + 1)
           nodes, res = generate_one_example(n_nodes, self.rng)
-          x.append(nodes)
-          y.append(res)
+          x[idx, :len(nodes)] = nodes  # zero-pad tours shorter than max_length
+          y[idx, :len(res)] = res

         np.savez(path, x=x, y=y)
         self.data[name] = TSP(x=x, y=y, name=name)
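Once written, a split can be sanity-checked straight from its `.npz` file; the path below is hypothetical, following the `get_path` pattern shown further down:

```python
import numpy as np

data = np.load("data/tsp_train=10000.npz")  # hypothetical name produced by get_path
x, y = data["x"], data["y"]
print(x.shape)  # (10000, max_length, 2): coordinates, zero-padded past n_nodes
print(y.shape)  # (10000, max_length):    tour indices, zero-padded likewise
```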
@@ -154,3 +192,51 @@ def get_path(self, name):
     return os.path.join(
         self.data_dir, "{}_{}={}.npz".format(
             self.task_name, name, self.data_num[name]))
+
+  def download_google_drive_file(self):
+    paths = {}
+    for mode in ['train', 'test']:
+      candidates = []
+      candidates.append(
+          '{}{}_{}'.format(self.task, self.max_length, mode))
+      candidates.append(
+          '{}{}-{}_{}'.format(self.task, self.min_length, self.max_length, mode))
+
+      for key in candidates:
+        for search_key in GOOGLE_DRIVE_IDS.keys():
+          if search_key.startswith(key):
+            path = os.path.join(self.data_dir, search_key)
+            tf.logging.info("Downloading the paper's dataset to {}".format(path))
+
+            if not os.path.exists(path):
+              download_file_from_google_drive(GOOGLE_DRIVE_IDS[search_key], path)
+              if path.endswith('zip'):
+                with zipfile.ZipFile(path, 'r') as z:
+                  z.extractall(self.data_dir)
+            paths[mode] = path
+
+    if len(paths) == 0:
+      tf.logging.info("Can't find a dataset from the paper!")
+    return paths
+
+  def read_zip_and_update_data(self, path, name):
+    if path.endswith('zip'):
+      filenames = zipfile.ZipFile(path).namelist()
+      paths = [os.path.join(self.data_dir, filename) for filename in filenames]
+    else:
+      paths = [path]
+
+    x_list, y_list = read_paper_dataset(paths, self.max_length)
+
+    x = np.zeros([len(x_list), self.max_length, 2], dtype=np.float32)
+    y = np.zeros([len(y_list), self.max_length], dtype=np.int32)
+
+    for idx, (nodes, res) in enumerate(tqdm(zip(x_list, y_list), total=len(x_list))):  # total= lets tqdm show progress for zip
+      x[idx, :len(nodes)] = nodes
+      y[idx, :len(res)] = res
+
+    if self.data is None:
+      self.data = {}
+
+    tf.logging.info("Updating [{}] data with {} from the paper".format(name, path))
+    self.data[name] = TSP(x=x, y=y, name=name)
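To close, a hedged usage sketch: the config fields below mirror exactly what `__init__` reads in this diff (any other fields the full file may need are omitted), and all values are illustrative:

```python
import numpy as np
from collections import namedtuple

# Hypothetical config; fields match what TSPDataLoader.__init__ reads above.
Config = namedtuple('Config', [
    'task', 'batch_size', 'min_data_length', 'max_data_length',
    'is_train', 'use_terminal_symbol', 'random_seed',
    'train_num', 'test_num', 'data_dir'])

config = Config(task='tsp', batch_size=32, min_data_length=5, max_data_length=10,
                is_train=True, use_terminal_symbol=True, random_seed=123,
                train_num=10000, test_num=1000, data_dir='data')

loader = TSPDataLoader(config, rng=np.random.RandomState(config.random_seed))
```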