updated tutorial with more info
soumith committed Oct 14, 2016
1 parent ac2bd4b commit ba60845
Showing 2 changed files with 103 additions and 62 deletions.
164 changes: 102 additions & 62 deletions Creating extensions using numpy and scipy.ipynb
@@ -1,27 +1,48 @@
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Creating extensions using numpy and scipy\n",
"\n",
"In this notebook, we shall go through two tasks:\n",
"\n",
"1. Create a neural network layer with no parameters. \n",
" - This calls into **numpy** as part of it's implementation\n",
"2. Create a neural network layer that has learnable weights\n",
" - This calls into **SciPy** as part of it's implementation"
]
},
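Both examples below follow the same pattern: subclass torch.autograd.Function, implement forward and backward in terms of numpy/scipy arrays, and convert to and from torch tensors at the boundary. As a minimal sketch of that pattern, written against the 2016-era autograd API this notebook uses (the NumpyIdentity name is ours, purely for illustration):

    import numpy as np
    import torch
    from torch.autograd import Function

    class NumpyIdentity(Function):
        # a do-nothing layer that round-trips through numpy, showing the
        # tensor <-> ndarray conversion points every extension here relies on
        def forward(self, input):
            numpy_input = input.numpy()          # torch.FloatTensor -> np.ndarray
            result = np.array(numpy_input)       # the real numpy computation would go here
            return torch.FloatTensor(result)     # back to a torch tensor

        def backward(self, grad_output):
            numpy_go = grad_output.numpy()       # incoming gradient as an ndarray
            grad_input = np.array(numpy_go)      # the gradient computation would go here
            return torch.FloatTensor(grad_input)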
{
"cell_type": "code",
"execution_count": 92,
"execution_count": 37,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"import torch\n",
"from torch.autograd import Function"
"from torch.autograd import Function\n",
"from torch.autograd import Variable"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Parameter-less example"
"## Parameter-less example\n",
"\n",
"This layer doesn't particularly do anything useful or mathematically correct.\n",
"\n",
"It is aptly named BadFFTFunction\n",
"\n",
"**Layer Implementation**"
]
},
{
"cell_type": "code",
"execution_count": 93,
"execution_count": 38,
"metadata": {
"collapsed": false
},
@@ -40,13 +61,22 @@
" result = irfft2(numpy_go)\n",
" return torch.FloatTensor(result)\n",
"\n",
"# since this layer does not have any parameters, we can\n",
"# simply declare this as a function, rather than as an nn.Module class\n",
"def incorrect_fft(input):\n",
" return FFTFunction()(input)"
" return BadFFTFunction()(input)"
]
},
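The hunk above elides the top of this cell (the imports and the forward pass). A hedged reconstruction of just that elided part, consistent with the 8x5 output that rfft2 produces for the 8x8 input printed below; this is our guess for illustration, not necessarily the exact elided code:

    from numpy.fft import rfft2, irfft2
    import torch
    from torch.autograd import Function

    class BadFFTFunction(Function):
        def forward(self, input):
            numpy_input = input.numpy()
            result = abs(rfft2(numpy_input))   # magnitude of a real 2-D FFT; 8x8 in -> 8x5 out
            return torch.FloatTensor(result)
        # backward continues as shown in the diff above: irfft2 of the incoming gradient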
{
"cell_type": "markdown",
"metadata": {},
"source": [
"**Example usage of the created layer:**"
]
},
{
"cell_type": "code",
"execution_count": 94,
"execution_count": 39,
"metadata": {
"collapsed": false
},
@@ -56,25 +86,25 @@
"output_type": "stream",
"text": [
"\n",
" 3.0878 7.1403 7.5860 1.7596 3.0176\n",
" 6.3160 15.2517 11.1081 0.9172 6.8577\n",
" 8.6503 2.2013 6.3555 11.1981 1.9266\n",
" 3.9919 6.8862 8.8132 5.7938 4.2413\n",
" 12.2501 10.7839 6.7181 12.1096 1.1942\n",
" 3.9919 9.3072 2.6704 3.3263 4.2413\n",
" 8.6503 6.8158 12.4148 2.6462 1.9266\n",
" 6.3160 15.2663 9.8261 5.8583 6.8577\n",
" 4.7742 8.5149 9.8856 10.2735 8.4410\n",
" 3.8592 2.2888 5.0019 5.9478 5.1993\n",
" 4.6596 3.4522 5.9725 11.0878 7.8076\n",
" 8.2634 6.6598 6.0634 15.5515 6.9418\n",
" 0.6407 7.4943 0.8726 4.4138 7.1496\n",
" 8.2634 6.8300 2.8353 8.3108 6.9418\n",
" 4.6596 1.9511 6.3037 5.1471 7.8076\n",
" 3.8592 7.3977 7.2260 1.6832 5.1993\n",
"[torch.FloatTensor of size 8x5]\n",
"\n",
"\n",
" 0.0569 -0.3193 0.0401 0.1293 0.0318 0.1293 0.0401 -0.3193\n",
" 0.0570 0.0161 -0.0421 -0.1272 0.0414 0.0121 -0.0592 -0.0874\n",
"-0.1144 -0.0146 0.0604 -0.0023 0.0222 0.0622 0.0825 -0.1057\n",
"-0.0451 0.1061 0.0329 -0.0274 0.0302 -0.0347 0.0227 -0.1079\n",
" 0.1287 0.1796 -0.0766 -0.0698 0.0929 -0.0698 -0.0766 0.1796\n",
"-0.0451 -0.1079 0.0227 -0.0347 0.0302 -0.0274 0.0329 0.1061\n",
"-0.1144 -0.1057 0.0825 0.0622 0.0222 -0.0023 0.0604 -0.0146\n",
" 0.0570 -0.0874 -0.0592 0.0121 0.0414 -0.1272 -0.0421 0.0161\n",
" 0.1044 0.0067 -0.0247 -0.0800 -0.1355 -0.0800 -0.0247 0.0067\n",
"-0.1948 -0.0138 -0.1396 -0.0084 0.0774 0.0370 0.1352 0.1332\n",
"-0.0153 -0.0668 0.1799 0.0574 0.0394 0.1392 0.0268 -0.1462\n",
" 0.0199 0.0676 -0.1475 -0.0332 0.1312 0.0740 -0.1128 -0.1948\n",
"-0.0416 -0.0159 -0.0166 -0.0070 0.1471 -0.0070 -0.0166 -0.0159\n",
" 0.0199 -0.1948 -0.1128 0.0740 0.1312 -0.0332 -0.1475 0.0676\n",
"-0.0153 -0.1462 0.0268 0.1392 0.0394 0.0574 0.1799 -0.0668\n",
"-0.1948 0.1332 0.1352 0.0370 0.0774 -0.0084 -0.1396 -0.0138\n",
"[torch.FloatTensor of size 8x8]\n",
"\n"
]
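The source of the usage cell that produced the two tensors above is elided by the diff. A hedged sketch of how such a cell would typically look (names and sizes chosen to match the 8x8 input and 8x5 output above, not necessarily the notebook's exact code):

    input = Variable(torch.randn(8, 8), requires_grad=True)
    result = incorrect_fft(input)
    print(result.data)                   # the 8x5 magnitude tensor printed above
    result.backward(torch.randn(8, 5))   # push an arbitrary gradient back through the layer
    print(input.grad)                    # the 8x8 "gradient" printed above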
@@ -92,12 +122,24 @@
"cell_type": "markdown",
"metadata": {},
"source": [
"# Parametrized example"
"## Parametrized example\n",
"\n",
"This implements a layer with learnable weights.\n",
"\n",
"It implements the Cross-correlation with a learnable kernel.\n",
"\n",
"In deep learning literature, it's confusingly referred to as Convolution.\n",
"\n",
"The backward computes the gradients wrt the input and gradients wrt the filter.\n",
"\n",
"**Implementation:**\n",
"\n",
"*Please Note that the implementation serves as an illustration, and we did not verify it's correctness*"
]
},
{
"cell_type": "code",
"execution_count": 95,
"execution_count": 40,
"metadata": {
"collapsed": false
},
@@ -116,7 +158,7 @@
" def backward(self, grad_output):\n",
" input, filter = self.saved_tensors\n",
" grad_input = convolve2d(grad_output.numpy(), filter.t().numpy(), mode='full')\n",
" grad_filter = convolve2d(grad_output.numpy(), input.numpy(), mode='valid')\n",
" grad_filter = convolve2d(input.numpy(), grad_output.numpy(), mode='valid')\n",
" return torch.FloatTensor(grad_input), torch.FloatTensor(grad_filter)\n",
"\n",
"\n",
@@ -131,9 +173,16 @@
" return ScipyConv2dFunction()(input, self.filter)"
]
},
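The hunk above shows only the backward pass and the last line of the module; the forward pass and the filter definition are elided. A hedged sketch of the complete layer in the same style (our reconstruction for illustration; whether the notebook uses nn.Parameter or a plain Variable for the filter is not visible here):

    from scipy.signal import convolve2d, correlate2d
    import torch
    from torch.autograd import Function
    from torch.nn import Module, Parameter

    class ScipyConv2dFunction(Function):
        def forward(self, input, filter):
            result = correlate2d(input.numpy(), filter.numpy(), mode='valid')  # 10x10 input, 3x3 kernel -> 8x8 output
            self.save_for_backward(input, filter)
            return torch.FloatTensor(result)
        # backward as shown in the diff above

    class ScipyConv2d(Module):
        def __init__(self, kh, kw):
            super(ScipyConv2d, self).__init__()
            self.filter = Parameter(torch.randn(kh, kw))  # the learnable kernel
        def forward(self, input):
            return ScipyConv2dFunction()(input, self.filter)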
{
"cell_type": "markdown",
"metadata": {},
"source": [
"**Example usage: **"
]
},
{
"cell_type": "code",
"execution_count": 96,
"execution_count": 41,
"metadata": {
"collapsed": false
},
@@ -143,33 +192,33 @@
"output_type": "stream",
"text": [
"[Variable containing:\n",
"-1.5070 1.2195 0.3059\n",
"-0.9716 -1.6591 0.0582\n",
" 0.3959 1.4859 0.5762\n",
"-1.0235 0.9875 0.2565\n",
" 0.1980 -0.6102 0.1088\n",
"-0.2887 0.4421 0.4697\n",
"[torch.FloatTensor of size 3x3]\n",
"]\n",
"Variable containing:\n",
" 0.8031 -2.6673 -3.7764 0.3957 -3.7494 -1.7617 -1.0052 -5.8402\n",
" 1.3038 6.2255 3.8769 2.4016 -1.7805 -3.1314 4.7049 11.2956\n",
" -3.4491 0.1618 -2.5647 2.3304 -0.2030 0.9072 -3.5095 -1.4599\n",
" 1.7574 0.6292 0.5140 -0.9045 -0.7373 -1.2061 -2.2977 3.6035\n",
" 0.4435 -1.0651 -0.5496 0.6387 1.7522 4.5231 -0.5720 -3.3034\n",
" -0.8580 -0.4809 2.4041 7.1462 -6.4747 -5.3665 2.0541 4.8248\n",
" -3.3959 0.2333 -0.2029 -2.6130 2.9378 2.5276 -0.8665 -2.6157\n",
" 4.6814 -5.2214 5.0351 0.9138 -5.0147 -3.1597 1.9054 -1.2458\n",
" 0.7426 -0.4963 2.1839 -0.0167 -1.6349 -0.7259 -0.2989 0.0568\n",
"-0.3100 2.2298 -2.2832 0.5753 4.0489 0.1377 0.1672 0.6429\n",
"-1.8680 1.3115 1.8970 0.3323 -4.5448 -0.0464 -2.3960 1.5496\n",
"-0.6578 0.6759 0.5512 -0.3498 2.6668 1.3984 1.9388 -1.6464\n",
"-0.5867 0.5676 2.8697 -0.5566 -2.8876 1.2372 -1.1336 -0.0219\n",
"-2.1587 1.1444 -0.5513 -0.5551 1.8229 0.6331 -0.0577 -1.4510\n",
" 2.6664 1.4183 2.1640 0.4424 -0.3112 -2.0792 1.7458 -3.3291\n",
"-0.4942 -2.1142 -0.2624 0.8993 1.4487 2.1706 -1.4943 0.8073\n",
"[torch.FloatTensor of size 8x8]\n",
"\n",
"\n",
" 0.1741 -1.9989 -0.2740 3.8120 0.3502 0.6712 3.0274 1.7058 0.4150 -0.3298\n",
"-1.8919 -2.6355 -3.2564 3.6947 2.5255 -6.7857 0.2239 -1.5672 -0.2663 -1.1211\n",
" 2.8815 2.5121 -4.7712 3.5822 -4.3752 0.7339 -0.7228 -1.7776 -2.0243 0.5019\n",
"-0.8926 0.1823 -4.3306 1.6298 1.4614 -1.5850 3.6988 3.1788 -1.2472 1.7891\n",
"-0.4497 2.5219 -0.0277 -2.5140 8.4283 -2.7177 -0.7160 2.5198 4.2670 -1.8847\n",
"-2.7016 -4.0250 2.7055 -0.6101 3.5926 0.5576 -1.8934 -3.3632 5.5995 -4.8563\n",
" 2.6918 -1.4062 1.1848 -1.7458 2.4408 0.9058 -3.6130 -3.0862 -0.1350 -1.6894\n",
"-0.2913 2.1607 4.0600 -1.4186 -4.5283 3.7960 -5.8559 -0.2632 -1.5944 1.9401\n",
" 0.4020 -2.5734 2.3380 -0.0078 -3.0894 3.5005 -1.3228 1.2757 0.7101 1.7986\n",
" 0.1187 -0.4283 -0.0142 -0.5494 -0.2744 0.8786 0.2644 0.7838 0.6230 0.4126\n",
" 0.2528 0.6793 1.4519 0.8932 -1.6100 0.2802 0.7728 -1.7915 0.6271 -0.4103\n",
" 1.1033 0.9326 -0.6076 0.0806 2.0530 -1.5469 -0.4001 2.3436 -1.4082 0.6746\n",
"-2.2699 0.4997 -1.0990 -0.9396 -2.2007 -0.3414 -1.1383 1.5647 -0.8794 0.9267\n",
"-0.0902 -2.0114 1.1145 -1.1107 0.4190 -0.7028 2.7191 -0.6072 1.3405 -0.2114\n",
" 3.1340 -1.3749 0.5132 0.1247 1.3468 0.2727 -1.0975 0.5712 0.2452 -1.0394\n",
"-1.7159 2.4817 -0.0412 -0.9571 0.8877 0.5806 0.1002 0.0128 -0.6611 -0.6181\n",
"-1.6527 -2.9061 -3.1407 0.1848 -1.4983 0.1549 0.0607 -1.4082 0.7121 -0.5538\n",
" 0.8319 2.1323 -0.5079 -1.8576 -0.9979 -1.6148 -1.2104 -0.2222 -0.6102 0.1271\n",
"-0.0115 -0.5239 2.0231 1.3474 0.3604 1.7257 -0.3180 1.3881 0.0142 0.9140\n",
"-0.0512 -0.3274 -0.1038 -0.1919 0.4578 1.0406 0.5750 1.0693 0.4735 0.4023\n",
"[torch.FloatTensor of size 10x10]\n",
"\n"
]
@@ -182,36 +231,27 @@
"output = module(input)\n",
"print(output)\n",
"output.backward(torch.randn(8, 8))\n",
"print(input.grad)\n"
"print(input.grad)"
]
},
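The setup lines of this usage cell are also elided by the diff. Given the 3x3 filter and the 10x10 input gradient printed above, they presumably construct the module and the input roughly as follows (a hedged sketch, inferred from the visible lines and the printed shapes):

    module = ScipyConv2d(3, 3)
    print(list(module.parameters()))   # the 3x3 learnable filter printed above
    input = Variable(torch.randn(10, 10), requires_grad=True)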
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"display_name": "Python 2",
"language": "python",
"name": "python3"
"name": "python2"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
"version": 2
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.5.2"
"pygments_lexer": "ipython2",
"version": "2.7.12"
}
},
"nbformat": 4,
1 change: 1 addition & 0 deletions README.md
@@ -5,3 +5,4 @@ PyTorch Tutorials
- A perfect introduction to PyTorch's torch, autograd, nn and optim APIs for the former Torch user
2. Custom C extensions
- [Write your own C code that interfaces into PyTorch via FFI](Creating%20Extensions%20using%20FFI.md)
3. [Writing your own neural network module that uses numpy and scipy](Creating%20extensions%20using%20numpy%20and%20scipy.ipynb)
