1
1
2
2
import os
3
3
4
- import tensorflow as tf
5
-
6
- from network import GANetwork
7
- from train import CONFIG
4
+ from train import CONFIG , get_network
8
5
9
6
10
7
def get_config (batch ):
@@ -14,20 +11,19 @@ def get_config(batch):
14
11
config ['grid_size' ] = int (batch ** 0.5 )
15
12
return config
16
13
17
-
18
14
def generate (name , amount = 1 ):
19
- gan = GANetwork (name , ** get_config (amount ))
20
- session , _ , iter = gan .get_session ()
21
- if iter == 0 :
15
+ gan = get_network (name , ** get_config (amount ))
16
+ session , _ , iteration = gan .get_session ()
17
+ if iteration == 0 :
22
18
print ("No already trained network found (%s)" % name )
23
19
return
24
20
print ("Generating %d images using the %s network" % (amount , name ))
25
21
gan .generate (session , gan .name , amount )
26
22
27
23
def generate_grid (name , size = 5 ):
28
- gan = GANetwork (name , ** get_config (size * size ))
29
- session , _ , iter = gan .get_session ()
30
- if iter == 0 :
24
+ gan = get_network (name , ** get_config (size * size ))
25
+ session , _ , iteration = gan .get_session ()
26
+ if iteration == 0 :
31
27
print ("No already trained network found (%s)" % name )
32
28
return
33
29
print ("Generating a image grid using the %s network" % name )
0 commit comments