Skip to content

Commit 9833cff

Browse files
committed
addpretrainedmodel
1 parent 31c00ea commit 9833cff

File tree

2 files changed

+51
-6
lines changed

2 files changed

+51
-6
lines changed

CapsNetMNIST.pth

8.28 MB
Binary file not shown.

L2_recon.ipynb

Lines changed: 51 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
"cells": [
33
{
44
"cell_type": "code",
5-
"execution_count": null,
5+
"execution_count": 8,
66
"metadata": {},
77
"outputs": [],
88
"source": [
@@ -12,6 +12,7 @@
1212
"import torch.nn as nn\n",
1313
"import torch.nn.functional as F\n",
1414
"from tqdm import tqdm\n",
15+
"import os\n",
1516
"\n",
1617
"from model.net import *\n",
1718
"from utils.training import *\n",
@@ -27,10 +28,12 @@
2728
},
2829
{
2930
"cell_type": "code",
30-
"execution_count": null,
31+
"execution_count": 10,
3132
"metadata": {},
3233
"outputs": [],
3334
"source": [
35+
"model_path = os.getcwd()\n",
36+
"\n",
3437
"args = {\n",
3538
" 'USE_CUDA': True if torch.cuda.is_available() else False,\n",
3639
" 'BATCH_SIZE': 32,\n",
@@ -50,9 +53,51 @@
5053
},
5154
{
5255
"cell_type": "code",
53-
"execution_count": null,
56+
"execution_count": 13,
5457
"metadata": {},
55-
"outputs": [],
58+
"outputs": [
59+
{
60+
"data": {
61+
"text/plain": [
62+
"['data',\n",
63+
" 'model',\n",
64+
" 'utils',\n",
65+
" 'LICENSE',\n",
66+
" '.gitignore',\n",
67+
" '.git',\n",
68+
" 'Capsule Network Train.ipynb',\n",
69+
" 'Pretrain_Capsule.ipynb',\n",
70+
" 'L2_recon.ipynb',\n",
71+
" '.ipynb_checkpoints',\n",
72+
" 'README.md',\n",
73+
" 'CapsNetMNIS.pth ']"
74+
]
75+
},
76+
"execution_count": 13,
77+
"metadata": {},
78+
"output_type": "execute_result"
79+
}
80+
],
81+
"source": [
82+
"os.listdir()"
83+
]
84+
},
85+
{
86+
"cell_type": "code",
87+
"execution_count": 15,
88+
"metadata": {},
89+
"outputs": [
90+
{
91+
"data": {
92+
"text/plain": [
93+
"<All keys matched successfully>"
94+
]
95+
},
96+
"execution_count": 15,
97+
"metadata": {},
98+
"output_type": "execute_result"
99+
}
100+
],
56101
"source": [
57102
"#Config for 49 16d vectors in the Primary Capsule. Set Softmax dimension to 0 in this case\n",
58103
"class Config:\n",
@@ -93,7 +138,7 @@
93138
"if args['USE_CUDA']:\n",
94139
" net = net.cuda()\n",
95140
" \n",
96-
"net.load_state_dict(torch.load('./CapsNetMNIST.pth'), map_location='cpu')"
141+
"net.load_state_dict(torch.load(os.path.join(model_path, 'CapsNetMNIST.pth'), map_location='cpu'))"
97142
]
98143
},
99144
{
@@ -134,7 +179,7 @@
134179
"metadata": {},
135180
"outputs": [],
136181
"source": [
137-
"# torch.save(capsule_net.state_dict(), \"./CapsNetMNIST.pth \")"
182+
"# torch.save(capsule_net.state_dict(), \"./CapsNetMNIST.pth\")"
138183
]
139184
},
140185
{

0 commit comments

Comments
 (0)