-
Notifications
You must be signed in to change notification settings - Fork 476
/
sagan_models.py
153 lines (123 loc) · 5.28 KB
/
sagan_models.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable
from spectral import SpectralNorm
import numpy as np
class Self_Attn(nn.Module):
    """Self-attention layer (SAGAN style).

    Builds a softmax attention map over all spatial positions of the input
    and adds the attention-weighted values back onto the input, scaled by a
    learnable ``gamma``. ``gamma`` is initialised to zero, so a freshly
    constructed layer returns its input unchanged.

    Args:
        in_dim: Number of input channels (also the number of output channels).
        activation: Stored as ``self.activation``; not used by ``forward``.
        reduction: Channel reduction factor for the query/key projections
            (default 8). The projected width is ``max(1, in_dim // reduction)``.

    Raises:
        ValueError: if ``reduction`` is smaller than 1.
    """
    def __init__(self, in_dim, activation, reduction=8):
        super(Self_Attn, self).__init__()
        if reduction < 1:
            raise ValueError("reduction must be >= 1, got %r" % (reduction,))
        self.chanel_in = in_dim  # attribute name (sic) kept for compatibility
        self.activation = activation
        # Without the max(1, ...) guard, in_dim < reduction would request
        # zero output channels for the query/key projections.
        qk_dim = max(1, in_dim // reduction)
        self.query_conv = nn.Conv2d(in_channels=in_dim, out_channels=qk_dim, kernel_size=1)
        self.key_conv = nn.Conv2d(in_channels=in_dim, out_channels=qk_dim, kernel_size=1)
        self.value_conv = nn.Conv2d(in_channels=in_dim, out_channels=in_dim, kernel_size=1)
        self.gamma = nn.Parameter(torch.zeros(1))
        self.softmax = nn.Softmax(dim=-1)  # normalise over key positions

    def forward(self, x):
        """
        inputs :
            x : input feature maps (B X C X W X H)
        returns :
            out : gamma * attention-weighted values + input feature (B X C X W X H)
            attention : B X N X N (N is Width*Height); row i holds the softmax
                weights of query position i over all key positions
        """
        m_batchsize, C, width, height = x.size()
        N = width * height
        proj_query = self.query_conv(x).view(m_batchsize, -1, N).permute(0, 2, 1)  # B X N X C'
        proj_key = self.key_conv(x).view(m_batchsize, -1, N)  # B X C' X N
        energy = torch.bmm(proj_query, proj_key)  # B X N X N (query i . key j)
        attention = self.softmax(energy)  # B X N X N, rows sum to 1
        proj_value = self.value_conv(x).view(m_batchsize, -1, N)  # B X C X N
        # out[:, :, i] = sum_j value[:, :, j] * attention[:, i, j]
        out = torch.bmm(proj_value, attention.permute(0, 2, 1))  # B X C X N
        out = out.view(m_batchsize, C, width, height)
        out = self.gamma * out + x
        return out, attention
class Generator(nn.Module):
    """SAGAN generator: maps a latent vector to a 3-channel image.

    Stages (``image_size=64``): l1 upsamples 1x1 -> 4x4, then l2, l3 and l4
    each double the resolution (8, 16, 32). Self-attention follows l3
    (``attn1``) and l4 (``attn2``). ``last`` doubles once more to 64x64 and
    applies tanh. For ``image_size=32`` the l4 stage is omitted
    (``self.l4 is None``) and both attention layers run at 16x16.

    Args:
        batch_size: Unused; kept for interface compatibility.
        image_size: Output resolution; only 32 and 64 are supported.
        z_dim: Number of latent channels expected by l1.
        conv_dim: Base channel width; l1 produces
            ``conv_dim * image_size // 8`` channels and each later stage halves it.

    Raises:
        ValueError: if ``image_size`` is not 32 or 64.
    """
    def __init__(self, batch_size, image_size=64, z_dim=100, conv_dim=64):
        super(Generator, self).__init__()
        if image_size not in (32, 64):
            raise ValueError(
                "Generator supports image_size 32 or 64, got %r" % (image_size,))
        self.imsize = image_size

        repeat_num = int(np.log2(self.imsize)) - 3
        mult = 2 ** repeat_num  # 8 for 64x64, 4 for 32x32
        curr_dim = conv_dim * mult

        # Layers are created in the same order as before (l1, l2, l3, l4,
        # last, attn1, attn2) so seeded initialisation is reproducible.
        layer1 = [
            SpectralNorm(nn.ConvTranspose2d(z_dim, curr_dim, 4)),
            nn.BatchNorm2d(curr_dim),
            nn.ReLU(),
        ]
        layer2 = self._up_layers(curr_dim, curr_dim // 2)
        curr_dim //= 2
        layer3 = self._up_layers(curr_dim, curr_dim // 2)
        curr_dim //= 2
        attn1_dim = curr_dim  # channel width of the l3 output

        # l4 is registered before l1..l3 (as originally) so the parameter
        # order, and hence saved optimizer state, is unchanged.
        if self.imsize == 64:
            self.l4 = nn.Sequential(*self._up_layers(curr_dim, curr_dim // 2))
            curr_dim //= 2
        else:
            self.l4 = None
        self.l1 = nn.Sequential(*layer1)
        self.l2 = nn.Sequential(*layer2)
        self.l3 = nn.Sequential(*layer3)

        self.last = nn.Sequential(
            nn.ConvTranspose2d(curr_dim, 3, 4, 2, 1),
            nn.Tanh(),
        )
        # Attention widths are derived from the actual stage outputs instead
        # of being hard-coded (128/64 only matched conv_dim=64).
        self.attn1 = Self_Attn(attn1_dim, 'relu')
        self.attn2 = Self_Attn(curr_dim, 'relu')

    @staticmethod
    def _up_layers(in_dim, out_dim):
        """Return [SN-ConvTranspose2d(stride 2), BatchNorm2d, ReLU] doubling resolution."""
        return [
            SpectralNorm(nn.ConvTranspose2d(in_dim, out_dim, 4, 2, 1)),
            nn.BatchNorm2d(out_dim),
            nn.ReLU(),
        ]

    def forward(self, z):
        """Generate images from latent codes.

        Args:
            z: Tensor whose first two sizes are (B, z_dim); it is reshaped to
                (B, z_dim, 1, 1).

        Returns:
            ``(out, p1, p2)``: the tanh-activated image batch and the
            attention maps returned by ``attn1`` and ``attn2``.
        """
        z = z.view(z.size(0), z.size(1), 1, 1)
        out = self.l1(z)
        out = self.l2(out)
        out = self.l3(out)
        out, p1 = self.attn1(out)
        if self.l4 is not None:
            out = self.l4(out)
        out, p2 = self.attn2(out)
        out = self.last(out)
        return out, p1, p2
class Discriminator(nn.Module):
    """SAGAN discriminator producing one real/fake score per image.

    Stages (``image_size=64``): l1..l4 are spectrally-normalised stride-2
    convolutions with LeakyReLU(0.1) that halve the resolution (64 -> 32 ->
    16 -> 8 -> 4) while doubling channels. Self-attention follows l3
    (``attn1``) and l4 (``attn2``). ``last`` is a 4x4 convolution to one
    channel. For ``image_size=32`` the l4 stage is omitted
    (``self.l4 is None``) and both attention layers run at 4x4.

    Args:
        batch_size: Unused; kept for interface compatibility.
        image_size: Input resolution; only 32 and 64 are supported.
        conv_dim: Channel width of l1; each later stage doubles it.

    Raises:
        ValueError: if ``image_size`` is not 32 or 64.
    """
    def __init__(self, batch_size=64, image_size=64, conv_dim=64):
        super(Discriminator, self).__init__()
        if image_size not in (32, 64):
            raise ValueError(
                "Discriminator supports image_size 32 or 64, got %r" % (image_size,))
        self.imsize = image_size

        # Layers are created in the same order as before (l1, l2, l3, l4,
        # last, attn1, attn2) so seeded initialisation is reproducible.
        layer1 = [SpectralNorm(nn.Conv2d(3, conv_dim, 4, 2, 1)), nn.LeakyReLU(0.1)]
        curr_dim = conv_dim
        layer2 = self._down_layers(curr_dim, curr_dim * 2)
        curr_dim *= 2
        layer3 = self._down_layers(curr_dim, curr_dim * 2)
        curr_dim *= 2
        attn1_dim = curr_dim  # channel width of the l3 output

        # l4 is registered before l1..l3 (as originally) so the parameter
        # order, and hence saved optimizer state, is unchanged.
        if self.imsize == 64:
            self.l4 = nn.Sequential(*self._down_layers(curr_dim, curr_dim * 2))
            curr_dim *= 2
        else:
            self.l4 = None
        self.l1 = nn.Sequential(*layer1)
        self.l2 = nn.Sequential(*layer2)
        self.l3 = nn.Sequential(*layer3)

        self.last = nn.Sequential(nn.Conv2d(curr_dim, 1, 4))
        # Attention widths are derived from the actual stage outputs instead
        # of being hard-coded (256/512 only matched conv_dim=64).
        self.attn1 = Self_Attn(attn1_dim, 'relu')
        self.attn2 = Self_Attn(curr_dim, 'relu')

    @staticmethod
    def _down_layers(in_dim, out_dim):
        """Return [SN-Conv2d(stride 2), LeakyReLU(0.1)] halving resolution."""
        return [SpectralNorm(nn.Conv2d(in_dim, out_dim, 4, 2, 1)), nn.LeakyReLU(0.1)]

    def forward(self, x):
        """Score a batch of images.

        Args:
            x: Image batch with 3 channels at ``image_size`` resolution.

        Returns:
            ``(out, p1, p2)``: ``out`` is the score tensor with its trailing
            singleton dimensions removed (shape (B,) for the supported sizes,
            including B == 1), plus the attention maps from ``attn1`` and ``attn2``.
        """
        out = self.l1(x)
        out = self.l2(out)
        out = self.l3(out)
        out, p1 = self.attn1(out)
        if self.l4 is not None:
            out = self.l4(out)
        out, p2 = self.attn2(out)
        out = self.last(out)
        # Squeeze only the channel/spatial dims: a bare squeeze() would also
        # drop the batch dimension when B == 1.
        return out.squeeze(3).squeeze(2).squeeze(1), p1, p2