Skip to content

Commit 956575d

Browse files
committed
Training Rate Scaling
1 parent 79812c0 commit 956575d

File tree

3 files changed

+14
-5
lines changed

3 files changed

+14
-5
lines changed

generator_gan.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -91,7 +91,7 @@ def setup_network(self):
9191
dis_var = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope='discriminator')
9292
self.generator_solver, self.discriminator_solver, self.scale = \
9393
gan_optimizer('train', gen_var, dis_var, gen_logit, image_logit, self._y_offset, 1-self._y_offset,
94-
*self.learning_rate, global_step=self.iterations, summary=self.log)
94+
*self.learning_rate, learning_rate_pivot=10000, global_step=self.iterations, summary=self.log)
9595

9696

9797
def random_input(self):
@@ -190,10 +190,14 @@ def train(self, batches=100000, print_interval=1):
190190
if timer() - last_save > 1800:
191191
saver.save(session, os.path.join(self.directory, self.name))
192192
last_save = timer()
193+
print("Saving the network")
194+
saver.save(session, os.path.join(self.directory, self.name))
195+
if self.log:
196+
logger.close()
197+
session.close()
193198
except KeyboardInterrupt:
194199
print()
195-
print("Stopping the training", end='')
196-
finally:
200+
print("Stopping the training")
197201
saver.save(session, os.path.join(self.directory, self.name))
198202
if self.log:
199203
logger.close()

network.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -91,9 +91,14 @@ def batch_optimizer(name, variables, positive_tensors=None, positive_value=1, po
9191
return solver
9292

9393
def gan_optimizer(name, gen_vars, dis_vars, fake_tensor, real_tensor, false_val=0, real_val=1,
94-
learning_rate=0.001, learning_momentum=0.9, learning_momentum2=0.99, global_step=None, summary=True):
94+
learning_rate=0.001, learning_momentum=0.9, learning_momentum2=0.99,
95+
learning_rate_pivot=0, global_step=None, summary=True):
9596
"""Create an optimizer for a GAN"""
9697
with openif_scope(name):
98+
#learning rate scaling
99+
if learning_rate_pivot > 0:
100+
scaler = tf.sqrt(tf.div(tf.to_float(global_step), float(learning_rate_pivot))+1)
101+
learning_rate = tf.div(learning_rate, scaler)
97102
#generator
98103
with tf.variable_scope('generator'):
99104
gen_labels = tf.fill(tf.shape(fake_tensor), real_val)

train.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55

66
CONFIG = {
77
'colors': 3,
8-
'batch_size': 128,
8+
'batch_size': 192,
99
'generator_base_width': 32,
1010
'image_size': 64,
1111
'discriminator_convolutions': 5,

0 commit comments

Comments
 (0)