Skip to content

Commit 66b51c2

Browse files
committed
multithreading batch queue input loader finished
1 parent 8705650 commit 66b51c2

File tree

2 files changed

+24
-22
lines changed

2 files changed

+24
-22
lines changed

data_loader.py

Lines changed: 24 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,7 @@ def __init__(self, config, rng=None):
4444
self.batch_size = config.batch_size
4545
self.min_length = config.min_data_length
4646
self.max_length = config.max_data_length
47+
self.is_train = config.is_train
4748

4849
self.data_num = {}
4950
self.data_num['train'] = config.train_num
@@ -58,55 +59,60 @@ def __init__(self, config, rng=None):
5859
self.coord = None
5960
self.input_ops, self.target_ops = None, None
6061
self.queue_ops, self.enqueue_ops = None, None
62+
self.x, self.y, self.mask = None, None, None
6163

6264
self._maybe_generate_and_save()
6365
self._create_input_queue()
6466

6567
def _create_input_queue(self, queue_capacity_factor=16):
6668
self.input_ops, self.target_ops = {}, {}
6769
self.queue_ops, self.enqueue_ops = {}, {}
70+
self.x, self.y, self.mask = {}, {}, {}
6871

6972
for name in self.data_num.keys():
7073
self.input_ops[name] = tf.placeholder(tf.float32, shape=[None, None])
7174
self.target_ops[name] = tf.placeholder(tf.int32, shape=[None])
7275

73-
min_after_dequeue = 5000
76+
min_after_dequeue = 1000
7477
capacity = min_after_dequeue + 3 * self.batch_size
7578

76-
if self.is_training:
77-
self.queue_ops[name] = tf.RandomShuffleQueue(
78-
capacity=capacity,
79-
min_after_dequeue=min_after_dequeue,
80-
dtypes=[tf.float32, tf.int32],
81-
name="random_{}".format(name))
82-
else:
83-
self.queue_ops[name] = tf.FIFOQueue(
84-
capacity=capacity,
85-
dtypes=[tf.float32, tf.int32],
86-
name="fifo_{}".format(name))
87-
79+
self.queue_ops[name] = tf.PaddingFIFOQueue(
80+
capacity=capacity,
81+
dtypes=[tf.float32, tf.int32],
82+
shapes=[[None, 2,], [None]],
83+
name="fifo_{}".format(name))
8884
self.enqueue_ops[name] = \
8985
self.queue_ops[name].enqueue([self.input_ops[name], self.target_ops[name]])
9086

91-
tf.train.queue_runner.add_queue_runner(tf.train.queue_runner.QueueRunner(
92-
values_queue, enqueue_ops))
87+
inputs, labels = self.queue_ops[name].dequeue()
88+
89+
caption_length = tf.shape(inputs)[0]
90+
input_length = tf.expand_dims(tf.subtract(caption_length, 1), 0)
91+
indicator = tf.ones(input_length, dtype=tf.int32)
92+
93+
self.x[name], self.y[name], self.mask[name] = tf.train.batch(
94+
[inputs, labels, indicator],
95+
batch_size=self.batch_size,
96+
capacity=capacity,
97+
dynamic_pad=True,
98+
name="batch_and_pad")
9399

94100
def run_input_queue(self, sess):
95101
threads = []
96102
self.coord = tf.train.Coordinator()
97103

98104
for name in self.data_num.keys():
99-
def load_and_enqueue(sess, name, input_ops, enqueue_ops, coord):
105+
def load_and_enqueue(sess, name, input_ops, target_ops, enqueue_ops, coord):
100106
idx = 0
101107
while not coord.should_stop():
102108
feed_dict = {
103109
input_ops[name]: self.data[name].x[idx],
104110
target_ops[name]: self.data[name].y[idx],
105111
}
106112
sess.run(self.enqueue_ops[name], feed_dict=feed_dict)
107-
idx += 1
113+
idx = idx+1 if idx+1 <= len(self.data[name].x) - 1 else 0
108114

109-
args = (sess, name, self.input_ops, self.enqueue_ops, self.coord)
115+
args = (sess, name, self.input_ops, self.target_ops, self.enqueue_ops, self.coord)
110116
t = threading.Thread(target=load_and_enqueue, args=args)
111117
t.start()
112118
threads.append(t)

trainer.py

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -55,10 +55,6 @@ def train(self):
5555
tf.logging.info("Training starts...")
5656

5757
self.data_loader.run_input_queue(self.sess)
58-
import ipdb; ipdb.set_trace()
59-
x = 123
60-
61-
self.data_loader.run_input_queue()
6258

6359
def test(self):
6460
tf.logging.info("Testing starts...")

0 commit comments

Comments
 (0)