Skip to content

Commit ac2bd4b

Browse files
committed
Add numpy and scipy extension example
1 parent 12d5986 commit ac2bd4b

File tree

1 file changed

+219
-0
lines changed

1 file changed

+219
-0
lines changed
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,219 @@
1+
{
2+
"cells": [
3+
{
4+
"cell_type": "code",
5+
"execution_count": 92,
6+
"metadata": {
7+
"collapsed": false
8+
},
9+
"outputs": [],
10+
"source": [
11+
"import torch\n",
12+
"from torch.autograd import Function"
13+
]
14+
},
15+
{
16+
"cell_type": "markdown",
17+
"metadata": {},
18+
"source": [
19+
"# Parameter-less example"
20+
]
21+
},
22+
{
23+
"cell_type": "code",
24+
"execution_count": 93,
25+
"metadata": {
26+
"collapsed": false
27+
},
28+
"outputs": [],
29+
"source": [
30+
"from numpy.fft import rfft2, irfft2\n",
31+
"class BadFFTFunction(Function):\n",
32+
" \n",
33+
" def forward(self, input):\n",
34+
" numpy_input = input.numpy()\n",
35+
" result = abs(rfft2(numpy_input))\n",
36+
" return torch.FloatTensor(result)\n",
37+
" \n",
38+
" def backward(self, grad_output):\n",
39+
" numpy_go = grad_output.numpy()\n",
40+
" result = irfft2(numpy_go)\n",
41+
" return torch.FloatTensor(result)\n",
42+
"\n",
43+
"def incorrect_fft(input):\n",
44+
" return FFTFunction()(input)"
45+
]
46+
},
47+
{
48+
"cell_type": "code",
49+
"execution_count": 94,
50+
"metadata": {
51+
"collapsed": false
52+
},
53+
"outputs": [
54+
{
55+
"name": "stdout",
56+
"output_type": "stream",
57+
"text": [
58+
"\n",
59+
" 3.0878 7.1403 7.5860 1.7596 3.0176\n",
60+
" 6.3160 15.2517 11.1081 0.9172 6.8577\n",
61+
" 8.6503 2.2013 6.3555 11.1981 1.9266\n",
62+
" 3.9919 6.8862 8.8132 5.7938 4.2413\n",
63+
" 12.2501 10.7839 6.7181 12.1096 1.1942\n",
64+
" 3.9919 9.3072 2.6704 3.3263 4.2413\n",
65+
" 8.6503 6.8158 12.4148 2.6462 1.9266\n",
66+
" 6.3160 15.2663 9.8261 5.8583 6.8577\n",
67+
"[torch.FloatTensor of size 8x5]\n",
68+
"\n",
69+
"\n",
70+
" 0.0569 -0.3193 0.0401 0.1293 0.0318 0.1293 0.0401 -0.3193\n",
71+
" 0.0570 0.0161 -0.0421 -0.1272 0.0414 0.0121 -0.0592 -0.0874\n",
72+
"-0.1144 -0.0146 0.0604 -0.0023 0.0222 0.0622 0.0825 -0.1057\n",
73+
"-0.0451 0.1061 0.0329 -0.0274 0.0302 -0.0347 0.0227 -0.1079\n",
74+
" 0.1287 0.1796 -0.0766 -0.0698 0.0929 -0.0698 -0.0766 0.1796\n",
75+
"-0.0451 -0.1079 0.0227 -0.0347 0.0302 -0.0274 0.0329 0.1061\n",
76+
"-0.1144 -0.1057 0.0825 0.0622 0.0222 -0.0023 0.0604 -0.0146\n",
77+
" 0.0570 -0.0874 -0.0592 0.0121 0.0414 -0.1272 -0.0421 0.0161\n",
78+
"[torch.FloatTensor of size 8x8]\n",
79+
"\n"
80+
]
81+
}
82+
],
83+
"source": [
84+
"input = Variable(torch.randn(8, 8), requires_grad=True)\n",
85+
"result = incorrect_fft(input)\n",
86+
"print(result.data)\n",
87+
"result.backward(torch.randn(result.size()))\n",
88+
"print(input.grad)"
89+
]
90+
},
91+
{
92+
"cell_type": "markdown",
93+
"metadata": {},
94+
"source": [
95+
"# Parametrized example"
96+
]
97+
},
98+
{
99+
"cell_type": "code",
100+
"execution_count": 95,
101+
"metadata": {
102+
"collapsed": false
103+
},
104+
"outputs": [],
105+
"source": [
106+
"from scipy.signal import convolve2d, correlate2d\n",
107+
"from torch.nn.modules.module import Module\n",
108+
"\n",
109+
"class ScipyConv2dFunction(Function):\n",
110+
" \n",
111+
" def forward(self, input, filter):\n",
112+
" result = correlate2d(input.numpy(), filter.numpy(), mode='valid')\n",
113+
" self.save_for_backward(input, filter)\n",
114+
" return torch.FloatTensor(result)\n",
115+
" \n",
116+
" def backward(self, grad_output):\n",
117+
" input, filter = self.saved_tensors\n",
118+
" grad_input = convolve2d(grad_output.numpy(), filter.t().numpy(), mode='full')\n",
119+
" grad_filter = convolve2d(grad_output.numpy(), input.numpy(), mode='valid')\n",
120+
" return torch.FloatTensor(grad_input), torch.FloatTensor(grad_filter)\n",
121+
"\n",
122+
"\n",
123+
"class ScipyConv2d(Module):\n",
124+
" \n",
125+
" def __init__(self, kh, kw):\n",
126+
" super(ScipyConv2d, self).__init__(\n",
127+
" filter=torch.randn(kh, kw)\n",
128+
" )\n",
129+
" \n",
130+
" def forward(self, input):\n",
131+
" return ScipyConv2dFunction()(input, self.filter)"
132+
]
133+
},
134+
{
135+
"cell_type": "code",
136+
"execution_count": 96,
137+
"metadata": {
138+
"collapsed": false
139+
},
140+
"outputs": [
141+
{
142+
"name": "stdout",
143+
"output_type": "stream",
144+
"text": [
145+
"[Variable containing:\n",
146+
"-1.5070 1.2195 0.3059\n",
147+
"-0.9716 -1.6591 0.0582\n",
148+
" 0.3959 1.4859 0.5762\n",
149+
"[torch.FloatTensor of size 3x3]\n",
150+
"]\n",
151+
"Variable containing:\n",
152+
" 0.8031 -2.6673 -3.7764 0.3957 -3.7494 -1.7617 -1.0052 -5.8402\n",
153+
" 1.3038 6.2255 3.8769 2.4016 -1.7805 -3.1314 4.7049 11.2956\n",
154+
" -3.4491 0.1618 -2.5647 2.3304 -0.2030 0.9072 -3.5095 -1.4599\n",
155+
" 1.7574 0.6292 0.5140 -0.9045 -0.7373 -1.2061 -2.2977 3.6035\n",
156+
" 0.4435 -1.0651 -0.5496 0.6387 1.7522 4.5231 -0.5720 -3.3034\n",
157+
" -0.8580 -0.4809 2.4041 7.1462 -6.4747 -5.3665 2.0541 4.8248\n",
158+
" -3.3959 0.2333 -0.2029 -2.6130 2.9378 2.5276 -0.8665 -2.6157\n",
159+
" 4.6814 -5.2214 5.0351 0.9138 -5.0147 -3.1597 1.9054 -1.2458\n",
160+
"[torch.FloatTensor of size 8x8]\n",
161+
"\n",
162+
"\n",
163+
" 0.1741 -1.9989 -0.2740 3.8120 0.3502 0.6712 3.0274 1.7058 0.4150 -0.3298\n",
164+
"-1.8919 -2.6355 -3.2564 3.6947 2.5255 -6.7857 0.2239 -1.5672 -0.2663 -1.1211\n",
165+
" 2.8815 2.5121 -4.7712 3.5822 -4.3752 0.7339 -0.7228 -1.7776 -2.0243 0.5019\n",
166+
"-0.8926 0.1823 -4.3306 1.6298 1.4614 -1.5850 3.6988 3.1788 -1.2472 1.7891\n",
167+
"-0.4497 2.5219 -0.0277 -2.5140 8.4283 -2.7177 -0.7160 2.5198 4.2670 -1.8847\n",
168+
"-2.7016 -4.0250 2.7055 -0.6101 3.5926 0.5576 -1.8934 -3.3632 5.5995 -4.8563\n",
169+
" 2.6918 -1.4062 1.1848 -1.7458 2.4408 0.9058 -3.6130 -3.0862 -0.1350 -1.6894\n",
170+
"-0.2913 2.1607 4.0600 -1.4186 -4.5283 3.7960 -5.8559 -0.2632 -1.5944 1.9401\n",
171+
" 0.4020 -2.5734 2.3380 -0.0078 -3.0894 3.5005 -1.3228 1.2757 0.7101 1.7986\n",
172+
" 0.1187 -0.4283 -0.0142 -0.5494 -0.2744 0.8786 0.2644 0.7838 0.6230 0.4126\n",
173+
"[torch.FloatTensor of size 10x10]\n",
174+
"\n"
175+
]
176+
}
177+
],
178+
"source": [
179+
"module = ScipyConv2d(3, 3)\n",
180+
"print(list(module.parameters()))\n",
181+
"input = Variable(torch.randn(10, 10), requires_grad=True)\n",
182+
"output = module(input)\n",
183+
"print(output)\n",
184+
"output.backward(torch.randn(8, 8))\n",
185+
"print(input.grad)\n"
186+
]
187+
},
188+
{
189+
"cell_type": "code",
190+
"execution_count": null,
191+
"metadata": {
192+
"collapsed": true
193+
},
194+
"outputs": [],
195+
"source": []
196+
}
197+
],
198+
"metadata": {
199+
"kernelspec": {
200+
"display_name": "Python 3",
201+
"language": "python",
202+
"name": "python3"
203+
},
204+
"language_info": {
205+
"codemirror_mode": {
206+
"name": "ipython",
207+
"version": 3
208+
},
209+
"file_extension": ".py",
210+
"mimetype": "text/x-python",
211+
"name": "python",
212+
"nbconvert_exporter": "python",
213+
"pygments_lexer": "ipython3",
214+
"version": "3.5.2"
215+
}
216+
},
217+
"nbformat": 4,
218+
"nbformat_minor": 1
219+
}

0 commit comments

Comments
 (0)