22 | 22 | "import torch.nn.functional as F\n",
23 | 23 | "import torchvision\n",
24 | 24 | "import torchvision.transforms as transforms\n",
25 |    | - "from qtorch.quant import Quantizer\n",
   | 25 | + "from qtorch.quant import Quantizer, quantizer\n",
26 | 26 | "from qtorch.optim import OptimLP\n",
27 | 27 | "from torch.optim import SGD\n",
28 | 28 | "from qtorch import FloatingPoint\n",
29 |    | - "from tqdm import tqdm"
   | 29 | + "from tqdm import tqdm\n",
   | 30 | + "import math"
30 | 31 | ]
31 | 32 | },
32 | 33 | {

103 | 104 | "bit_16 = FloatingPoint(exp=6, man=9)\n",
104 | 105 | "\n",
105 | 106 | "# define quantization functions\n",
106 |     | - "weight_quant = Quantizer(forward_number=bit_8, backward_number=None,\n",
107 |     | - "                         forward_rounding=\"nearest\", backward_rounding=\"nearest\")\n",
108 |     | - "grad_quant = Quantizer(forward_number=bit_8, backward_number=None,\n",
109 |     | - "                       forward_rounding=\"nearest\", backward_rounding=\"stochastic\")\n",
110 |     | - "momentum_quant = Quantizer(forward_number=bit_16, backward_number=None,\n",
111 |     | - "                           forward_rounding=\"nearest\", backward_rounding=\"stochastic\")\n",
112 |     | - "acc_quant = Quantizer(forward_number=bit_16, backward_number=None,\n",
113 |     | - "                      forward_rounding=\"nearest\", backward_rounding=\"nearest\")\n",
    | 107 | + "weight_quant = quantizer(forward_number=bit_8,\n",
    | 108 | + "                         forward_rounding=\"nearest\")\n",
    | 109 | + "grad_quant = quantizer(forward_number=bit_8,\n",
    | 110 | + "                       forward_rounding=\"nearest\")\n",
    | 111 | + "momentum_quant = quantizer(forward_number=bit_16,\n",
    | 112 | + "                           forward_rounding=\"stochastic\")\n",
    | 113 | + "acc_quant = quantizer(forward_number=bit_16,\n",
    | 114 | + "                      forward_rounding=\"stochastic\")\n",
114 | 115 | "\n",
115 | 116 | "# define a lambda function so that the Quantizer module can be duplicated easily\n",
116 | 117 | "act_error_quant = lambda : Quantizer(forward_number=bit_8, backward_number=bit_8,\n",
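
For context on the hunk above: the switch from `Quantizer` to the lowercase `quantizer` matters because the two qtorch APIs return different things. `quantizer(...)` returns a plain function that maps a tensor onto the low-precision grid, which is what `OptimLP` expects for weights, gradients, momentum, and accumulators; `Quantizer(...)` is an `nn.Module` that can additionally quantize the backward error signal, so it is what gets inserted into the network. A minimal sketch of the difference; the `FloatingPoint(exp=5, man=2)` format for `bit_8` is an assumption here, since its definition sits in an elided part of the diff:

```python
import torch
from qtorch import FloatingPoint
from qtorch.quant import Quantizer, quantizer

# assumed 8-bit float format: 1 sign + 5 exponent + 2 mantissa bits
bit_8 = FloatingPoint(exp=5, man=2)

# quantizer(...) returns a function: tensor in, low-precision tensor out
quant_fn = quantizer(forward_number=bit_8, forward_rounding="nearest")
x = torch.randn(4)
print(quant_fn(x))  # entries rounded to the nearest representable 8-bit value

# Quantizer(...) wraps the same rounding in an nn.Module and can also
# quantize the gradient flowing backward through it
quant_module = Quantizer(forward_number=bit_8, backward_number=bit_8,
                         forward_rounding="nearest", backward_rounding="stochastic")
print(quant_module(x))
```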

121 | 122 | "cell_type": "markdown",
122 | 123 | "metadata": {},
123 | 124 | "source": [
124 |     | - "Next, we define a low-precision VGG network. In the definition, we recursively insert quantization module after every convolution layer. Note that the quantization of weight, gradient, momentum, and gradient accumulator are not handled here."
    | 125 | + "Next, we define a low-precision ResNet. In the definition, we recursively insert a quantization module after every convolution layer. Note that the quantization of weights, gradients, momentum, and the gradient accumulator is not handled here."
125 | 126 | ]
126 | 127 | },
127 | 128 | {

130 | 131 | "metadata": {},
131 | 132 | "outputs": [],
132 | 133 | "source": [
133 |     | - "# let's define the model we are using\n",
134 |     | - "def make_layers(cfg, quant):\n",
135 |     | - "    layers = list()\n",
136 |     | - "    in_channels = 3\n",
137 |     | - "    n = 1\n",
138 |     | - "    for v in cfg:\n",
139 |     | - "        if v == 'M':\n",
140 |     | - "            layers += [nn.MaxPool2d(kernel_size=2, stride=2)]\n",
141 |     | - "        else:\n",
142 |     | - "            use_quant = v[-1] != 'N'\n",
143 |     | - "            filters = int(v) if use_quant else int(v[:-1])\n",
144 |     | - "            conv2d = nn.Conv2d(in_channels, filters, kernel_size=3, padding=1)\n",
145 |     | - "            layers += [conv2d, nn.ReLU(inplace=True)]\n",
146 |     | - "            if use_quant: layers += [quant()] # inserting quantization modules\n",
147 |     | - "            n += 1\n",
148 |     | - "            in_channels = filters\n",
149 |     | - "    return nn.Sequential(*layers)\n",
    | 134 | + "def conv3x3(in_planes, out_planes, stride=1):\n",
    | 135 | + "    return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,\n",
    | 136 | + "                     padding=1, bias=False)\n",
150 | 137 | "\n",
151 |     | - "class VGGLP(nn.Module):\n",
152 |     | - "    def __init__(self, config, quant=None, num_classes=10):\n",
    | 138 | + "class BasicBlock(nn.Module):\n",
    | 139 | + "    expansion = 1\n",
153 | 140 | "\n",
154 |     | - "        super(VGGLP, self).__init__()\n",
155 |     | - "        self.features = make_layers(config, quant)\n",
156 |     | - "        self.classifier = nn.Sequential(\n",
157 |     | - "            nn.Dropout(),\n",
158 |     | - "            nn.Linear(512, 512),\n",
159 |     | - "            nn.ReLU(True),\n",
160 |     | - "            quant(),\n",
161 |     | - "            nn.Dropout(),\n",
162 |     | - "            nn.Linear(512, 512),\n",
163 |     | - "            nn.ReLU(True),\n",
164 |     | - "            quant(),\n",
165 |     | - "            nn.Linear(512, num_classes),\n",
166 |     | - "        )\n",
    | 141 | + "    def __init__(self, inplanes, planes, quant, stride=1, downsample=None):\n",
    | 142 | + "        super(BasicBlock, self).__init__()\n",
    | 143 | + "        self.bn1 = nn.BatchNorm2d(inplanes)\n",
    | 144 | + "        self.relu = nn.ReLU(inplace=True)\n",
    | 145 | + "        self.conv1 = conv3x3(inplanes, planes, stride)\n",
    | 146 | + "        self.bn2 = nn.BatchNorm2d(planes)\n",
    | 147 | + "        self.conv2 = conv3x3(planes, planes)\n",
    | 148 | + "        self.downsample = downsample\n",
    | 149 | + "        self.stride = stride\n",
    | 150 | + "        self.quant = quant()\n",
167 | 151 | "\n",
168 | 152 | "    def forward(self, x):\n",
169 |     | - "        x = self.features(x)\n",
170 |     | - "        x = x.view(x.size(0), -1)\n",
171 |     | - "        x = self.classifier(x)\n",
172 |     | - "        return x\n",
    | 153 | + "        residual = x\n",
    | 154 | + "\n",
    | 155 | + "        out = self.bn1(x)\n",
    | 156 | + "        out = self.relu(out)\n",
    | 157 | + "        out = self.quant(out)\n",
    | 158 | + "        out = self.conv1(out)\n",
    | 159 | + "        out = self.quant(out)\n",
    | 160 | + "\n",
    | 161 | + "        out = self.bn2(out)\n",
    | 162 | + "        out = self.relu(out)\n",
    | 163 | + "        out = self.quant(out)\n",
    | 164 | + "        out = self.conv2(out)\n",
    | 165 | + "        out = self.quant(out)\n",
    | 166 | + "\n",
    | 167 | + "        if self.downsample is not None:\n",
    | 168 | + "            residual = self.downsample(x)\n",
    | 169 | + "\n",
    | 170 | + "        out += residual\n",
    | 171 | + "\n",
    | 172 | + "        return out\n",
173 | 173 | " \n",
174 |     | - "config = ['64', '64', 'M', '128', '128', 'M', \n",
175 |     | - "          '256', '256', '256', 'M', '512', '512', '512', 'M', '512', '512', '512', 'M'] # VGG16\n",
    | 174 | + "class PreResNet(nn.Module):\n",
    | 175 | + "\n",
    | 176 | + "    def __init__(self, quant, num_classes=10, depth=20):\n",
    | 177 | + "\n",
    | 178 | + "        super(PreResNet, self).__init__()\n",
    | 179 | + "        assert (depth - 2) % 6 == 0, 'depth should be 6n+2'\n",
    | 180 | + "        n = (depth - 2) // 6\n",
    | 181 | + "\n",
    | 182 | + "        block = BasicBlock\n",
    | 183 | + "\n",
    | 184 | + "        self.inplanes = 16\n",
    | 185 | + "        self.conv1 = nn.Conv2d(3, 16, kernel_size=3, padding=1,\n",
    | 186 | + "                               bias=False)\n",
    | 187 | + "        self.layer1 = self._make_layer(block, 16, n, quant)\n",
    | 188 | + "        self.layer2 = self._make_layer(block, 32, n, quant, stride=2)\n",
    | 189 | + "        self.layer3 = self._make_layer(block, 64, n, quant, stride=2)\n",
    | 190 | + "        self.bn = nn.BatchNorm2d(64 * block.expansion)\n",
    | 191 | + "        self.relu = nn.ReLU(inplace=True)\n",
    | 192 | + "        self.avgpool = nn.AvgPool2d(8)\n",
    | 193 | + "        self.fc = nn.Linear(64 * block.expansion, num_classes)\n",
    | 194 | + "        self.quant = quant()\n",
    | 195 | + "        IBM_half = FloatingPoint(exp=6, man=9)\n",
    | 196 | + "        self.quant_half = Quantizer(IBM_half, IBM_half, \"nearest\", \"nearest\")\n",
    | 197 | + "        for m in self.modules():\n",
    | 198 | + "            if isinstance(m, nn.Conv2d):\n",
    | 199 | + "                n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels\n",
    | 200 | + "                m.weight.data.normal_(0, math.sqrt(2. / n))\n",
    | 201 | + "            elif isinstance(m, nn.BatchNorm2d):\n",
    | 202 | + "                m.weight.data.fill_(1)\n",
    | 203 | + "                m.bias.data.zero_()\n",
    | 204 | + "\n",
    | 205 | + "    def _make_layer(self, block, planes, blocks, quant, stride=1):\n",
    | 206 | + "        downsample = None\n",
    | 207 | + "        if stride != 1 or self.inplanes != planes * block.expansion:\n",
    | 208 | + "            downsample = nn.Sequential(\n",
    | 209 | + "                nn.Conv2d(self.inplanes, planes * block.expansion,\n",
    | 210 | + "                          kernel_size=1, stride=stride, bias=False),\n",
    | 211 | + "            )\n",
    | 212 | + "\n",
    | 213 | + "        layers = list()\n",
    | 214 | + "        layers.append(block(self.inplanes, planes, quant, stride, downsample))\n",
    | 215 | + "        self.inplanes = planes * block.expansion\n",
    | 216 | + "        for i in range(1, blocks):\n",
    | 217 | + "            layers.append(block(self.inplanes, planes, quant))\n",
176 | 218 | "\n",
177 |     | - "model = VGGLP(config, act_error_quant)"
    | 219 | + "        return nn.Sequential(*layers)\n",
    | 220 | + "\n",
    | 221 | + "    def forward(self, x):\n",
    | 222 | + "        x = self.quant_half(x)\n",
    | 223 | + "        x = self.conv1(x)\n",
    | 224 | + "        x = self.quant(x)\n",
    | 225 | + "\n",
    | 226 | + "        x = self.layer1(x)  # 32x32\n",
    | 227 | + "        x = self.layer2(x)  # 16x16\n",
    | 228 | + "        x = self.layer3(x)  # 8x8\n",
    | 229 | + "        x = self.bn(x)\n",
    | 230 | + "        x = self.relu(x)\n",
    | 231 | + "        x = self.quant(x)\n",
    | 232 | + "\n",
    | 233 | + "        x = self.avgpool(x)\n",
    | 234 | + "        x = x.view(x.size(0), -1)\n",
    | 235 | + "        x = self.fc(x)\n",
    | 236 | + "        x = self.quant_half(x)\n",
    | 237 | + "\n",
    | 238 | + "        return x"
178 | 239 | ]
179 | 240 | },
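
As a quick orientation on the new model's shape: with the default `depth=20`, the `(depth - 2) % 6 == 0` assertion passes and each of the three stages (`layer1` through `layer3`) stacks `n = 3` `BasicBlock`s of two 3x3 convolutions each. A small sanity check of that arithmetic (illustrative only, not part of the notebook):

```python
# depth = 6n + 2 for this CIFAR-style pre-activation ResNet:
# 3 stages x n blocks x 2 convs per block, plus the stem conv and the final fc
depth = 20
assert (depth - 2) % 6 == 0, 'depth should be 6n+2'
n = (depth - 2) // 6
print(n)              # 3 blocks per stage
print(3 * n * 2 + 2)  # 20 weighted layers in total
```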
180 | 241 | {
181 | 242 | "cell_type": "code",
182 | 243 | "execution_count": 5,
183 | 244 | "metadata": {},
184 | 245 | "outputs": [],
    | 246 | + "source": [
    | 247 | + "model = PreResNet(act_error_quant)"
    | 248 | + ]
    | 249 | + },
    | 250 | + {
    | 251 | + "cell_type": "code",
    | 252 | + "execution_count": 7,
    | 253 | + "metadata": {},
    | 254 | + "outputs": [],
185 | 255 | "source": [
186 | 256 | "device = 'cuda' # change device to 'cpu' if you want to run this example on cpu\n",
187 | 257 | "model = model.to(device=device)"

196 | 266 | },
197 | 267 | {
198 | 268 | "cell_type": "code",
199 |     | - "execution_count": 6,
    | 269 | + "execution_count": 8,
200 | 270 | "metadata": {},
201 | 271 | "outputs": [],
202 | 272 | "source": [
203 |     | - "optimizer = SGD(model.parameters(), lr=0.05, momentum=0.9, weight_decay=1e-4)\n",
    | 273 | + "optimizer = SGD(model.parameters(), lr=0.05, momentum=0.9, weight_decay=5e-4)\n",
204 | 274 | "optimizer = OptimLP(optimizer,\n",
205 | 275 | "                    weight_quant=weight_quant,\n",
206 | 276 | "                    grad_quant=grad_quant,\n",
207 | 277 | "                    momentum_quant=momentum_quant,\n",
208 |     | - "                    acc_quant=acc_quant\n",
    | 278 | + "                    acc_quant=acc_quant,\n",
    | 279 | + "                    grad_scaling=1/1000\n",
209 | 280 | ")"
210 | 281 | ]
211 | 282 | },
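
A note on how `OptimLP` uses these four quantizers, based on QPyTorch's description of its low-precision SGD wrapper; the sketch below is illustrative pseudocode for one parameter update under the settings above, not the library's exact internals:

```python
# one conceptual OptimLP/SGD step (illustrative pseudocode):
def optimlp_step(weight, acc, grad, buf, lr=0.05, momentum=0.9):
    grad = grad_quant(grad * (1 / 1000))         # grad_scaling undoes the x1000
                                                 # loss scaling, then quantizes to 8-bit
    buf = momentum_quant(momentum * buf + grad)  # momentum buffer kept in 16-bit
    acc = acc_quant(acc - lr * buf)              # update a 16-bit accumulator copy
    weight = weight_quant(acc)                   # round back down to 8-bit weights
    return weight, acc, buf
```

The new `grad_scaling=1/1000` argument pairs with the `loss = loss * 1000` line added to the training loop below.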

218 | 289 | },
219 | 290 | {
220 | 291 | "cell_type": "code",
221 |     | - "execution_count": 7,
    | 292 | + "execution_count": 9,
222 | 293 | "metadata": {},
223 | 294 | "outputs": [],
224 | 295 | "source": [

243 | 314 | "        ttl += input.size()[0]\n",
244 | 315 | "\n",
245 | 316 | "        if phase==\"train\":\n",
    | 317 | + "            loss = loss * 1000\n",
246 | 318 | "            optimizer.zero_grad()\n",
247 | 319 | "            loss.backward()\n",
248 | 320 | "            optimizer.step()\n",
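
The `loss = loss * 1000` line added here is the other half of the loss-scaling pattern: multiplying the loss before `backward()` keeps small gradients from underflowing to zero when they are quantized to 8-bit floats, and `OptimLP`'s `grad_scaling=1/1000` divides the scale back out before the weight update, so the effective step size is unchanged. The factor 1000 is this notebook's choice; any scale large enough to lift the smallest gradients above the format's underflow threshold plays the same role.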

263 | 335 | },
264 | 336 | {
265 | 337 | "cell_type": "code",
266 |     | - "execution_count": 8,
    | 338 | + "execution_count": 10,
267 | 339 | "metadata": {},
268 | 340 | "outputs": [
269 | 341 | {
270 | 342 | "name": "stderr",
271 | 343 | "output_type": "stream",
272 | 344 | "text": [
273 |     | - "100%|██████████| 391/391 [00:34<00:00, 11.34it/s]\n",
274 |     | - "100%|██████████| 79/79 [00:01<00:00, 70.06it/s]\n"
    | 345 | + "100%|██████████| 391/391 [00:14<00:00, 26.41it/s]\n",
    | 346 | + "100%|██████████| 79/79 [00:01<00:00, 78.18it/s]\n"
275 | 347 | ]
276 | 348 | }
277 | 349 | ],

282 | 354 | "    test_res = run_epoch(loaders['test'], model, F.cross_entropy,\n",
283 | 355 | "                         optimizer=optimizer, phase=\"eval\")"
284 | 356 | ]
    | 357 | + },
    | 358 | + {
    | 359 | + "cell_type": "code",
    | 360 | + "execution_count": 11,
    | 361 | + "metadata": {},
    | 362 | + "outputs": [
    | 363 | + {
    | 364 | + "data": {
    | 365 | + "text/plain": [
    | 366 | + "{'loss': 1.6471979439544677, 'accuracy': 37.566}"
    | 367 | + ]
    | 368 | + },
    | 369 | + "execution_count": 11,
    | 370 | + "metadata": {},
    | 371 | + "output_type": "execute_result"
    | 372 | + }
    | 373 | + ],
    | 374 | + "source": [
    | 375 | + "train_res"
    | 376 | + ]
    | 377 | + },
    | 378 | + {
    | 379 | + "cell_type": "code",
    | 380 | + "execution_count": 12,
    | 381 | + "metadata": {},
    | 382 | + "outputs": [
    | 383 | + {
    | 384 | + "data": {
    | 385 | + "text/plain": [
    | 386 | + "{'loss': 1.5749474658966065, 'accuracy': 43.63}"
    | 387 | + ]
    | 388 | + },
    | 389 | + "execution_count": 12,
    | 390 | + "metadata": {},
    | 391 | + "output_type": "execute_result"
    | 392 | + }
    | 393 | + ],
    | 394 | + "source": [
    | 395 | + "test_res"
    | 396 | + ]
    | 397 | + },
285 | 397 | }

286 | 398 | ],
287 | 399 | "metadata": {

300 | 412 | "name": "python",
301 | 413 | "nbconvert_exporter": "python",
302 | 414 | "pygments_lexer": "ipython3",
303 |     | - "version": "3.6.7"
    | 415 | + "version": "3.7.3"
304 | 416 | }
305 | 417 | },
306 | 418 | "nbformat": 4,