Skip to content

Commit 593759e

Browse files
authored
Add custom class tutorial back (#3546)
Reverts part of #3453 which removed the Custom Class tutorial Custom classes are a generic concept in PyTorch. Although it was developed for TorchScript, it is also supported in Pytorch 2.0 as many users (like TensorRT and internal Ads frameworks) rely on it. Therefore, we should not delete this tutorial. Additionally the [Custom Class x PT2 tutorial](https://docs.pytorch.org/tutorials/advanced/custom_class_pt2.html) also references this tutorial, although CI did not seem to catch the broken link (cc @svekars) As part of revert, I removed references to TS.
1 parent 3cbb0f7 commit 593759e

File tree

12 files changed

+536
-8
lines changed

12 files changed

+536
-8
lines changed

advanced_source/custom_class_pt2.rst

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@ Supporting Custom C++ Classes in torch.compile/torch.export
33

44

55
This tutorial is a follow-on to the
6-
:doc:`custom C++ classes <torch_script_custom_classes>` tutorial, and
6+
:doc:`custom C++ classes <custom_classes>` tutorial, and
77
introduces additional steps that are needed to support custom C++ classes in
88
torch.compile/torch.export.
99

@@ -30,7 +30,7 @@ Concretely, there are a few steps:
3030
states returned by ``__obj_flatten__``.
3131

3232
Here is a breakdown of the diff. Following the guide in
33-
:doc:`Extending TorchScript with Custom C++ Classes <torch_script_custom_classes>`,
33+
:doc:`Extending TorchScript with Custom C++ Classes <custom_classes>`,
3434
we can create a thread-safe tensor queue and build it.
3535

3636
.. code-block:: cpp

advanced_source/custom_classes.rst

Lines changed: 231 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,231 @@
1+
Extending PyTorch with Custom C++ Classes
2+
===============================================
3+
4+
5+
This tutorial introduces an API for binding C++ classes into PyTorch.
6+
The API is very similar to
7+
`pybind11 <https://github.com/pybind/pybind11>`_, and most of the concepts will transfer
8+
over if you're familiar with that system.
9+
10+
Implementing and Binding the Class in C++
11+
-----------------------------------------
12+
13+
For this tutorial, we are going to define a simple C++ class that maintains persistent
14+
state in a member variable.
15+
16+
.. literalinclude:: ../advanced_source/custom_classes/custom_class_project/class.cpp
17+
:language: cpp
18+
:start-after: BEGIN class
19+
:end-before: END class
20+
21+
There are several things to note:
22+
23+
- ``torch/custom_class.h`` is the header you need to include to extend PyTorch
24+
with your custom class.
25+
- Notice that whenever we are working with instances of the custom
26+
class, we do it via instances of ``c10::intrusive_ptr<>``. Think of ``intrusive_ptr``
27+
as a smart pointer like ``std::shared_ptr``, but the reference count is stored
28+
directly in the object, as opposed to a separate metadata block (as is done in
29+
``std::shared_ptr``. ``torch::Tensor`` internally uses the same pointer type;
30+
and custom classes have to also use this pointer type so that we can
31+
consistently manage different object types.
32+
- The second thing to notice is that the user-defined class must inherit from
33+
``torch::CustomClassHolder``. This ensures that the custom class has space to
34+
store the reference count.
35+
36+
Now let's take a look at how we will make this class visible to PyTorch, a process called
37+
*binding* the class:
38+
39+
.. literalinclude:: ../advanced_source/custom_classes/custom_class_project/class.cpp
40+
:language: cpp
41+
:start-after: BEGIN binding
42+
:end-before: END binding
43+
:append:
44+
;
45+
}
46+
47+
48+
49+
Building the Example as a C++ Project With CMake
50+
------------------------------------------------
51+
52+
Now, we're going to build the above C++ code with the `CMake
53+
<https://cmake.org>`_ build system. First, take all the C++ code
54+
we've covered so far and place it in a file called ``class.cpp``.
55+
Then, write a simple ``CMakeLists.txt`` file and place it in the
56+
same directory. Here is what ``CMakeLists.txt`` should look like:
57+
58+
.. literalinclude:: ../advanced_source/custom_classes/custom_class_project/CMakeLists.txt
59+
:language: cmake
60+
61+
Also, create a ``build`` directory. Your file tree should look like this::
62+
63+
custom_class_project/
64+
class.cpp
65+
CMakeLists.txt
66+
build/
67+
68+
Go ahead and invoke cmake and then make to build the project:
69+
70+
.. code-block:: shell
71+
72+
$ cd build
73+
$ cmake -DCMAKE_PREFIX_PATH="$(python -c 'import torch.utils; print(torch.utils.cmake_prefix_path)')" ..
74+
-- The C compiler identification is GNU 7.3.1
75+
-- The CXX compiler identification is GNU 7.3.1
76+
-- Check for working C compiler: /opt/rh/devtoolset-7/root/usr/bin/cc
77+
-- Check for working C compiler: /opt/rh/devtoolset-7/root/usr/bin/cc -- works
78+
-- Detecting C compiler ABI info
79+
-- Detecting C compiler ABI info - done
80+
-- Detecting C compile features
81+
-- Detecting C compile features - done
82+
-- Check for working CXX compiler: /opt/rh/devtoolset-7/root/usr/bin/c++
83+
-- Check for working CXX compiler: /opt/rh/devtoolset-7/root/usr/bin/c++ -- works
84+
-- Detecting CXX compiler ABI info
85+
-- Detecting CXX compiler ABI info - done
86+
-- Detecting CXX compile features
87+
-- Detecting CXX compile features - done
88+
-- Looking for pthread.h
89+
-- Looking for pthread.h - found
90+
-- Looking for pthread_create
91+
-- Looking for pthread_create - not found
92+
-- Looking for pthread_create in pthreads
93+
-- Looking for pthread_create in pthreads - not found
94+
-- Looking for pthread_create in pthread
95+
-- Looking for pthread_create in pthread - found
96+
-- Found Threads: TRUE
97+
-- Found torch: /torchbind_tutorial/libtorch/lib/libtorch.so
98+
-- Configuring done
99+
-- Generating done
100+
-- Build files have been written to: /torchbind_tutorial/build
101+
$ make -j
102+
Scanning dependencies of target custom_class
103+
[ 50%] Building CXX object CMakeFiles/custom_class.dir/class.cpp.o
104+
[100%] Linking CXX shared library libcustom_class.so
105+
[100%] Built target custom_class
106+
107+
What you'll find is there is now (among other things) a dynamic library
108+
file present in the build directory. On Linux, this is probably named
109+
``libcustom_class.so``. So the file tree should look like::
110+
111+
custom_class_project/
112+
class.cpp
113+
CMakeLists.txt
114+
build/
115+
libcustom_class.so
116+
117+
Using the C++ Class from Python
118+
-----------------------------------------------
119+
120+
Now that we have our class and its registration compiled into an ``.so`` file,
121+
we can load that `.so` into Python and try it out. Here's a script that
122+
demonstrates that:
123+
124+
.. literalinclude:: ../advanced_source/custom_classes/custom_class_project/custom_test.py
125+
:language: python
126+
127+
128+
Defining Serialization/Deserialization Methods for Custom C++ Classes
129+
---------------------------------------------------------------------
130+
131+
If you try to save a ``ScriptModule`` with a custom-bound C++ class as
132+
an attribute, you'll get the following error:
133+
134+
.. literalinclude:: ../advanced_source/custom_classes/custom_class_project/export_attr.py
135+
:language: python
136+
137+
.. code-block:: shell
138+
139+
$ python export_attr.py
140+
RuntimeError: Cannot serialize custom bound C++ class __torch__.torch.classes.my_classes.MyStackClass. Please define serialization methods via def_pickle for this class. (pushIValueImpl at ../torch/csrc/jit/pickler.cpp:128)
141+
142+
This is because PyTorch cannot automatically figure out what information
143+
save from your C++ class. You must specify that manually. The way to do that
144+
is to define ``__getstate__`` and ``__setstate__`` methods on the class using
145+
the special ``def_pickle`` method on ``class_``.
146+
147+
.. note::
148+
The semantics of ``__getstate__`` and ``__setstate__`` are
149+
equivalent to that of the Python pickle module. You can
150+
`read more <https://github.com/pytorch/pytorch/blob/master/torch/csrc/jit/docs/serialization.md#getstate-and-setstate>`_
151+
about how we use these methods.
152+
153+
Here is an example of the ``def_pickle`` call we can add to the registration of
154+
``MyStackClass`` to include serialization methods:
155+
156+
.. literalinclude:: ../advanced_source/custom_classes/custom_class_project/class.cpp
157+
:language: cpp
158+
:start-after: BEGIN def_pickle
159+
:end-before: END def_pickle
160+
161+
.. note::
162+
We take a different approach from pybind11 in the pickle API. Whereas pybind11
163+
as a special function ``pybind11::pickle()`` which you pass into ``class_::def()``,
164+
we have a separate method ``def_pickle`` for this purpose. This is because the
165+
name ``torch::jit::pickle`` was already taken, and we didn't want to cause confusion.
166+
167+
Once we have defined the (de)serialization behavior in this way, our script can
168+
now run successfully:
169+
170+
.. code-block:: shell
171+
172+
$ python ../export_attr.py
173+
testing
174+
175+
Defining Custom Operators that Take or Return Bound C++ Classes
176+
---------------------------------------------------------------
177+
178+
Once you've defined a custom C++ class, you can also use that class
179+
as an argument or return from a custom operator (i.e. free functions). Suppose
180+
you have the following free function:
181+
182+
.. literalinclude:: ../advanced_source/custom_classes/custom_class_project/class.cpp
183+
:language: cpp
184+
:start-after: BEGIN free_function
185+
:end-before: END free_function
186+
187+
You can register it running the following code inside your ``TORCH_LIBRARY``
188+
block:
189+
190+
.. literalinclude:: ../advanced_source/custom_classes/custom_class_project/class.cpp
191+
:language: cpp
192+
:start-after: BEGIN def_free
193+
:end-before: END def_free
194+
195+
Once this is done, you can use the op like the following example:
196+
197+
.. code-block:: python
198+
199+
class TryCustomOp(torch.nn.Module):
200+
def __init__(self):
201+
super(TryCustomOp, self).__init__()
202+
self.f = torch.classes.my_classes.MyStackClass(["foo", "bar"])
203+
204+
def forward(self):
205+
return torch.ops.my_classes.manipulate_instance(self.f)
206+
207+
.. note::
208+
209+
Registration of an operator that takes a C++ class as an argument requires that
210+
the custom class has already been registered. You can enforce this by
211+
making sure the custom class registration and your free function definitions
212+
are in the same ``TORCH_LIBRARY`` block, and that the custom class
213+
registration comes first. In the future, we may relax this requirement,
214+
so that these can be registered in any order.
215+
216+
217+
Conclusion
218+
----------
219+
220+
This tutorial walked you through how to expose a C++ class to PyTorch, how to
221+
register its methods, how to use that class from Python, and how to save and
222+
load code using the class and run that code in a standalone C++ process. You
223+
are now ready to extend your PyTorch models with C++ classes that interface
224+
with third party C++ libraries or implement any other use case that requires
225+
the lines between Python and C++ to blend smoothly.
226+
227+
As always, if you run into any problems or have questions, you can use our
228+
`forum <https://discuss.pytorch.org/>`_ or `GitHub issues
229+
<https://github.com/pytorch/pytorch/issues>`_ to get in touch. Also, our
230+
`frequently asked questions (FAQ) page
231+
<https://pytorch.org/cppdocs/notes/faq.html>`_ may have helpful information.
Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
cmake_minimum_required(VERSION 3.1 FATAL_ERROR)
2+
project(infer)
3+
4+
find_package(Torch REQUIRED)
5+
6+
add_subdirectory(custom_class_project)
7+
8+
# Define our library target
9+
add_executable(infer infer.cpp)
10+
set(CMAKE_CXX_STANDARD 14)
11+
# Link against LibTorch
12+
target_link_libraries(infer "${TORCH_LIBRARIES}")
13+
# This is where we link in our libcustom_class code, making our
14+
# custom class available in our binary.
15+
target_link_libraries(infer -Wl,--no-as-needed custom_class)
Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
cmake_minimum_required(VERSION 3.1 FATAL_ERROR)
2+
project(custom_class)
3+
4+
find_package(Torch REQUIRED)
5+
6+
# Define our library target
7+
add_library(custom_class SHARED class.cpp)
8+
set(CMAKE_CXX_STANDARD 14)
9+
# Link against LibTorch
10+
target_link_libraries(custom_class "${TORCH_LIBRARIES}")

0 commit comments

Comments
 (0)