Skip to content

Commit 55d8823

Browse files
committed
test cases for loading
1 parent 44b76e1 commit 55d8823

File tree

1 file changed

+259
-0
lines changed

1 file changed

+259
-0
lines changed

load_pretrained.ipynb

+259
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,259 @@
1+
{
2+
"cells": [
3+
{
4+
"cell_type": "code",
5+
"execution_count": null,
6+
"metadata": {},
7+
"outputs": [],
8+
"source": [
9+
"import torchvision\n",
10+
"import torchvision.models as models\n",
11+
"import torch.utils.model_zoo as model_zoo"
12+
]
13+
},
14+
{
15+
"cell_type": "markdown",
16+
"metadata": {},
17+
"source": [
18+
"### Configuration"
19+
]
20+
},
21+
{
22+
"cell_type": "code",
23+
"execution_count": null,
24+
"metadata": {},
25+
"outputs": [],
26+
"source": [
27+
"model_urls = {\n",
28+
" 'alexnet': 'https://download.pytorch.org/models/alexnet-owt-4df8aa71.pth',\n",
29+
" 'densenet121': 'https://download.pytorch.org/models/densenet121-241335ed.pth',\n",
30+
" 'densenet169': 'https://download.pytorch.org/models/densenet169-6f0f7f60.pth',\n",
31+
" 'densenet201': 'https://download.pytorch.org/models/densenet201-4c113574.pth',\n",
32+
" 'densenet161': 'https://download.pytorch.org/models/densenet161-17b70270.pth',\n",
33+
" #truncated _google to match module name\n",
34+
" 'inception_v3': 'https://download.pytorch.org/models/inception_v3_google-1a9a5a14.pth',\n",
35+
" 'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth',\n",
36+
" 'resnet34': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth',\n",
37+
" 'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth',\n",
38+
" 'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth',\n",
39+
" 'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth', \n",
40+
" 'squeezenet1_0': 'https://download.pytorch.org/models/squeezenet1_0-a815701f.pth',\n",
41+
" 'squeezenet1_1': 'https://download.pytorch.org/models/squeezenet1_1-f364aa15.pth',\n",
42+
" 'vgg11': 'https://download.pytorch.org/models/vgg11-bbd30ac9.pth',\n",
43+
" 'vgg13': 'https://download.pytorch.org/models/vgg13-c768596a.pth',\n",
44+
" 'vgg16': 'https://download.pytorch.org/models/vgg16-397923af.pth',\n",
45+
" 'vgg19': 'https://download.pytorch.org/models/vgg19-dcbb9e9d.pth', \n",
46+
"}\n",
47+
"\n",
48+
"model_names = model_urls.keys()\n",
49+
"\n",
50+
"input_sizes = {\n",
51+
" 'alexnet' : (224,224),\n",
52+
" 'densenet': (224,224),\n",
53+
" 'resnet' : (224,224),\n",
54+
" 'inception' : (299,299),\n",
55+
" 'squeezenet' : (224,224),#not 255,255 acc. to https://github.com/pytorch/pytorch/issues/1120\n",
56+
" 'vgg' : (224,224)\n",
57+
"}\n",
58+
"\n",
59+
"#models_to_test = ['alexnet', 'densenet169', 'inception_v3', \\\n",
60+
"# 'resnet34', 'squeezenet1_1', 'vgg13']\n",
61+
"\n",
62+
"models_to_test = model_names"
63+
]
64+
},
65+
{
66+
"cell_type": "markdown",
67+
"metadata": {},
68+
"source": [
69+
"### Generic pretrained model loading"
70+
]
71+
},
72+
{
73+
"cell_type": "code",
74+
"execution_count": null,
75+
"metadata": {},
76+
"outputs": [],
77+
"source": [
78+
"#We solve the dimensionality mismatch between\n",
79+
"#final layers in the constructed vs pretrained\n",
80+
"#modules at the data level.\n",
81+
"def diff_states(dict_canonical, dict_subset):\n",
82+
" names1, names2 = (list(dict_canonical.keys()), list(dict_subset.keys()))\n",
83+
" \n",
84+
" #Sanity check that param names overlap\n",
85+
" #Note that params are not necessarily in the same order\n",
86+
" #for every pretrained model\n",
87+
" not_in_1 = [n for n in names1 if n not in names2]\n",
88+
" not_in_2 = [n for n in names2 if n not in names1]\n",
89+
" assert len(not_in_1) == 0\n",
90+
" assert len(not_in_2) == 0\n",
91+
"\n",
92+
" for name, v1 in dict_canonical.items():\n",
93+
" v2 = dict_subset[name]\n",
94+
" assert hasattr(v2, 'size')\n",
95+
" if v1.size() != v2.size():\n",
96+
" yield (name, v1) \n",
97+
"\n",
98+
"def load_model_named(name): \n",
99+
" #Densenets don't (yet) pass on num_classes, hack it in\n",
100+
" if \"densenet\" in name:\n",
101+
" if name == 'densenet169':\n",
102+
" return models.DenseNet(num_init_features=64, growth_rate=32, \\\n",
103+
" block_config=(6, 12, 32, 32), num_classes=num_classes)\n",
104+
" \n",
105+
" elif name == 'densenet121':\n",
106+
" return models.DenseNet(num_init_features=64, growth_rate=32, \\\n",
107+
" block_config=(6, 12, 24, 16), num_classes=num_classes)\n",
108+
" \n",
109+
" elif name == 'densenet201':\n",
110+
" return models.DenseNet(num_init_features=64, growth_rate=32, \\\n",
111+
" block_config=(6, 12, 48, 32), num_classes=num_classes)\n",
112+
"\n",
113+
" elif name == 'densenet161':\n",
114+
" return models.DenseNet(num_init_features=96, growth_rate=48, \\\n",
115+
" block_config=(6, 12, 36, 24), num_classes=num_classes)\n",
116+
" else:\n",
117+
" raise ValueError(\"Cirumventing missing num_classes kwargs not implemented for %s\" % name)\n",
118+
" \n",
119+
" return models.__dict__[name](num_classes=num_classes)\n",
120+
" \n",
121+
" \n",
122+
"def load_model(name, num_classes):\n",
123+
" \n",
124+
" model = load_model_named(name)\n",
125+
" pretrained_state = model_zoo.load_url(model_urls[name])\n",
126+
"\n",
127+
" #Diff\n",
128+
" diff = list(diff_states(model.state_dict(), pretrained_state))\n",
129+
" \n",
130+
" for name, value in diff:\n",
131+
" pretrained_state[name] = value\n",
132+
" \n",
133+
" assert len(list(diff_states(model.state_dict(), pretrained_state))) == 0\n",
134+
" \n",
135+
" #Merge\n",
136+
" model.load_state_dict(pretrained_state)\n",
137+
" return model, diff"
138+
]
139+
},
140+
{
141+
"cell_type": "code",
142+
"execution_count": null,
143+
"metadata": {},
144+
"outputs": [],
145+
"source": [
146+
"# Method to mutate module programmatically (PR #175)\n",
147+
"# https://github.com/pytorch/vision/pull/175\n",
148+
"\n",
149+
"def resize_network_output(net, output_size):\n",
150+
" if isinstance(net, torch.nn.DataParallel):\n",
151+
" return resize_network_output(net.module, output_size)\n",
152+
"\n",
153+
" # Edit: Can't index iterable in python3\n",
154+
" #output_layer = net._modules.keys()[-1]\n",
155+
" for output_layer in net._modules.keys():\n",
156+
" pass\n",
157+
" old_output_layer = net._modules[output_layer]\n",
158+
"\n",
159+
" if isinstance(old_output_layer, nn.Sequential):\n",
160+
" return resize_network_output(old_output_layer, output_size)\n",
161+
" elif isinstance(old_output_layer, nn.modules.pooling.AvgPool2d):\n",
162+
" # Go back in the layer sequence and find the last conv layer and resize that\n",
163+
" # Only happens for squeezenet1_0\n",
164+
" # Edit: iteritems deprecated in python3\n",
165+
" for name, layer in list(net._modules.items())[::-1][1:]:\n",
166+
" if isinstance(layer, nn.modules.conv.Conv2d):\n",
167+
" net._modules[name] = nn.modules.conv.Conv2d(layer.in_channels, output_size, layer.kernel_size,\n",
168+
" layer.stride, layer.padding, layer.dilation, layer.groups)\n",
169+
" return\n",
170+
" assert False\n",
171+
"\n",
172+
" assert isinstance(old_output_layer, nn.Linear), 'Class of old_output_layer {}'.format(old_output_layer.__class__.__name__)\n",
173+
" input_size = old_output_layer.weight.size()[1]\n",
174+
"\n",
175+
" net._modules[output_layer] = nn.Linear(input_size, output_size)\n",
176+
"\n",
177+
"\n",
178+
"def load_model_resize_post(name, num_classes):\n",
179+
" model = load_model_named(name)\n",
180+
" resize_network_output(model, num_classes)\n",
181+
" return model"
182+
]
183+
},
184+
{
185+
"cell_type": "markdown",
186+
"metadata": {},
187+
"source": [
188+
"## Compare generic loading methods"
189+
]
190+
},
191+
{
192+
"cell_type": "code",
193+
"execution_count": null,
194+
"metadata": {
195+
"collapsed": true
196+
},
197+
"outputs": [],
198+
"source": [
199+
"# If no cuda is present, unpickle fails with this net...\n",
200+
"# Need to update pretrained model with cpu to resolve?\n",
201+
"# models_to_test.remove('densenet169')"
202+
]
203+
},
204+
{
205+
"cell_type": "code",
206+
"execution_count": null,
207+
"metadata": {},
208+
"outputs": [],
209+
"source": [
210+
"num_classes = 10\n",
211+
"\n",
212+
"for name in models_to_test:\n",
213+
" print(\"\")\n",
214+
" print(name, \"with %d classes\" % num_classes)\n",
215+
" try:\n",
216+
" model_merged, diff = load_model(name, num_classes)\n",
217+
" diff_vanilla = [d[0] for d in diff]\n",
218+
" result = (\"... merge loading: \" + str(diff_vanilla)).ljust(99) \\\n",
219+
" + \"OK\" if len(diff_vanilla) > 0 else \"X\"\n",
220+
" except Exception as e:\n",
221+
" result = (\"... merge loading: \" + str(e)).ljust(99) + \"X\"\n",
222+
" finally:\n",
223+
" print(result)\n",
224+
" \n",
225+
" try:\n",
226+
" model_resized = load_model_resize_post(name, num_classes)\n",
227+
" diff_merged_resized = [p[0] for p in \\\n",
228+
" diff_states(model_merged.state_dict(), model_resized.state_dict())]\n",
229+
" result = (\"... resizing after load: \" + str(diff_merged_resized)).ljust(99) \\\n",
230+
" + \"OK\" if len(diff_merged_resized) == 0 else \"X\"\n",
231+
" except Exception as e:\n",
232+
" result = (\"... resizing after load: \" + str(e)).ljust(99) + \"X\"\n",
233+
" finally:\n",
234+
" print(result) "
235+
]
236+
}
237+
],
238+
"metadata": {
239+
"kernelspec": {
240+
"display_name": "Python 3",
241+
"language": "python",
242+
"name": "python3"
243+
},
244+
"language_info": {
245+
"codemirror_mode": {
246+
"name": "ipython",
247+
"version": 3
248+
},
249+
"file_extension": ".py",
250+
"mimetype": "text/x-python",
251+
"name": "python",
252+
"nbconvert_exporter": "python",
253+
"pygments_lexer": "ipython3",
254+
"version": "3.6.0"
255+
}
256+
},
257+
"nbformat": 4,
258+
"nbformat_minor": 2
259+
}

0 commit comments

Comments
 (0)