Unify C++ API with C++ extensions (pytorch#11510)
Summary:
Currently the C++ API and C++ extensions are effectively two different, entirely orthogonal code paths. This PR unifies the C++ API with the C++ extension API by adding Python binding support to the C++ API. This means the `torch/torch.h` included by C++ extensions, which currently routes to `torch/csrc/torch.h`, can now be rerouted to `torch/csrc/api/include/torch/torch.h` -- i.e. the main C++ API header. This header then includes Python binding support conditioned on a define (`TORCH_WITH_PYTHON_BINDINGS`), *which is only passed when building a C++ extension*.
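
To make the mechanism concrete, here is a minimal sketch of the layering described above, assuming illustrative header names; it is not the actual contents of the C++ API header:

// Sketch only: the real torch/csrc/api/include/torch/torch.h differs in detail.
#pragma once

#include <torch/nn/module.h>        // C++ API headers (modules, optimizers, ...)
#include <torch/optim/optimizer.h>

#if defined(TORCH_WITH_PYTHON_BINDINGS)
// Only defined when building a C++ extension: pulls in the pybind11 glue,
// e.g. torch::python::bind_module.
#include <torch/python.h>
#endif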

Currently stacked on top of pytorch#11498

Why is this useful?

1. One less code path. In particular, there has been repeated trouble due to the two `torch/torch.h` header files and the ambiguity that arose when both ended up in the include path. This is now fixed.
2. I have found that it is quite common to want to bind a C++ API module back into Python. This could be for simple experimentation, or to have your training loop in Python but your models in C++. This PR makes that easier by adding pybind11 support to the C++ API (see the sketch after this list).
3. The C++ extension API simply becomes richer by gaining access to the C++ API headers.
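
To illustrate point (2), here is a minimal sketch of a C++ extension that defines a `torch::nn` module and binds it back into Python. It mirrors the new `test/cpp_extensions/cpp_api_extension.cpp` added in this PR; the module name `TinyNet` is illustrative:

#include <torch/extension.h>
#include <torch/python.h>
#include <torch/torch.h>

// A small C++ API module (name is illustrative).
struct TinyNet : torch::nn::Module {
  TinyNet(int64_t in, int64_t out) : fc(in, out) {
    register_module("fc", fc);
  }

  torch::Tensor forward(torch::Tensor x) {
    return torch::relu(fc->forward(x));
  }

  torch::nn::Linear fc;
};

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  // bind_module also exposes the generic Module machinery (train()/eval(),
  // parameters(), named_parameters(), ...) to Python.
  torch::python::bind_module<TinyNet>(m, "TinyNet")
      .def(py::init<int64_t, int64_t>())
      .def("forward", &TinyNet::forward);
}

Once built with `torch.utils.cpp_extension.load`, the resulting Python class behaves like the one exercised in `test_cpp_api_extension` below.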

soumith ezyang apaszke
Pull Request resolved: pytorch#11510

Reviewed By: ezyang

Differential Revision: D9998835

Pulled By: goldsborough

fbshipit-source-id: 7a94b44a9d7e0377b7f1cfc99ba2060874d51535
goldsborough authored and facebook-github-bot committed Sep 24, 2018
1 parent 1c09bfd commit e05d689
Showing 22 changed files with 320 additions and 76 deletions.
16 changes: 10 additions & 6 deletions cmake/TorchConfig.cmake.in
@@ -24,9 +24,13 @@ endif()

# Include directories.
if (EXISTS "${TORCH_INSTALL_PREFIX}/lib/include")
set(TORCH_INCLUDE_DIRS "${TORCH_INSTALL_PREFIX}/lib/include")
set(TORCH_INCLUDE_DIRS
${TORCH_INSTALL_PREFIX}/lib/include
${TORCH_INSTALL_PREFIX}/lib/include/torch/csrc/api/include)
else()
set(TORCH_INCLUDE_DIRS "${TORCH_INSTALL_PREFIX}/include")
set(TORCH_INCLUDE_DIRS
${TORCH_INSTALL_PREFIX}/include
${TORCH_INSTALL_PREFIX}/include/torch/csrc/api/include)
endif()

# Library dependencies.
@@ -45,7 +49,7 @@ if (@USE_CUDA@)
set(TORCH_CUDA_LIBRARIES
${NVTOOLEXT_HOME}/lib/x64/nvToolsExt64_1.lib
${CUDA_LIBRARIES})
-list(APPEND TORCH_INCLUDE_DIRS "${NVTOOLEXT_HOME}/include")
+list(APPEND TORCH_INCLUDE_DIRS ${NVTOOLEXT_HOME}/include)
elseif(APPLE)
set(TORCH_CUDA_LIBRARIES
${CUDA_TOOLKIT_ROOT_DIR}/lib/libcudart.dylib
@@ -66,8 +70,8 @@ endif()
set(TORCH_CXX_FLAGS "-D_GLIBCXX_USE_CXX11_ABI=@GLIBCXX_USE_CXX11_ABI@")

set_target_properties(torch PROPERTIES
-  IMPORTED_LOCATION ${TORCH_LIBRARY}
-  INTERFACE_INCLUDE_DIRECTORIES ${TORCH_INCLUDE_DIRS}
-  INTERFACE_COMPILE_OPTIONS ${TORCH_CXX_FLAGS}
+  IMPORTED_LOCATION "${TORCH_LIBRARY}"
+  INTERFACE_INCLUDE_DIRECTORIES "${TORCH_INCLUDE_DIRS}"
+  INTERFACE_COMPILE_OPTIONS "${TORCH_CXX_FLAGS}"
CXX_STANDARD 11
)
17 changes: 7 additions & 10 deletions setup.py
@@ -471,18 +471,9 @@ def check_file(f):
if not same:
shutil.copyfile(orig_file, sym_file)

-# Copy headers necessary to compile C++ extensions.
-#
-# This is not perfect solution as build does not depend on any of
-# the auto-generated code and auto-generated files will not be
-# included in this copy. If we want to use auto-generated files,
-# we need to find a better way to do this.
-# More information can be found in conversation thread of PR #5772
-
self.copy_tree('torch/lib/tmp_install/share', 'torch/share')
self.copy_tree('third_party/pybind11/include/pybind11/',
'torch/lib/include/pybind11')
-self.copy_file('torch/csrc/torch.h', 'torch/lib/include/torch/torch.h')


build_dep_cmds = {}
@@ -1212,7 +1203,13 @@ def make_relative_rpath(path):
'lib/include/c10/macros/*.h',
'lib/include/torch/*.h',
'lib/include/torch/csrc/*.h',
-'lib/include/torch/csrc/api/include/torch/detail/ordered_dict.h',
+'lib/include/torch/csrc/api/include/torch/*.h',
+'lib/include/torch/csrc/api/include/torch/detail/*.h',
+'lib/include/torch/csrc/api/include/torch/nn/*.h',
+'lib/include/torch/csrc/api/include/torch/nn/modules/*.h',
+'lib/include/torch/csrc/api/include/torch/nn/parallel/*.h',
+'lib/include/torch/csrc/api/include/torch/optim/*.h',
+'lib/include/torch/csrc/api/include/torch/serialize/*.h',
'lib/include/torch/csrc/autograd/*.h',
'lib/include/torch/csrc/autograd/generated/*.h',
'lib/include/torch/csrc/cuda/*.h',
2 changes: 1 addition & 1 deletion test/cpp_extensions/complex_registration_extension.cpp
@@ -1,4 +1,4 @@
-#include <torch/torch.h>
+#include <torch/extension.h>

#include <ATen/CPUFloatType.h>
#include <ATen/Type.h>
38 changes: 38 additions & 0 deletions test/cpp_extensions/cpp_api_extension.cpp
@@ -0,0 +1,38 @@
#include <torch/extension.h>
#include <torch/python.h>
#include <torch/torch.h>

struct Net : torch::nn::Module {
Net(int64_t in, int64_t out)
: fc(in, out),
bn(torch::nn::BatchNormOptions(out).stateful(true)),
dropout(0.5) {
register_module("fc", fc);
register_module("bn", bn);
register_module("dropout", dropout);
}

torch::Tensor forward(torch::Tensor x) {
return dropout->forward(bn->forward(torch::relu(fc->forward(x))));
}

void set_bias(torch::Tensor bias) {
fc->bias = bias;
}

torch::Tensor get_bias() const {
return fc->bias;
}

torch::nn::Linear fc;
torch::nn::BatchNorm bn;
torch::nn::Dropout dropout;
};

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
torch::python::bind_module<Net>(m, "Net")
.def(py::init<int64_t, int64_t>())
.def("forward", &Net::forward)
.def("set_bias", &Net::set_bias)
.def("get_bias", &Net::get_bias);
}
2 changes: 1 addition & 1 deletion test/cpp_extensions/cuda_extension.cpp
@@ -1,4 +1,4 @@
-#include <torch/torch.h>
+#include <torch/extension.h>

// Declare the function from cuda_extension.cu. It will be compiled
// separately with nvcc and linked with the object file of cuda_extension.cpp
2 changes: 1 addition & 1 deletion test/cpp_extensions/cudnn_extension.cpp
@@ -10,7 +10,7 @@
* 5) Return something (optional).
*/

-#include <torch/torch.h>
+#include <torch/extension.h>

#include <ATen/cudnn/Descriptors.h> // for TensorDescriptor
#include <ATen/cuda/Exceptions.h> // for CUDNN_CHECK
2 changes: 1 addition & 1 deletion test/cpp_extensions/doubler.h
@@ -1,4 +1,4 @@
-#include <torch/torch.h>
+#include <torch/extension.h>

struct Doubler {
Doubler(int A, int B) {
2 changes: 1 addition & 1 deletion test/cpp_extensions/extension.cpp
@@ -1,4 +1,4 @@
-#include <torch/torch.h>
+#include <torch/extension.h>

at::Tensor sigmoid_add(at::Tensor x, at::Tensor y) {
return x.sigmoid() + y.sigmoid();
2 changes: 1 addition & 1 deletion test/cpp_extensions/half_support.cu
@@ -1,4 +1,4 @@
-#include <torch/torch.h>
+#include <torch/extension.h>

#include <THC/THCNumerics.cuh>

2 changes: 1 addition & 1 deletion test/cpp_extensions/jit_extension.cpp
@@ -1,4 +1,4 @@
-#include <torch/torch.h>
+#include <torch/extension.h>

#include "doubler.h"

2 changes: 1 addition & 1 deletion test/cpp_extensions/jit_extension2.cpp
@@ -1,4 +1,4 @@
-#include <torch/torch.h>
+#include <torch/extension.h>

using namespace at;

51 changes: 49 additions & 2 deletions test/test_cpp_extensions.py
@@ -23,6 +23,9 @@
TEST_CUDNN = TEST_CUDA and CUDNN_HEADER_EXISTS and torch.backends.cudnn.is_available()


IS_WINDOWS = sys.platform == 'win32'


class TestCppExtension(common.TestCase):
def setUp(self):
if sys.platform != 'win32':
@@ -189,7 +192,7 @@ def test_inline_jit_compile_extension_multiple_sources_and_no_functions(self):
'''

cpp_source2 = '''
-#include <torch/torch.h>
+#include <torch/extension.h>
at::Tensor sin_add(at::Tensor x, at::Tensor y);
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("sin_add", &sin_add, "sin(x) + sin(y)");
@@ -265,7 +268,7 @@ def test_lenient_flag_handling_in_jit_extensions(self):
cpp_sources=cpp_source,
functions='tanh_add',
extra_cflags=['-g\n\n', '-O0 -Wall'],
-extra_include_paths=[' cpp_extensions\n', '../'],
+extra_include_paths=[' cpp_extensions\n'],
verbose=True)

x = torch.zeros(100, dtype=torch.float32)
@@ -341,6 +344,50 @@ def compile(code):
module = compile('int f() { return 789; }')
self.assertEqual(module.f(), 789)

@unittest.skipIf(IS_WINDOWS, "C++ API not yet supported on Windows")
def test_cpp_api_extension(self):
here = os.path.abspath(__file__)
pytorch_root = os.path.dirname(os.path.dirname(here))
api_include = os.path.join(pytorch_root, 'torch', 'csrc', 'api', 'include')
module = torch.utils.cpp_extension.load(
name='cpp_api_extension',
sources='cpp_extensions/cpp_api_extension.cpp',
extra_include_paths=api_include,
extra_cflags=[] if IS_WINDOWS else ['-UTORCH_API_INCLUDE_EXTENSION_H'],
verbose=True)

net = module.Net(3, 5)

self.assertTrue(net.training)
net.eval()
self.assertFalse(net.training)
net.train()
self.assertTrue(net.training)
net.eval()

input = torch.randn(2, 3, dtype=torch.float32)
output = net.forward(input)
self.assertEqual(output, net.forward(input))
self.assertEqual(list(output.shape), [2, 5])

bias = net.get_bias()
self.assertEqual(list(bias.shape), [5])
net.set_bias(bias + 1)
self.assertEqual(net.get_bias(), bias + 1)
output2 = net.forward(input)

self.assertNotEqual(output + 1, output2)

self.assertEqual(len(net.parameters()), 4)

p = net.named_parameters()
self.assertEqual(type(p), dict)
self.assertEqual(len(p), 4)
self.assertIn('fc.weight', p)
self.assertIn('fc.bias', p)
self.assertIn('bn.weight', p)
self.assertIn('bn.bias', p)


if __name__ == '__main__':
common.run_tests()
2 changes: 1 addition & 1 deletion torch/CMakeLists.txt
@@ -411,7 +411,7 @@ endif()
install(DIRECTORY "${TORCH_SRC_DIR}/csrc"
DESTINATION ${TORCH_INSTALL_INCLUDE_DIR}/torch
FILES_MATCHING PATTERN "*.h")
install(FILES "${TORCH_SRC_DIR}/script.h"
install(FILES "${TORCH_SRC_DIR}/script.h" "${TORCH_SRC_DIR}/extension.h"
DESTINATION ${TORCH_INSTALL_INCLUDE_DIR}/torch)

install(TARGETS torch
8 changes: 3 additions & 5 deletions torch/csrc/api/include/torch/nn/modules/dropout.h
@@ -12,20 +12,18 @@ namespace nn {

/// Options for `Dropout` and `FeatureDropout`.
struct DropoutOptions {
-  DropoutOptions(double rate);
+  /* implicit */ DropoutOptions(double rate = 0.5);
/// The probability with which a particular component of the input is set to
/// zero.
/// Changes to this parameter at runtime are effective.
-  TORCH_ARG(double, rate) = 0.5;
+  TORCH_ARG(double, rate);
};

namespace detail {
template <typename Derived>
class DropoutImplBase : public torch::nn::Cloneable<Derived> {
public:
-  explicit DropoutImplBase(double rate)
-      : DropoutImplBase(DropoutOptions(rate)) {}
-  explicit DropoutImplBase(DropoutOptions options_);
+  explicit DropoutImplBase(DropoutOptions options_ = DropoutOptions());

void reset() override;

2 changes: 2 additions & 0 deletions torch/csrc/api/include/torch/nn/modules/sequential.h
@@ -92,6 +92,8 @@ class SequentialImpl : public Cloneable<SequentialImpl> {
using Iterator = std::vector<AnyModule>::iterator;
using ConstIterator = std::vector<AnyModule>::const_iterator;

SequentialImpl() = default;

/// Constructs the `Sequential` from a variadic list of modules.
template <typename... Modules>
explicit SequentialImpl(Modules&&... modules) {
47 changes: 47 additions & 0 deletions torch/csrc/api/include/torch/nn/pimpl-inl.h
@@ -0,0 +1,47 @@
// This class exists only to do SFINAE on abstract types `T` that are really
// `ModuleHolder<ModuleType>`, because there's no good way to say that `T` is a
// `ModuleHolder` over some unknown type `ModuleType`. With this, you can do
// `enable_if_t<is_base_of_v<ModuleHolderIndicator, T>>`.
struct ModuleHolderIndicator {};

// A type trait that is true for types that are `ModuleHolder`s.
template <typename T>
using is_module_holder = std::is_base_of<ModuleHolderIndicator, decay_t<T>>;

template <typename T>
using disable_if_module_holder_t = disable_if_t<is_module_holder<T>::value>;

// A collection of templates that answer the question whether a type `T` is a
// `ModuleHolder`, and if so whether its contained type is of type `C`. This is
// tricky because it is hard to short circuit in template metaprogramming. A
// naive and incorrect solution to this problem would be something like
// `disable_if<is_module_holder<T>::value && typename T::ContainedType == C>`.
// This would disable all types that are not `ModuleHolder`s, because even
// though the `is_module_holder<T>::value` may be `false` for such types the
// `T::ContainedType` access would be ill-formed and thus fail the whole
// expression by the rules of SFINAE. Instead we have to use template
// specialization to statically branch on the first condition
// (`is_module_holder<T>`) and are only then allowed to query
// `T::ContainedType` in the branch for which the condition was true.

// Base template.
template <bool is_module_holder_value, typename T, typename C>
struct is_module_holder_of_impl;

// False branch. `T` is not a `ModuleHolder` and thus not a `ModuleHolder` with
// contained type `C`.
template <typename T, typename C>
struct is_module_holder_of_impl<false, T, C> : std::false_type {};

// True branch. `T` is a `ModuleHolder` and thus we can legit access its
// `ContainedType` and compare it against `C`.
template <typename T, typename C>
struct is_module_holder_of_impl<true, T, C>
: std::is_same<typename T::ContainedType, C> {};

// Helper template.
template <typename T, typename C>
struct is_module_holder_of : is_module_holder_of_impl<
detail::is_module_holder<T>::value,
torch::decay_t<T>,
torch::decay_t<C>> {};
