Skip to content

Commit 765f849

Browse files
committed
fix
1 parent 8f5b779 commit 765f849

File tree

1 file changed

+14
-24
lines changed

1 file changed

+14
-24
lines changed

09-Generative Adversarial network/simple_Gan.py

Lines changed: 14 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -25,18 +25,15 @@ def to_img(x):
2525

2626
# Image processing
2727
img_transform = transforms.Compose([
28-
transforms.ToTensor(),
29-
transforms.Normalize(mean=(0.5, 0.5, 0.5),
30-
std=(0.5, 0.5, 0.5))])
28+
transforms.ToTensor(),
29+
transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))
30+
])
3131
# MNIST dataset
32-
mnist = datasets.MNIST(root='./data/',
33-
train=True,
34-
transform=img_transform,
35-
download=True)
32+
mnist = datasets.MNIST(
33+
root='./data/', train=True, transform=img_transform, download=True)
3634
# Data loader
37-
dataloader = torch.utils.data.DataLoader(dataset=mnist,
38-
batch_size=batch_size,
39-
shuffle=True)
35+
dataloader = torch.utils.data.DataLoader(
36+
dataset=mnist, batch_size=batch_size, shuffle=True)
4037

4138

4239
# Discriminator
@@ -47,10 +44,7 @@ def __init__(self):
4744
nn.Linear(784, 256),
4845
nn.LeakyReLU(0.2),
4946
nn.Linear(256, 256),
50-
nn.LeakyReLU(0.2),
51-
nn.Linear(256, 1),
52-
nn.Sigmoid()
53-
)
47+
nn.LeakyReLU(0.2), nn.Linear(256, 1), nn.Sigmoid())
5448

5549
def forward(self, x):
5650
x = self.dis(x)
@@ -64,11 +58,7 @@ def __init__(self):
6458
self.gen = nn.Sequential(
6559
nn.Linear(100, 256),
6660
nn.ReLU(True),
67-
nn.Linear(256, 256),
68-
nn.ReLU(True),
69-
nn.Linear(256, 784),
70-
nn.Tanh()
71-
)
61+
nn.Linear(256, 256), nn.ReLU(True), nn.Linear(256, 784), nn.Tanh())
7262

7363
def forward(self, x):
7464
x = self.gen(x)
@@ -125,17 +115,17 @@ def forward(self, x):
125115
g_loss.backward()
126116
g_optimizer.step()
127117

128-
if (i+1) % 100 == 0:
118+
if (i + 1) % 100 == 0:
129119
print('Epoch [{}/{}], d_loss: {:.6f}, g_loss: {:.6f} '
130-
'D real: {:.6f}, D fake: {:.6f}'
131-
.format(epoch, num_epoch, d_loss.data[0], g_loss.data[0],
132-
real_scores.data.mean(), fake_scores.data.mean()))
120+
'D real: {:.6f}, D fake: {:.6f}'.format(
121+
epoch, num_epoch, d_loss.data[0], g_loss.data[0],
122+
real_scores.data.mean(), fake_scores.data.mean()))
133123
if epoch == 0:
134124
real_images = to_img(real_img.cpu().data)
135125
save_image(real_images, './img/real_images.png')
136126

137127
fake_images = to_img(fake_img.cpu().data)
138-
save_image(fake_images, './img/fake_images-{}.png'.format(epoch+1))
128+
save_image(fake_images, './img/fake_images-{}.png'.format(epoch + 1))
139129

140130
torch.save(G.state_dict(), './generator.pth')
141131
torch.save(D.state_dict(), './discriminator.pth')

0 commit comments

Comments
 (0)