|
2 | 2 | "cells": [
|
3 | 3 | {
|
4 | 4 | "cell_type": "code",
|
5 |
| - "execution_count": null, |
| 5 | + "execution_count": 8, |
6 | 6 | "metadata": {},
|
7 | 7 | "outputs": [],
|
8 | 8 | "source": [
|
|
12 | 12 | "import torch.nn as nn\n",
|
13 | 13 | "import torch.nn.functional as F\n",
|
14 | 14 | "from tqdm import tqdm\n",
|
| 15 | + "import os\n", |
15 | 16 | "\n",
|
16 | 17 | "from model.net import *\n",
|
17 | 18 | "from utils.training import *\n",
|
|
27 | 28 | },
|
28 | 29 | {
|
29 | 30 | "cell_type": "code",
|
30 |
| - "execution_count": null, |
| 31 | + "execution_count": 10, |
31 | 32 | "metadata": {},
|
32 | 33 | "outputs": [],
|
33 | 34 | "source": [
|
| 35 | + "model_path = os.getcwd()\n", |
| 36 | + "\n", |
34 | 37 | "args = {\n",
|
35 | 38 | " 'USE_CUDA': True if torch.cuda.is_available() else False,\n",
|
36 | 39 | " 'BATCH_SIZE': 32,\n",
|
|
50 | 53 | },
|
51 | 54 | {
|
52 | 55 | "cell_type": "code",
|
53 |
| - "execution_count": null, |
| 56 | + "execution_count": 13, |
54 | 57 | "metadata": {},
|
55 |
| - "outputs": [], |
| 58 | + "outputs": [ |
| 59 | + { |
| 60 | + "data": { |
| 61 | + "text/plain": [ |
| 62 | + "['data',\n", |
| 63 | + " 'model',\n", |
| 64 | + " 'utils',\n", |
| 65 | + " 'LICENSE',\n", |
| 66 | + " '.gitignore',\n", |
| 67 | + " '.git',\n", |
| 68 | + " 'Capsule Network Train.ipynb',\n", |
| 69 | + " 'Pretrain_Capsule.ipynb',\n", |
| 70 | + " 'L2_recon.ipynb',\n", |
| 71 | + " '.ipynb_checkpoints',\n", |
| 72 | + " 'README.md',\n", |
| 73 | + " 'CapsNetMNIS.pth ']" |
| 74 | + ] |
| 75 | + }, |
| 76 | + "execution_count": 13, |
| 77 | + "metadata": {}, |
| 78 | + "output_type": "execute_result" |
| 79 | + } |
| 80 | + ], |
| 81 | + "source": [ |
| 82 | + "os.listdir()" |
| 83 | + ] |
| 84 | + }, |
| 85 | + { |
| 86 | + "cell_type": "code", |
| 87 | + "execution_count": 15, |
| 88 | + "metadata": {}, |
| 89 | + "outputs": [ |
| 90 | + { |
| 91 | + "data": { |
| 92 | + "text/plain": [ |
| 93 | + "<All keys matched successfully>" |
| 94 | + ] |
| 95 | + }, |
| 96 | + "execution_count": 15, |
| 97 | + "metadata": {}, |
| 98 | + "output_type": "execute_result" |
| 99 | + } |
| 100 | + ], |
56 | 101 | "source": [
|
57 | 102 | "#Config for 49 16d vectors in the Primary Capsule. Set Softmax dimension to 0 in this case\n",
|
58 | 103 | "class Config:\n",
|
|
93 | 138 | "if args['USE_CUDA']:\n",
|
94 | 139 | " net = net.cuda()\n",
|
95 | 140 | " \n",
|
96 |
| - "net.load_state_dict(torch.load('./CapsNetMNIST.pth'), map_location='cpu')" |
| 141 | + "net.load_state_dict(torch.load(os.path.join(model_path, 'CapsNetMNIST.pth'), map_location='cpu'))" |
97 | 142 | ]
|
98 | 143 | },
|
99 | 144 | {
|
|
134 | 179 | "metadata": {},
|
135 | 180 | "outputs": [],
|
136 | 181 | "source": [
|
137 |
| - "# torch.save(capsule_net.state_dict(), \"./CapsNetMNIST.pth \")" |
| 182 | + "# torch.save(capsule_net.state_dict(), \"./CapsNetMNIST.pth\")" |
138 | 183 | ]
|
139 | 184 | },
|
140 | 185 | {
|
|
0 commit comments