forked from pytorch/pytorch
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Unify C++ API with C++ extensions (pytorch#11510)
Summary: Currently the C++ API and C++ extensions are effectively two different, entirely orthogonal code paths. This PR unifies the C++ API with the C++ extension API by adding an element of Python binding support to the C++ API. This means the `torch/torch.h` included by C++ extensions, which currently routes to `torch/csrc/torch.h`, can now be rerouted to `torch/csrc/api/include/torch/torch.h` -- i.e. the main C++ API header. This header then includes Python binding support conditioned on a define (`TORCH_WITH_PYTHON_BINDINGS`), *which is only passed when building a C++ extension*. Currently stacked on top of pytorch#11498 Why is this useful? 1. One less codepath. In particular, there has been trouble again and again due to the two `torch/torch.h` header files and ambiguity when both ended up in the include path. This is now fixed. 2. I have found that it is quite common to want to bind a C++ API module back into Python. This could be for simple experimentation, or to have your training loop in Python but your models in C++. This PR makes this easier by adding pybind11 support to the C++ API. 3. The C++ extension API simply becomes richer by gaining access to the C++ API headers. soumith ezyang apaszke Pull Request resolved: pytorch#11510 Reviewed By: ezyang Differential Revision: D9998835 Pulled By: goldsborough fbshipit-source-id: 7a94b44a9d7e0377b7f1cfc99ba2060874d51535
- Loading branch information
1 parent
1c09bfd
commit e05d689
Showing
22 changed files
with
320 additions
and
76 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,4 +1,4 @@ | ||
#include <torch/torch.h> | ||
#include <torch/extension.h> | ||
|
||
#include <ATen/CPUFloatType.h> | ||
#include <ATen/Type.h> | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,38 @@ | ||
#include <torch/extension.h> | ||
#include <torch/python.h> | ||
#include <torch/torch.h> | ||
|
||
struct Net : torch::nn::Module { | ||
Net(int64_t in, int64_t out) | ||
: fc(in, out), | ||
bn(torch::nn::BatchNormOptions(out).stateful(true)), | ||
dropout(0.5) { | ||
register_module("fc", fc); | ||
register_module("bn", bn); | ||
register_module("dropout", dropout); | ||
} | ||
|
||
torch::Tensor forward(torch::Tensor x) { | ||
return dropout->forward(bn->forward(torch::relu(fc->forward(x)))); | ||
} | ||
|
||
void set_bias(torch::Tensor bias) { | ||
fc->bias = bias; | ||
} | ||
|
||
torch::Tensor get_bias() const { | ||
return fc->bias; | ||
} | ||
|
||
torch::nn::Linear fc; | ||
torch::nn::BatchNorm bn; | ||
torch::nn::Dropout dropout; | ||
}; | ||
|
||
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { | ||
torch::python::bind_module<Net>(m, "Net") | ||
.def(py::init<int64_t, int64_t>()) | ||
.def("forward", &Net::forward) | ||
.def("set_bias", &Net::set_bias) | ||
.def("get_bias", &Net::get_bias); | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,4 +1,4 @@ | ||
#include <torch/torch.h> | ||
#include <torch/extension.h> | ||
|
||
struct Doubler { | ||
Doubler(int A, int B) { | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,4 +1,4 @@ | ||
#include <torch/torch.h> | ||
#include <torch/extension.h> | ||
|
||
#include <THC/THCNumerics.cuh> | ||
|
||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,4 +1,4 @@ | ||
#include <torch/torch.h> | ||
#include <torch/extension.h> | ||
|
||
#include "doubler.h" | ||
|
||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,4 +1,4 @@ | ||
#include <torch/torch.h> | ||
#include <torch/extension.h> | ||
|
||
using namespace at; | ||
|
||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,47 @@ | ||
// This class exists only to do SFINAE on abstract types `T` that are really | ||
// `ModuleHolder<ModuleType>`, because there's no good way to say that `T` is a | ||
// `ModuleHolder` over some unknown type `ModuleType`. With this, you can do | ||
// `enable_if_t<is_base_of_v<ModuleHolderIndicator, T>>`. | ||
struct ModuleHolderIndicator {}; | ||
|
||
// A type trait that is true for types that are `ModuleHolder`s. | ||
template <typename T> | ||
using is_module_holder = std::is_base_of<ModuleHolderIndicator, decay_t<T>>; | ||
|
||
template <typename T> | ||
using disable_if_module_holder_t = disable_if_t<is_module_holder<T>::value>; | ||
|
||
// A collection of templates that answer the question whether a type `T` is a | ||
// `ModuleHolder`, and if so whether its contained type is of type `C`. This is | ||
// tricky because it is hard to short circuit in template metaprogramming. A | ||
// naive and incorrect solution to this problem would be something like | ||
// `disable_if<is_module_holder<T>::value && typename T::ContainedType == C>`. | ||
// This would disable all types that are not `ModuleHolder`s, because even | ||
// though the `is_module_holder<T>::value` may be `false` for such types the | ||
// `T::ContainedType` access would be ill-formed and thus fail the whole | ||
// expression by the rules of SFINAE. Instead we have to use template | ||
// specialization to statically branch on the first condition | ||
// (`is_module_holder<T>`) and are only then allowed to query | ||
// `T::ContainedType` in the branch for which the condition was true. | ||
|
||
// Base template. | ||
template <bool is_module_holder_value, typename T, typename C> | ||
struct is_module_holder_of_impl; | ||
|
||
// False branch. `T` is not a `ModuleHolder` and thus not a `ModuleHolder` with | ||
// contained type `C`. | ||
template <typename T, typename C> | ||
struct is_module_holder_of_impl<false, T, C> : std::false_type {}; | ||
|
||
// True branch. `T` is a `ModuleHolder` and thus we can legit access its | ||
// `ContainedType` and compare it against `C`. | ||
template <typename T, typename C> | ||
struct is_module_holder_of_impl<true, T, C> | ||
: std::is_same<typename T::ContainedType, C> {}; | ||
|
||
// Helper template. | ||
template <typename T, typename C> | ||
struct is_module_holder_of : is_module_holder_of_impl< | ||
detail::is_module_holder<T>::value, | ||
torch::decay_t<T>, | ||
torch::decay_t<C>> {}; |
Oops, something went wrong.