Skip to content

Commit d9917f0

Browse files
author
James Reed
committed
address comments
1 parent 77a60c3 commit d9917f0

File tree

1 file changed

+61
-55
lines changed

1 file changed

+61
-55
lines changed

advanced_source/torch_script_custom_classes.rst

Lines changed: 61 additions & 55 deletions
Original file line numberDiff line numberDiff line change
@@ -64,51 +64,52 @@ Now let's take a look at how we will make this class visible to TorchScript, a p
6464

6565
.. code-block:: cpp
6666
67+
// Notice a few things:
68+
// - We pass the class to be registered as a template parameter to
69+
// `torch::jit::class_`. In this instance, we've passed the
70+
// specialization of the Stack class ``Stack<std::string>``.
71+
// In general, you cannot register a non-specialized template
72+
// class. For non-templated classes, you can just pass the
73+
// class name directly as the template parameter.
74+
// - The single parameter to ``torch::jit::class_()`` is a
75+
// string indicating the name of the class. This is the name
76+
// the class will appear as in both Python and TorchScript.
77+
// For example, our Stack class would appear as ``torch.classes.Stack``.
6778
static auto testStack =
68-
torch::jit::class_<Stack<std::string>>("Stack")
69-
.def(torch::jit::init<std::vector<std::string>>())
70-
.def("top", [](const c10::intrusive_ptr<Stack<std::string>>& self) {
71-
return self->stack_.back();
72-
})
73-
.def("push", &Stack<std::string>::push)
74-
.def("pop", &Stack<std::string>::pop)
75-
.def("clone", &Stack<std::string>::clone)
76-
.def("merge", &Stack<std::string>::merge);
77-
78-
Notice the following:
79-
80-
- We pass the class to be registered as a template parameter to ``torch::jit::class_``.
81-
In this instance, we've passed the specialization of the Stack class ``Stack<std::string>``.
82-
In general, you cannot register a non-specialized template class. For non-templated classes,
83-
you can just pass the class name directly as the template parameter.
84-
- The single parameter to ``torch::jit::class_()`` is a string indicating the name of the class.
85-
This is the name the class will appear as in both Python and TorchScript. For example, our
86-
Stack class would appear as ``torch.classes.Stack``.
87-
- For each method of the class we'd like to expose to Python and TorchScript, we use the
88-
``.def()`` method on ``torch::jit::class_``. We can chain these together to register
89-
multiple methods as well. Let's examine the different callsites of ``def()`` in our example:
90-
91-
- ``torch::jit::init<std::vector<std::string>>()`` registers the contructor of our Stack
92-
class that takes a single ``std::vector<std::string>`` argument, i.e. it exposes the C++
93-
method ``Stack(std::vector<T> init)``. Currently, we do not support registering overloaded
94-
constructors, so for now you can only ``def()`` one instance of ``torch::jit::init``.
95-
- The next line registers a stateless (i.e. no captures) C++ lambda function as a method.
96-
Note that a lambda function must take a ``c10::intrusive_ptr<YourClass>`` (or some
97-
const/rev version of that) to work.
98-
- ``.def("push", &Stack<std::string>::push)`` exposes the ``void push(T x)`` method.
99-
``torch::jit::class_`` will automatically examine the argument and return types of
100-
the passed-in method pointers and expose these to Python and TorchScript accordingly.
101-
Finally, notice that we must take the *address* of the fully-qualified method name,
102-
i.e. use the unary ``&`` operator, due to C++ typing rules.
103-
- The rest of the method registrations follow the same pattern.
79+
torch::jit::class_<Stack<std::string>>("Stack")
80+
// The following line registers the contructor of our Stack
81+
// class that takes a single `std::vector<std::string>` argument,
82+
// i.e. it exposes the C++ method `Stack(std::vector<T> init)`.
83+
// Currently, we do not support registering overloaded
84+
// constructors, so for now you can only `def()` one instance of
85+
// `torch::jit::init`.
86+
.def(torch::jit::init<std::vector<std::string>>())
87+
// The next line registers a stateless (i.e. no captures) C++ lambda
88+
// function as a method. Note that a lambda function must take a
89+
// `c10::intrusive_ptr<YourClass>` (or some const/ref version of that)
90+
// as the first argument. Other arguments can be whatever you want.
91+
.def("top", [](const c10::intrusive_ptr<Stack<std::string>>& self) {
92+
return self->stack_.back();
93+
})
94+
// The following four lines expose methods of the Stack<std::string>
95+
// class as-is. `torch::jit::class_` will automatically examine the
96+
// argument and return types of the passed-in method pointers and
97+
// expose these to Python and TorchScript accordingly. Finally, notice
98+
// that we must take the *address* of the fully-qualified method name,
99+
// i.e. use the unary `&` operator, due to C++ typing rules.
100+
.def("push", &Stack<std::string>::push)
101+
.def("pop", &Stack<std::string>::pop)
102+
.def("clone", &Stack<std::string>::clone)
103+
.def("merge", &Stack<std::string>::merge);
104+
104105
105106
106107
Building the Example as a C++ Project With CMake
107108
------------------------------------------------
108109

