#
# Dynamic Routing Between Capsules
# https://arxiv.org/pdf/1710.09829.pdf
#
# Forked from motokimura/capsnet_pytorch.
#
import math

import torch
import torch.nn as nn
import torch.nn.functional as F

def squash(x, dim=2):
    """Squash nonlinearity (Eq. 1 of the paper), applied along `dim`."""
    v_length_sq = x.pow(2).sum(dim=dim, keepdim=True)
    v_length = torch.sqrt(v_length_sq)
    scaling_factor = v_length_sq / (1 + v_length_sq) / v_length
    return x * scaling_factor
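# For reference, squash() implements Eq. 1 of the paper:
#
#     v_j = (||s_j||^2 / (1 + ||s_j||^2)) * (s_j / ||s_j||)
#
# It preserves the direction of each capsule vector while mapping its length
# into [0, 1), so the length can be read as the probability that the entity
# represented by the capsule is present.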
class PrimaryCaps(nn.Module):
    """
    PrimaryCaps layer: a convolution whose output channels are grouped into
    `out_capsules` capsules of dimension `out_capsule_dim`.
    """

    def __init__(self, in_channels, out_capsules, out_capsule_dim,
                 kernel_size=9, stride=2):
        super(PrimaryCaps, self).__init__()

        self.in_channels = in_channels
        self.out_capsules = out_capsules
        self.out_capsule_dim = out_capsule_dim

        self.capsules = nn.Conv2d(
            in_channels=in_channels,
            out_channels=out_capsules * out_capsule_dim,
            kernel_size=kernel_size,
            stride=stride,
            bias=True
        )

    def forward(self, x):
        """
        Revised based on adambielski's implementation.
        ref: https://github.com/adambielski/CapsNet-pytorch/blob/master/net.py
        """
        # x: [batch_size, in_channels=256, 20, 20]
        batch_size = x.size(0)

        out = self.capsules(x)
        # out: [batch_size, out_capsules=32 * out_capsule_dim=8 = 256, 6, 6]

        _, C, H, W = out.size()
        out = out.view(batch_size, self.out_capsules, self.out_capsule_dim, H, W)
        out = out.permute(0, 1, 3, 4, 2).contiguous()
        out = out.view(out.size(0), -1, out.size(4))
        # out: [batch_size, 32 * 6 * 6 = 1152, 8]

        # Squash the capsule vectors.
        out = squash(out, dim=2)

        return out
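# Shape note for the MNIST-style configuration used in the comments above (an
# assumption; other input sizes follow the same arithmetic): a 20x20 feature
# map with a 9x9 kernel and stride 2 gives floor((20 - 9) / 2) + 1 = 6, so the
# conv output is 6x6, and 32 capsule types x 6 x 6 positions = 1152 primary
# capsules of dimension 8.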
class DigitCaps(nn.Module):
    """
    DigitCaps layer: routes the primary capsules to the output (digit) capsules
    by iterative routing-by-agreement.
    """

    def __init__(self, in_capsules, in_capsule_dim, out_capsules, out_capsule_dim,
                 routing_iters=3):
        super(DigitCaps, self).__init__()

        self.routing_iters = routing_iters
        self.in_capsules = in_capsules
        self.in_capsule_dim = in_capsule_dim
        self.out_capsules = out_capsules
        self.out_capsule_dim = out_capsule_dim

        self.W = nn.Parameter(
            torch.Tensor(
                self.in_capsules,
                self.out_capsules,
                self.out_capsule_dim,
                self.in_capsule_dim
            )
        )
        # W: [in_capsules, out_capsules, out_capsule_dim, in_capsule_dim] = [1152, 10, 16, 8]

        self.reset_parameters()

    def reset_parameters(self):
        stdv = 1. / math.sqrt(self.in_capsules)
        self.W.data.uniform_(-stdv, stdv)
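    # The forward pass below follows routing-by-agreement (Procedure 1 of the
    # paper):
    #   1. Compute prediction vectors u_hat_{j|i} = W_ij u_i.
    #   2. Initialize routing logits b_ij to zero.
    #   3. For `routing_iters` iterations:
    #        c_ij = softmax(b_ij) over output capsules j
    #        s_j  = sum_i c_ij * u_hat_{j|i}
    #        v_j  = squash(s_j)
    #        b_ij += u_hat_{j|i} . v_j   (agreement)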
    def forward(self, x):
        # x: [batch_size, in_capsules=1152, in_capsule_dim=8]
        batch_size = x.size(0)

        x = torch.stack([x] * self.out_capsules, dim=2)
        # x: [batch_size, in_capsules=1152, out_capsules=10, in_capsule_dim=8]

        W = torch.cat([self.W.unsqueeze(0)] * batch_size, dim=0)
        # W: [batch_size, in_capsules=1152, out_capsules=10, out_capsule_dim=16, in_capsule_dim=8]

        # Transform inputs by the weight matrix `W`.
        u_hat = torch.matmul(W, x.unsqueeze(4))  # matrix multiplication
        # u_hat: [batch_size, in_capsules=1152, out_capsules=10, out_capsule_dim=16, 1]

        u_hat_detached = u_hat.detach()
        # u_hat_detached: [batch_size, in_capsules=1152, out_capsules=10, out_capsule_dim=16, 1]
        # In the forward pass, `u_hat_detached` equals `u_hat`; in the backward pass,
        # no gradient can flow from `u_hat_detached` back to `u_hat`.

        # Initialize routing logits to zero, on the same device as the input.
        b_ij = torch.zeros(batch_size, self.in_capsules, self.out_capsules, 1, device=x.device)
        # b_ij: [batch_size, in_capsules=1152, out_capsules=10, 1]

        # Iterative routing.
        for iteration in range(self.routing_iters):
            # Convert routing logits to coupling coefficients with a softmax
            # over the output capsules.
            c_ij = F.softmax(b_ij, dim=2).unsqueeze(4)
            # c_ij: [batch_size, in_capsules=1152, out_capsules=10, 1, 1]

            if iteration == self.routing_iters - 1:
                # Apply routing `c_ij` to the weighted inputs `u_hat`.
                s_j = (c_ij * u_hat).sum(dim=1, keepdim=True)  # element-wise product
                # s_j: [batch_size, 1, out_capsules=10, out_capsule_dim=16, 1]

                v_j = squash(s_j, dim=3)
                # v_j: [batch_size, 1, out_capsules=10, out_capsule_dim=16, 1]
            else:
                # Apply routing `c_ij` to the detached weighted inputs, so the
                # intermediate iterations contribute no gradients.
                s_j = (c_ij * u_hat_detached).sum(dim=1, keepdim=True)  # element-wise product
                # s_j: [batch_size, 1, out_capsules=10, out_capsule_dim=16, 1]

                v_j = squash(s_j, dim=3)
                # v_j: [batch_size, 1, out_capsules=10, out_capsule_dim=16, 1]

                # Compute inner products of the two 16-D vectors, `u_hat_detached` and `v_j`.
                u_vj1 = torch.matmul(u_hat_detached.transpose(3, 4), v_j).squeeze(4)
                # u_vj1: [batch_size, in_capsules=1152, out_capsules=10, 1]
                # Do not average the agreement over the batch.

                # Update the routing logits.
                b_ij = b_ij + u_vj1

        return v_j.squeeze(4).squeeze(1)  # [batch_size, out_capsules=10, out_capsule_dim=16]
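if __name__ == "__main__":
    # Minimal shape check: a sketch assuming the MNIST-style configuration
    # described in the comments above (a 256-channel 20x20 feature map in,
    # 10 digit capsules of dimension 16 out).
    primary_caps = PrimaryCaps(in_channels=256, out_capsules=32, out_capsule_dim=8,
                               kernel_size=9, stride=2)
    digit_caps = DigitCaps(in_capsules=32 * 6 * 6, in_capsule_dim=8,
                           out_capsules=10, out_capsule_dim=16, routing_iters=3)

    x = torch.randn(2, 256, 20, 20)  # e.g. the output of a Conv2d stem
    u = primary_caps(x)              # [2, 1152, 8]
    v = digit_caps(u)                # [2, 10, 16]
    print(u.size(), v.size())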