Skip to content

Commit a665020

Browse files
committed
Constant image pool
1 parent efbfe1c commit a665020

File tree

5 files changed

+64
-105
lines changed

5 files changed

+64
-105
lines changed

generator_gan.py

Lines changed: 6 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -67,13 +67,11 @@ def __init__(self, name, setup=True, image_size=64, colors=3, batch_size=64, dir
6767
self._dis_scale = dicriminator_scaling_favor
6868
#Setup Images
6969
if image_manager is None:
70-
self.image_manager = ImageVariations(image_size=image_size, batch_size=batch_size, colored=(colors == 3))
70+
self.image_manager = ImageVariations(image_size=image_size, colored=(colors == 3))
7171
else:
7272
self.image_manager = image_manager
73-
self.image_manager.batch_size = batch_size
7473
self.image_manager.image_size = image_size
7574
self.image_manager.colored = (colors == 3)
76-
self.image_manager.start_threads()
7775
#Setup Networks
7876
self.iterations = tf.Variable(0, name="training_iterations", trainable=False)
7977
with tf.variable_scope('input'):
@@ -151,7 +149,7 @@ def get_session(self):
151149

152150
def __get_feed_dict__(self):
153151
return {
154-
self.image_input: self.image_manager.get_batch(),
152+
self.image_input: self.image_manager.get_batch(self.batch_size),
155153
self.generator_input: self.random_input()
156154
}
157155

@@ -202,7 +200,7 @@ def train(self, batches=100000, print_interval=1):
202200
logger(i)
203201
#Save network
204202
if timer() - last_save > 1800:
205-
saver.save(session, os.path.join(self.directory, self.name))
203+
saver.save(session, os.path.join(self.directory, self.name), self.iterations)
206204
last_save = timer()
207205
except KeyboardInterrupt:
208206
print()
@@ -234,10 +232,8 @@ def __call__(self, iteration):
234232
#Save image
235233
if iteration%self.image_interval == 0:
236234
#Hack to make tensorboard show multiple images, not just the latest one
237-
feed_dict = {
238-
self.gan.generator_input: self.batch_input,
239-
self.gan.image_input: self.gan.image_manager.get_old_batch()
240-
}
235+
feed_dict = self.gan.__get_feed_dict__()
236+
feed_dict[self.gan.generator_input] = self.batch_input,
241237
image, summary = self.session.run(
242238
[tf.summary.image(
243239
'training/iteration/%d'%iteration,
@@ -250,10 +246,7 @@ def __call__(self, iteration):
250246
self.writer.add_summary(image, iteration)
251247
self.writer.add_summary(summary, iteration)
252248
elif iteration%self.summary_interval == 0:
253-
feed_dict = {
254-
self.gan.generator_input: self.gan.random_input(),
255-
self.gan.image_input: self.gan.image_manager.get_old_batch()
256-
}
249+
feed_dict = self.gan.__get_feed_dict__()
257250
#Save summary
258251
summary = self.session.run(self.summary, feed_dict=feed_dict)
259252
self.writer.add_summary(summary, iteration)

image.py

Lines changed: 54 additions & 87 deletions
Original file line numberDiff line numberDiff line change
@@ -1,108 +1,80 @@
11

2-
import math
32
import os
43
import random
54
import time
6-
from threading import Event, Thread
5+
from multiprocessing import Pool
76

87
import numpy as np
98
from PIL import Image, ImageEnhance
109

11-
from queue import Queue
12-
1310

1411
class ImageVariations():
15-
def __init__(self, image_size=64, batch_size=64, colored=True,
16-
pools=8, pool_renew=1,
12+
def __init__(self, image_size=64, colored=True, pool_size=10000,
1713
in_directory='input', out_directory='output',
1814
rotation_range=(-15, 15), brightness_range=(0.7, 1.2),
1915
saturation_range=(0.7, 1.), contrast_range=(0.9, 1.3),
2016
size_range=(0.6, 0.8)):
2117
#Parameters
2218
self.image_size = image_size
23-
self.batch_size = batch_size
2419
self.in_directory = in_directory
2520
self.out_directory = out_directory
26-
self.pools = pools
27-
self.pool_renew = pool_renew
21+
self.images_count = pool_size
2822
#Variation Config
2923
self.rotation_range = rotation_range
3024
self.brightness_range = brightness_range
3125
self.saturation_range = saturation_range
3226
self.contrast_range = contrast_range
3327
self.size_range = size_range
3428
self.colored = colored
35-
#Thread variables
36-
self.pool = []
37-
self.pool_index = 0
38-
self.pool_iteration = 0
39-
self.queue = Queue()
40-
self.files = []
41-
self.threads = []
42-
self.event = Event()
43-
self.closing = True
44-
45-
def start_threads(self):
46-
"""Start the threads that are generating image variations"""
47-
self.closing = True
48-
self.event.set()
49-
self.files = [f for f in os.listdir(self.in_directory) if os.path.isfile(os.path.join(self.in_directory, f))]
50-
num_threads = os.cpu_count()
51-
if num_threads is None:
52-
num_threads = 4
53-
self.threads = [Thread(target=self.__thread__, args=(self.files[i::num_threads],), daemon=True)
54-
for i in range(num_threads)]
55-
self.event.clear()
56-
self.closing = False
57-
for t in self.threads:
58-
t.start()
59-
if(self.pools > 1):
60-
print('Processing input images')
61-
self.pool = [[] for _ in range(self.pools)]
62-
63-
def stop_threads(self):
64-
"""Stop the threads that are generating image variations (freeing memory)"""
65-
self.closing = True
66-
self.event.set()
67-
68-
def get_batch(self):
69-
"""Get a batch of images as arrays"""
70-
if self.closing: #Start threads
71-
self.start_threads()
72-
self.event.set()
73-
if len(self.pool[self.pool_index]) == 0: #Check and fill image pool
74-
self.pool[self.pool_index] = [self.queue.get() for _ in range(self.batch_size)]
75-
np.random.shuffle(self.pool[self.pool_index])
76-
images = self.pool[self.pool_index]
77-
for i in range(self.pool_renew): #Replace old images
78-
self.pool[self.pool_index][(self.pool_iteration+i)%self.batch_size] = self.queue.get()
79-
self.pool_index += 1
80-
if self.pool_index == self.pools: #Cycle indexes
81-
self.pool_index = 0
82-
self.pool_iteration = (self.pool_iteration+self.pool_renew)%self.batch_size
83-
self.event.clear()
84-
return images
85-
86-
def get_old_batch(self):
87-
if self.closing or len(self.pool[self.pool_index]) == 0:
88-
return self.get_batch()
89-
return self.pool[self.pool_index-1]
29+
#Generate Images
30+
self.index = 0
31+
if self.images_count > 0:
32+
if self.images_count > 20:
33+
print("Processing Images")
34+
files = [f for f in os.listdir(self.in_directory) if os.path.isfile(os.path.join(self.in_directory, f))]
35+
np.random.shuffle(files)
36+
mp = self.images_count//len(files)
37+
rest = self.images_count%len(files)
38+
if mp > 0:
39+
pool = Pool()
40+
images = pool.starmap(self.__generate_images__, [(f, mp) for f in files])
41+
self.pool = [img for sub in images for img in sub]
42+
pool.close()
43+
else:
44+
self.pool = []
45+
self.pool += [img for sub in [self.__generate_images__(f, 1) for f in files[:rest]] for img in sub]
46+
np.random.shuffle(self.pool)
9047

91-
def __thread__(self, files):
48+
def __generate_images__(self, image_file, iterations):
9249
if self.colored:
93-
images = [Image.open(os.path.join(self.in_directory, file)) for file in files]
50+
image = Image.open(os.path.join(self.in_directory, image_file))
9451
else:
95-
images = [Image.open(os.path.join(self.in_directory, file)).convert("L") for file in files]
96-
index = 0
97-
while not self.closing:
98-
image = images[index]
99-
index = (index+1)%len(images)
52+
image = Image.open(os.path.join(self.in_directory, image_file)).convert("L")
53+
def variation_to_numpy():
10054
arr = np.asarray(self.get_variation(image), dtype=np.float)
10155
if not self.colored:
10256
arr.shape = arr.shape+(1,)
103-
self.queue.put(arr)
104-
while self.queue.qsize() >= self.batch_size and not self.closing:
105-
self.event.wait()
57+
return arr
58+
return [variation_to_numpy() for _ in range(iterations)]
59+
60+
61+
def get_batch(self, count):
62+
"""Get a batch of images as arrays"""
63+
if self.index + count < len(self.pool):
64+
batch = self.pool[self.index:self.index+count]
65+
self.index += count
66+
return batch
67+
else:
68+
batch = self.pool[self.index:]
69+
self.index = 0
70+
np.random.shuffle(self.pool)
71+
return batch + self.get_batch(count - len(batch))
72+
73+
def get_rnd_batch(self, count):
74+
if count > len(self.pool):
75+
return self.get_batch(count)
76+
index = np.random.randint(0, len(self.pool)-count)
77+
return self.pool[index:index+count]
10678

10779
def get_variation(self, image):
10880
"""Get an variation of the image according to the object config"""
@@ -155,19 +127,14 @@ def save_image(self, image, name=None):
155127

156128
if __name__ == "__main__":
157129
if len(os.sys.argv) > 1:
158-
imgvariations = ImageVariations(pools=1, batch_size=int(os.sys.argv[1]))
159-
imgvariations.start_threads()
160-
images_batch = imgvariations.get_batch()
161-
imgvariations.stop_threads()
162-
for variant_id in range(int(os.sys.argv[1])):
130+
num_imgs = int(os.sys.argv[1])
131+
imgvariations = ImageVariations(pool_size=num_imgs)
132+
images_batch = imgvariations.get_batch(num_imgs)
133+
for variant_id in range(num_imgs):
163134
imgvariations.save_image(images_batch[variant_id], name="variant_%d"%variant_id)
164-
print("Generated %s image variations as they are when fed to the network"%os.sys.argv[1])
135+
print("Generated %i image variations as they are when fed to the network"%num_imgs)
165136
else:
166137
print("Testing memory requiremens")
167-
imgvariations = ImageVariations()
168-
input("Press Enter to continue... (all images loaded and pools filled)")
169-
iml1 = imgvariations.get_batch()
170-
iml2 = imgvariations.get_batch()
171-
iml3 = imgvariations.get_batch()
172-
iml4 = imgvariations.get_batch()
173-
input("Press Enter to continue... (also four batches)")
138+
num_imgs = 10000
139+
imgvariations = ImageVariations(pool_size=num_imgs)
140+
input("Press Enter to continue... (Pool countains %i images)"%num_imgs)

network.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -33,8 +33,8 @@ def image_decoder(input_tensors, name='decoder', image_size=64, convolutions=5,
3333
assert conv_image_size*(2**convolutions) == image_size, "Images must be a multiple of two (and >= 2**convolutions)"
3434
with tf.variable_scope(name):
3535
prev_layer = expand_relu(input_tensors, [-1, conv_image_size, conv_image_size, base_width*2**(convolutions-1)], 'expand')
36-
for i in range(convolutions-1):
37-
prev_layer = conv2d_transpose(prev_layer, batch_size, 2**(convolutions-i-2)*base_width, 'convolution_%d'%i)
36+
for i in range(convolutions):
37+
prev_layer = conv2d_transpose(prev_layer, batch_size, 2**(convolutions-i-1)*base_width, 'convolution_%d'%i)
3838
prev_layer = conv2d_transpose_tanh(prev_layer, batch_size, colors, 'output')
3939
return prev_layer
4040

operators.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -110,8 +110,8 @@ def conv2d_transpose_tanh(tensors, batch_size=1, conv_size=32, name: str='conv2d
110110
filt = tf.get_variable('filter', [5, 5, conv_size, tensor_shape[-1]], tf.float32, tf.random_normal_initializer(0, stddev), trainable=True)
111111
output = []
112112
for tensor in tensors:
113-
conv_shape = [batch_size, int(tensor_shape[1]*2), int(tensor_shape[2]*2), conv_size]
114-
deconv = tf.nn.conv2d_transpose(tensor, filt, conv_shape, [1, 2, 2, 1])
113+
conv_shape = [batch_size, int(tensor_shape[1]), int(tensor_shape[2]), conv_size]
114+
deconv = tf.nn.conv2d_transpose(tensor, filt, conv_shape, [1, 1, 1, 1])
115115
output.append(tf.nn.tanh(deconv))
116116
return output
117117

train.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,6 @@
1818
}
1919

2020
IMAGE_CONFIG = {
21-
'pool_renew': 2,
2221
'rotation_range': (-20, 20),
2322
'brightness_range': (0.7, 1.2),
2423
'saturation_range': (0.9, 1.5),

0 commit comments

Comments
 (0)