@@ -25,18 +25,15 @@ def to_img(x):
25
25
26
26
# Image processing
27
27
img_transform = transforms .Compose ([
28
- transforms .ToTensor (),
29
- transforms .Normalize (mean = (0.5 , 0.5 , 0.5 ),
30
- std = ( 0.5 , 0.5 , 0.5 )) ])
28
+ transforms .ToTensor (),
29
+ transforms .Normalize (mean = (0.5 , 0.5 , 0.5 ), std = ( 0.5 , 0.5 , 0.5 ))
30
+ ])
31
31
# MNIST dataset
32
- mnist = datasets .MNIST (root = './data/' ,
33
- train = True ,
34
- transform = img_transform ,
35
- download = True )
32
+ mnist = datasets .MNIST (
33
+ root = './data/' , train = True , transform = img_transform , download = True )
36
34
# Data loader
37
- dataloader = torch .utils .data .DataLoader (dataset = mnist ,
38
- batch_size = batch_size ,
39
- shuffle = True )
35
+ dataloader = torch .utils .data .DataLoader (
36
+ dataset = mnist , batch_size = batch_size , shuffle = True )
40
37
41
38
42
39
# Discriminator
@@ -47,10 +44,7 @@ def __init__(self):
47
44
nn .Linear (784 , 256 ),
48
45
nn .LeakyReLU (0.2 ),
49
46
nn .Linear (256 , 256 ),
50
- nn .LeakyReLU (0.2 ),
51
- nn .Linear (256 , 1 ),
52
- nn .Sigmoid ()
53
- )
47
+ nn .LeakyReLU (0.2 ), nn .Linear (256 , 1 ), nn .Sigmoid ())
54
48
55
49
def forward (self , x ):
56
50
x = self .dis (x )
@@ -64,11 +58,7 @@ def __init__(self):
64
58
self .gen = nn .Sequential (
65
59
nn .Linear (100 , 256 ),
66
60
nn .ReLU (True ),
67
- nn .Linear (256 , 256 ),
68
- nn .ReLU (True ),
69
- nn .Linear (256 , 784 ),
70
- nn .Tanh ()
71
- )
61
+ nn .Linear (256 , 256 ), nn .ReLU (True ), nn .Linear (256 , 784 ), nn .Tanh ())
72
62
73
63
def forward (self , x ):
74
64
x = self .gen (x )
@@ -125,17 +115,17 @@ def forward(self, x):
125
115
g_loss .backward ()
126
116
g_optimizer .step ()
127
117
128
- if (i + 1 ) % 100 == 0 :
118
+ if (i + 1 ) % 100 == 0 :
129
119
print ('Epoch [{}/{}], d_loss: {:.6f}, g_loss: {:.6f} '
130
- 'D real: {:.6f}, D fake: {:.6f}'
131
- . format ( epoch , num_epoch , d_loss .data [0 ], g_loss .data [0 ],
132
- real_scores .data .mean (), fake_scores .data .mean ()))
120
+ 'D real: {:.6f}, D fake: {:.6f}' . format (
121
+ epoch , num_epoch , d_loss .data [0 ], g_loss .data [0 ],
122
+ real_scores .data .mean (), fake_scores .data .mean ()))
133
123
if epoch == 0 :
134
124
real_images = to_img (real_img .cpu ().data )
135
125
save_image (real_images , './img/real_images.png' )
136
126
137
127
fake_images = to_img (fake_img .cpu ().data )
138
- save_image (fake_images , './img/fake_images-{}.png' .format (epoch + 1 ))
128
+ save_image (fake_images , './img/fake_images-{}.png' .format (epoch + 1 ))
139
129
140
130
torch .save (G .state_dict (), './generator.pth' )
141
131
torch .save (D .state_dict (), './discriminator.pth' )
0 commit comments