Extending TorchScript with Custom C++ Classes
=============================================

This tutorial is a follow-on to the
`custom operator <https://pytorch.org/tutorials/advanced/torch_script_custom_ops.html>`_
tutorial, and introduces the API we've built for binding C++ classes into TorchScript
and Python simultaneously. The API is very similar to
`pybind11 <https://github.com/pybind/pybind11>`_, and most of the concepts will transfer
over if you're familiar with that system.

Implementing and Binding the Class in C++
-----------------------------------------

For this tutorial, we are going to define a simple C++ class that maintains persistent
state in a member variable.

.. code-block:: cpp

  #include <torch/custom_class.h>

  #include <string>
  #include <vector>

  template <class T>
  struct Stack : torch::jit::CustomClassHolder {
    std::vector<T> stack_;
    Stack(std::vector<T> init) : stack_(init.begin(), init.end()) {}

    void push(T x) {
      stack_.push_back(x);
    }
    T pop() {
      auto val = stack_.back();
      stack_.pop_back();
      return val;
    }

    c10::intrusive_ptr<Stack> clone() const {
      return c10::make_intrusive<Stack>(stack_);
    }

    void merge(const c10::intrusive_ptr<Stack>& c) {
      for (auto& elem : c->stack_) {
        push(elem);
      }
    }
  };

There are several things to note:

- ``torch/custom_class.h`` is the header you need to include to extend TorchScript
  with your custom class.
- Whenever we work with instances of the custom class, we do it via instances of
  ``c10::intrusive_ptr<>``. Think of ``intrusive_ptr`` as a smart pointer like
  ``std::shared_ptr``. The reason for using this smart pointer is to ensure
  consistent lifetime management of the object instances between languages
  (C++, Python, and TorchScript).
- The user-defined class must inherit from ``torch::jit::CustomClassHolder``.
  This ensures that everything is set up to handle the lifetime management
  system previously mentioned.
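As an aside, the "intrusive" in ``intrusive_ptr`` means the reference count lives inside the managed object (which is what inheriting from ``CustomClassHolder`` provides) rather than in a separate control block, as with ``std::shared_ptr``. A rough Python sketch of the bookkeeping, purely for intuition (the names ``incref``/``decref`` below are hypothetical, not part of any PyTorch API):

.. code-block:: python

  # Conceptual sketch of intrusive reference counting: the count is a field of
  # the object itself, so any pointer to the object can be promoted back to an
  # owning reference without consulting a separate control block.
  class IntrusivelyCounted:
      def __init__(self):
          self.refcount = 0  # stored inside the object, like CustomClassHolder

  def incref(obj):
      obj.refcount += 1

  def decref(obj):
      """Drop one reference; return True when the object can be destroyed."""
      obj.refcount -= 1
      return obj.refcount == 0

  obj = IntrusivelyCounted()
  incref(obj)                   # first owner (e.g. a C++ intrusive_ptr)
  incref(obj)                   # second owner (e.g. a Python reference)
  assert decref(obj) is False   # one owner remains
  assert decref(obj) is True    # last owner gone -> safe to destroy

Because the count is embedded in the object, the same instance can be handed between C++, Python, and TorchScript without any one runtime owning the bookkeeping.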

Now let's take a look at how we will make this class visible to TorchScript, a process called
*binding* the class:

.. code-block:: cpp

  static auto testStack =
    torch::jit::class_<Stack<std::string>>("Stack")
        .def(torch::jit::init<std::vector<std::string>>())
        .def("top", [](const c10::intrusive_ptr<Stack<std::string>>& self) {
          return self->stack_.back();
        })
        .def("push", &Stack<std::string>::push)
        .def("pop", &Stack<std::string>::pop)
        .def("clone", &Stack<std::string>::clone)
        .def("merge", &Stack<std::string>::merge);

Notice the following:

- We pass the class to be registered as a template parameter to ``torch::jit::class_``.
  In this instance, we've passed the specialization of the Stack class ``Stack<std::string>``.
  In general, you cannot register a non-specialized template class. For non-templated classes,
  you can just pass the class name directly as the template parameter.
- The single parameter to ``torch::jit::class_()`` is a string indicating the name of the class.
  This is the name the class will appear as in both Python and TorchScript. For example, our
  Stack class would appear as ``torch.classes.Stack``.
- For each method of the class we'd like to expose to Python and TorchScript, we use the
  ``.def()`` method on ``torch::jit::class_``. We can chain these together to register
  multiple methods as well. Let's examine the different call sites of ``def()`` in our example:

  - ``torch::jit::init<std::vector<std::string>>()`` registers the constructor of our Stack
    class that takes a single ``std::vector<std::string>`` argument, i.e. it exposes the C++
    method ``Stack(std::vector<T> init)``. Currently, we do not support registering overloaded
    constructors, so for now you can only ``def()`` one instance of ``torch::jit::init``.
  - The next line registers a stateless (i.e. capture-free) C++ lambda function as a method.
    Note that a lambda function must take a ``c10::intrusive_ptr<YourClass>`` (or some
    const/ref version of that) as its first argument to work.
  - ``.def("push", &Stack<std::string>::push)`` exposes the ``void push(T x)`` method.
    ``torch::jit::class_`` will automatically examine the argument and return types of
    the passed-in method pointers and expose these to Python and TorchScript accordingly.
    Finally, notice that we must take the *address* of the fully-qualified method name,
    i.e. use the unary ``&`` operator, due to C++ typing rules.
  - The rest of the method registrations follow the same pattern.

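The chained ``.def()`` calls follow the same builder pattern as pybind11. A toy Python model of the idea (the ``ClassBinder``/``def_`` names are hypothetical, for illustration only; the real registry lives in C++) shows why chaining works: each ``def`` records a method and returns the binder itself:

.. code-block:: python

  # Toy model of the class_ builder pattern: def_() records a (name, callable)
  # pair in a registry and returns self, which is what makes chaining possible.
  class ClassBinder:
      def __init__(self, name):
          self.name = name
          self.methods = {}

      def def_(self, name, fn):
          self.methods[name] = fn
          return self  # returning self enables .def_(...).def_(...) chains

  binder = (
      ClassBinder("Stack")
      .def_("push", lambda self, x: self._items.append(x))
      .def_("pop", lambda self: self._items.pop())
  )
  assert binder.name == "Stack"
  assert sorted(binder.methods) == ["pop", "push"]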

Building the Example as a C++ Project With CMake
------------------------------------------------

Now, we're going to build the above C++ code with the `CMake
<https://cmake.org>`_ build system. First, take all the C++ code
we've covered so far and place it in a file called ``class.cpp``.
Then, write a simple ``CMakeLists.txt`` file and place it in the
same directory. Here is what ``CMakeLists.txt`` should look like:

.. code-block:: cmake

  cmake_minimum_required(VERSION 3.1 FATAL_ERROR)
  project(custom_class)

  find_package(Torch REQUIRED)

  # Targets pick up CMAKE_CXX_STANDARD when they are created, so set it first
  set(CMAKE_CXX_STANDARD 14)

  # Define our library target
  add_library(custom_class SHARED class.cpp)
  # Link against LibTorch
  target_link_libraries(custom_class "${TORCH_LIBRARIES}")

Also, create a ``build`` directory. Your file tree should look like this::

  custom_class_project/
    class.cpp
    CMakeLists.txt
    build/

Now, to build the project, go ahead and download the appropriate libtorch
binary from the `PyTorch website <https://pytorch.org/>`_. Extract the
zip archive somewhere (within the project directory might be convenient)
and note the path you've extracted it to. Next, go ahead and invoke cmake and
then make to build the project:

.. code-block:: shell

  $ cd build
  $ cmake -DCMAKE_PREFIX_PATH=/path/to/libtorch ..
  -- The C compiler identification is GNU 7.3.1
  -- The CXX compiler identification is GNU 7.3.1
  -- Check for working C compiler: /opt/rh/devtoolset-7/root/usr/bin/cc
  -- Check for working C compiler: /opt/rh/devtoolset-7/root/usr/bin/cc -- works
  -- Detecting C compiler ABI info
  -- Detecting C compiler ABI info - done
  -- Detecting C compile features
  -- Detecting C compile features - done
  -- Check for working CXX compiler: /opt/rh/devtoolset-7/root/usr/bin/c++
  -- Check for working CXX compiler: /opt/rh/devtoolset-7/root/usr/bin/c++ -- works
  -- Detecting CXX compiler ABI info
  -- Detecting CXX compiler ABI info - done
  -- Detecting CXX compile features
  -- Detecting CXX compile features - done
  -- Looking for pthread.h
  -- Looking for pthread.h - found
  -- Looking for pthread_create
  -- Looking for pthread_create - not found
  -- Looking for pthread_create in pthreads
  -- Looking for pthread_create in pthreads - not found
  -- Looking for pthread_create in pthread
  -- Looking for pthread_create in pthread - found
  -- Found Threads: TRUE
  -- Found torch: /torchbind_tutorial/libtorch/lib/libtorch.so
  -- Configuring done
  -- Generating done
  -- Build files have been written to: /torchbind_tutorial/build
  $ make -j
  Scanning dependencies of target custom_class
  [ 50%] Building CXX object CMakeFiles/custom_class.dir/class.cpp.o
  [100%] Linking CXX shared library libcustom_class.so
  [100%] Built target custom_class

You'll find that there is now (among other things) a ``libcustom_class.so``
file present in the build directory. So the file tree should look like::

  custom_class_project/
    class.cpp
    CMakeLists.txt
    build/
      libcustom_class.so


Using the C++ Class from Python and TorchScript
-----------------------------------------------

Now that we have our class and its registration compiled into an ``.so`` file,
we can load that ``.so`` into Python and try it out. Here's a script that
demonstrates that:

.. code-block:: python

  import torch

  # `torch.classes.load_library()` allows you to pass the path to your .so file
  # to load it in and make the custom C++ classes available to both Python and
  # TorchScript
  torch.classes.load_library("libcustom_class.so")

  # We can find and instantiate our custom C++ class in Python by using the
  # `torch.classes` namespace:
  #
  # This instantiation will invoke the Stack(std::vector<T> init) constructor
  # we registered earlier
  s = torch.classes.Stack(["foo", "bar"])

  # We can call methods in Python
  s.push("pushed")
  assert s.pop() == "pushed"

  # Returning and passing instances of custom classes works as you'd expect
  s2 = s.clone()
  s.merge(s2)
  for expected in ["bar", "foo", "bar", "foo"]:
      assert s.pop() == expected

  # We can also use the class in TorchScript
  # For now, we need to assign the class's type to a local in order to
  # annotate the type on the TorchScript function
  Stack = torch.classes.Stack

  # This demonstrates:
  # - passing a custom class instance to TorchScript
  # - instantiating a class in TorchScript
  # - calling a custom class's methods in TorchScript
  # - returning a custom class instance from TorchScript
  @torch.jit.script
  def do_stacks(s: Stack):
      s2 = torch.classes.Stack(["hi", "mom"])
      s2.merge(s)
      return s2.clone(), s2.top()

  stack, top = do_stacks(torch.classes.Stack(["wow"]))
  assert top == "wow"
  for expected in ["wow", "mom", "hi"]:
      assert stack.pop() == expected

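To see why the pops come out in the order asserted above, it helps to trace the semantics with a plain-Python stand-in for the bound class (``MockStack`` is an illustration only; the real class is the C++ one we bound):

.. code-block:: python

  # A pure-Python stand-in mirroring the bound C++ Stack's semantics:
  # push/pop operate on the top, clone copies the storage, and merge pushes
  # the other stack's elements in bottom-to-top order.
  class MockStack:
      def __init__(self, init):
          self._items = list(init)

      def push(self, x):
          self._items.append(x)

      def pop(self):
          return self._items.pop()

      def top(self):
          return self._items[-1]

      def clone(self):
          return MockStack(self._items)

      def merge(self, other):
          for elem in other._items:  # bottom-to-top, like the C++ range-for
              self.push(elem)

  # The clone/merge sequence from the script above:
  s = MockStack(["foo", "bar"])
  s2 = s.clone()             # s2: [foo, bar]
  s.merge(s2)                # s:  [foo, bar, foo, bar]
  assert [s.pop() for _ in range(4)] == ["bar", "foo", "bar", "foo"]

  # The do_stacks sequence: merge(["wow"]) onto ["hi", "mom"]
  s3 = MockStack(["hi", "mom"])
  s3.merge(MockStack(["wow"]))   # s3: [hi, mom, wow]
  assert s3.top() == "wow"
  assert [s3.pop() for _ in range(3)] == ["wow", "mom", "hi"]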

Saving, Loading, and Running TorchScript Code Using Custom Classes
------------------------------------------------------------------

We can also use custom-registered C++ classes in a C++ process using
libtorch. As an example, let's define a simple ``nn.Module`` that
instantiates and calls a method on our Stack class:

.. code-block:: python

  import torch

  torch.classes.load_library('libcustom_class.so')

  class Foo(torch.nn.Module):
      def __init__(self):
          super().__init__()

      def forward(self, s: str) -> str:
          stack = torch.classes.Stack(["hi", "mom"])
          return stack.pop() + s

  scripted_foo = torch.jit.script(Foo())
  print(scripted_foo.graph)

  scripted_foo.save('foo.pt')

``foo.pt`` in our filesystem now contains the serialized TorchScript
program we've just defined.

Now, we're going to define a new CMake project to show how you can load
this model and its required ``.so`` file. For a full treatment of how to do this,
please have a look at the `Loading a TorchScript Model in C++ tutorial <https://pytorch.org/tutorials/advanced/cpp_export.html>`_.

Similarly to before, let's create a file structure containing the following::

  cpp_inference_example/
    infer.cpp
    CMakeLists.txt
    foo.pt
    build/
    custom_class_project/

Notice we've copied over the serialized ``foo.pt`` file, as well as the source
tree from the ``custom_class_project`` above. We will be adding the
``custom_class_project`` as a dependency to this C++ project so that we can
build the custom class into the binary.

Let's populate ``infer.cpp`` with the following:

.. code-block:: cpp

  #include <torch/script.h>

  #include <iostream>
  #include <memory>

  int main(int argc, const char* argv[]) {
    torch::jit::script::Module module;
    try {
      // Deserialize the ScriptModule from a file using torch::jit::load().
      module = torch::jit::load("foo.pt");
    }
    catch (const c10::Error& e) {
      std::cerr << "error loading the model\n";
      return -1;
    }

    std::vector<c10::IValue> inputs = {"foobarbaz"};
    auto output = module.forward(inputs).toString();
    std::cout << output->string() << std::endl;
  }


And similarly let's define our ``CMakeLists.txt`` file:

.. code-block:: cmake

  cmake_minimum_required(VERSION 3.1 FATAL_ERROR)
  project(infer)

  find_package(Torch REQUIRED)

  add_subdirectory(custom_class_project)

  # Targets pick up CMAKE_CXX_STANDARD when they are created, so set it first
  set(CMAKE_CXX_STANDARD 14)

  # Define our executable target
  add_executable(infer infer.cpp)
  # Link against LibTorch
  target_link_libraries(infer "${TORCH_LIBRARIES}")
  # Link against our custom class library. `-Wl,--no-as-needed` keeps the
  # dependency even though `infer` references none of its symbols directly:
  # the class is registered via static initializers when the library loads.
  target_link_libraries(infer -Wl,--no-as-needed custom_class)

You know the drill: ``cd build``, ``cmake``, and ``make``:

.. code-block:: shell

  $ cd build
  $ cmake -DCMAKE_PREFIX_PATH=/path/to/libtorch ..
  -- The C compiler identification is GNU 7.3.1
  -- The CXX compiler identification is GNU 7.3.1
  -- Check for working C compiler: /opt/rh/devtoolset-7/root/usr/bin/cc
  -- Check for working C compiler: /opt/rh/devtoolset-7/root/usr/bin/cc -- works
  -- Detecting C compiler ABI info
  -- Detecting C compiler ABI info - done
  -- Detecting C compile features
  -- Detecting C compile features - done
  -- Check for working CXX compiler: /opt/rh/devtoolset-7/root/usr/bin/c++
  -- Check for working CXX compiler: /opt/rh/devtoolset-7/root/usr/bin/c++ -- works
  -- Detecting CXX compiler ABI info
  -- Detecting CXX compiler ABI info - done
  -- Detecting CXX compile features
  -- Detecting CXX compile features - done
  -- Looking for pthread.h
  -- Looking for pthread.h - found
  -- Looking for pthread_create
  -- Looking for pthread_create - not found
  -- Looking for pthread_create in pthreads
  -- Looking for pthread_create in pthreads - not found
  -- Looking for pthread_create in pthread
  -- Looking for pthread_create in pthread - found
  -- Found Threads: TRUE
  -- Found torch: /local/miniconda3/lib/python3.7/site-packages/torch/lib/libtorch.so
  -- Configuring done
  -- Generating done
  -- Build files have been written to: /cpp_inference_example/build
  $ make -j
  Scanning dependencies of target custom_class
  [ 25%] Building CXX object custom_class_project/CMakeFiles/custom_class.dir/class.cpp.o
  [ 50%] Linking CXX shared library libcustom_class.so
  [ 50%] Built target custom_class
  Scanning dependencies of target infer
  [ 75%] Building CXX object CMakeFiles/infer.dir/infer.cpp.o
  [100%] Linking CXX executable infer
  [100%] Built target infer

And now we can run our exciting C++ binary:

.. code-block:: shell

  $ ./infer
  momfoobarbaz

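The output makes sense given ``Foo.forward``: the freshly constructed stack pops ``"mom"`` off its top and prepends it to the input string. A quick Python check of that logic, using a plain list as a stand-in for the C++ stack:

.. code-block:: python

  # forward() builds a stack ["hi", "mom"], pops the top element ("mom"),
  # and concatenates it with the input string passed to the module.
  stack = ["hi", "mom"]
  result = stack.pop() + "foobarbaz"
  assert result == "momfoobarbaz"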

Incredible!

Conclusion
----------

This tutorial walked you through how to expose a C++ class to TorchScript
(and by extension Python), how to register its methods, how to use that
class from Python and TorchScript, and how to save and load code using
the class and run that code in a standalone C++ process. You are now ready
to extend your TorchScript models with C++ classes that interface with
third-party C++ libraries or implement any other use case that requires the
lines between Python, TorchScript, and C++ to blend smoothly.

As always, if you run into any problems or have questions, you can use our
`forum <https://discuss.pytorch.org/>`_ or `GitHub issues
<https://github.com/pytorch/pytorch/issues>`_ to get in touch. Also, our
`frequently asked questions (FAQ) page
<https://pytorch.org/cppdocs/notes/faq.html>`_ may have helpful information.