109110
Now, we're going to build the above C++ code with the `CMake
110-
<https://cmake.org>`_ build system. First, put all the C++ code
111-
we've covered so far, and place it in a file called ``class.cpp``.
111+
<https://cmake.org>`_ build system. First, take all the C++ code
112+
we've covered so far and place it in a file called ``class.cpp``.
112113
Then, write a simple ``CMakeLists.txt`` file and place it in the
113114
same directory. Here is what ``CMakeLists.txt`` should look like:
114115

@@ -175,8 +176,9 @@ then make to build the project:
175176
[100%] Linking CXX shared library libcustom_class.so
176177
[100%] Built target custom_class
177178
178-
What you'll find is there is now (among other things) a libcustom_class.so
179-
file present in the build directory. So the file tree should look like::
179+
What you'll find is there is now (among other things) a dynamic library
180+
file present in the build directory. On Linux, this is probably named
181+
``libcustom_class.so``. So the file tree should look like::
180182
181183
custom_class_project/
182184
class.cpp
@@ -199,6 +201,9 @@ demonstrates that:
199201
# to load it in and make the custom C++ classes available to both Python and
200202
# TorchScript
201203
torch.classes.load_library("libcustom_class.so")
204+
# You can query the loaded libraries like this:
205+
print(torch.classes.loaded_libraries)
206+
# prints {'/custom_class_project/build/libcustom_class.so'}
202207
203208
# We can find and instantiate our custom C++ class in python by using the
204209
# `torch.classes` namespace:
@@ -212,27 +217,23 @@ demonstrates that:
212217
assert s.pop() == "pushed"
213218
214219
# Returning and passing instances of custom classes works as you'd expect
215-
216220
s2 = s.clone()
217221
s.merge(s2)
218222
for expected in ["bar", "foo", "bar", "foo"]:
219223
assert s.pop() == expected
220224
221225
# We can also use the class in TorchScript
222-
# For now, we need to assign the class's type to the local in order to
223-
# annotate the type on the TorchScript function
226+
# For now, we need to assign the class's type to a local in order to
227+
# annotate the type on the TorchScript function. This may change
228+
# in the future.
224229
Stack = torch.classes.Stack
225230
226-
# This demonstrates:
227-
# - passing a custom class instance to TorchScript
228-
# - instantiating a class in TorchScript
229-
# - calling a custom class's methods in torchscript
230-
# - returning a custom class instance from TorchScript
231231
@torch.jit.script
232-
def do_stacks(s : Stack):
233-
s2 = torch.classes.Stack(["hi", "mom"])
234-
s2.merge(s)
235-
return s2.clone(), s2.top()
232+
def do_stacks(s : Stack): # We can pass a custom class instance to TorchScript
233+
s2 = torch.classes.Stack(["hi", "mom"]) # We can instantiate the class
234+
s2.merge(s) # We can call a method on the class
235+
return s2.clone(), s2.top() # We can also return instances of the class
236+
# from TorchScript function/methods
236237
237238
stack, top = do_stacks(torch.classes.Stack(["wow"]))
238239
assert top == "wow"
@@ -280,6 +281,9 @@ Similarly to before, let's create a file structure containing the following::
280281
foo.pt
281282
build/
282283
custom_class_project/
284+
class.cpp
285+
CMakeLists.txt
286+
build/
283287
284288
Notice we've copied over the serialized ``foo.pt`` file, as well as the source
285289
tree from the ``custom_class_project`` above. We will be adding the
@@ -313,7 +317,7 @@ Let's populate ``infer.cpp`` with the following:
313317
314318
And similarly let's define our CMakeLists.txt file:
315319
316-
.. code-block: cmake
320+
.. code-block:: cmake
317321
318322
cmake_minimum_required(VERSION 3.1 FATAL_ERROR)
319323
project(infer)
@@ -327,11 +331,13 @@ And similarly let's define our CMakeLists.txt file:
327331
set(CMAKE_CXX_STANDARD 14)
328332
# Link against LibTorch
329333
target_link_libraries(infer "${TORCH_LIBRARIES}")
334+
# This is where we link in our libcustom_class code, making our
335+
# custom class available in our binary.
330336
target_link_libraries(infer -Wl,--no-as-needed custom_class)
331337
332338
You know the drill: ``cd build``, ``cmake``, and ``make``:
333339
334-
.. code-block: shell
340+
.. code-block:: shell
335341
336342
$ cd build
337343
$ cmake -DCMAKE_PREFIX_PATH=/path/to/libtorch ..
@@ -374,7 +380,7 @@ You know the drill: ``cd build``, ``cmake``, and ``make``:
374380
375381
And now we can run our exciting C++ binary:
376382
377-
.. code-block: shell
383+
.. code-block:: shell
378384
379385
$ ./infer
380386
momfoobarbaz

0 commit comments

Comments
 (0)