Skip to content

Commit a4e914e

Browse files
committed
update cifar tutorial.
1 parent b311972 commit a4e914e

File tree

1 file changed

+135
-72
lines changed

1 file changed

+135
-72
lines changed

examples/CIFAR10_Low_Precision_Training_Example.ipynb

+135-72
Original file line numberDiff line numberDiff line change
@@ -5,12 +5,12 @@
55
"metadata": {},
66
"source": [
77
"# CIFAR10 Low Precision Training Example\n",
8-
"In this notebook, we present a quick example of how to simulate training a deep neural network in low precision"
8+
"In this notebook, we present a quick example of how to simulate training a deep neural network in low precision with QPyTorch."
99
]
1010
},
1111
{
1212
"cell_type": "code",
13-
"execution_count": 33,
13+
"execution_count": 1,
1414
"metadata": {},
1515
"outputs": [],
1616
"source": [
@@ -25,33 +25,108 @@
2525
"from qtorch.quant import Quantizer\n",
2626
"from qtorch.optim import OptimLP\n",
2727
"from torch.optim import SGD\n",
28-
"from qtorch import FloatingPoint"
28+
"from qtorch import FloatingPoint\n",
29+
"from tqdm import tqdm"
30+
]
31+
},
32+
{
33+
"cell_type": "markdown",
34+
"metadata": {},
35+
"source": [
36+
"We first load the data. In this example, we will experiment with CIFAR10."
2937
]
3038
},
3139
{
3240
"cell_type": "code",
33-
"execution_count": 20,
41+
"execution_count": 2,
42+
"metadata": {
43+
"scrolled": false
44+
},
45+
"outputs": [
46+
{
47+
"name": "stdout",
48+
"output_type": "stream",
49+
"text": [
50+
"Files already downloaded and verified\n",
51+
"Files already downloaded and verified\n"
52+
]
53+
}
54+
],
55+
"source": [
56+
"# loading data\n",
57+
"ds = torchvision.datasets.CIFAR10\n",
58+
"path = os.path.join(\"./data\", \"CIFAR10\")\n",
59+
"transform_train = transforms.Compose([\n",
60+
" transforms.RandomCrop(32, padding=4),\n",
61+
" transforms.RandomHorizontalFlip(),\n",
62+
" transforms.ToTensor(),\n",
63+
" transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),\n",
64+
"])\n",
65+
"transform_test = transforms.Compose([\n",
66+
" transforms.ToTensor(),\n",
67+
" transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),\n",
68+
"])\n",
69+
"train_set = ds(path, train=True, download=True, transform=transform_train)\n",
70+
"test_set = ds(path, train=False, download=True, transform=transform_test)\n",
71+
"loaders = {\n",
72+
" 'train': torch.utils.data.DataLoader(\n",
73+
" train_set,\n",
74+
" batch_size=128,\n",
75+
" shuffle=True,\n",
76+
" num_workers=4,\n",
77+
" pin_memory=True\n",
78+
" ),\n",
79+
" 'test': torch.utils.data.DataLoader(\n",
80+
" test_set,\n",
81+
" batch_size=128,\n",
82+
" num_workers=4,\n",
83+
" pin_memory=True\n",
84+
" )\n",
85+
"}"
86+
]
87+
},
88+
{
89+
"cell_type": "markdown",
90+
"metadata": {},
91+
"source": [
92+
"We then define the quantization setting we are going to use. In particular, here we follow the setting reported in the paper \"Training Deep Neural Networks with 8-bit Floating Point Numbers\", where the authors propose to use specialized 8-bit and 16-bit floating point format."
93+
]
94+
},
95+
{
96+
"cell_type": "code",
97+
"execution_count": 3,
3498
"metadata": {},
3599
"outputs": [],
36100
"source": [
37-
"# let's define the quantizers we are using\n",
101+
"# define two floating point formats\n",
38102
"bit_8 = FloatingPoint(exp=5, man=2)\n",
39103
"bit_16 = FloatingPoint(exp=6, man=9)\n",
104+
"\n",
105+
"# define quantization functions\n",
40106
"weight_quant = Quantizer(forward_number=bit_8, backward_number=None,\n",
41107
" forward_rounding=\"nearest\", backward_rounding=\"nearest\")\n",
42108
"grad_quant = Quantizer(forward_number=bit_8, backward_number=None,\n",
43-
" forward_rounding=\"nearest\", backward_rounding=\"nearest\")\n",
109+
" forward_rounding=\"nearest\", backward_rounding=\"stochastic\")\n",
44110
"momentum_quant = Quantizer(forward_number=bit_16, backward_number=None,\n",
45-
" forward_rounding=\"nearest\", backward_rounding=\"nearest\")\n",
111+
" forward_rounding=\"nearest\", backward_rounding=\"stochastic\")\n",
46112
"acc_quant = Quantizer(forward_number=bit_16, backward_number=None,\n",
47113
" forward_rounding=\"nearest\", backward_rounding=\"nearest\")\n",
114+
"\n",
115+
"# define a lambda function so that the Quantizer module can be duplicated easily\n",
48116
"act_error_quant = lambda : Quantizer(forward_number=bit_8, backward_number=bit_8,\n",
49117
" forward_rounding=\"nearest\", backward_rounding=\"nearest\")"
50118
]
51119
},
120+
{
121+
"cell_type": "markdown",
122+
"metadata": {},
123+
"source": [
124+
"Next, we define a low-precision VGG network. In the definition, we recursively insert quantization module after every convolution layer. Note that the quantization of weight, gradient, momentum, and gradient accumulator are not handled here."
125+
]
126+
},
52127
{
53128
"cell_type": "code",
54-
"execution_count": 23,
129+
"execution_count": 4,
55130
"metadata": {},
56131
"outputs": [],
57132
"source": [
@@ -68,7 +143,7 @@
68143
" filters = int(v) if use_quant else int(v[:-1])\n",
69144
" conv2d = nn.Conv2d(in_channels, filters, kernel_size=3, padding=1)\n",
70145
" layers += [conv2d, nn.ReLU(inplace=True)]\n",
71-
" if use_quant: layers += [quant()]\n",
146+
" if use_quant: layers += [quant()] # inserting quantization modules\n",
72147
" n += 1\n",
73148
" in_channels = filters\n",
74149
" return nn.Sequential(*layers)\n",
@@ -99,30 +174,52 @@
99174
"config = ['64', '64', 'M', '128', '128', 'M', \n",
100175
" '256', '256', '256', 'M', '512', '512', '512', 'M', '512', '512', '512', 'M'] # VGG16\n",
101176
"\n",
102-
"model = VGGLP(config, act_error_quant, )"
177+
"model = VGGLP(config, act_error_quant)"
103178
]
104179
},
105180
{
106181
"cell_type": "code",
107-
"execution_count": 26,
182+
"execution_count": 5,
183+
"metadata": {},
184+
"outputs": [],
185+
"source": [
186+
"device = 'cuda' # change device to 'cpu' if you want to run this example on cpu\n",
187+
"model = model.to(device=device)"
188+
]
189+
},
190+
{
191+
"cell_type": "markdown",
192+
"metadata": {},
193+
"source": [
194+
"We now use the low-precision optimizer wrapper to help define the quantization of weight, gradient, momentum, and gradient accumulator."
195+
]
196+
},
197+
{
198+
"cell_type": "code",
199+
"execution_count": 6,
108200
"metadata": {},
109201
"outputs": [],
110202
"source": [
111-
"# define optimizer\n",
112203
"optimizer = SGD(model.parameters(), lr=0.05, momentum=0.9, weight_decay=1e-4)\n",
113204
"optimizer = OptimLP(optimizer,\n",
114-
" weight_quant=weight_quant,\n",
115-
" grad_quant=grad_quant,\n",
116-
" momentum_quant=momentum_quant\n",
205+
" weight_quant=weight_quant,\n",
206+
" grad_quant=grad_quant,\n",
207+
" momentum_quant=momentum_quant,\n",
208+
" acc_quant=acc_quant\n",
117209
")"
118210
]
119211
},
212+
{
213+
"cell_type": "markdown",
214+
"metadata": {},
215+
"source": [
216+
"We can reuse common training scripts without any extra codes to handle quantization."
217+
]
218+
},
120219
{
121220
"cell_type": "code",
122-
"execution_count": 38,
123-
"metadata": {
124-
"collapsed": true
125-
},
221+
"execution_count": 7,
222+
"metadata": {},
126223
"outputs": [],
127224
"source": [
128225
"def run_epoch(loader, model, criterion, optimizer=None, phase=\"train\"):\n",
@@ -135,12 +232,11 @@
135232
"\n",
136233
" ttl = 0\n",
137234
" with torch.autograd.set_grad_enabled(phase==\"train\"):\n",
138-
" for i, (input, target) in enumerate(loader):\n",
139-
" input = input.cuda(async=True)\n",
140-
" target = target.cuda(async=True)\n",
235+
" for i, (input, target) in tqdm(enumerate(loader), total=len(loader)):\n",
236+
" input = input.to(device=device)\n",
237+
" target = target.to(device=device)\n",
141238
" output = model(input)\n",
142239
" loss = criterion(output, target)\n",
143-
"\n",
144240
" loss_sum += loss.cpu().item() * input.size(0)\n",
145241
" pred = output.data.max(1, keepdim=True)[1]\n",
146242
" correct += pred.eq(target.data.view_as(pred)).sum()\n",
@@ -158,66 +254,33 @@
158254
" }"
159255
]
160256
},
257+
{
258+
"cell_type": "markdown",
259+
"metadata": {},
260+
"source": [
261+
"Begin the training process just as usual. Enjoy!"
262+
]
263+
},
161264
{
162265
"cell_type": "code",
163-
"execution_count": 37,
266+
"execution_count": 8,
164267
"metadata": {},
165268
"outputs": [
166269
{
167-
"name": "stdout",
270+
"name": "stderr",
168271
"output_type": "stream",
169272
"text": [
170-
"Files already downloaded and verified\n",
171-
"Files already downloaded and verified\n"
273+
"100%|██████████| 391/391 [00:34<00:00, 11.34it/s]\n",
274+
"100%|██████████| 79/79 [00:01<00:00, 70.06it/s]\n"
172275
]
173276
}
174277
],
175278
"source": [
176-
"# load data\n",
177-
"ds = torchvision.datasets.CIFAR10\n",
178-
"path = os.path.join(\"./data\", \"CIFAR10\")\n",
179-
"transform_train = transforms.Compose([\n",
180-
" transforms.RandomCrop(32, padding=4),\n",
181-
" transforms.RandomHorizontalFlip(),\n",
182-
" transforms.ToTensor(),\n",
183-
" transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),\n",
184-
"])\n",
185-
"transform_test = transforms.Compose([\n",
186-
" transforms.ToTensor(),\n",
187-
" transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),\n",
188-
"])\n",
189-
"train_set = ds(path, train=True, download=True, transform=transform_train)\n",
190-
"test_set = ds(path, train=False, download=True, transform=transform_test)\n",
191-
"loaders = {\n",
192-
" 'train': torch.utils.data.DataLoader(\n",
193-
" train_set,\n",
194-
" batch_size=128,\n",
195-
" shuffle=True,\n",
196-
" num_workers=4,\n",
197-
" pin_memory=True\n",
198-
" ),\n",
199-
" 'test': torch.utils.data.DataLoader(\n",
200-
" test_set,\n",
201-
" batch_size=128,\n",
202-
" num_workers=4,\n",
203-
" pin_memory=True\n",
204-
" )\n",
205-
"}"
206-
]
207-
},
208-
{
209-
"cell_type": "code",
210-
"execution_count": null,
211-
"metadata": {
212-
"collapsed": true
213-
},
214-
"outputs": [],
215-
"source": [
216-
"for epoch in range(start_epoch, 10):\n",
217-
" train_res = utils.run_epoch(loaders['train'], model, criterion,\n",
279+
"for epoch in range(1):\n",
280+
" train_res = run_epoch(loaders['train'], model, F.cross_entropy,\n",
218281
" optimizer=optimizer, phase=\"train\")\n",
219-
" test_res = utils.run_epoch(loaders['test'], model, criterion,\n",
220-
" optimizer=optimizer, phase=\"test\")"
282+
" test_res = run_epoch(loaders['test'], model, F.cross_entropy,\n",
283+
" optimizer=optimizer, phase=\"eval\")"
221284
]
222285
}
223286
],
@@ -237,7 +300,7 @@
237300
"name": "python",
238301
"nbconvert_exporter": "python",
239302
"pygments_lexer": "ipython3",
240-
"version": "3.6.8"
303+
"version": "3.6.7"
241304
}
242305
},
243306
"nbformat": 4,

0 commit comments

Comments
 (0)