@@ -44,6 +44,7 @@ def __init__(self, config, rng=None):
         self.batch_size = config.batch_size
         self.min_length = config.min_data_length
         self.max_length = config.max_data_length
+        self.is_train = config.is_train
 
         self.data_num = {}
         self.data_num['train'] = config.train_num
@@ -58,55 +59,60 @@ def __init__(self, config, rng=None):
         self.coord = None
         self.input_ops, self.target_ops = None, None
         self.queue_ops, self.enqueue_ops = None, None
+        self.x, self.y, self.mask = None, None, None
 
         self._maybe_generate_and_save()
         self._create_input_queue()
 
     def _create_input_queue(self, queue_capacity_factor=16):
         self.input_ops, self.target_ops = {}, {}
         self.queue_ops, self.enqueue_ops = {}, {}
+        self.x, self.y, self.mask = {}, {}, {}
 
         for name in self.data_num.keys():
             self.input_ops[name] = tf.placeholder(tf.float32, shape=[None, None])
             self.target_ops[name] = tf.placeholder(tf.int32, shape=[None])
 
-            min_after_dequeue = 5000
+            min_after_dequeue = 1000
             capacity = min_after_dequeue + 3 * self.batch_size
 
-            if self.is_training:
-                self.queue_ops[name] = tf.RandomShuffleQueue(
-                    capacity=capacity,
-                    min_after_dequeue=min_after_dequeue,
-                    dtypes=[tf.float32, tf.int32],
-                    name="random_{}".format(name))
-            else:
-                self.queue_ops[name] = tf.FIFOQueue(
-                    capacity=capacity,
-                    dtypes=[tf.float32, tf.int32],
-                    name="fifo_{}".format(name))
-
+            self.queue_ops[name] = tf.PaddingFIFOQueue(
+                capacity=capacity,
+                dtypes=[tf.float32, tf.int32],
+                shapes=[[None, 2,], [None]],
+                name="fifo_{}".format(name))
             self.enqueue_ops[name] = \
                 self.queue_ops[name].enqueue([self.input_ops[name], self.target_ops[name]])
 
-            tf.train.queue_runner.add_queue_runner(tf.train.queue_runner.QueueRunner(
-                values_queue, enqueue_ops))
+            inputs, labels = self.queue_ops[name].dequeue()
+
+            caption_length = tf.shape(inputs)[0]
+            input_length = tf.expand_dims(tf.subtract(caption_length, 1), 0)
+            indicator = tf.ones(input_length, dtype=tf.int32)
+
+            self.x[name], self.y[name], self.mask[name] = tf.train.batch(
+                [inputs, labels, indicator],
+                batch_size=self.batch_size,
+                capacity=capacity,
+                dynamic_pad=True,
+                name="batch_and_pad")
 
     def run_input_queue(self, sess):
         threads = []
         self.coord = tf.train.Coordinator()
 
         for name in self.data_num.keys():
-            def load_and_enqueue(sess, name, input_ops, enqueue_ops, coord):
+            def load_and_enqueue(sess, name, input_ops, target_ops, enqueue_ops, coord):
                 idx = 0
                 while not coord.should_stop():
                     feed_dict = {
                         input_ops[name]: self.data[name].x[idx],
                         target_ops[name]: self.data[name].y[idx],
                     }
                     sess.run(self.enqueue_ops[name], feed_dict=feed_dict)
-                    idx += 1
+                    idx = idx + 1 if idx + 1 <= len(self.data[name].x) - 1 else 0
 
-            args = (sess, name, self.input_ops, self.enqueue_ops, self.coord)
+            args = (sess, name, self.input_ops, self.target_ops, self.enqueue_ops, self.coord)
             t = threading.Thread(target=load_and_enqueue, args=args)
             t.start()
             threads.append(t)
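
Not part of the diff: below is a minimal standalone sketch of the pattern this change adopts, a `tf.PaddingFIFOQueue` fed with variable-length examples and drained through `tf.train.batch(..., dynamic_pad=True)` so each batch is padded to its longest element, with a mask marking the real steps. The 2-D coordinate shape, lengths, and names in the sketch are illustrative assumptions, not values taken from this repository.

```python
import numpy as np
import tensorflow as tf

# One example = a variable number of (x, y) points plus one int label per point.
input_op = tf.placeholder(tf.float32, shape=[None, 2])
target_op = tf.placeholder(tf.int32, shape=[None])

# PaddingFIFOQueue accepts elements whose leading dimension differs per example.
queue = tf.PaddingFIFOQueue(
    capacity=32,
    dtypes=[tf.float32, tf.int32],
    shapes=[[None, 2], [None]],
    name="fifo_demo")
enqueue_op = queue.enqueue([input_op, target_op])

inputs, labels = queue.dequeue()
length = tf.shape(inputs)[0]
# The diff builds its indicator over caption_length - 1 steps; full length here.
mask = tf.ones(tf.expand_dims(length, 0), dtype=tf.int32)

# dynamic_pad=True pads every tensor in the batch to the longest example.
x, y, m = tf.train.batch(
    [inputs, labels, mask],
    batch_size=2,
    capacity=32,
    dynamic_pad=True)

with tf.Session() as sess:
    coord = tf.train.Coordinator()
    threads = tf.train.start_queue_runners(sess=sess, coord=coord)

    # Feed two examples of different lengths (3 and 5 points).
    for n in (3, 5):
        sess.run(enqueue_op, feed_dict={
            input_op: np.random.rand(n, 2),
            target_op: np.arange(n, dtype=np.int32)})

    bx, by, bm = sess.run([x, y, m])
    print(bx.shape, by.shape)  # (2, 5, 2) (2, 5): padded to the longest example
    print(bm)                  # e.g. [1 1 1 0 0] and [1 1 1 1 1]

    coord.request_stop()
    sess.run(queue.close(cancel_pending_enqueues=True))
    coord.join(threads)
```

This is the same shape of pipeline `_create_input_queue` wires up above; the batched mask can later be used to exclude padded steps, e.g. when weighting a per-step loss.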