-
Notifications
You must be signed in to change notification settings - Fork 5
/
Copy pathmodels.py
97 lines (82 loc) · 2.69 KB
/
models.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
import torch
import torch.nn as nn
nz = 100  # size of the latent code (input channels of the generator) — see "# Z" below
nc = 3    # number of image channels (presumably RGB — 3 channels)
ngf = 64  # base feature-map count for the generator
ndf = 64  # base feature-map count for the discriminator
class _netG(nn.Module):
    """DCGAN-style generator.

    Maps a latent code of shape (N, nz, 1, 1) to an image of shape
    (N, nc, 32, 32); Tanh squashes outputs into [-1, 1].
    """

    def __init__(self):
        super(_netG, self).__init__()
        self.main = nn.Sequential(
            # Z: (N, nz, 1, 1)
            nn.ConvTranspose2d(nz, ngf * 8, 2, 1, 0, bias=False),
            nn.BatchNorm2d(ngf * 8),
            nn.ReLU(True),
            # (ngf * 8) x 2 x 2
            nn.ConvTranspose2d(ngf * 8, ngf * 4, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ngf * 4),
            nn.ReLU(True),
            # (ngf * 4) x 4 x 4
            nn.ConvTranspose2d(ngf * 4, ngf * 2, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ngf * 2),
            nn.ReLU(True),
            # (ngf * 2) x 8 x 8
            nn.ConvTranspose2d(ngf * 2, ngf, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ngf),
            # was nn.ReLU(): made inplace=True for consistency with the other
            # activations; output values are unchanged, one temp copy avoided.
            nn.ReLU(True),
            # ngf x 16 x 16
            nn.ConvTranspose2d(ngf, nc, 4, 2, 1, bias=False),
            nn.Tanh()
            # (nc) x 32 x 32
        )

    def forward(self, input):
        """Run the generator on a (N, nz, 1, 1) latent batch."""
        output = self.main(input)
        return output
class _netD(nn.Module):
    """Discriminator for (nc) x 32 x 32 inputs.

    Produces one non-negative score per batch element (Softplus head),
    returned as a 1-D tensor of length N.
    """

    def __init__(self):
        super(_netD, self).__init__()
        # First stage: (nc) x 32 x 32 -> ndf x 16 x 16 (no BatchNorm here).
        layers = [
            nn.Conv2d(nc, ndf, 4, 2, 1, bias=False),
            nn.LeakyReLU(0.2, inplace=True),
        ]
        # Three downsampling stages; each halves the spatial size and
        # doubles the channel count: ndf -> 2*ndf -> 4*ndf -> 8*ndf.
        channels = ndf
        for _ in range(3):
            layers += [
                nn.Conv2d(channels, channels * 2, 4, 2, 1, bias=False),
                nn.BatchNorm2d(channels * 2),
                nn.LeakyReLU(0.2, inplace=True),
            ]
            channels *= 2
        # Head: (ndf * 8) x 2 x 2 -> a single score map, then Softplus.
        layers += [
            nn.Conv2d(channels, 1, 2, 1, 0, bias=False),
            nn.Softplus(),
        ]
        self.main = nn.Sequential(*layers)

    def forward(self, input):
        """Score a batch of images; returns a tensor of shape (N,)."""
        score = self.main(input)
        return score.view(-1, 1).squeeze(1)
def weights_init(m):
    """Custom weight initializer, intended for use via ``net.apply(weights_init)``.

    Conv / ConvTranspose layers: weight ~ N(0, 0.01).
    BatchNorm layers: weight ~ N(1, 0.02), bias zeroed.
    All other module types are left untouched.

    NOTE(review): the 0.01 conv std differs from the usual DCGAN 0.02 —
    presumed intentional; confirm with the training recipe.
    """
    classname = m.__class__.__name__
    if 'Conv' in classname:
        # Matches both Conv2d and ConvTranspose2d. nn.init replaces the
        # deprecated direct `.data` mutation with the supported API.
        nn.init.normal_(m.weight, 0.0, 0.01)
    elif 'BatchNorm' in classname:
        nn.init.normal_(m.weight, 1.0, 0.02)
        nn.init.constant_(m.bias, 0)
def get_netG():
    """Construct a generator, apply the custom weight init, and move it
    to the GPU when CUDA is available. Returns the ready-to-use module."""
    netG = _netG()
    netG.apply(weights_init)
    if torch.cuda.is_available():
        print("USE CUDA")
        netG.cuda()
    return netG
def get_netD():
    """Construct a discriminator, apply the custom weight init, and move
    it to the GPU when CUDA is available. Returns the ready-to-use module."""
    netD = _netD()
    netD.apply(weights_init)
    if torch.cuda.is_available():
        print("USE CUDA")
        netD.cuda()
    return netD