dis_modules.py
import torch
import torch.nn as nn
import torch.nn.functional as F
from network import *
def squash(s, dim=-1):
    '''
    "Squashing" non-linearity that shrinks short vectors to almost zero length and long vectors to a length slightly below 1.
    Eq. (1): v_j = ||s_j||^2 / (1 + ||s_j||^2) * s_j / ||s_j||
    Args:
        s: Vector before activation
        dim: Dimension along which to calculate the norm
    Returns:
        Squashed vector
    '''
    squared_norm = torch.sum(s**2, dim=dim, keepdim=True)
    # small epsilon keeps the division stable when the norm is close to zero
    return squared_norm / (1 + squared_norm) * s / (torch.sqrt(squared_norm) + 1e-8)
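# Worked example of Eq. (1): a vector of norm 10 squashes to norm 100/101 ≈ 0.990,
# while a vector of norm 0.1 squashes to norm 0.01/1.01 ≈ 0.0099, so long vectors
# saturate just below 1 and short vectors collapse toward 0.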
class PrimaryCapsules(nn.Module):
    def __init__(self, in_channels, out_channels, dim_caps,
                 kernel_size=9, stride=2, padding=0):
        """
        Initialize the layer.
        Args:
            in_channels: Number of input channels.
            out_channels: Number of output channels.
            dim_caps: Dimensionality, i.e. length, of the output capsule vector.
            kernel_size: Kernel size of the underlying convolution.
            stride: Stride of the underlying convolution.
            padding: Padding of the underlying convolution.
        """
        super(PrimaryCapsules, self).__init__()
        self.dim_caps = dim_caps
        self._caps_channel = int(out_channels / dim_caps)
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, stride=stride, padding=padding)
        torch.nn.init.xavier_normal_(self.conv.weight)
    def forward(self, x):
        out = self.conv(x)
        # (batch_size, out_channels, H, W) -> (batch_size, caps_channels, H, W, dim_caps)
        out = out.view(out.size(0), self._caps_channel, out.size(2), out.size(3), self.dim_caps)
        # flatten the spatial grid into a list of capsules: (batch_size, num_caps, dim_caps)
        out = out.view(out.size(0), -1, self.dim_caps)
        return squash(out)
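# Shape sketch (assuming CapsNet-style settings, not values fixed by this module:
# in_channels=256, out_channels=256, dim_caps=8 on a 20x20 feature map): the 9x9
# stride-2 convolution yields a 6x6 grid, i.e. 32 capsule channels * 6 * 6 = 1152
# primary capsules of length 8.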
class RoutingCapsules(nn.Module):
    def __init__(self, in_dim, in_caps, num_caps, dim_caps, num_routing):
        """
        Initialize the layer.
        Args:
            in_dim: Dimensionality (i.e. length) of each input capsule vector.
            in_caps: Number of input capsules (e.g. the primary capsules when this is the digit layer).
            num_caps: Number of capsules in this layer.
            dim_caps: Dimensionality, i.e. length, of the output capsule vector.
            num_routing: Number of iterations of the routing algorithm.
        """
        super(RoutingCapsules, self).__init__()
        self.in_dim = in_dim
        self.in_caps = in_caps
        self.num_caps = num_caps
        self.dim_caps = dim_caps
        self.num_routing = num_routing
        # transformation matrices W_ij, scaled for a Xavier-style initialization
        self.W = nn.Parameter(torch.randn(1, num_caps, in_caps, dim_caps, in_dim) * (3 / (in_dim + dim_caps + num_caps))**0.5)
    def __repr__(self):
        tab = '  '
        line = '\n'
        res = self.__class__.__name__ + '('
        res = res + line + tab + '(0): ' + 'CapsuleLinear('
        res = res + str(self.in_dim) + ', ' + str(self.dim_caps) + ')'
        res = res + line + tab + '(1): ' + 'Routing('
        res = res + 'num_routing=' + str(self.num_routing) + ')'
        res = res + line + ')'
        return res
    def forward(self, x):
        batch_size = x.size(0)
        # (batch_size, in_caps, in_dim) -> (batch_size, 1, in_caps, in_dim, 1)
        x = x.unsqueeze(1).unsqueeze(4)
        #
        # W @ x =
        # (1, num_caps, in_caps, dim_caps, in_dim) @ (batch_size, 1, in_caps, in_dim, 1) =
        # (batch_size, num_caps, in_caps, dim_caps, 1)
        u_hat = torch.matmul(self.W, x)
        # (batch_size, num_caps, in_caps, dim_caps)
        u_hat = u_hat.squeeze(-1)
        # detach u_hat during routing iterations to prevent gradients from flowing through them
        temp_u_hat = u_hat.detach()
        '''
        Procedure 1: Routing algorithm
        '''
        b = torch.zeros(batch_size, self.num_caps, self.in_caps, 1, device=x.device)
        for route_iter in range(self.num_routing - 1):
            # (batch_size, num_caps, in_caps, 1) -> softmax along num_caps
            c = F.softmax(b, dim=1)
            # element-wise multiplication
            # (batch_size, num_caps, in_caps, 1) * (batch_size, num_caps, in_caps, dim_caps) ->
            # (batch_size, num_caps, in_caps, dim_caps), sum across in_caps ->
            # (batch_size, num_caps, dim_caps)
            s = (c * temp_u_hat).sum(dim=2)
            # apply "squashing" non-linearity along dim_caps
            v = squash(s)
            # dot-product agreement between the current output v_j and the prediction u_hat_j|i
            # (batch_size, num_caps, in_caps, dim_caps) @ (batch_size, num_caps, dim_caps, 1)
            # -> (batch_size, num_caps, in_caps, 1)
            uv = torch.matmul(temp_u_hat, v.unsqueeze(-1))
            b += uv
        # last iteration is done on the original u_hat, without the routing-weight update,
        # so gradients flow through this pass
        c = F.softmax(b, dim=1)
        s = (c * u_hat).sum(dim=2)
        # note: the final "squashing" non-linearity is skipped here; the raw weighted
        # sum s is returned and callers can apply squash() themselves
        # v = squash(s)
        return s
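# Usage sketch (assuming the 1152 primary capsules from the example above and ten
# 16-dimensional output capsules): RoutingCapsules(in_dim=8, in_caps=1152, num_caps=10,
# dim_caps=16, num_routing=3) maps a (batch_size, 1152, 8) tensor to (batch_size, 10, 16).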
class RealOrFake(nn.Module):
    def __init__(self, num_caps, dim_caps, dim_real, num_routing):
        """
        Routes the digit capsules into a single output capsule of length dim_real.
        Args:
            num_caps: Number of input (digit) capsules.
            dim_caps: Length of each input capsule vector.
            dim_real: Length of the single output capsule vector.
            num_routing: Number of routing iterations.
        """
        super(RealOrFake, self).__init__()
        self.num_caps = num_caps
        self.dim_caps = dim_caps
        self.dim_real = dim_real
        self.num_routing = num_routing
        self.W = nn.Parameter(torch.randn(1, 1, self.num_caps, self.dim_real, self.dim_caps) * (2 / (dim_real + dim_caps))**0.5)
    def forward(self, x):
        '''
        x: squashed digit capsules, of dimensions (batch_size, num_caps, dim_caps)
        '''
        batch_size = x.size(0)
        # (batch_size, num_caps, dim_caps) -> (batch_size, 1, num_caps, dim_caps, 1)
        x = x.unsqueeze(1).unsqueeze(4)
        #
        # W @ x =
        # (1, 1, num_caps, dim_real, dim_caps) @ (batch_size, 1, num_caps, dim_caps, 1) =
        # (batch_size, 1, num_caps, dim_real, 1)
        u_hat = torch.matmul(self.W, x)
        # (batch_size, 1, num_caps, dim_real)
        u_hat = u_hat.squeeze(-1)
        # detach u_hat during routing iterations to prevent gradients from flowing through them
        temp_u_hat = u_hat.detach()
        '''
        Procedure 1: Routing algorithm
        '''
        b = torch.zeros(batch_size, 1, self.num_caps, 1, device=x.device)
        for route_iter in range(self.num_routing - 1):
            # (batch_size, 1, num_caps, 1) -> softmax along the (single) output-capsule dim
            c = F.softmax(b, dim=1)
            # element-wise multiplication
            # (batch_size, 1, num_caps, 1) * (batch_size, 1, num_caps, dim_real) ->
            # (batch_size, 1, num_caps, dim_real), sum across num_caps ->
            # (batch_size, 1, dim_real)
            s = (c * temp_u_hat).sum(dim=2)
            # apply "squashing" non-linearity along dim_real
            v = squash(s)
            # dot-product agreement between the current output v and the predictions u_hat
            # (batch_size, 1, num_caps, dim_real) @ (batch_size, 1, dim_real, 1)
            # -> (batch_size, 1, num_caps, 1)
            uv = torch.matmul(temp_u_hat, v.unsqueeze(-1))
            b += uv
        # last iteration is done on the original u_hat, without the routing-weight update,
        # so gradients flow through this pass
        c = F.softmax(b, dim=1)
        s = (c * u_hat).sum(dim=2)
        # apply "squashing" non-linearity along dim_real
        v = squash(s)
        return v
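# Minimal end-to-end shape check (a sketch only: the 28x28 single-channel input, the
# 9x9 stride-1 front-end convolution, and the capsule sizes below are assumptions in
# the spirit of the original CapsNet configuration, not values defined in this module).
if __name__ == "__main__":
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    x = torch.randn(2, 1, 28, 28, device=device)                       # toy batch of 2 images
    conv1 = nn.Conv2d(1, 256, kernel_size=9, stride=1).to(device)      # -> (2, 256, 20, 20)
    primary = PrimaryCapsules(256, 256, dim_caps=8).to(device)         # -> (2, 1152, 8)
    routing = RoutingCapsules(in_dim=8, in_caps=32 * 6 * 6, num_caps=10,
                              dim_caps=16, num_routing=3).to(device)   # -> (2, 10, 16)
    real_or_fake = RealOrFake(num_caps=10, dim_caps=16, dim_real=16,
                              num_routing=3).to(device)                # -> (2, 1, 16)
    digit_caps = routing(primary(F.relu(conv1(x))))
    print(digit_caps.shape)                                            # torch.Size([2, 10, 16])
    print(real_or_fake(squash(digit_caps)).shape)                      # torch.Size([2, 1, 16])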