@@ -155,11 +155,11 @@ def __get_feed_dict__(self):
155
155
156
156
def __training_iteration__ (self , session , i ):
157
157
if i < 500 : #Initialising iterations
158
- if i < 100 :
158
+ if i < 50 :
159
159
session .run ([self .discriminator_solver ], feed_dict = self .__get_feed_dict__ ())
160
160
else :
161
161
session .run ([self .discriminator_solver , self .generator_solver ], feed_dict = self .__get_feed_dict__ ())
162
- elif i % 20 == 0 : #Check the scaling
162
+ elif i % 10 == 0 : #Check the scaling
163
163
_ , _ , self .current_scale = session .run ([self .discriminator_solver , self .generator_solver , self .scale ], feed_dict = self .__get_feed_dict__ ())
164
164
elif self .current_scale > 1.3 : #Train only the worse performing network (do some additional faster iterations)
165
165
session .run (self .generator_solver , feed_dict = {self .generator_input : self .random_input ()})
@@ -193,8 +193,8 @@ def train(self, batches=100000, print_interval=1):
193
193
curr_time = timer ()
194
194
time_per = time_per * 0.6 + (curr_time - last_time )/ print_interval * 0.4
195
195
time = curr_time - start_time
196
- print ("Iteration : %04d Time: %02d:%02d:%02d (%02.1fs / iteration)" % \
197
- (i , time // 3600 , time % 3600 // 60 , time % 60 , time_per ), end = '\r ' )
196
+ print ("\r Iteration : %04d Time: %02d:%02d:%02d (%02.1fs / iteration)" % \
197
+ (i , time // 3600 , time % 3600 // 60 , time % 60 , time_per ), end = '' )
198
198
last_time = curr_time
199
199
if self .log :
200
200
logger (i )
@@ -203,9 +203,9 @@ def train(self, batches=100000, print_interval=1):
203
203
saver .save (session , os .path .join (self .directory , self .name ), self .iterations )
204
204
last_save = timer ()
205
205
except KeyboardInterrupt :
206
- print ()
207
- print ("Stopping the training" )
206
+ pass
208
207
finally :
208
+ print ()
209
209
if self .log :
210
210
logger .close ()
211
211
print ("Saving the network" )
@@ -233,7 +233,7 @@ def __call__(self, iteration):
233
233
if iteration % self .image_interval == 0 :
234
234
#Hack to make tensorboard show multiple images, not just the latest one
235
235
feed_dict = self .gan .__get_feed_dict__ ()
236
- feed_dict [self .gan .generator_input ] = self .batch_input ,
236
+ feed_dict [self .gan .generator_input ] = self .batch_input
237
237
image , summary = self .session .run (
238
238
[tf .summary .image (
239
239
'training/iteration/%d' % iteration ,
0 commit comments