First, you have to write your C functions.

Below you can find an example implementation of the forward and backward functions of a module that adds both of its inputs.

In your `.c` files you can include TH using an `#include <TH/TH.h>` directive, and THC using `#include <THC/THC.h>`.

The ffi utils will make sure a compiler can find them during the build.
```C
/* src/my_lib.c */
#include <TH/TH.h>

int my_lib_add_forward(THFloatTensor *input1, THFloatTensor *input2,
                       THFloatTensor *output)
{
  if (!THFloatTensor_isSameSizeAs(input1, input2))
    return 0;
  THFloatTensor_resizeAs(output, input1);
  /* cadd computes output = input1 + 1.0 * input2;
     TH's plain add expects a scalar as its last argument */
  THFloatTensor_cadd(output, input1, 1.0, input2);
  return 1;
}

int my_lib_add_backward(THFloatTensor *grad_output, THFloatTensor *grad_input)
{
  THFloatTensor_resizeAs(grad_input, grad_output);
  THFloatTensor_fill(grad_input, 1);
  return 1;
}
```
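Every TH call above has a THC counterpart for CUDA tensors. Purely as an illustration of the `#include <THC/THC.h>` path mentioned earlier (this file is not part of the tutorial; the `_cuda` name and the `extern THCState *state` convention are assumptions about how THC extensions are typically written), a GPU forward might look like:

```C
/* hypothetical src/my_lib_cuda.c -- a sketch, not part of the tutorial */
#include <THC/THC.h>

/* THC calls need a THCState pointer; ffi-built CUDA extensions
   conventionally receive it as a global from the generated wrapper */
extern THCState *state;

int my_lib_add_forward_cuda(THCudaTensor *input1, THCudaTensor *input2,
                            THCudaTensor *output)
{
  if (!THCudaTensor_isSameSizeAs(state, input1, input2))
    return 0;
  THCudaTensor_resizeAs(state, output, input1);
  THCudaTensor_cadd(state, output, input1, 1.0f, input2);
  return 1;
}
```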
Next, you'll need a short header file that declares all the functions you want to call from Python. It will be used by the ffi utils to generate appropriate wrappers.

```C
/* src/my_lib.h */
int my_lib_add_forward(THFloatTensor *input1, THFloatTensor *input2, THFloatTensor *output);
int my_lib_add_backward(THFloatTensor *grad_output, THFloatTensor *grad_input);
```
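The build script itself is not part of this excerpt (only its `with_cuda=False` line survives in the surrounding text), but a minimal sketch of what it would look like with `torch.utils.ffi.create_extension` is shown below; the file name and directory layout are assumptions.

``` python
# build.py -- a minimal sketch; only with_cuda=False is confirmed by the
# excerpt above, the rest follows the usual torch.utils.ffi layout.
from torch.utils.ffi import create_extension

ffi = create_extension(
    '_ext.my_lib',             # generated package path: _ext/my_lib
    headers=['src/my_lib.h'],  # declarations the wrappers are generated from
    sources=['src/my_lib.c'],  # C sources compiled into the extension
    with_cuda=False
)

if __name__ == '__main__':
    ffi.build()
```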
## Step 2: Include it in your Python code

After you run it, pytorch will create an `_ext` directory and put `my_lib` inside.

The package name can have an arbitrary number of packages preceding the final module name (including none). If the build succeeded, you can import your extension just like a regular python file.
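Before wrapping the raw functions in an autograd `Function`, you can sanity-check the binding directly. A quick smoke test (assuming the build above succeeded; the values are arbitrary):

``` python
import torch
from _ext import my_lib

a, b = torch.randn(5, 5), torch.randn(5, 5)
out = torch.FloatTensor()           # the C function resizes it as needed
my_lib.my_lib_add_forward(a, b, out)
print((out - (a + b)).abs().max())  # should be (close to) zero
```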
``` python
# functions/add.py
import torch
from torch.autograd import Function
from _ext import my_lib


class MyAddFunction(Function):
    def forward(self, input1, input2):
        output = torch.FloatTensor()
        my_lib.my_lib_add_forward(input1, input2, output)
        return output

    def backward(self, grad_output):
        grad_input = torch.FloatTensor()
        my_lib.my_lib_add_backward(grad_output, grad_input)
        return grad_input
```
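With the old-style autograd API used here, you apply a `Function` by instantiating it and calling it on `Variable`s, for example:

``` python
from torch.autograd import Variable

x = Variable(torch.randn(5, 5))
y = Variable(torch.randn(5, 5))
z = MyAddFunction()(x, y)  # forward runs through the C extension
```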
``` python
# modules/add.py
from torch.nn import Module
from functions.add import MyAddFunction


class MyAddModule(Module):
    def forward(self, input1, input2):
        return MyAddFunction()(input1, input2)
```
``` python
import torch
import torch.nn as nn
from torch.autograd import Variable
from modules.add import MyAddModule


class MyNetwork(nn.Container):
    def __init__(self):
        super(MyNetwork, self).__init__(
            add=MyAddModule(),
        )

    def forward(self, input1, input2):
        return self.add(input1, input2)


model = MyNetwork()
input1, input2 = Variable(torch.randn(5, 5)), Variable(torch.randn(5, 5))
print(model(input1, input2))
```