Skip to content

Commit 2a05515

Browse files
committed
Remove the autoencoder initilizer
1 parent a665020 commit 2a05515

File tree

3 files changed

+9
-49
lines changed

3 files changed

+9
-49
lines changed

generator_autoinit.py

Lines changed: 0 additions & 40 deletions
This file was deleted.

generator_gan.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -155,11 +155,11 @@ def __get_feed_dict__(self):
155155

156156
def __training_iteration__(self, session, i):
157157
if i < 500: #Initialising iterations
158-
if i < 100:
158+
if i < 50:
159159
session.run([self.discriminator_solver], feed_dict=self.__get_feed_dict__())
160160
else:
161161
session.run([self.discriminator_solver, self.generator_solver], feed_dict=self.__get_feed_dict__())
162-
elif i%20 == 0: #Check the scaling
162+
elif i%10 == 0: #Check the scaling
163163
_, _, self.current_scale = session.run([self.discriminator_solver, self.generator_solver, self.scale], feed_dict=self.__get_feed_dict__())
164164
elif self.current_scale > 1.3: #Train only the worse performing network (do some additional faster iterations)
165165
session.run(self.generator_solver, feed_dict={self.generator_input: self.random_input()})
@@ -193,8 +193,8 @@ def train(self, batches=100000, print_interval=1):
193193
curr_time = timer()
194194
time_per = time_per*0.6 + (curr_time-last_time)/print_interval*0.4
195195
time = curr_time - start_time
196-
print("Iteration: %04d Time: %02d:%02d:%02d (%02.1fs / iteration)" % \
197-
(i, time//3600, time%3600//60, time%60, time_per), end='\r')
196+
print("\rIteration: %04d Time: %02d:%02d:%02d (%02.1fs / iteration)" % \
197+
(i, time//3600, time%3600//60, time%60, time_per), end='')
198198
last_time = curr_time
199199
if self.log:
200200
logger(i)
@@ -203,9 +203,9 @@ def train(self, batches=100000, print_interval=1):
203203
saver.save(session, os.path.join(self.directory, self.name), self.iterations)
204204
last_save = timer()
205205
except KeyboardInterrupt:
206-
print()
207-
print("Stopping the training")
206+
pass
208207
finally:
208+
print()
209209
if self.log:
210210
logger.close()
211211
print("Saving the network")
@@ -233,7 +233,7 @@ def __call__(self, iteration):
233233
if iteration%self.image_interval == 0:
234234
#Hack to make tensorboard show multiple images, not just the latest one
235235
feed_dict = self.gan.__get_feed_dict__()
236-
feed_dict[self.gan.generator_input] = self.batch_input,
236+
feed_dict[self.gan.generator_input] = self.batch_input
237237
image, summary = self.session.run(
238238
[tf.summary.image(
239239
'training/iteration/%d'%iteration,

train.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,8 +10,8 @@
1010
'generator_base_width': 32,
1111
'image_size': 64,
1212
'input_size': 128,
13-
'discriminator_convolutions': 4,
14-
'generator_convolutions': 4,
13+
'discriminator_convolutions': 3,
14+
'generator_convolutions': 3,
1515
'learning_rate': 0.0002,
1616
'learning_momentum': 0.8,
1717
'learning_momentum2': 0.95

0 commit comments

Comments
 (0)