Skip to content

Commit 047eaa1

Browse files
committed
recon trainer
1 parent 9833cff commit 047eaa1

File tree

2 files changed

+444
-2
lines changed

2 files changed

+444
-2
lines changed

L2_recon.ipynb

Lines changed: 32 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -79,7 +79,36 @@
7979
}
8080
],
8181
"source": [
82-
"os.listdir()"
82+
"#Modified Decoder\n",
83+
"\n",
84+
"class Decoder(nn.Module):\n",
85+
" def __init__(self, input_width=28, input_height=28, input_channel=1):\n",
86+
" super(Decoder, self).__init__()\n",
87+
" self.input_width = input_width\n",
88+
" self.input_height = input_height\n",
89+
" self.input_channel = input_channel\n",
90+
" self.reconstraction_layers = nn.Sequential(\n",
91+
" nn.Linear(16 * 10, 512),\n",
92+
" nn.ReLU(inplace=True),\n",
93+
" nn.Linear(512, 1024),\n",
94+
" nn.ReLU(inplace=True),\n",
95+
" nn.Linear(1024, self.input_height * self.input_height * self.input_channel),\n",
96+
" nn.Sigmoid()\n",
97+
" )\n",
98+
"\n",
99+
" def forward(self, x, data):\n",
100+
" classes = torch.sqrt((x ** 2).sum(2))\n",
101+
" classes = F.softmax(classes.squeeze(), dim=1)\n",
102+
"\n",
103+
" _, max_length_indices = classes.max(dim=1)\n",
104+
" masked = torch.sparse.torch.eye(10)\n",
105+
" if USE_CUDA:\n",
106+
" masked = masked.cuda()\n",
107+
" masked = masked.index_select(dim=0, index=max_length_indices.squeeze().data)\n",
108+
"# t = (x * masked[:, :, None, None]).view(x.size(0), -1)\n",
109+
" reconstructions = self.reconstraction_layers(x)\n",
110+
" reconstructions = reconstructions.view(-1, self.input_channel, self.input_width, self.input_height)\n",
111+
" return reconstructions, None"
83112
]
84113
},
85114
{
@@ -161,7 +190,8 @@
161190
"cell_type": "markdown",
162191
"metadata": {},
163192
"source": [
164-
"## Training CapsuleNet"
193+
"## The histogram of L2 distances between the input and the reconstruction using the\n",
194+
"## correct capsule or other capsules in CapsNet on the real MNIST images. "
165195
]
166196
},
167197
{

0 commit comments

Comments
 (0)