Commit 92016cd: adding FFI tutorial
1 parent 0cd9f0b

Creating Extensions using FFI.md (1 file changed, +117 lines)
# Custom C extensions for PyTorch

## Step 1: Prepare your C code

First, you have to write your C functions.

Below you can find an example implementation of the forward and backward functions of a module that adds both of its inputs.

In your .c files you can include TH using an `#include <TH/TH.h>` directive, and THC using `#include <THC/THC.h>`.

The ffi utils will make sure the compiler can find them during the build.
```C
/* src/my_lib.c */
#include <TH/TH.h>

int my_lib_add_forward(THFloatTensor *input1, THFloatTensor *input2,
                       THFloatTensor *output)
{
  if (!THFloatTensor_isSameSizeAs(input1, input2))
    return 0;
  THFloatTensor_resizeAs(output, input1);
  /* output = input1 + 1.0 * input2 */
  THFloatTensor_cadd(output, input1, 1.0, input2);
  return 1;
}

int my_lib_add_backward(THFloatTensor *grad_output, THFloatTensor *grad_input)
{
  /* The gradient of a sum w.r.t. either input is 1 everywhere. */
  THFloatTensor_resizeAs(grad_input, grad_output);
  THFloatTensor_fill(grad_input, 1);
  return 1;
}
```
34+
35+
There are no constraints on the code, except that you will have to prepare a single header,
36+
which will list all functions want to call from python.
37+
38+
It will be used by the ffi utils to generate appropriate wrappers.
39+
40+
```C
/* src/my_lib.h */
int my_lib_add_forward(THFloatTensor *input1, THFloatTensor *input2,
                       THFloatTensor *output);
int my_lib_add_backward(THFloatTensor *grad_output, THFloatTensor *grad_input);
```
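To make the wrapper-generation step concrete, here is a rough sketch of how a build utility could discover the exported functions in such a header. This is only an illustration of the idea, not torch's actual implementation:

```python
import re

# Header contents from src/my_lib.h above.
header = """
int my_lib_add_forward(THFloatTensor *input1, THFloatTensor *input2,
                       THFloatTensor *output);
int my_lib_add_backward(THFloatTensor *grad_output, THFloatTensor *grad_input);
"""

# Match "int <name>(" declarations to find the exported functions.
exported = re.findall(r"\bint\s+(\w+)\s*\(", header)
print(exported)  # ['my_lib_add_forward', 'my_lib_add_backward']
```

Each discovered name ends up callable as an attribute of the generated Python module.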
Now, you'll need a super short file that will build your custom extension:

```python
# build.py
from torch.utils.ffi import compile_extension

compile_extension(
    name='_ext.my_lib',
    header='src/my_lib.h',
    sources=['src/my_lib.c'],
    with_cuda=False
)
```
## Step 2: Include it in your Python code

After you run it, PyTorch will create an `_ext` directory and put `my_lib` inside.

The package name can have an arbitrary number of packages preceding the final module name (including none).

If the build succeeded, you can import your extension just like a regular Python module.
```python
# functions/add.py
import torch
from torch.autograd import Function
from _ext import my_lib


class MyAddFunction(Function):

    def forward(self, input1, input2):
        output = torch.FloatTensor()
        my_lib.my_lib_add_forward(input1, input2, output)
        return output

    def backward(self, grad_output):
        grad_input = torch.FloatTensor()
        my_lib.my_lib_add_backward(grad_output, grad_input)
        return grad_input
```
```python
# modules/add.py
from torch.nn import Module
from functions.add import MyAddFunction


class MyAddModule(Module):

    def forward(self, input1, input2):
        return MyAddFunction()(input1, input2)
```
```python
# main.py
import torch
import torch.nn as nn
from torch.autograd import Variable
from modules.add import MyAddModule


class MyNetwork(nn.Container):
    def __init__(self):
        super(MyNetwork, self).__init__(
            add=MyAddModule(),
        )

    def forward(self, input1, input2):
        return self.add(input1, input2)


model = MyNetwork()
input1, input2 = Variable(torch.randn(5, 5)), Variable(torch.randn(5, 5))
print(model(input1, input2))
print(input1 + input2)
```
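If you just want to see the data flow without compiling anything, the compiled module can be faked in pure Python. The sketch below (`FakeMyLib` is a hypothetical stand-in, not part of torch) mimics the out-parameter convention the C functions above use:

```python
# Pure-Python stand-in for the compiled my_lib module (illustration only).
# Like the C code, each function writes its result into an output argument
# and returns 1 on success, 0 on failure.
class FakeMyLib:
    @staticmethod
    def my_lib_add_forward(input1, input2, output):
        if len(input1) != len(input2):
            return 0  # mirrors the isSameSizeAs check
        output[:] = [a + b for a, b in zip(input1, input2)]
        return 1

    @staticmethod
    def my_lib_add_backward(grad_output, grad_input):
        # Gradient of a sum w.r.t. either input is 1 everywhere.
        grad_input[:] = [1.0] * len(grad_output)
        return 1


out = []
FakeMyLib.my_lib_add_forward([1.0, 2.0], [3.0, 4.0], out)
print(out)  # [4.0, 6.0]

grad = []
FakeMyLib.my_lib_add_backward([0.5, 0.5], grad)
print(grad)  # [1.0, 1.0]
```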
