|
79 | 79 | }
|
80 | 80 | ],
|
81 | 81 | "source": [
|
82 |
| - "os.listdir()" |
# Modified Decoder

class Decoder(nn.Module):
    """Reconstructs the input image from the capsule output vectors.

    Three fully connected layers map the flattened pose vectors of the
    10 digit capsules (16 * 10 = 160 values) to a flat image in [0, 1]
    (Sigmoid), which is then reshaped to (channel, width, height).

    Parameters
    ----------
    input_width, input_height : int
        Spatial size of the image to reconstruct (default 28x28, MNIST).
    input_channel : int
        Number of image channels (default 1, grayscale MNIST).
    """

    def __init__(self, input_width=28, input_height=28, input_channel=1):
        super(Decoder, self).__init__()
        self.input_width = input_width
        self.input_height = input_height
        self.input_channel = input_channel
        # NOTE: the attribute keeps the original (misspelled) name
        # "reconstraction_layers" so any external references keep working.
        self.reconstraction_layers = nn.Sequential(
            nn.Linear(16 * 10, 512),
            nn.ReLU(inplace=True),
            nn.Linear(512, 1024),
            nn.ReLU(inplace=True),
            # BUGFIX: the original used input_height * input_height (height
            # squared), which only works for square images. The output size
            # must match the (channel, width, height) reshape in forward().
            nn.Linear(1024, self.input_width * self.input_height * self.input_channel),
            nn.Sigmoid()
        )

    def forward(self, x, data):
        """Reconstruct images from capsule outputs.

        `x` is the capsule output tensor; the `sum(2)` implies at least
        3 dims, and the final Linear implies it flattens to (batch, 160)
        — TODO confirm the exact shape at the caller.
        `data` is accepted for interface compatibility but is not used.
        Returns (reconstructions, None).
        """
        # Capsule lengths -> per-class probabilities (only used to pick
        # the most active capsule below).
        classes = torch.sqrt((x ** 2).sum(2))
        classes = F.softmax(classes.squeeze(), dim=1)

        _, max_length_indices = classes.max(dim=1)
        # One-hot mask for the most active capsule. The actual masking of x
        # (the commented-out `t = ...` line) is disabled, so `masked` is
        # currently dead; kept as-is to preserve the original behavior.
        masked = torch.sparse.torch.eye(10)
        if USE_CUDA:
            masked = masked.cuda()
        masked = masked.index_select(dim=0, index=max_length_indices.squeeze().data)
# t = (x * masked[:, :, None, None]).view(x.size(0), -1)
        reconstructions = self.reconstraction_layers(x)
        reconstructions = reconstructions.view(-1, self.input_channel, self.input_width, self.input_height)
        return reconstructions, None
83 | 112 | ]
|
84 | 113 | },
|
85 | 114 | {
|
|
161 | 190 | "cell_type": "markdown",
|
162 | 191 | "metadata": {},
|
163 | 192 | "source": [
|
164 |
| - "## Training CapsuleNet" |
| 193 | + "## Histogram of L2 distances between the input and its reconstruction,\n", |
| 194 | + "## using the correct capsule versus the other capsules (CapsNet, real MNIST images)." |
165 | 195 | ]
|
166 | 196 | },
|
167 | 197 | {
|
|
0 commit comments