@@ -64,51 +64,52 @@ Now let's take a look at how we will make this class visible to TorchScript, a p

.. code-block:: cpp

+  // Notice a few things:
+  // - We pass the class to be registered as a template parameter to
+  //   `torch::jit::class_`. In this instance, we've passed the
+  //   specialization of the Stack class `Stack<std::string>`.
+  //   In general, you cannot register a non-specialized template
+  //   class. For non-templated classes, you can just pass the
+  //   class name directly as the template parameter.
+  // - The single parameter to `torch::jit::class_()` is a
+  //   string indicating the name of the class. This is the name
+  //   the class will appear as in both Python and TorchScript.
+  //   For example, our Stack class would appear as `torch.classes.Stack`.
   static auto testStack =
-    torch::jit::class_<Stack<std::string>>("Stack")
-      .def(torch::jit::init<std::vector<std::string>>())
-      .def("top", [](const c10::intrusive_ptr<Stack<std::string>>& self) {
-        return self->stack_.back();
-      })
-      .def("push", &Stack<std::string>::push)
-      .def("pop", &Stack<std::string>::pop)
-      .def("clone", &Stack<std::string>::clone)
-      .def("merge", &Stack<std::string>::merge);
-
- Notice the following:
-
- - We pass the class to be registered as a template parameter to ``torch::jit::class_``.
-   In this instance, we've passed the specialization of the Stack class ``Stack<std::string>``.
-   In general, you cannot register a non-specialized template class. For non-templated classes,
-   you can just pass the class name directly as the template parameter.
- - The single parameter to ``torch::jit::class_()`` is a string indicating the name of the class.
-   This is the name the class will appear as in both Python and TorchScript. For example, our
-   Stack class would appear as ``torch.classes.Stack``.
- - For each method of the class we'd like to expose to Python and TorchScript, we use the
-   ``.def()`` method on ``torch::jit::class_``. We can chain these together to register
-   multiple methods as well. Let's examine the different callsites of ``def()`` in our example:
-
-   - ``torch::jit::init<std::vector<std::string>>()`` registers the contructor of our Stack
-     class that takes a single ``std::vector<std::string>`` argument, i.e. it exposes the C++
-     method ``Stack(std::vector<T> init)``. Currently, we do not support registering overloaded
-     constructors, so for now you can only ``def()`` one instance of ``torch::jit::init``.
-   - The next line registers a stateless (i.e. no captures) C++ lambda function as a method.
-     Note that a lambda function must take a ``c10::intrusive_ptr<YourClass>`` (or some
-     const/rev version of that) to work.
-   - ``.def("push", &Stack<std::string>::push)`` exposes the ``void push(T x)`` method.
-     ``torch::jit::class_`` will automatically examine the argument and return types of
-     the passed-in method pointers and expose these to Python and TorchScript accordingly.
-     Finally, notice that we must take the *address* of the fully-qualified method name,
-     i.e. use the unary ``&`` operator, due to C++ typing rules.
-   - The rest of the method registrations follow the same pattern.
+  static auto testStack =
+    torch::jit::class_<Stack<std::string>>("Stack")
+      // The following line registers the constructor of our Stack
+      // class that takes a single `std::vector<std::string>` argument,
+      // i.e. it exposes the C++ method `Stack(std::vector<T> init)`.
+      // Currently, we do not support registering overloaded
+      // constructors, so for now you can only `def()` one instance of
+      // `torch::jit::init`.
+      .def(torch::jit::init<std::vector<std::string>>())
+      // The next line registers a stateless (i.e. no captures) C++ lambda
+      // function as a method. Note that a lambda function must take a
+      // `c10::intrusive_ptr<YourClass>` (or some const/ref version of that)
+      // as the first argument. Other arguments can be whatever you want.
+      .def("top", [](const c10::intrusive_ptr<Stack<std::string>>& self) {
+        return self->stack_.back();
+      })
+      // The following four lines expose methods of the Stack<std::string>
+      // class as-is. `torch::jit::class_` will automatically examine the
+      // argument and return types of the passed-in method pointers and
+      // expose these to Python and TorchScript accordingly. Finally, notice
+      // that we must take the *address* of the fully-qualified method name,
+      // i.e. use the unary `&` operator, due to C++ typing rules.
+      .def("push", &Stack<std::string>::push)
+      .def("pop", &Stack<std::string>::pop)
+      .def("clone", &Stack<std::string>::clone)
+      .def("merge", &Stack<std::string>::merge);
+
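To make the registered surface concrete, here is a small pure-Python stand-in with the same methods (hypothetical and for illustration only; the real object is the C++ class, reached as ``torch.classes.Stack`` after loading the library). Note that ``merge`` appending the other stack bottom-to-top is an assumption, chosen to be consistent with how the class is used later in this tutorial:

```python
from typing import List


class StackModel:
    """Hypothetical pure-Python stand-in for the registered Stack<std::string>."""

    def __init__(self, init: List[str]):
        # mirrors torch::jit::init<std::vector<std::string>>()
        self.stack_ = list(init)

    def top(self) -> str:
        # mirrors the registered lambda: return self->stack_.back();
        return self.stack_[-1]

    def push(self, x: str) -> None:
        self.stack_.append(x)

    def pop(self) -> str:
        return self.stack_.pop()

    def clone(self) -> "StackModel":
        return StackModel(self.stack_)

    def merge(self, other: "StackModel") -> None:
        # assumption: appends the other stack's contents bottom-to-top
        self.stack_.extend(other.stack_)


s = StackModel(["foo", "bar"])
s.push("pushed")
assert s.pop() == "pushed"
assert s.top() == "bar"
```

This is only a mental model; the real class keeps its state in C++ and is handed around as a ``c10::intrusive_ptr``.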
Building the Example as a C++ Project With CMake
------------------------------------------------

Now, we're going to build the above C++ code with the `CMake
- <https://cmake.org> `_ build system. First, put all the C++ code
- we've covered so far, and place it in a file called ``class.cpp ``.
+ <https://cmake.org>`_ build system. First, take all the C++ code
+ we've covered so far and place it in a file called ``class.cpp``.

Then, write a simple ``CMakeLists.txt`` file and place it in the
same directory. Here is what ``CMakeLists.txt`` should look like:
@@ -175,8 +176,9 @@ then make to build the project:
   [100%] Linking CXX shared library libcustom_class.so
   [100%] Built target custom_class

- What you'll find is there is now (among other things) a libcustom_class.so
- file present in the build directory. So the file tree should look like::
+ What you'll find is that there is now (among other things) a dynamic library
+ file present in the build directory. On Linux, this is probably named
+ ``libcustom_class.so``. So the file tree should look like::

   custom_class_project/
     class.cpp
@@ -199,6 +201,9 @@ demonstrates that:
   # to load it in and make the custom C++ classes available to both Python and
   # TorchScript
   torch.classes.load_library("libcustom_class.so")
+  # You can query the loaded libraries like this:
+  print(torch.classes.loaded_libraries)
+  # prints {'/custom_class_project/build/libcustom_class.so'}

   # We can find and instantiate our custom C++ class in Python by using the
   # `torch.classes` namespace:
@@ -212,27 +217,23 @@ demonstrates that:
   assert s.pop() == "pushed"

   # Returning and passing instances of custom classes works as you'd expect
-
   s2 = s.clone()
   s.merge(s2)
   for expected in ["bar", "foo", "bar", "foo"]:
       assert s.pop() == expected

   # We can also use the class in TorchScript
-  # For now, we need to assign the class's type to the local in order to
-  # annotate the type on the TorchScript function
+  # For now, we need to assign the class's type to a local in order to
+  # annotate the type on the TorchScript function. This may change
+  # in the future.
   Stack = torch.classes.Stack

-  # This demonstrates:
-  # - passing a custom class instance to TorchScript
-  # - instantiating a class in TorchScript
-  # - calling a custom class's methods in torchscript
-  # - returning a custom class instance from TorchScript
   @torch.jit.script
-  def do_stacks(s: Stack):
-      s2 = torch.classes.Stack(["hi", "mom"])
-      s2.merge(s)
-      return s2.clone(), s2.top()
+  def do_stacks(s: Stack):  # We can pass a custom class instance to TorchScript
+      s2 = torch.classes.Stack(["hi", "mom"])  # We can instantiate the class
+      s2.merge(s)  # We can call a method on the class
+      return s2.clone(), s2.top()  # We can also return instances of the class
+                                   # from TorchScript functions/methods

   stack, top = do_stacks(torch.classes.Stack(["wow"]))
   assert top == "wow"
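+ To see why ``top == "wow"`` holds, here is a torch-free trace of ``do_stacks`` using a minimal stand-in class (hypothetical, for illustration; it assumes ``merge`` appends the other stack bottom-to-top, which matches the pop order asserted above):

```python
# Minimal stand-in for torch.classes.Stack (illustration only, no torch needed).
class FakeStack:
    def __init__(self, init):
        self.stack_ = list(init)

    def merge(self, other):
        # assumed semantics: append the other stack's contents bottom-to-top
        self.stack_.extend(other.stack_)

    def clone(self):
        return FakeStack(self.stack_)

    def top(self):
        return self.stack_[-1]


def do_stacks(s):
    s2 = FakeStack(["hi", "mom"])  # stack_ == ["hi", "mom"]
    s2.merge(s)                    # appends s's contents on top
    return s2.clone(), s2.top()


stack, top = do_stacks(FakeStack(["wow"]))
assert top == "wow"  # "wow" ends up on top of ["hi", "mom", "wow"]
```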
@@ -280,6 +281,9 @@ Similarly to before, let's create a file structure containing the following::
     foo.pt
     build/
     custom_class_project/
+      class.cpp
+      CMakeLists.txt
+      build/

Notice we've copied over the serialized ``foo.pt`` file, as well as the source
tree from the ``custom_class_project`` above. We will be adding the
@@ -313,7 +317,7 @@ Let's populate ``infer.cpp`` with the following:

And similarly let's define our CMakeLists.txt file:

- .. code-block: cmake
+ .. code-block:: cmake

   cmake_minimum_required(VERSION 3.1 FATAL_ERROR)
   project(infer)
@@ -327,11 +331,13 @@ And similarly let's define our CMakeLists.txt file:
   set(CMAKE_CXX_STANDARD 14)
   # Link against LibTorch
   target_link_libraries(infer "${TORCH_LIBRARIES}")
+  # This is where we link in our libcustom_class code, making our
+  # custom class available in our binary.
   target_link_libraries(infer -Wl,--no-as-needed custom_class)

You know the drill: ``cd build``, ``cmake``, and ``make``:

- .. code-block: shell
+ .. code-block:: shell

   $ cd build
   $ cmake -DCMAKE_PREFIX_PATH=/path/to/libtorch ..
@@ -374,7 +380,7 @@ You know the drill: ``cd build``, ``cmake``, and ``make``:

And now we can run our exciting C++ binary:

- .. code-block: shell
+ .. code-block:: shell

   $ ./infer
   momfoobarbaz