Extending TorchScript with Custom C++ Classes
===============================================

This tutorial is a follow-on to the
`custom operator <https://pytorch.org/tutorials/advanced/torch_script_custom_ops.html>`_
tutorial, and introduces the API we've built for binding C++ classes into TorchScript
and Python simultaneously. The API is very similar to
`pybind11 <https://github.com/pybind/pybind11>`_, and most of the concepts will transfer
over if you're familiar with that system.
Implementing and Binding the Class in C++
-----------------------------------------

For this tutorial, we are going to define a simple C++ class that maintains persistent
state in a member variable.

.. code-block:: cpp

  #include <torch/custom_class.h>

  #include <string>
  #include <vector>

  // A stack of values, templated on the element type. Custom classes must
  // inherit from torch::jit::CustomClassHolder so that their lifetimes can
  // be managed via c10::intrusive_ptr (see the notes below).
  template <class T>
  struct Stack : torch::jit::CustomClassHolder {
    std::vector<T> stack_;
    Stack(std::vector<T> init) : stack_(init.begin(), init.end()) {}

    void push(T x) {
      stack_.push_back(x);
    }
    T pop() {
      auto val = stack_.back();
      stack_.pop_back();
      return val;
    }

    // Return a deep copy of this Stack as a new reference-counted instance.
    c10::intrusive_ptr<Stack> clone() const {
      return c10::make_intrusive<Stack>(stack_);
    }

    // Push every element of `c` onto this Stack.
    void merge(const c10::intrusive_ptr<Stack>& c) {
      for (auto& elem : c->stack_) {
        push(elem);
      }
    }
  };
There are several things to note:

- ``torch/custom_class.h`` is the header you need to include to extend TorchScript
  with your custom class.
- Notice that whenever we are working with instances of the custom
  class, we do it via instances of ``c10::intrusive_ptr<>``. Think of ``intrusive_ptr``
  as a smart pointer like ``std::shared_ptr``. The reason for using this smart pointer
  is to ensure consistent lifetime management of the object instances between languages
  (C++, Python and TorchScript).
- The user-defined class must inherit from ``torch::jit::CustomClassHolder``. This
  ensures that everything is set up to handle the lifetime management system
  previously mentioned.
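To make the ownership model concrete, here is a minimal sketch (it is not part
of the tutorial's project files) of creating and using an instance of the class
defined above directly from C++:

.. code-block:: cpp

  // `c10::make_intrusive` forwards its arguments to the Stack constructor and
  // returns a reference-counted handle; no manual `delete` is ever needed.
  c10::intrusive_ptr<Stack<std::string>> s =
      c10::make_intrusive<Stack<std::string>>(
          std::vector<std::string>{"foo", "bar"});
  s->push("baz");
  auto copy = s->clone();  // `copy` owns an independent copy of the elements.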
Now let's take a look at how we will make this class visible to TorchScript, a process called
*binding* the class:

.. code-block:: cpp

  static auto testStack =
      torch::jit::class_<Stack<std::string>>("Stack")
          .def(torch::jit::init<std::vector<std::string>>())
          .def("top", [](const c10::intrusive_ptr<Stack<std::string>>& self) {
            return self->stack_.back();
          })
          .def("push", &Stack<std::string>::push)
          .def("pop", &Stack<std::string>::pop)
          .def("clone", &Stack<std::string>::clone)
          .def("merge", &Stack<std::string>::merge);
Notice the following:

- We pass the class to be registered as a template parameter to ``torch::jit::class_``.
  In this instance, we've passed the specialization of the Stack class, ``Stack<std::string>``.
  In general, you cannot register a non-specialized template class. For non-templated classes,
  you can just pass the class name directly as the template parameter (see the sketch after
  this list).
- The single parameter to ``torch::jit::class_()`` is a string indicating the name of the class.
  This is the name the class will appear as in both Python and TorchScript. For example, our
  Stack class would appear as ``torch.classes.Stack``.
- For each method of the class we'd like to expose to Python and TorchScript, we use the
  ``.def()`` method on ``torch::jit::class_``. We can chain these together to register
  multiple methods as well. Let's examine the different call sites of ``def()`` in our example:

  - ``torch::jit::init<std::vector<std::string>>()`` registers the constructor of our Stack
    class that takes a single ``std::vector<std::string>`` argument, i.e. it exposes the C++
    method ``Stack(std::vector<T> init)``. Currently, we do not support registering overloaded
    constructors, so for now you can only ``def()`` one instance of ``torch::jit::init``.
  - The next line registers a stateless (i.e. no captures) C++ lambda function as a method.
    Note that a lambda function must take a ``c10::intrusive_ptr<YourClass>`` (or some
    const/ref version of that) as its first argument.
  - ``.def("push", &Stack<std::string>::push)`` exposes the ``void push(T x)`` method.
    ``torch::jit::class_`` will automatically examine the argument and return types of
    the passed-in method pointers and expose these to Python and TorchScript accordingly.
    Finally, notice that we must take the *address* of the fully-qualified method name,
    i.e. use the unary ``&`` operator, due to C++ typing rules.
  - The rest of the method registrations follow the same pattern.
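To illustrate the same pattern without the template machinery, here is a
sketch of binding a non-templated class. The ``Counter`` class, its methods,
and its registration name are hypothetical, written only for this example,
and assume the same API surface used above:

.. code-block:: cpp

  // A hypothetical non-templated class. TorchScript integers are 64-bit,
  // so we use int64_t for the exposed state.
  struct Counter : torch::jit::CustomClassHolder {
    int64_t count_;
    explicit Counter(int64_t start) : count_(start) {}
    void increment() { count_ += 1; }
  };

  // The class name itself is the template parameter; after loading the
  // library, this would be visible as `torch.classes.Counter`.
  static auto counterReg =
      torch::jit::class_<Counter>("Counter")
          .def(torch::jit::init<int64_t>())  // binds Counter(int64_t start)
          .def("increment", &Counter::increment)
          // A capture-free lambda taking the intrusive_ptr "self" argument:
          .def("value", [](const c10::intrusive_ptr<Counter>& self) {
            return self->count_;
          });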
Building the Example as a C++ Project With CMake
------------------------------------------------

Now, we're going to build the above C++ code with the `CMake
<https://cmake.org>`_ build system. First, take all the C++ code
we've covered so far and place it in a file called ``class.cpp``.
Then, write a simple ``CMakeLists.txt`` file and place it in the
same directory. Here is what ``CMakeLists.txt`` should look like:

.. code-block:: cmake

  cmake_minimum_required(VERSION 3.1 FATAL_ERROR)
  project(custom_class)

  find_package(Torch REQUIRED)

  # CMAKE_CXX_STANDARD must be set before the target is created for it to
  # take effect on that target.
  set(CMAKE_CXX_STANDARD 14)

  # Define our library target
  add_library(custom_class SHARED class.cpp)
  # Link against LibTorch
  target_link_libraries(custom_class "${TORCH_LIBRARIES}")
Also, create a ``build`` directory. Your file tree should look like this::

  custom_class_project/
    class.cpp
    CMakeLists.txt
    build/
Now, to build the project, go ahead and download the appropriate libtorch
binary from the `PyTorch website <https://pytorch.org/>`_. Extract the
zip archive somewhere (within the project directory might be convenient)
and note the path you've extracted it to. Next, go ahead and invoke cmake and
then make to build the project:

.. code-block:: shell

  $ cd build
  $ cmake -DCMAKE_PREFIX_PATH=/path/to/libtorch ..
  -- The C compiler identification is GNU 7.3.1
  -- The CXX compiler identification is GNU 7.3.1
  -- Check for working C compiler: /opt/rh/devtoolset-7/root/usr/bin/cc
  -- Check for working C compiler: /opt/rh/devtoolset-7/root/usr/bin/cc -- works
  -- Detecting C compiler ABI info
  -- Detecting C compiler ABI info - done
  -- Detecting C compile features
  -- Detecting C compile features - done
  -- Check for working CXX compiler: /opt/rh/devtoolset-7/root/usr/bin/c++
  -- Check for working CXX compiler: /opt/rh/devtoolset-7/root/usr/bin/c++ -- works
  -- Detecting CXX compiler ABI info
  -- Detecting CXX compiler ABI info - done
  -- Detecting CXX compile features
  -- Detecting CXX compile features - done
  -- Looking for pthread.h
  -- Looking for pthread.h - found
  -- Looking for pthread_create
  -- Looking for pthread_create - not found
  -- Looking for pthread_create in pthreads
  -- Looking for pthread_create in pthreads - not found
  -- Looking for pthread_create in pthread
  -- Looking for pthread_create in pthread - found
  -- Found Threads: TRUE
  -- Found torch: /torchbind_tutorial/libtorch/lib/libtorch.so
  -- Configuring done
  -- Generating done
  -- Build files have been written to: /torchbind_tutorial/build
  $ make -j
  Scanning dependencies of target custom_class
  [ 50%] Building CXX object CMakeFiles/custom_class.dir/class.cpp.o
  [100%] Linking CXX shared library libcustom_class.so
  [100%] Built target custom_class
You'll find that there is now (among other things) a ``libcustom_class.so``
file present in the build directory, so the file tree should look like::

  custom_class_project/
    class.cpp
    CMakeLists.txt
    build/
      libcustom_class.so
Using the C++ Class from Python and TorchScript
-----------------------------------------------

Now that we have our class and its registration compiled into an ``.so`` file,
we can load that ``.so`` into Python and try it out. Here's a script that
demonstrates that:

.. code-block:: python

  import torch

  # `torch.classes.load_library()` allows you to pass the path to your .so file
  # to load it in and make the custom C++ classes available to both Python and
  # TorchScript
  torch.classes.load_library("libcustom_class.so")

  # We can find and instantiate our custom C++ class in Python by using the
  # `torch.classes` namespace:
  #
  # This instantiation will invoke the Stack(std::vector<T> init) constructor
  # we registered earlier
  s = torch.classes.Stack(["foo", "bar"])

  # We can call methods in Python
  s.push("pushed")
  assert s.pop() == "pushed"

  # Returning and passing instances of custom classes works as you'd expect
  s2 = s.clone()
  s.merge(s2)
  for expected in ["bar", "foo", "bar", "foo"]:
      assert s.pop() == expected

  # We can also use the class in TorchScript.
  # For now, we need to assign the class's type to a local in order to
  # annotate the type on the TorchScript function
  Stack = torch.classes.Stack

  # This demonstrates:
  # - passing a custom class instance to TorchScript
  # - instantiating a class in TorchScript
  # - calling a custom class's methods in TorchScript
  # - returning a custom class instance from TorchScript
  @torch.jit.script
  def do_stacks(s: Stack):
      s2 = torch.classes.Stack(["hi", "mom"])
      s2.merge(s)
      return s2.clone(), s2.top()

  stack, top = do_stacks(torch.classes.Stack(["wow"]))
  assert top == "wow"
  for expected in ["wow", "mom", "hi"]:
      assert stack.pop() == expected
Saving, Loading, and Running TorchScript Code Using Custom Classes
------------------------------------------------------------------

We can also use custom-registered C++ classes in a C++ process using
libtorch. As an example, let's define a simple ``nn.Module`` that
instantiates and calls a method on our Stack class:

.. code-block:: python

  import torch

  torch.classes.load_library('libcustom_class.so')

  class Foo(torch.nn.Module):
      def __init__(self):
          super().__init__()

      def forward(self, s: str) -> str:
          stack = torch.classes.Stack(["hi", "mom"])
          return stack.pop() + s

  scripted_foo = torch.jit.script(Foo())
  print(scripted_foo.graph)

  scripted_foo.save('foo.pt')

``foo.pt`` in our filesystem now contains the serialized TorchScript
program we've just defined.
Now, we're going to define a new CMake project to show how you can load
this model and its required ``.so`` file. For a full treatment of how to do this,
please have a look at the `Loading a TorchScript Model in C++ Tutorial <https://pytorch.org/tutorials/advanced/cpp_export.html>`_.

Similarly to before, let's create a file structure containing the following::

  cpp_inference_example/
    infer.cpp
    CMakeLists.txt
    foo.pt
    build/
    custom_class_project/

Notice we've copied over the serialized ``foo.pt`` file, as well as the source
tree from the ``custom_class_project`` above. We will be adding the
``custom_class_project`` as a dependency to this C++ project so that we can
build the custom class into the binary.
Let's populate ``infer.cpp`` with the following:

.. code-block:: cpp

  #include <torch/script.h>

  #include <iostream>
  #include <memory>

  int main(int argc, const char* argv[]) {
    torch::jit::script::Module module;
    try {
      // Deserialize the ScriptModule from a file using torch::jit::load().
      module = torch::jit::load("foo.pt");
    }
    catch (const c10::Error& e) {
      std::cerr << "error loading the model\n";
      return -1;
    }

    // forward() returns an IValue; we know our model returns a string,
    // so convert it and print it out.
    std::vector<c10::IValue> inputs = {"foobarbaz"};
    auto output = module.forward(inputs).toString();
    std::cout << output->string() << std::endl;
  }
And similarly let's define our ``CMakeLists.txt`` file:

.. code-block:: cmake

  cmake_minimum_required(VERSION 3.1 FATAL_ERROR)
  project(infer)

  find_package(Torch REQUIRED)

  add_subdirectory(custom_class_project)

  # As noted earlier, this must be set before the target is created.
  set(CMAKE_CXX_STANDARD 14)

  # Define our executable target
  add_executable(infer infer.cpp)
  # Link against LibTorch
  target_link_libraries(infer "${TORCH_LIBRARIES}")
  # Link against our custom class library. The class is registered by a
  # static initializer in libcustom_class.so, and infer does not reference
  # any of its symbols directly, so we pass -Wl,--no-as-needed to keep the
  # linker from dropping the dependency.
  target_link_libraries(infer -Wl,--no-as-needed custom_class)
You know the drill: ``cd build``, ``cmake``, and ``make``:

.. code-block:: shell

  $ cd build
  $ cmake -DCMAKE_PREFIX_PATH=/path/to/libtorch ..
  -- The C compiler identification is GNU 7.3.1
  -- The CXX compiler identification is GNU 7.3.1
  -- Check for working C compiler: /opt/rh/devtoolset-7/root/usr/bin/cc
  -- Check for working C compiler: /opt/rh/devtoolset-7/root/usr/bin/cc -- works
  -- Detecting C compiler ABI info
  -- Detecting C compiler ABI info - done
  -- Detecting C compile features
  -- Detecting C compile features - done
  -- Check for working CXX compiler: /opt/rh/devtoolset-7/root/usr/bin/c++
  -- Check for working CXX compiler: /opt/rh/devtoolset-7/root/usr/bin/c++ -- works
  -- Detecting CXX compiler ABI info
  -- Detecting CXX compiler ABI info - done
  -- Detecting CXX compile features
  -- Detecting CXX compile features - done
  -- Looking for pthread.h
  -- Looking for pthread.h - found
  -- Looking for pthread_create
  -- Looking for pthread_create - not found
  -- Looking for pthread_create in pthreads
  -- Looking for pthread_create in pthreads - not found
  -- Looking for pthread_create in pthread
  -- Looking for pthread_create in pthread - found
  -- Found Threads: TRUE
  -- Found torch: /local/miniconda3/lib/python3.7/site-packages/torch/lib/libtorch.so
  -- Configuring done
  -- Generating done
  -- Build files have been written to: /cpp_inference_example/build
  $ make -j
  Scanning dependencies of target custom_class
  [ 25%] Building CXX object custom_class_project/CMakeFiles/custom_class.dir/class.cpp.o
  [ 50%] Linking CXX shared library libcustom_class.so
  [ 50%] Built target custom_class
  Scanning dependencies of target infer
  [ 75%] Building CXX object CMakeFiles/infer.dir/infer.cpp.o
  [100%] Linking CXX executable infer
  [100%] Built target infer
And now we can run our exciting C++ binary:

.. code-block:: shell

  $ ./infer
  momfoobarbaz

Incredible!
Conclusion
----------

This tutorial walked you through how to expose a C++ class to TorchScript
(and by extension Python), how to register its methods, how to use that
class from Python and TorchScript, and how to save and load code using
the class and run that code in a standalone C++ process. You are now ready
to extend your TorchScript models with C++ classes that interface with
third-party C++ libraries or implement any other use case that requires the
lines between Python, TorchScript and C++ to blend smoothly.

As always, if you run into any problems or have questions, you can use our
`forum <https://discuss.pytorch.org/>`_ or `GitHub issues
<https://github.com/pytorch/pytorch/issues>`_ to get in touch. Also, our
`frequently asked questions (FAQ) page
<https://pytorch.org/cppdocs/notes/faq.html>`_ may have helpful information.
