|
| 1 | +{ |
| 2 | + "cells": [ |
| 3 | + { |
| 4 | + "cell_type": "code", |
| 5 | + "execution_count": null, |
| 6 | + "metadata": {}, |
| 7 | + "outputs": [], |
| 8 | + "source": [ |
| 9 | + "import torchvision\n", |
| 10 | + "import torchvision.models as models\n", |
| 11 | + "import torch.utils.model_zoo as model_zoo" |
| 12 | + ] |
| 13 | + }, |
| 14 | + { |
| 15 | + "cell_type": "markdown", |
| 16 | + "metadata": {}, |
| 17 | + "source": [ |
| 18 | + "### Configuration" |
| 19 | + ] |
| 20 | + }, |
| 21 | + { |
| 22 | + "cell_type": "code", |
| 23 | + "execution_count": null, |
| 24 | + "metadata": {}, |
| 25 | + "outputs": [], |
| 26 | + "source": [ |
| 27 | + "model_urls = {\n", |
| 28 | + " 'alexnet': 'https://download.pytorch.org/models/alexnet-owt-4df8aa71.pth',\n", |
| 29 | + " 'densenet121': 'https://download.pytorch.org/models/densenet121-241335ed.pth',\n", |
| 30 | + " 'densenet169': 'https://download.pytorch.org/models/densenet169-6f0f7f60.pth',\n", |
| 31 | + " 'densenet201': 'https://download.pytorch.org/models/densenet201-4c113574.pth',\n", |
| 32 | + " 'densenet161': 'https://download.pytorch.org/models/densenet161-17b70270.pth',\n", |
| 33 | + " #truncated _google to match module name\n", |
| 34 | + " 'inception_v3': 'https://download.pytorch.org/models/inception_v3_google-1a9a5a14.pth',\n", |
| 35 | + " 'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth',\n", |
| 36 | + " 'resnet34': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth',\n", |
| 37 | + " 'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth',\n", |
| 38 | + " 'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth',\n", |
| 39 | + " 'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth', \n", |
| 40 | + " 'squeezenet1_0': 'https://download.pytorch.org/models/squeezenet1_0-a815701f.pth',\n", |
| 41 | + " 'squeezenet1_1': 'https://download.pytorch.org/models/squeezenet1_1-f364aa15.pth',\n", |
| 42 | + " 'vgg11': 'https://download.pytorch.org/models/vgg11-bbd30ac9.pth',\n", |
| 43 | + " 'vgg13': 'https://download.pytorch.org/models/vgg13-c768596a.pth',\n", |
| 44 | + " 'vgg16': 'https://download.pytorch.org/models/vgg16-397923af.pth',\n", |
| 45 | + " 'vgg19': 'https://download.pytorch.org/models/vgg19-dcbb9e9d.pth', \n", |
| 46 | + "}\n", |
| 47 | + "\n", |
| 48 | + "model_names = model_urls.keys()\n", |
| 49 | + "\n", |
| 50 | + "input_sizes = {\n", |
| 51 | + " 'alexnet' : (224,224),\n", |
| 52 | + " 'densenet': (224,224),\n", |
| 53 | + " 'resnet' : (224,224),\n", |
| 54 | + " 'inception' : (299,299),\n", |
| 55 | + " 'squeezenet' : (224,224),#not 255,255 acc. to https://github.com/pytorch/pytorch/issues/1120\n", |
| 56 | + " 'vgg' : (224,224)\n", |
| 57 | + "}\n", |
| 58 | + "\n", |
| 59 | + "#models_to_test = ['alexnet', 'densenet169', 'inception_v3', \\\n", |
| 60 | + "# 'resnet34', 'squeezenet1_1', 'vgg13']\n", |
| 61 | + "\n", |
| 62 | + "models_to_test = model_names" |
| 63 | + ] |
| 64 | + }, |
| 65 | + { |
| 66 | + "cell_type": "markdown", |
| 67 | + "metadata": {}, |
| 68 | + "source": [ |
| 69 | + "### Generic pretrained model loading" |
| 70 | + ] |
| 71 | + }, |
| 72 | + { |
| 73 | + "cell_type": "code", |
| 74 | + "execution_count": null, |
| 75 | + "metadata": {}, |
| 76 | + "outputs": [], |
| 77 | + "source": [ |
| 78 | + "#We solve the dimensionality mismatch between\n", |
| 79 | + "#final layers in the constructed vs pretrained\n", |
| 80 | + "#modules at the data level.\n", |
| 81 | + "def diff_states(dict_canonical, dict_subset):\n", |
| 82 | + " names1, names2 = (list(dict_canonical.keys()), list(dict_subset.keys()))\n", |
| 83 | + " \n", |
| 84 | + " #Sanity check that param names overlap\n", |
| 85 | + " #Note that params are not necessarily in the same order\n", |
| 86 | + " #for every pretrained model\n", |
| 87 | + " not_in_1 = [n for n in names1 if n not in names2]\n", |
| 88 | + " not_in_2 = [n for n in names2 if n not in names1]\n", |
| 89 | + " assert len(not_in_1) == 0\n", |
| 90 | + " assert len(not_in_2) == 0\n", |
| 91 | + "\n", |
| 92 | + " for name, v1 in dict_canonical.items():\n", |
| 93 | + " v2 = dict_subset[name]\n", |
| 94 | + " assert hasattr(v2, 'size')\n", |
| 95 | + " if v1.size() != v2.size():\n", |
| 96 | + " yield (name, v1) \n", |
| 97 | + "\n", |
| 98 | + "def load_model_named(name): \n", |
| 99 | + " #Densenets don't (yet) pass on num_classes, hack it in\n", |
| 100 | + " if \"densenet\" in name:\n", |
| 101 | + " if name == 'densenet169':\n", |
| 102 | + " return models.DenseNet(num_init_features=64, growth_rate=32, \\\n", |
| 103 | + " block_config=(6, 12, 32, 32), num_classes=num_classes)\n", |
| 104 | + " \n", |
| 105 | + " elif name == 'densenet121':\n", |
| 106 | + " return models.DenseNet(num_init_features=64, growth_rate=32, \\\n", |
| 107 | + " block_config=(6, 12, 24, 16), num_classes=num_classes)\n", |
| 108 | + " \n", |
| 109 | + " elif name == 'densenet201':\n", |
| 110 | + " return models.DenseNet(num_init_features=64, growth_rate=32, \\\n", |
| 111 | + " block_config=(6, 12, 48, 32), num_classes=num_classes)\n", |
| 112 | + "\n", |
| 113 | + " elif name == 'densenet161':\n", |
| 114 | + " return models.DenseNet(num_init_features=96, growth_rate=48, \\\n", |
| 115 | + " block_config=(6, 12, 36, 24), num_classes=num_classes)\n", |
| 116 | + " else:\n", |
| 117 | + " raise ValueError(\"Cirumventing missing num_classes kwargs not implemented for %s\" % name)\n", |
| 118 | + " \n", |
| 119 | + " return models.__dict__[name](num_classes=num_classes)\n", |
| 120 | + " \n", |
| 121 | + " \n", |
| 122 | + "def load_model(name, num_classes):\n", |
| 123 | + " \n", |
| 124 | + " model = load_model_named(name)\n", |
| 125 | + " pretrained_state = model_zoo.load_url(model_urls[name])\n", |
| 126 | + "\n", |
| 127 | + " #Diff\n", |
| 128 | + " diff = list(diff_states(model.state_dict(), pretrained_state))\n", |
| 129 | + " \n", |
| 130 | + " for name, value in diff:\n", |
| 131 | + " pretrained_state[name] = value\n", |
| 132 | + " \n", |
| 133 | + " assert len(list(diff_states(model.state_dict(), pretrained_state))) == 0\n", |
| 134 | + " \n", |
| 135 | + " #Merge\n", |
| 136 | + " model.load_state_dict(pretrained_state)\n", |
| 137 | + " return model, diff" |
| 138 | + ] |
| 139 | + }, |
| 140 | + { |
| 141 | + "cell_type": "code", |
| 142 | + "execution_count": null, |
| 143 | + "metadata": {}, |
| 144 | + "outputs": [], |
| 145 | + "source": [ |
| 146 | + "# Method to mutate module programmatically (PR #175)\n", |
| 147 | + "# https://github.com/pytorch/vision/pull/175\n", |
| 148 | + "\n", |
| 149 | + "def resize_network_output(net, output_size):\n", |
| 150 | + " if isinstance(net, torch.nn.DataParallel):\n", |
| 151 | + " return resize_network_output(net.module, output_size)\n", |
| 152 | + "\n", |
| 153 | + " # Edit: Can't index iterable in python3\n", |
| 154 | + " #output_layer = net._modules.keys()[-1]\n", |
| 155 | + " for output_layer in net._modules.keys():\n", |
| 156 | + " pass\n", |
| 157 | + " old_output_layer = net._modules[output_layer]\n", |
| 158 | + "\n", |
| 159 | + " if isinstance(old_output_layer, nn.Sequential):\n", |
| 160 | + " return resize_network_output(old_output_layer, output_size)\n", |
| 161 | + " elif isinstance(old_output_layer, nn.modules.pooling.AvgPool2d):\n", |
| 162 | + " # Go back in the layer sequence and find the last conv layer and resize that\n", |
| 163 | + " # Only happens for squeezenet1_0\n", |
| 164 | + " # Edit: iteritems deprecated in python3\n", |
| 165 | + " for name, layer in list(net._modules.items())[::-1][1:]:\n", |
| 166 | + " if isinstance(layer, nn.modules.conv.Conv2d):\n", |
| 167 | + " net._modules[name] = nn.modules.conv.Conv2d(layer.in_channels, output_size, layer.kernel_size,\n", |
| 168 | + " layer.stride, layer.padding, layer.dilation, layer.groups)\n", |
| 169 | + " return\n", |
| 170 | + " assert False\n", |
| 171 | + "\n", |
| 172 | + " assert isinstance(old_output_layer, nn.Linear), 'Class of old_output_layer {}'.format(old_output_layer.__class__.__name__)\n", |
| 173 | + " input_size = old_output_layer.weight.size()[1]\n", |
| 174 | + "\n", |
| 175 | + " net._modules[output_layer] = nn.Linear(input_size, output_size)\n", |
| 176 | + "\n", |
| 177 | + "\n", |
| 178 | + "def load_model_resize_post(name, num_classes):\n", |
| 179 | + " model = load_model_named(name)\n", |
| 180 | + " resize_network_output(model, num_classes)\n", |
| 181 | + " return model" |
| 182 | + ] |
| 183 | + }, |
| 184 | + { |
| 185 | + "cell_type": "markdown", |
| 186 | + "metadata": {}, |
| 187 | + "source": [ |
| 188 | + "## Compare generic loading methods" |
| 189 | + ] |
| 190 | + }, |
| 191 | + { |
| 192 | + "cell_type": "code", |
| 193 | + "execution_count": null, |
| 194 | + "metadata": { |
| 195 | + "collapsed": true |
| 196 | + }, |
| 197 | + "outputs": [], |
| 198 | + "source": [ |
| 199 | + "# If no cuda is present, unpickle fails with this net...\n", |
| 200 | + "# Need to update pretrained model with cpu to resolve?\n", |
| 201 | + "# models_to_test.remove('densenet169')" |
| 202 | + ] |
| 203 | + }, |
| 204 | + { |
| 205 | + "cell_type": "code", |
| 206 | + "execution_count": null, |
| 207 | + "metadata": {}, |
| 208 | + "outputs": [], |
| 209 | + "source": [ |
| 210 | + "num_classes = 10\n", |
| 211 | + "\n", |
| 212 | + "for name in models_to_test:\n", |
| 213 | + " print(\"\")\n", |
| 214 | + " print(name, \"with %d classes\" % num_classes)\n", |
| 215 | + " try:\n", |
| 216 | + " model_merged, diff = load_model(name, num_classes)\n", |
| 217 | + " diff_vanilla = [d[0] for d in diff]\n", |
| 218 | + " result = (\"... merge loading: \" + str(diff_vanilla)).ljust(99) \\\n", |
| 219 | + " + \"OK\" if len(diff_vanilla) > 0 else \"X\"\n", |
| 220 | + " except Exception as e:\n", |
| 221 | + " result = (\"... merge loading: \" + str(e)).ljust(99) + \"X\"\n", |
| 222 | + " finally:\n", |
| 223 | + " print(result)\n", |
| 224 | + " \n", |
| 225 | + " try:\n", |
| 226 | + " model_resized = load_model_resize_post(name, num_classes)\n", |
| 227 | + " diff_merged_resized = [p[0] for p in \\\n", |
| 228 | + " diff_states(model_merged.state_dict(), model_resized.state_dict())]\n", |
| 229 | + " result = (\"... resizing after load: \" + str(diff_merged_resized)).ljust(99) \\\n", |
| 230 | + " + \"OK\" if len(diff_merged_resized) == 0 else \"X\"\n", |
| 231 | + " except Exception as e:\n", |
| 232 | + " result = (\"... resizing after load: \" + str(e)).ljust(99) + \"X\"\n", |
| 233 | + " finally:\n", |
| 234 | + " print(result) " |
| 235 | + ] |
| 236 | + } |
| 237 | + ], |
| 238 | + "metadata": { |
| 239 | + "kernelspec": { |
| 240 | + "display_name": "Python 3", |
| 241 | + "language": "python", |
| 242 | + "name": "python3" |
| 243 | + }, |
| 244 | + "language_info": { |
| 245 | + "codemirror_mode": { |
| 246 | + "name": "ipython", |
| 247 | + "version": 3 |
| 248 | + }, |
| 249 | + "file_extension": ".py", |
| 250 | + "mimetype": "text/x-python", |
| 251 | + "name": "python", |
| 252 | + "nbconvert_exporter": "python", |
| 253 | + "pygments_lexer": "ipython3", |
| 254 | + "version": "3.6.0" |
| 255 | + } |
| 256 | + }, |
| 257 | + "nbformat": 4, |
| 258 | + "nbformat_minor": 2 |
| 259 | +} |
0 commit comments