13
13
14
14
class GANetwork ():
15
15
16
- def __init__ (self , name , image_size = 64 , colors = 3 , batch_size = 64 , directory = 'network' , image_manager = None ,
16
+ def __init__ (self , name , setup = True , image_size = 64 , colors = 3 , batch_size = 64 , directory = 'network' , image_manager = None ,
17
17
input_size = 64 , learning_rate = 0.0002 , dropout = 0.4 , generator_convolutions = 5 , generator_base_width = 32 ,
18
18
discriminator_convolutions = 4 , discriminator_base_width = 32 , classification_depth = 1 , grid_size = 4 ,
19
19
log = True , y_offset = 0.1 , learning_momentum = 0.6 , learning_momentum2 = 0.9 ):
20
20
"""
21
21
Create a GAN for generating images
22
22
Args:
23
23
name: The name of the network
24
+ setup: Initialize the network in the constructor
24
25
image_size: The size of the generated images
25
26
colors: number of color layers (3 is rgb, 1 is grayscale)
26
27
batch_size: images per training batch
@@ -44,11 +45,21 @@ def __init__(self, name, image_size=64, colors=3, batch_size=64, directory='netw
44
45
self .image_size = image_size
45
46
self .colors = colors
46
47
self .batch_size = batch_size
48
+ self .grid_size = min (grid_size , int (math .sqrt (batch_size )))
49
+ self .log = log
47
50
self .directory = directory
51
+ os .makedirs (directory , exist_ok = True )
52
+ #Network variables
48
53
self .input_size = input_size
54
+ self ._gen_conv = generator_convolutions
55
+ self ._gen_width = generator_base_width
56
+ self ._dis_conv = discriminator_convolutions
57
+ self ._dis_width = discriminator_base_width
58
+ self ._class_depth = classification_depth
49
59
self .dropout = dropout
50
- self .grid_size = min (grid_size , int (math .sqrt (batch_size )))
51
- self .log = log
60
+ #Training variables
61
+ self .learning_rate = (learning_rate , learning_momentum , learning_momentum2 )
62
+ self ._y_offset = y_offset
52
63
#Setup Images
53
64
if image_manager is None :
54
65
self .image_manager = ImageVariations (image_size = image_size , batch_size = batch_size , colored = (colors == 3 ))
@@ -60,33 +71,37 @@ def __init__(self, name, image_size=64, colors=3, batch_size=64, directory='netw
60
71
self .image_manager .start_threads ()
61
72
#Setup Networks
62
73
self .iterations = tf .Variable (0 , name = "training_iterations" , trainable = False )
63
- os . makedirs ( directory , exist_ok = True )
64
- #Generator
65
- self .input = self . generator_output = None
66
- self .generator ( generator_convolutions , generator_base_width )
67
- #Generated output
74
+ with tf . variable_scope ( 'input' ):
75
+ self . generator_input = tf . placeholder ( tf . float32 , [ None , self . input_size ], name = 'generator_input' )
76
+ self .image_input = tf . placeholder ( tf . uint8 , shape = [ None , image_size , image_size , self . colors ], name = 'image_input' )
77
+ self .image_input_scaled = tf . subtract ( tf . to_float ( self . image_input ) / 127.5 , 1 , name = 'image_scaling' )
78
+ self . generator_output = None
68
79
self .image_output = self .image_grid_output = None
69
- self .setup_output ()
70
- #Discriminator
71
- self .image_input = self .image_logit = self .generated_logit = None
80
+ self .image_logit = self .generated_logit = None
72
81
self .variation_updater = self .image_variation = None
73
- self .discriminator (self .generator_output , discriminator_convolutions , discriminator_base_width , classification_depth )
74
- #Losses and Solvers
75
- self .generator_loss , self .discriminator_loss , self .d_loss_real , self .d_loss_fake = \
76
- self .loss_functions (self .image_logit , self .generated_logit , y_offset )
77
- self .generator_solver , self .discriminator_solver = \
78
- self .solver_functions (self .generator_loss , self .discriminator_loss , learning_rate , learning_momentum , learning_momentum2 )
82
+ self .generator_solver = self .discriminator_solver = None
83
+ if setup :
84
+ self .setup_network ()
85
+
86
+ def setup_network (self ):
87
+ """Initialize the network if it is not done in the constructor"""
88
+ self .__generator__ ()
89
+ self .setup_output ()
90
+ self .discriminator (self .generator_output , self ._dis_conv , self ._dis_width , self ._class_depth )
91
+ g_loss , d_loss , d_loss_real , d_loss_fake = self .loss_functions (self .image_logit , self .generated_logit , self ._y_offset )
92
+ self .generator_solver , self .discriminator_solver = self .solver_functions (g_loss , d_loss , * self .learning_rate )
79
93
80
94
81
- def generator (self , conv_layers , conv_size ):
95
+ def __generator__ (self ):
82
96
"""Create a Generator Network"""
97
+ conv_layers = self ._gen_conv
98
+ conv_size = self ._gen_width
83
99
with tf .variable_scope ('generator' ):
84
100
#Network layer variables
85
101
conv_image_size = self .image_size // (2 ** conv_layers )
86
102
assert conv_image_size * (2 ** conv_layers ) == self .image_size , "Images must be a multiple of two (or at least divisible by 2**num_of_conv_layers_plus_one)"
87
- #Input Layers
88
- self .input = tf .placeholder (tf .float32 , [None , self .input_size ], name = 'input' )
89
- prev_layer = expand_relu (self .input , [- 1 , conv_image_size , conv_image_size , conv_size * 2 ** (conv_layers - 1 )], 'expand' )
103
+ #Input Layer
104
+ prev_layer = expand_relu (self .generator_input , [- 1 , conv_image_size , conv_image_size , conv_size * 2 ** (conv_layers - 1 )], 'expand' )
90
105
#Conv layers
91
106
for i in range (conv_layers - 1 ):
92
107
prev_layer = conv2d_transpose (prev_layer , self .batch_size , 2 ** (conv_layers - i - 2 )* conv_size , 'convolution_%d' % i )
@@ -118,9 +133,6 @@ def discriminator(self, generator_output, conv_layers, conv_size, class_layers):
118
133
"""Create a Discriminator Network"""
119
134
image_size = self .image_size
120
135
with tf .variable_scope ('discriminator' ) as scope :
121
- with tf .variable_scope ('real_input' ):
122
- self .image_input = tf .placeholder (tf .uint8 , shape = [None , image_size , image_size , self .colors ], name = 'image_input' )
123
- real_input_scaled = tf .subtract (tf .to_float (self .image_input )/ 127.5 , 1 , name = 'scaling' )
124
136
conv_output_size = ((image_size // (2 ** conv_layers ))** 2 ) * conv_size * conv_layers
125
137
class_output_size = 2 ** int (math .log (conv_output_size // 2 , 2 ))
126
138
#Create Layers
@@ -135,11 +147,11 @@ def create_network(layer, summary=True):
135
147
return linear (layer , 1 , 'output' , summary = summary )
136
148
self .generated_logit = create_network (generator_output )
137
149
scope .reuse_variables ()
138
- self .image_logit = create_network (real_input_scaled , False )
150
+ self .image_logit = create_network (self . image_input_scaled , False )
139
151
if self .log :
140
152
with tf .variable_scope ('pixel_variation' ):
141
153
#Pixel Variations
142
- img_tot_var = tf .image .total_variation (real_input_scaled )
154
+ img_tot_var = tf .image .total_variation (self . image_input_scaled )
143
155
gen_tot_var = tf .image .total_variation (generator_output )
144
156
image_variation = tf .reduce_sum (img_tot_var )
145
157
gener_variation = tf .reduce_sum (gen_tot_var )
@@ -204,7 +216,7 @@ def generate(self, session, name, amount=1):
204
216
def get_arr ():
205
217
arr = np .asarray (session .run (
206
218
self .image_output ,
207
- feed_dict = {self .input : self .random_input (self .batch_size )}
219
+ feed_dict = {self .generator_input : self .random_input (self .batch_size )}
208
220
), np .uint8 )
209
221
arr .shape = self .batch_size , self .image_size , self .image_size , self .colors
210
222
return arr
@@ -223,7 +235,7 @@ def generate_grid(self, session, name):
223
235
"""Generate a image and save it"""
224
236
grid = session .run (
225
237
self .image_grid_output ,
226
- feed_dict = {self .input : self .random_input (self .batch_size )}
238
+ feed_dict = {self .generator_input : self .random_input (self .batch_size )}
227
239
)
228
240
self .image_manager .image_size = self .image_grid_output .get_shape ()[1 ]
229
241
self .image_manager .save_image (grid , name )
@@ -238,7 +250,7 @@ def get_session(self):
238
250
saver .restore (session , os .path .join (self .directory , self .name ))
239
251
start_iteration = session .run (self .iterations )
240
252
print ("\n Loaded an existing network\n " )
241
- except Exception :
253
+ except Exception as e :
242
254
start_iteration = 0
243
255
if self .log :
244
256
tf .summary .FileWriter (os .path .join (LOG_DIR , self .name ), session .graph )
@@ -265,7 +277,7 @@ def train(self, batches=100000, print_interval=1):
265
277
for i in range (start_iteration , start_iteration + batches + 1 ):
266
278
feed_dict = {
267
279
self .image_input : self .image_manager .get_batch (),
268
- self .input : self .random_input (self .batch_size )
280
+ self .generator_input : self .random_input (self .batch_size )
269
281
}
270
282
session .run (calculations , feed_dict = feed_dict )
271
283
#Print progress
0 commit comments