
Commit c870b94

fix bug.
1 parent 7962268 commit c870b94

1 file changed: +170 −58 lines changed

examples/tutorial/CIFAR10_Low_Precision_Training_Example.ipynb

@@ -22,11 +22,12 @@
 "import torch.nn.functional as F\n",
 "import torchvision\n",
 "import torchvision.transforms as transforms\n",
-"from qtorch.quant import Quantizer\n",
+"from qtorch.quant import Quantizer, quantizer\n",
 "from qtorch.optim import OptimLP\n",
 "from torch.optim import SGD\n",
 "from qtorch import FloatingPoint\n",
-"from tqdm import tqdm"
+"from tqdm import tqdm\n",
+"import math"
 ]
 },
 {
@@ -103,14 +104,14 @@
 "bit_16 = FloatingPoint(exp=6, man=9)\n",
 "\n",
 "# define quantization functions\n",
-"weight_quant = Quantizer(forward_number=bit_8, backward_number=None,\n",
-"                         forward_rounding=\"nearest\", backward_rounding=\"nearest\")\n",
-"grad_quant = Quantizer(forward_number=bit_8, backward_number=None,\n",
-"                       forward_rounding=\"nearest\", backward_rounding=\"stochastic\")\n",
-"momentum_quant = Quantizer(forward_number=bit_16, backward_number=None,\n",
-"                           forward_rounding=\"nearest\", backward_rounding=\"stochastic\")\n",
-"acc_quant = Quantizer(forward_number=bit_16, backward_number=None,\n",
-"                      forward_rounding=\"nearest\", backward_rounding=\"nearest\")\n",
+"weight_quant = quantizer(forward_number=bit_8,\n",
+"                         forward_rounding=\"nearest\")\n",
+"grad_quant = quantizer(forward_number=bit_8,\n",
+"                       forward_rounding=\"nearest\")\n",
+"momentum_quant = quantizer(forward_number=bit_16,\n",
+"                           forward_rounding=\"stochastic\")\n",
+"acc_quant = quantizer(forward_number=bit_16,\n",
+"                      forward_rounding=\"stochastic\")\n",
 "\n",
 "# define a lambda function so that the Quantizer module can be duplicated easily\n",
 "act_error_quant = lambda : Quantizer(forward_number=bit_8, backward_number=bit_8,\n",
@@ -121,7 +122,7 @@
 "cell_type": "markdown",
 "metadata": {},
 "source": [
-"Next, we define a low-precision VGG network. In the definition, we recursively insert quantization module after every convolution layer. Note that the quantization of weight, gradient, momentum, and gradient accumulator are not handled here."
+"Next, we define a low-precision ResNet. In the definition, we recursively insert quantization modules after every convolution layer. Note that the quantization of weights, gradients, momentum, and the gradient accumulator is not handled here."
 ]
 },
 {
@@ -130,58 +131,127 @@
 "metadata": {},
 "outputs": [],
 "source": [
-"# let's define the model we are using\n",
-"def make_layers(cfg, quant):\n",
-"    layers = list()\n",
-"    in_channels = 3\n",
-"    n = 1\n",
-"    for v in cfg:\n",
-"        if v == 'M':\n",
-"            layers += [nn.MaxPool2d(kernel_size=2, stride=2)]\n",
-"        else:\n",
-"            use_quant = v[-1] != 'N'\n",
-"            filters = int(v) if use_quant else int(v[:-1])\n",
-"            conv2d = nn.Conv2d(in_channels, filters, kernel_size=3, padding=1)\n",
-"            layers += [conv2d, nn.ReLU(inplace=True)]\n",
-"            if use_quant: layers += [quant()]  # inserting quantization modules\n",
-"            n += 1\n",
-"            in_channels = filters\n",
-"    return nn.Sequential(*layers)\n",
+"def conv3x3(in_planes, out_planes, stride=1):\n",
+"    return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,\n",
+"                     padding=1, bias=False)\n",
 "\n",
-"class VGGLP(nn.Module):\n",
-"    def __init__(self, config, quant=None, num_classes=10):\n",
+"class BasicBlock(nn.Module):\n",
+"    expansion = 1\n",
 "\n",
-"        super(VGGLP, self).__init__()\n",
-"        self.features = make_layers(config, quant)\n",
-"        self.classifier = nn.Sequential(\n",
-"            nn.Dropout(),\n",
-"            nn.Linear(512, 512),\n",
-"            nn.ReLU(True),\n",
-"            quant(),\n",
-"            nn.Dropout(),\n",
-"            nn.Linear(512, 512),\n",
-"            nn.ReLU(True),\n",
-"            quant(),\n",
-"            nn.Linear(512, num_classes),\n",
-"        )\n",
+"    def __init__(self, inplanes, planes, quant, stride=1, downsample=None):\n",
+"        super(BasicBlock, self).__init__()\n",
+"        self.bn1 = nn.BatchNorm2d(inplanes)\n",
+"        self.relu = nn.ReLU(inplace=True)\n",
+"        self.conv1 = conv3x3(inplanes, planes, stride)\n",
+"        self.bn2 = nn.BatchNorm2d(planes)\n",
+"        self.conv2 = conv3x3(planes, planes)\n",
+"        self.downsample = downsample\n",
+"        self.stride = stride\n",
+"        self.quant = quant()\n",
 "\n",
 "    def forward(self, x):\n",
-"        x = self.features(x)\n",
-"        x = x.view(x.size(0), -1)\n",
-"        x = self.classifier(x)\n",
-"        return x\n",
+"        residual = x\n",
+"\n",
+"        out = self.bn1(x)\n",
+"        out = self.relu(out)\n",
+"        out = self.quant(out)\n",
+"        out = self.conv1(out)\n",
+"        out = self.quant(out)\n",
+"\n",
+"        out = self.bn2(out)\n",
+"        out = self.relu(out)\n",
+"        out = self.quant(out)\n",
+"        out = self.conv2(out)\n",
+"        out = self.quant(out)\n",
+"\n",
+"        if self.downsample is not None:\n",
+"            residual = self.downsample(x)\n",
+"\n",
+"        out += residual\n",
+"\n",
+"        return out\n",
 "    \n",
-"config = ['64', '64', 'M', '128', '128', 'M', \n",
-"          '256', '256', '256', 'M', '512', '512', '512', 'M', '512', '512', '512', 'M']  # VGG16\n",
+"class PreResNet(nn.Module):\n",
+"\n",
+"    def __init__(self, quant, num_classes=10, depth=20):\n",
+"\n",
+"        super(PreResNet, self).__init__()\n",
+"        assert (depth - 2) % 6 == 0, 'depth should be 6n+2'\n",
+"        n = (depth - 2) // 6\n",
+"\n",
+"        block = BasicBlock\n",
+"\n",
+"        self.inplanes = 16\n",
+"        self.conv1 = nn.Conv2d(3, 16, kernel_size=3, padding=1,\n",
+"                               bias=False)\n",
+"        self.layer1 = self._make_layer(block, 16, n, quant)\n",
+"        self.layer2 = self._make_layer(block, 32, n, quant, stride=2)\n",
+"        self.layer3 = self._make_layer(block, 64, n, quant, stride=2)\n",
+"        self.bn = nn.BatchNorm2d(64 * block.expansion)\n",
+"        self.relu = nn.ReLU(inplace=True)\n",
+"        self.avgpool = nn.AvgPool2d(8)\n",
+"        self.fc = nn.Linear(64 * block.expansion, num_classes)\n",
+"        self.quant = quant()\n",
+"        IBM_half = FloatingPoint(exp=6, man=9)\n",
+"        self.quant_half = Quantizer(IBM_half, IBM_half, \"nearest\", \"nearest\")\n",
+"        for m in self.modules():\n",
+"            if isinstance(m, nn.Conv2d):\n",
+"                n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels\n",
+"                m.weight.data.normal_(0, math.sqrt(2. / n))\n",
+"            elif isinstance(m, nn.BatchNorm2d):\n",
+"                m.weight.data.fill_(1)\n",
+"                m.bias.data.zero_()\n",
+"\n",
+"    def _make_layer(self, block, planes, blocks, quant, stride=1):\n",
+"        downsample = None\n",
+"        if stride != 1 or self.inplanes != planes * block.expansion:\n",
+"            downsample = nn.Sequential(\n",
+"                nn.Conv2d(self.inplanes, planes * block.expansion,\n",
+"                          kernel_size=1, stride=stride, bias=False),\n",
+"            )\n",
+"\n",
+"        layers = list()\n",
+"        layers.append(block(self.inplanes, planes, quant, stride, downsample))\n",
+"        self.inplanes = planes * block.expansion\n",
+"        for i in range(1, blocks):\n",
+"            layers.append(block(self.inplanes, planes, quant))\n",
 "\n",
-"model = VGGLP(config, act_error_quant)"
+"        return nn.Sequential(*layers)\n",
+"\n",
+"    def forward(self, x):\n",
+"        x = self.quant_half(x)\n",
+"        x = self.conv1(x)\n",
+"        x = self.quant(x)\n",
+"\n",
+"        x = self.layer1(x)  # 32x32\n",
+"        x = self.layer2(x)  # 16x16\n",
+"        x = self.layer3(x)  # 8x8\n",
+"        x = self.bn(x)\n",
+"        x = self.relu(x)\n",
+"        x = self.quant(x)\n",
+"\n",
+"        x = self.avgpool(x)\n",
+"        x = x.view(x.size(0), -1)\n",
+"        x = self.fc(x)\n",
+"        x = self.quant_half(x)\n",
+"\n",
+"        return x"
 ]
 },
 {
 "cell_type": "code",
 "execution_count": 5,
 "metadata": {},
 "outputs": [],
+"source": [
+"model = PreResNet(act_error_quant)"
+]
+},
+{
+"cell_type": "code",
+"execution_count": 7,
+"metadata": {},
+"outputs": [],
 "source": [
 "device = 'cuda' # change device to 'cpu' if you want to run this example on cpu\n",
 "model = model.to(device=device)"
@@ -196,16 +266,17 @@
 },
 {
 "cell_type": "code",
-"execution_count": 6,
+"execution_count": 8,
 "metadata": {},
 "outputs": [],
 "source": [
-"optimizer = SGD(model.parameters(), lr=0.05, momentum=0.9, weight_decay=1e-4)\n",
+"optimizer = SGD(model.parameters(), lr=0.05, momentum=0.9, weight_decay=5e-4)\n",
 "optimizer = OptimLP(optimizer,\n",
 "                    weight_quant=weight_quant,\n",
 "                    grad_quant=grad_quant,\n",
 "                    momentum_quant=momentum_quant,\n",
-"                    acc_quant=acc_quant\n",
+"                    acc_quant=acc_quant,\n",
+"                    grad_scaling=1/1000\n",
 ")"
 ]
 },
@@ -218,7 +289,7 @@
 },
 {
 "cell_type": "code",
-"execution_count": 7,
+"execution_count": 9,
 "metadata": {},
 "outputs": [],
 "source": [
@@ -243,6 +314,7 @@
 "        ttl += input.size()[0]\n",
 "\n",
 "        if phase==\"train\":\n",
+"            loss = loss * 1000\n",
 "            optimizer.zero_grad()\n",
 "            loss.backward()\n",
 "            optimizer.step()\n",
@@ -263,15 +335,15 @@
 },
 {
 "cell_type": "code",
-"execution_count": 8,
+"execution_count": 10,
 "metadata": {},
 "outputs": [
 {
 "name": "stderr",
 "output_type": "stream",
 "text": [
-"100%|██████████| 391/391 [00:34<00:00, 11.34it/s]\n",
-"100%|██████████| 79/79 [00:01<00:00, 70.06it/s]\n"
+"100%|██████████| 391/391 [00:14<00:00, 26.41it/s]\n",
+"100%|██████████| 79/79 [00:01<00:00, 78.18it/s]\n"
 ]
 }
 ],
@@ -282,6 +354,46 @@
 "    test_res = run_epoch(loaders['test'], model, F.cross_entropy,\n",
 "                         optimizer=optimizer, phase=\"eval\")"
 ]
+},
+{
+"cell_type": "code",
+"execution_count": 11,
+"metadata": {},
+"outputs": [
+{
+"data": {
+"text/plain": [
+"{'loss': 1.6471979439544677, 'accuracy': 37.566}"
+]
+},
+"execution_count": 11,
+"metadata": {},
+"output_type": "execute_result"
+}
+],
+"source": [
+"train_res"
+]
+},
+{
+"cell_type": "code",
+"execution_count": 12,
+"metadata": {},
+"outputs": [
+{
+"data": {
+"text/plain": [
+"{'loss': 1.5749474658966065, 'accuracy': 43.63}"
+]
+},
+"execution_count": 12,
+"metadata": {},
+"output_type": "execute_result"
+}
+],
+"source": [
+"test_res"
+]
 }
 ],
 "metadata": {
@@ -300,7 +412,7 @@
 "name": "python",
 "nbconvert_exporter": "python",
 "pygments_lexer": "ipython3",
-"version": "3.6.7"
+"version": "3.7.3"
 }
 },
 "nbformat": 4,
