Skip to content

Commit bf4902c

Browse files
committed
Prepare for subclasses
1 parent 9deaa8c commit bf4902c

File tree

3 files changed

+53
-31
lines changed

3 files changed

+53
-31
lines changed

autogan.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
from network import GANetwork
2+
3+
class AutoGanGenerator(GANetwork):
4+
5+
def __init__(self, **kwargs):
6+
super().__init__(setup=False, **kwargs)
7+
#TODO setup autencoder
8+
self.setup_network()

network.py

Lines changed: 42 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -13,14 +13,15 @@
1313

1414
class GANetwork():
1515

16-
def __init__(self, name, image_size=64, colors=3, batch_size=64, directory='network', image_manager=None,
16+
def __init__(self, name, setup=True, image_size=64, colors=3, batch_size=64, directory='network', image_manager=None,
1717
input_size=64, learning_rate=0.0002, dropout=0.4, generator_convolutions=5, generator_base_width=32,
1818
discriminator_convolutions=4, discriminator_base_width=32, classification_depth=1, grid_size=4,
1919
log=True, y_offset=0.1, learning_momentum=0.6, learning_momentum2=0.9):
2020
"""
2121
Create a GAN for generating images
2222
Args:
2323
name: The name of the network
24+
setup: Initialize the network in the constructor
2425
image_size: The size of the generated images
2526
colors: number of color layers (3 is rgb, 1 is grayscale)
2627
batch_size: images per training batch
@@ -44,11 +45,21 @@ def __init__(self, name, image_size=64, colors=3, batch_size=64, directory='netw
4445
self.image_size = image_size
4546
self.colors = colors
4647
self.batch_size = batch_size
48+
self.grid_size = min(grid_size, int(math.sqrt(batch_size)))
49+
self.log = log
4750
self.directory = directory
51+
os.makedirs(directory, exist_ok=True)
52+
#Network variables
4853
self.input_size = input_size
54+
self._gen_conv = generator_convolutions
55+
self._gen_width = generator_base_width
56+
self._dis_conv = discriminator_convolutions
57+
self._dis_width = discriminator_base_width
58+
self._class_depth = classification_depth
4959
self.dropout = dropout
50-
self.grid_size = min(grid_size, int(math.sqrt(batch_size)))
51-
self.log = log
60+
#Training variables
61+
self.learning_rate = (learning_rate, learning_momentum, learning_momentum2)
62+
self._y_offset = y_offset
5263
#Setup Images
5364
if image_manager is None:
5465
self.image_manager = ImageVariations(image_size=image_size, batch_size=batch_size, colored=(colors == 3))
@@ -60,33 +71,37 @@ def __init__(self, name, image_size=64, colors=3, batch_size=64, directory='netw
6071
self.image_manager.start_threads()
6172
#Setup Networks
6273
self.iterations = tf.Variable(0, name="training_iterations", trainable=False)
63-
os.makedirs(directory, exist_ok=True)
64-
#Generator
65-
self.input = self.generator_output = None
66-
self.generator(generator_convolutions, generator_base_width)
67-
#Generated output
74+
with tf.variable_scope('input'):
75+
self.generator_input = tf.placeholder(tf.float32, [None, self.input_size], name='generator_input')
76+
self.image_input = tf.placeholder(tf.uint8, shape=[None, image_size, image_size, self.colors], name='image_input')
77+
self.image_input_scaled = tf.subtract(tf.to_float(self.image_input)/127.5, 1, name='image_scaling')
78+
self.generator_output = None
6879
self.image_output = self.image_grid_output = None
69-
self.setup_output()
70-
#Discriminator
71-
self.image_input = self.image_logit = self.generated_logit = None
80+
self.image_logit = self.generated_logit = None
7281
self.variation_updater = self.image_variation = None
73-
self.discriminator(self.generator_output, discriminator_convolutions, discriminator_base_width, classification_depth)
74-
#Losses and Solvers
75-
self.generator_loss, self.discriminator_loss, self.d_loss_real, self.d_loss_fake = \
76-
self.loss_functions(self.image_logit, self.generated_logit, y_offset)
77-
self.generator_solver, self.discriminator_solver = \
78-
self.solver_functions(self.generator_loss, self.discriminator_loss, learning_rate, learning_momentum, learning_momentum2)
82+
self.generator_solver = self.discriminator_solver = None
83+
if setup:
84+
self.setup_network()
85+
86+
def setup_network(self):
87+
"""Initialize the network if it is not done in the constructor"""
88+
self.__generator__()
89+
self.setup_output()
90+
self.discriminator(self.generator_output, self._dis_conv, self._dis_width, self._class_depth)
91+
g_loss, d_loss, d_loss_real, d_loss_fake = self.loss_functions(self.image_logit, self.generated_logit, self._y_offset)
92+
self.generator_solver, self.discriminator_solver = self.solver_functions(g_loss, d_loss, *self.learning_rate)
7993

8094

81-
def generator(self, conv_layers, conv_size):
95+
def __generator__(self):
8296
"""Create a Generator Network"""
97+
conv_layers = self._gen_conv
98+
conv_size = self._gen_width
8399
with tf.variable_scope('generator'):
84100
#Network layer variables
85101
conv_image_size = self.image_size // (2**conv_layers)
86102
assert conv_image_size*(2**conv_layers) == self.image_size, "Images must be a multiple of two (or at least divisible by 2**num_of_conv_layers_plus_one)"
87-
#Input Layers
88-
self.input = tf.placeholder(tf.float32, [None, self.input_size], name='input')
89-
prev_layer = expand_relu(self.input, [-1, conv_image_size, conv_image_size, conv_size*2**(conv_layers-1)], 'expand')
103+
#Input Layer
104+
prev_layer = expand_relu(self.generator_input, [-1, conv_image_size, conv_image_size, conv_size*2**(conv_layers-1)], 'expand')
90105
#Conv layers
91106
for i in range(conv_layers-1):
92107
prev_layer = conv2d_transpose(prev_layer, self.batch_size, 2**(conv_layers-i-2)*conv_size, 'convolution_%d'%i)
@@ -118,9 +133,6 @@ def discriminator(self, generator_output, conv_layers, conv_size, class_layers):
118133
"""Create a Discriminator Network"""
119134
image_size = self.image_size
120135
with tf.variable_scope('discriminator') as scope:
121-
with tf.variable_scope('real_input'):
122-
self.image_input = tf.placeholder(tf.uint8, shape=[None, image_size, image_size, self.colors], name='image_input')
123-
real_input_scaled = tf.subtract(tf.to_float(self.image_input)/127.5, 1, name='scaling')
124136
conv_output_size = ((image_size//(2**conv_layers))**2) * conv_size * conv_layers
125137
class_output_size = 2**int(math.log(conv_output_size//2, 2))
126138
#Create Layers
@@ -135,11 +147,11 @@ def create_network(layer, summary=True):
135147
return linear(layer, 1, 'output', summary=summary)
136148
self.generated_logit = create_network(generator_output)
137149
scope.reuse_variables()
138-
self.image_logit = create_network(real_input_scaled, False)
150+
self.image_logit = create_network(self.image_input_scaled, False)
139151
if self.log:
140152
with tf.variable_scope('pixel_variation'):
141153
#Pixel Variations
142-
img_tot_var = tf.image.total_variation(real_input_scaled)
154+
img_tot_var = tf.image.total_variation(self.image_input_scaled)
143155
gen_tot_var = tf.image.total_variation(generator_output)
144156
image_variation = tf.reduce_sum(img_tot_var)
145157
gener_variation = tf.reduce_sum(gen_tot_var)
@@ -204,7 +216,7 @@ def generate(self, session, name, amount=1):
204216
def get_arr():
205217
arr = np.asarray(session.run(
206218
self.image_output,
207-
feed_dict={self.input: self.random_input(self.batch_size)}
219+
feed_dict={self.generator_input: self.random_input(self.batch_size)}
208220
), np.uint8)
209221
arr.shape = self.batch_size, self.image_size, self.image_size, self.colors
210222
return arr
@@ -223,7 +235,7 @@ def generate_grid(self, session, name):
223235
"""Generate a image and save it"""
224236
grid = session.run(
225237
self.image_grid_output,
226-
feed_dict={self.input: self.random_input(self.batch_size)}
238+
feed_dict={self.generator_input: self.random_input(self.batch_size)}
227239
)
228240
self.image_manager.image_size = self.image_grid_output.get_shape()[1]
229241
self.image_manager.save_image(grid, name)
@@ -238,7 +250,7 @@ def get_session(self):
238250
saver.restore(session, os.path.join(self.directory, self.name))
239251
start_iteration = session.run(self.iterations)
240252
print("\nLoaded an existing network\n")
241-
except Exception:
253+
except Exception as e:
242254
start_iteration = 0
243255
if self.log:
244256
tf.summary.FileWriter(os.path.join(LOG_DIR, self.name), session.graph)
@@ -265,7 +277,7 @@ def train(self, batches=100000, print_interval=1):
265277
for i in range(start_iteration, start_iteration+batches+1):
266278
feed_dict = {
267279
self.image_input: self.image_manager.get_batch(),
268-
self.input: self.random_input(self.batch_size)
280+
self.generator_input: self.random_input(self.batch_size)
269281
}
270282
session.run(calculations, feed_dict=feed_dict)
271283
#Print progress

train.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11

22
import os
33
from network import GANetwork
4+
from autogan import AutoGanGenerator
45

56

67
CONFIG = {
@@ -13,7 +14,8 @@
1314
}
1415

1516
def get_network(name, **config):
16-
return GANetwork(name, **config)
17+
#return GANetwork(name, **config)
18+
return AutoGanGenerator(name=name, **config)
1719

1820

1921
if __name__ == '__main__':

0 commit comments

Comments
 (0)