|
5 | 5 | "metadata": {},
|
6 | 6 | "source": [
|
7 | 7 | "# CIFAR10 Low Precision Training Example\n",
|
8 |
| - "In this notebook, we present a quick example of how to simulate training a deep neural network in low precision" |
| 8 | + "In this notebook, we present a quick example of how to simulate training a deep neural network in low precision with QPyTorch." |
9 | 9 | ]
|
10 | 10 | },
|
11 | 11 | {
|
12 | 12 | "cell_type": "code",
|
13 |
| - "execution_count": 33, |
| 13 | + "execution_count": 1, |
14 | 14 | "metadata": {},
|
15 | 15 | "outputs": [],
|
16 | 16 | "source": [
|
|
25 | 25 | "from qtorch.quant import Quantizer\n",
|
26 | 26 | "from qtorch.optim import OptimLP\n",
|
27 | 27 | "from torch.optim import SGD\n",
|
28 |
| - "from qtorch import FloatingPoint" |
| 28 | + "from qtorch import FloatingPoint\n", |
| 29 | + "from tqdm import tqdm" |
| 30 | + ] |
| 31 | + }, |
| 32 | + { |
| 33 | + "cell_type": "markdown", |
| 34 | + "metadata": {}, |
| 35 | + "source": [ |
| 36 | + "We first load the data. In this example, we will experiment with CIFAR10." |
29 | 37 | ]
|
30 | 38 | },
|
31 | 39 | {
|
32 | 40 | "cell_type": "code",
|
33 |
| - "execution_count": 20, |
| 41 | + "execution_count": 2, |
| 42 | + "metadata": { |
| 43 | + "scrolled": false |
| 44 | + }, |
| 45 | + "outputs": [ |
| 46 | + { |
| 47 | + "name": "stdout", |
| 48 | + "output_type": "stream", |
| 49 | + "text": [ |
| 50 | + "Files already downloaded and verified\n", |
| 51 | + "Files already downloaded and verified\n" |
| 52 | + ] |
| 53 | + } |
| 54 | + ], |
| 55 | + "source": [ |
| 56 | + "# loading data\n", |
| 57 | + "ds = torchvision.datasets.CIFAR10\n", |
| 58 | + "path = os.path.join(\"./data\", \"CIFAR10\")\n", |
| 59 | + "transform_train = transforms.Compose([\n", |
| 60 | + " transforms.RandomCrop(32, padding=4),\n", |
| 61 | + " transforms.RandomHorizontalFlip(),\n", |
| 62 | + " transforms.ToTensor(),\n", |
| 63 | + " transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),\n", |
| 64 | + "])\n", |
| 65 | + "transform_test = transforms.Compose([\n", |
| 66 | + " transforms.ToTensor(),\n", |
| 67 | + " transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),\n", |
| 68 | + "])\n", |
| 69 | + "train_set = ds(path, train=True, download=True, transform=transform_train)\n", |
| 70 | + "test_set = ds(path, train=False, download=True, transform=transform_test)\n", |
| 71 | + "loaders = {\n", |
| 72 | + " 'train': torch.utils.data.DataLoader(\n", |
| 73 | + " train_set,\n", |
| 74 | + " batch_size=128,\n", |
| 75 | + " shuffle=True,\n", |
| 76 | + " num_workers=4,\n", |
| 77 | + " pin_memory=True\n", |
| 78 | + " ),\n", |
| 79 | + " 'test': torch.utils.data.DataLoader(\n", |
| 80 | + " test_set,\n", |
| 81 | + " batch_size=128,\n", |
| 82 | + " num_workers=4,\n", |
| 83 | + " pin_memory=True\n", |
| 84 | + " )\n", |
| 85 | + "}" |
| 86 | + ] |
| 87 | + }, |
| 88 | + { |
| 89 | + "cell_type": "markdown", |
| 90 | + "metadata": {}, |
| 91 | + "source": [ |
| 92 | + "We then define the quantization setting we are going to use. In particular, here we follow the setting reported in the paper \"Training Deep Neural Networks with 8-bit Floating Point Numbers\", where the authors propose to use specialized 8-bit and 16-bit floating point format." |
| 93 | + ] |
| 94 | + }, |
| 95 | + { |
| 96 | + "cell_type": "code", |
| 97 | + "execution_count": 3, |
34 | 98 | "metadata": {},
|
35 | 99 | "outputs": [],
|
36 | 100 | "source": [
|
37 |
| - "# let's define the quantizers we are using\n", |
| 101 | + "# define two floating point formats\n", |
38 | 102 | "bit_8 = FloatingPoint(exp=5, man=2)\n",
|
39 | 103 | "bit_16 = FloatingPoint(exp=6, man=9)\n",
|
| 104 | + "\n", |
| 105 | + "# define quantization functions\n", |
40 | 106 | "weight_quant = Quantizer(forward_number=bit_8, backward_number=None,\n",
|
41 | 107 | " forward_rounding=\"nearest\", backward_rounding=\"nearest\")\n",
|
42 | 108 | "grad_quant = Quantizer(forward_number=bit_8, backward_number=None,\n",
|
43 |
| - " forward_rounding=\"nearest\", backward_rounding=\"nearest\")\n", |
| 109 | + " forward_rounding=\"nearest\", backward_rounding=\"stochastic\")\n", |
44 | 110 | "momentum_quant = Quantizer(forward_number=bit_16, backward_number=None,\n",
|
45 |
| - " forward_rounding=\"nearest\", backward_rounding=\"nearest\")\n", |
| 111 | + " forward_rounding=\"nearest\", backward_rounding=\"stochastic\")\n", |
46 | 112 | "acc_quant = Quantizer(forward_number=bit_16, backward_number=None,\n",
|
47 | 113 | " forward_rounding=\"nearest\", backward_rounding=\"nearest\")\n",
|
| 114 | + "\n", |
| 115 | + "# define a lambda function so that the Quantizer module can be duplicated easily\n", |
48 | 116 | "act_error_quant = lambda : Quantizer(forward_number=bit_8, backward_number=bit_8,\n",
|
49 | 117 | " forward_rounding=\"nearest\", backward_rounding=\"nearest\")"
|
50 | 118 | ]
|
51 | 119 | },
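{
 "cell_type": "markdown",
 "metadata": {},
 "source": [
  "As a quick sanity check (an illustrative aside, not part of the original example), we can apply the 8-bit format directly to a tensor with QPyTorch's functional interface, `float_quantize` from `qtorch.quant`, and inspect the rounding."
 ]
},
{
 "cell_type": "code",
 "execution_count": null,
 "metadata": {},
 "outputs": [],
 "source": [
  "# illustrative sketch: round a random tensor into the 8-bit format defined above\n",
  "from qtorch.quant import float_quantize\n",
  "\n",
  "x = torch.rand(4)\n",
  "x_low = float_quantize(x, exp=5, man=2, rounding=\"nearest\")\n",
  "print(x)      # full-precision values\n",
  "print(x_low)  # nearest representable 8-bit floating-point values"
 ]
},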
|
| 120 | + { |
| 121 | + "cell_type": "markdown", |
| 122 | + "metadata": {}, |
| 123 | + "source": [ |
| 124 | + "Next, we define a low-precision VGG network. In the definition, we recursively insert quantization module after every convolution layer. Note that the quantization of weight, gradient, momentum, and gradient accumulator are not handled here." |
| 125 | + ] |
| 126 | + }, |
52 | 127 | {
|
53 | 128 | "cell_type": "code",
|
54 |
| - "execution_count": 23, |
| 129 | + "execution_count": 4, |
55 | 130 | "metadata": {},
|
56 | 131 | "outputs": [],
|
57 | 132 | "source": [
|
|
68 | 143 | " filters = int(v) if use_quant else int(v[:-1])\n",
|
69 | 144 | " conv2d = nn.Conv2d(in_channels, filters, kernel_size=3, padding=1)\n",
|
70 | 145 | " layers += [conv2d, nn.ReLU(inplace=True)]\n",
|
71 |
| - " if use_quant: layers += [quant()]\n", |
| 146 | + " if use_quant: layers += [quant()] # inserting quantization modules\n", |
72 | 147 | " n += 1\n",
|
73 | 148 | " in_channels = filters\n",
|
74 | 149 | " return nn.Sequential(*layers)\n",
|
|
99 | 174 | "config = ['64', '64', 'M', '128', '128', 'M', \n",
|
100 | 175 | " '256', '256', '256', 'M', '512', '512', '512', 'M', '512', '512', '512', 'M'] # VGG16\n",
|
101 | 176 | "\n",
|
102 |
| - "model = VGGLP(config, act_error_quant, )" |
| 177 | + "model = VGGLP(config, act_error_quant)" |
103 | 178 | ]
|
104 | 179 | },
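{
 "cell_type": "markdown",
 "metadata": {},
 "source": [
  "To make the role of the inserted quantizers concrete, here is a toy stack (an assumed illustration, not part of the original example): each `Quantizer` returned by `act_error_quant` quantizes activations on the forward pass and the propagated errors on the backward pass."
 ]
},
{
 "cell_type": "code",
 "execution_count": null,
 "metadata": {},
 "outputs": [],
 "source": [
  "# toy illustration: a fresh Quantizer instance sits between two layers,\n",
  "# quantizing activations (forward) and errors (backward) into 8-bit floats\n",
  "toy = nn.Sequential(\n",
  "    nn.Conv2d(3, 8, kernel_size=3, padding=1),\n",
  "    nn.ReLU(inplace=True),\n",
  "    act_error_quant(),  # calling the lambda creates a new quantizer module\n",
  "    nn.Conv2d(8, 8, kernel_size=3, padding=1),\n",
  ")"
 ]
},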
|
105 | 180 | {
|
106 | 181 | "cell_type": "code",
|
107 |
| - "execution_count": 26, |
| 182 | + "execution_count": 5, |
| 183 | + "metadata": {}, |
| 184 | + "outputs": [], |
| 185 | + "source": [ |
| 186 | + "device = 'cuda' # change device to 'cpu' if you want to run this example on cpu\n", |
| 187 | + "model = model.to(device=device)" |
| 188 | + ] |
| 189 | + }, |
| 190 | + { |
| 191 | + "cell_type": "markdown", |
| 192 | + "metadata": {}, |
| 193 | + "source": [ |
| 194 | + "We now use the low-precision optimizer wrapper to help define the quantization of weight, gradient, momentum, and gradient accumulator." |
| 195 | + ] |
| 196 | + }, |
| 197 | + { |
| 198 | + "cell_type": "code", |
| 199 | + "execution_count": 6, |
108 | 200 | "metadata": {},
|
109 | 201 | "outputs": [],
|
110 | 202 | "source": [
|
111 |
| - "# define optimizer\n", |
112 | 203 | "optimizer = SGD(model.parameters(), lr=0.05, momentum=0.9, weight_decay=1e-4)\n",
|
113 | 204 | "optimizer = OptimLP(optimizer,\n",
|
114 |
| - " weight_quant=weight_quant,\n", |
115 |
| - " grad_quant=grad_quant,\n", |
116 |
| - " momentum_quant=momentum_quant\n", |
| 205 | + " weight_quant=weight_quant,\n", |
| 206 | + " grad_quant=grad_quant,\n", |
| 207 | + " momentum_quant=momentum_quant,\n", |
| 208 | + " acc_quant=acc_quant\n", |
117 | 209 | ")"
|
118 | 210 | ]
|
119 | 211 | },
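{
 "cell_type": "markdown",
 "metadata": {},
 "source": [
  "Conceptually, each update now keeps a 16-bit accumulator alongside the 8-bit weights. The sketch below is a rough illustration of the idea for a single tensor (momentum handled analogously); it is not OptimLP's actual implementation."
 ]
},
{
 "cell_type": "code",
 "execution_count": null,
 "metadata": {},
 "outputs": [],
 "source": [
  "# rough sketch of one low-precision SGD update for a single tensor\n",
  "def lp_sgd_update(weight, grad, acc, lr=0.05):\n",
  "    grad = grad_quant(grad)           # quantize the incoming gradient (8-bit)\n",
  "    acc = acc_quant(acc - lr * grad)  # accumulate the update in 16-bit\n",
  "    weight = weight_quant(acc)        # store the low-precision weight (8-bit)\n",
  "    return weight, acc"
 ]
},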
|
| 212 | + { |
| 213 | + "cell_type": "markdown", |
| 214 | + "metadata": {}, |
| 215 | + "source": [ |
| 216 | + "We can reuse common training scripts without any extra codes to handle quantization." |
| 217 | + ] |
| 218 | + }, |
120 | 219 | {
|
121 | 220 | "cell_type": "code",
|
122 |
| - "execution_count": 38, |
123 |
| - "metadata": { |
124 |
| - "collapsed": true |
125 |
| - }, |
| 221 | + "execution_count": 7, |
| 222 | + "metadata": {}, |
126 | 223 | "outputs": [],
|
127 | 224 | "source": [
|
128 | 225 | "def run_epoch(loader, model, criterion, optimizer=None, phase=\"train\"):\n",
|
|
135 | 232 | "\n",
|
136 | 233 | " ttl = 0\n",
|
137 | 234 | " with torch.autograd.set_grad_enabled(phase==\"train\"):\n",
|
138 |
| - " for i, (input, target) in enumerate(loader):\n", |
139 |
| - " input = input.cuda(async=True)\n", |
140 |
| - " target = target.cuda(async=True)\n", |
| 235 | + " for i, (input, target) in tqdm(enumerate(loader), total=len(loader)):\n", |
| 236 | + " input = input.to(device=device)\n", |
| 237 | + " target = target.to(device=device)\n", |
141 | 238 | " output = model(input)\n",
|
142 | 239 | " loss = criterion(output, target)\n",
|
143 |
| - "\n", |
144 | 240 | " loss_sum += loss.cpu().item() * input.size(0)\n",
|
145 | 241 | " pred = output.data.max(1, keepdim=True)[1]\n",
|
146 | 242 | " correct += pred.eq(target.data.view_as(pred)).sum()\n",
|
|
158 | 254 | " }"
|
159 | 255 | ]
|
160 | 256 | },
|
| 257 | + { |
| 258 | + "cell_type": "markdown", |
| 259 | + "metadata": {}, |
| 260 | + "source": [ |
| 261 | + "Begin the training process just as usual. Enjoy!" |
| 262 | + ] |
| 263 | + }, |
161 | 264 | {
|
162 | 265 | "cell_type": "code",
|
163 |
| - "execution_count": 37, |
| 266 | + "execution_count": 8, |
164 | 267 | "metadata": {},
|
165 | 268 | "outputs": [
|
166 | 269 | {
|
167 |
| - "name": "stdout", |
| 270 | + "name": "stderr", |
168 | 271 | "output_type": "stream",
|
169 | 272 | "text": [
|
170 |
| - "Files already downloaded and verified\n", |
171 |
| - "Files already downloaded and verified\n" |
| 273 | + "100%|██████████| 391/391 [00:34<00:00, 11.34it/s]\n", |
| 274 | + "100%|██████████| 79/79 [00:01<00:00, 70.06it/s]\n" |
172 | 275 | ]
|
173 | 276 | }
|
174 | 277 | ],
|
175 | 278 | "source": [
|
176 |
| - "# load data\n", |
177 |
| - "ds = torchvision.datasets.CIFAR10\n", |
178 |
| - "path = os.path.join(\"./data\", \"CIFAR10\")\n", |
179 |
| - "transform_train = transforms.Compose([\n", |
180 |
| - " transforms.RandomCrop(32, padding=4),\n", |
181 |
| - " transforms.RandomHorizontalFlip(),\n", |
182 |
| - " transforms.ToTensor(),\n", |
183 |
| - " transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),\n", |
184 |
| - "])\n", |
185 |
| - "transform_test = transforms.Compose([\n", |
186 |
| - " transforms.ToTensor(),\n", |
187 |
| - " transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),\n", |
188 |
| - "])\n", |
189 |
| - "train_set = ds(path, train=True, download=True, transform=transform_train)\n", |
190 |
| - "test_set = ds(path, train=False, download=True, transform=transform_test)\n", |
191 |
| - "loaders = {\n", |
192 |
| - " 'train': torch.utils.data.DataLoader(\n", |
193 |
| - " train_set,\n", |
194 |
| - " batch_size=128,\n", |
195 |
| - " shuffle=True,\n", |
196 |
| - " num_workers=4,\n", |
197 |
| - " pin_memory=True\n", |
198 |
| - " ),\n", |
199 |
| - " 'test': torch.utils.data.DataLoader(\n", |
200 |
| - " test_set,\n", |
201 |
| - " batch_size=128,\n", |
202 |
| - " num_workers=4,\n", |
203 |
| - " pin_memory=True\n", |
204 |
| - " )\n", |
205 |
| - "}" |
206 |
| - ] |
207 |
| - }, |
208 |
| - { |
209 |
| - "cell_type": "code", |
210 |
| - "execution_count": null, |
211 |
| - "metadata": { |
212 |
| - "collapsed": true |
213 |
| - }, |
214 |
| - "outputs": [], |
215 |
| - "source": [ |
216 |
| - "for epoch in range(start_epoch, 10):\n", |
217 |
| - " train_res = utils.run_epoch(loaders['train'], model, criterion,\n", |
| 279 | + "for epoch in range(1):\n", |
| 280 | + " train_res = run_epoch(loaders['train'], model, F.cross_entropy,\n", |
218 | 281 | " optimizer=optimizer, phase=\"train\")\n",
|
219 |
| - " test_res = utils.run_epoch(loaders['test'], model, criterion,\n", |
220 |
| - " optimizer=optimizer, phase=\"test\")" |
| 282 | + " test_res = run_epoch(loaders['test'], model, F.cross_entropy,\n", |
| 283 | + " optimizer=optimizer, phase=\"eval\")" |
221 | 284 | ]
|
222 | 285 | }
|
223 | 286 | ],
|
|
237 | 300 | "name": "python",
|
238 | 301 | "nbconvert_exporter": "python",
|
239 | 302 | "pygments_lexer": "ipython3",
|
240 |
| - "version": "3.6.8" |
| 303 | + "version": "3.6.7" |
241 | 304 | }
|
242 | 305 | },
|
243 | 306 | "nbformat": 4,
|