Skip to content

Commit c6c0e8d

Browse files
Formatting and indentation fixes
1 parent 9de1575 commit c6c0e8d

File tree

1 file changed

+28
-31
lines changed

1 file changed

+28
-31
lines changed

Creating Extensions using FFI.md

+28-31
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@ First, you have to write your C functions.
66

77
Below you can find an example implementation of forward and backward functions of a module that adds both of its inputs.
88

9-
In your .c files you can include TH using an #include <TH/TH.h> directive, and THC using #include <THC/THC.h>.
9+
In your `.c` files you can include TH using an `#include <TH/TH.h>` directive, and THC using `#include <THC/THC.h>`.
1010

1111
ffi utils will make sure a compiler can find them during the build.
1212

@@ -17,18 +17,18 @@ ffi utils will make sure a compiler can find them during the build.
1717
int my_lib_add_forward(THFloatTensor *input1, THFloatTensor *input2,
1818
THFloatTensor *output)
1919
{
20-
if (!THFloatTensor_isSameSizeAs(input1, input2))
21-
return 0;
22-
THFloatTensor_resizeAs(output, input1);
23-
THFloatTensor_add(output, input1, input2);
24-
return 1;
20+
if (!THFloatTensor_isSameSizeAs(input1, input2))
21+
return 0;
22+
THFloatTensor_resizeAs(output, input1);
23+
THFloatTensor_add(output, input1, input2);
24+
return 1;
2525
}
2626

2727
int my_lib_add_backward(THFloatTensor *grad_output, THFloatTensor *grad_input)
2828
{
29-
THFloatTensor_resizeAs(grad_input, grad_output);
30-
THFloatTensor_fill(grad_input, 1);
31-
return 1;
29+
THFloatTensor_resizeAs(grad_input, grad_output);
30+
THFloatTensor_fill(grad_input, 1);
31+
return 1;
3232
}
3333
```
3434
@@ -39,8 +39,7 @@ It will be used by the ffi utils to generate appropriate wrappers.
3939
4040
```C
4141
/* src/my_lib.h */
42-
int my_lib_add_forward(THFloatTensor *input1, THFloatTensor *input2,
43-
THFloatTensor *output);
42+
int my_lib_add_forward(THFloatTensor *input1, THFloatTensor *input2, THFloatTensor *output);
4443
int my_lib_add_backward(THFloatTensor *grad_output, THFloatTensor *grad_input);
4544
```
4645

@@ -59,7 +58,7 @@ with_cuda=False
5958

6059
## Step 2: Include it in your Python code
6160

62-
After you run it, pytorch will create an _ext directory and put my_lib inside.
61+
After you run it, pytorch will create an `_ext` directory and put `my_lib` inside.
6362

6463
Package name can have an arbitrary number of packages preceding the final module name (including none).
6564
If the build succeeded you can import your extension just like a regular Python file.
@@ -72,16 +71,15 @@ from _ext import my_lib
7271

7372

7473
class MyAddFunction(Function):
75-
76-
def forward(self, input1, input2):
77-
output = torch.FloatTensor()
78-
my_lib.my_lib_add_forward(input1, input2, output)
79-
return output
80-
81-
def backward(self, grad_output):
82-
grad_input = torch.FloatTensor()
83-
my_lib.my_lib_add_backward(grad_output, grad_input)
84-
return grad_input
74+
def forward(self, input1, input2):
75+
output = torch.FloatTensor()
76+
my_lib.my_lib_add_forward(input1, input2, output)
77+
return output
78+
79+
def backward(self, grad_output):
80+
grad_input = torch.FloatTensor()
81+
my_lib.my_lib_add_backward(grad_output, grad_input)
82+
return grad_input
8583
```
8684

8785
```python
@@ -90,9 +88,8 @@ from torch.nn import Module
9088
from functions.add import MyAddFunction
9189

9290
class MyAddModule(Module):
93-
94-
def forward(self, input1, input2):
95-
return MyAddFunction()(input1, input2)
91+
def forward(self, input1, input2):
92+
return MyAddFunction()(input1, input2)
9693
```
9794

9895
```python
@@ -102,13 +99,13 @@ from torch.autograd import Variable
10299
from modules.add import MyAddModule
103100

104101
class MyNetwork(nn.Container):
105-
def __init__(self):
106-
super(MyNetwork, self).__init__(
107-
add=MyAddModule(),
108-
)
102+
def __init__(self):
103+
super(MyNetwork, self).__init__(
104+
add=MyAddModule(),
105+
)
109106

110-
def forward(self, input1, input2):
111-
return self.add(input1, input2)
107+
def forward(self, input1, input2):
108+
return self.add(input1, input2)
112109

113110
model = MyNetwork()
114111
input1, input2 = Variable(torch.randn(5, 5)), Variable(torch.randn(5, 5))

0 commit comments

Comments
 (0)