|
| 1 | +{ |
| 2 | + "cells": [ |
| 3 | + { |
| 4 | + "cell_type": "code", |
| 5 | + "execution_count": 92, |
| 6 | + "metadata": { |
| 7 | + "collapsed": false |
| 8 | + }, |
| 9 | + "outputs": [], |
| 10 | + "source": [ |
| 11 | + "import torch\n", |
| 12 | + "from torch.autograd import Function" |
| 13 | + ] |
| 14 | + }, |
| 15 | + { |
| 16 | + "cell_type": "markdown", |
| 17 | + "metadata": {}, |
| 18 | + "source": [ |
| 19 | + "# Parameter-less example" |
| 20 | + ] |
| 21 | + }, |
| 22 | + { |
| 23 | + "cell_type": "code", |
| 24 | + "execution_count": 93, |
| 25 | + "metadata": { |
| 26 | + "collapsed": false |
| 27 | + }, |
| 28 | + "outputs": [], |
| 29 | + "source": [ |
| 30 | + "from numpy.fft import rfft2, irfft2\n", |
| 31 | + "class BadFFTFunction(Function):\n", |
| 32 | + " \n", |
| 33 | + " def forward(self, input):\n", |
| 34 | + " numpy_input = input.numpy()\n", |
| 35 | + " result = abs(rfft2(numpy_input))\n", |
| 36 | + " return torch.FloatTensor(result)\n", |
| 37 | + " \n", |
| 38 | + " def backward(self, grad_output):\n", |
| 39 | + " numpy_go = grad_output.numpy()\n", |
| 40 | + " result = irfft2(numpy_go)\n", |
| 41 | + " return torch.FloatTensor(result)\n", |
| 42 | + "\n", |
| 43 | + "def incorrect_fft(input):\n", |
| 44 | + " return FFTFunction()(input)" |
| 45 | + ] |
| 46 | + }, |
| 47 | + { |
| 48 | + "cell_type": "code", |
| 49 | + "execution_count": 94, |
| 50 | + "metadata": { |
| 51 | + "collapsed": false |
| 52 | + }, |
| 53 | + "outputs": [ |
| 54 | + { |
| 55 | + "name": "stdout", |
| 56 | + "output_type": "stream", |
| 57 | + "text": [ |
| 58 | + "\n", |
| 59 | + " 3.0878 7.1403 7.5860 1.7596 3.0176\n", |
| 60 | + " 6.3160 15.2517 11.1081 0.9172 6.8577\n", |
| 61 | + " 8.6503 2.2013 6.3555 11.1981 1.9266\n", |
| 62 | + " 3.9919 6.8862 8.8132 5.7938 4.2413\n", |
| 63 | + " 12.2501 10.7839 6.7181 12.1096 1.1942\n", |
| 64 | + " 3.9919 9.3072 2.6704 3.3263 4.2413\n", |
| 65 | + " 8.6503 6.8158 12.4148 2.6462 1.9266\n", |
| 66 | + " 6.3160 15.2663 9.8261 5.8583 6.8577\n", |
| 67 | + "[torch.FloatTensor of size 8x5]\n", |
| 68 | + "\n", |
| 69 | + "\n", |
| 70 | + " 0.0569 -0.3193 0.0401 0.1293 0.0318 0.1293 0.0401 -0.3193\n", |
| 71 | + " 0.0570 0.0161 -0.0421 -0.1272 0.0414 0.0121 -0.0592 -0.0874\n", |
| 72 | + "-0.1144 -0.0146 0.0604 -0.0023 0.0222 0.0622 0.0825 -0.1057\n", |
| 73 | + "-0.0451 0.1061 0.0329 -0.0274 0.0302 -0.0347 0.0227 -0.1079\n", |
| 74 | + " 0.1287 0.1796 -0.0766 -0.0698 0.0929 -0.0698 -0.0766 0.1796\n", |
| 75 | + "-0.0451 -0.1079 0.0227 -0.0347 0.0302 -0.0274 0.0329 0.1061\n", |
| 76 | + "-0.1144 -0.1057 0.0825 0.0622 0.0222 -0.0023 0.0604 -0.0146\n", |
| 77 | + " 0.0570 -0.0874 -0.0592 0.0121 0.0414 -0.1272 -0.0421 0.0161\n", |
| 78 | + "[torch.FloatTensor of size 8x8]\n", |
| 79 | + "\n" |
| 80 | + ] |
| 81 | + } |
| 82 | + ], |
| 83 | + "source": [ |
| 84 | + "input = Variable(torch.randn(8, 8), requires_grad=True)\n", |
| 85 | + "result = incorrect_fft(input)\n", |
| 86 | + "print(result.data)\n", |
| 87 | + "result.backward(torch.randn(result.size()))\n", |
| 88 | + "print(input.grad)" |
| 89 | + ] |
| 90 | + }, |
| 91 | + { |
| 92 | + "cell_type": "markdown", |
| 93 | + "metadata": {}, |
| 94 | + "source": [ |
| 95 | + "# Parametrized example" |
| 96 | + ] |
| 97 | + }, |
| 98 | + { |
| 99 | + "cell_type": "code", |
| 100 | + "execution_count": 95, |
| 101 | + "metadata": { |
| 102 | + "collapsed": false |
| 103 | + }, |
| 104 | + "outputs": [], |
| 105 | + "source": [ |
| 106 | + "from scipy.signal import convolve2d, correlate2d\n", |
| 107 | + "from torch.nn.modules.module import Module\n", |
| 108 | + "\n", |
| 109 | + "class ScipyConv2dFunction(Function):\n", |
| 110 | + " \n", |
| 111 | + " def forward(self, input, filter):\n", |
| 112 | + " result = correlate2d(input.numpy(), filter.numpy(), mode='valid')\n", |
| 113 | + " self.save_for_backward(input, filter)\n", |
| 114 | + " return torch.FloatTensor(result)\n", |
| 115 | + " \n", |
| 116 | + " def backward(self, grad_output):\n", |
| 117 | + " input, filter = self.saved_tensors\n", |
| 118 | + " grad_input = convolve2d(grad_output.numpy(), filter.t().numpy(), mode='full')\n", |
| 119 | + " grad_filter = convolve2d(grad_output.numpy(), input.numpy(), mode='valid')\n", |
| 120 | + " return torch.FloatTensor(grad_input), torch.FloatTensor(grad_filter)\n", |
| 121 | + "\n", |
| 122 | + "\n", |
| 123 | + "class ScipyConv2d(Module):\n", |
| 124 | + " \n", |
| 125 | + " def __init__(self, kh, kw):\n", |
| 126 | + " super(ScipyConv2d, self).__init__(\n", |
| 127 | + " filter=torch.randn(kh, kw)\n", |
| 128 | + " )\n", |
| 129 | + " \n", |
| 130 | + " def forward(self, input):\n", |
| 131 | + " return ScipyConv2dFunction()(input, self.filter)" |
| 132 | + ] |
| 133 | + }, |
| 134 | + { |
| 135 | + "cell_type": "code", |
| 136 | + "execution_count": 96, |
| 137 | + "metadata": { |
| 138 | + "collapsed": false |
| 139 | + }, |
| 140 | + "outputs": [ |
| 141 | + { |
| 142 | + "name": "stdout", |
| 143 | + "output_type": "stream", |
| 144 | + "text": [ |
| 145 | + "[Variable containing:\n", |
| 146 | + "-1.5070 1.2195 0.3059\n", |
| 147 | + "-0.9716 -1.6591 0.0582\n", |
| 148 | + " 0.3959 1.4859 0.5762\n", |
| 149 | + "[torch.FloatTensor of size 3x3]\n", |
| 150 | + "]\n", |
| 151 | + "Variable containing:\n", |
| 152 | + " 0.8031 -2.6673 -3.7764 0.3957 -3.7494 -1.7617 -1.0052 -5.8402\n", |
| 153 | + " 1.3038 6.2255 3.8769 2.4016 -1.7805 -3.1314 4.7049 11.2956\n", |
| 154 | + " -3.4491 0.1618 -2.5647 2.3304 -0.2030 0.9072 -3.5095 -1.4599\n", |
| 155 | + " 1.7574 0.6292 0.5140 -0.9045 -0.7373 -1.2061 -2.2977 3.6035\n", |
| 156 | + " 0.4435 -1.0651 -0.5496 0.6387 1.7522 4.5231 -0.5720 -3.3034\n", |
| 157 | + " -0.8580 -0.4809 2.4041 7.1462 -6.4747 -5.3665 2.0541 4.8248\n", |
| 158 | + " -3.3959 0.2333 -0.2029 -2.6130 2.9378 2.5276 -0.8665 -2.6157\n", |
| 159 | + " 4.6814 -5.2214 5.0351 0.9138 -5.0147 -3.1597 1.9054 -1.2458\n", |
| 160 | + "[torch.FloatTensor of size 8x8]\n", |
| 161 | + "\n", |
| 162 | + "\n", |
| 163 | + " 0.1741 -1.9989 -0.2740 3.8120 0.3502 0.6712 3.0274 1.7058 0.4150 -0.3298\n", |
| 164 | + "-1.8919 -2.6355 -3.2564 3.6947 2.5255 -6.7857 0.2239 -1.5672 -0.2663 -1.1211\n", |
| 165 | + " 2.8815 2.5121 -4.7712 3.5822 -4.3752 0.7339 -0.7228 -1.7776 -2.0243 0.5019\n", |
| 166 | + "-0.8926 0.1823 -4.3306 1.6298 1.4614 -1.5850 3.6988 3.1788 -1.2472 1.7891\n", |
| 167 | + "-0.4497 2.5219 -0.0277 -2.5140 8.4283 -2.7177 -0.7160 2.5198 4.2670 -1.8847\n", |
| 168 | + "-2.7016 -4.0250 2.7055 -0.6101 3.5926 0.5576 -1.8934 -3.3632 5.5995 -4.8563\n", |
| 169 | + " 2.6918 -1.4062 1.1848 -1.7458 2.4408 0.9058 -3.6130 -3.0862 -0.1350 -1.6894\n", |
| 170 | + "-0.2913 2.1607 4.0600 -1.4186 -4.5283 3.7960 -5.8559 -0.2632 -1.5944 1.9401\n", |
| 171 | + " 0.4020 -2.5734 2.3380 -0.0078 -3.0894 3.5005 -1.3228 1.2757 0.7101 1.7986\n", |
| 172 | + " 0.1187 -0.4283 -0.0142 -0.5494 -0.2744 0.8786 0.2644 0.7838 0.6230 0.4126\n", |
| 173 | + "[torch.FloatTensor of size 10x10]\n", |
| 174 | + "\n" |
| 175 | + ] |
| 176 | + } |
| 177 | + ], |
| 178 | + "source": [ |
| 179 | + "module = ScipyConv2d(3, 3)\n", |
| 180 | + "print(list(module.parameters()))\n", |
| 181 | + "input = Variable(torch.randn(10, 10), requires_grad=True)\n", |
| 182 | + "output = module(input)\n", |
| 183 | + "print(output)\n", |
| 184 | + "output.backward(torch.randn(8, 8))\n", |
| 185 | + "print(input.grad)\n" |
| 186 | + ] |
| 187 | + }, |
| 188 | + { |
| 189 | + "cell_type": "code", |
| 190 | + "execution_count": null, |
| 191 | + "metadata": { |
| 192 | + "collapsed": true |
| 193 | + }, |
| 194 | + "outputs": [], |
| 195 | + "source": [] |
| 196 | + } |
| 197 | + ], |
| 198 | + "metadata": { |
| 199 | + "kernelspec": { |
| 200 | + "display_name": "Python 3", |
| 201 | + "language": "python", |
| 202 | + "name": "python3" |
| 203 | + }, |
| 204 | + "language_info": { |
| 205 | + "codemirror_mode": { |
| 206 | + "name": "ipython", |
| 207 | + "version": 3 |
| 208 | + }, |
| 209 | + "file_extension": ".py", |
| 210 | + "mimetype": "text/x-python", |
| 211 | + "name": "python", |
| 212 | + "nbconvert_exporter": "python", |
| 213 | + "pygments_lexer": "ipython3", |
| 214 | + "version": "3.5.2" |
| 215 | + } |
| 216 | + }, |
| 217 | + "nbformat": 4, |
| 218 | + "nbformat_minor": 1 |
| 219 | +} |
0 commit comments