Commit e2a4c7f
Merge branch 'main' into patch-1
2 parents 5b559b5 + 9dc8613

1 file changed: 110 additions, 80 deletions

TorchInductor C++ Wrapper Tutorial
==============================================================

**Authors**: `Chunyuan Wu <https://github.com/chunyuan-w>`__, `Bin Bao <https://github.com/desertfire>`__, `Jiong Gong <https://github.com/jgong5>`__

Introduction
------------

In ``torch.compile``, the default backend **TorchInductor** emits Python wrapper
code that manages memory allocation and kernel invocation. This design provides
flexibility and ease of debugging, but the interpreted nature of Python
introduces runtime overhead in performance-sensitive environments.

To address this limitation, TorchInductor includes a specialized mode that
generates **C++ wrapper code** in place of the Python wrapper, enabling faster
execution with minimal Python involvement.

Enabling the C++ wrapper mode
-----------------------------

To enable the C++ wrapper mode in TorchInductor, add the following configuration to your code:

.. code:: python

    import torch._inductor.config as config
    config.cpp_wrapper = True
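The same setting can also be scoped to a single compilation instead of being set globally. As a minimal sketch (assuming the default Inductor backend, whose ``options`` dict is forwarded to the Inductor config), one might write:

.. code:: python

    import torch

    def fn(x, y):
        return (x + y).sum()

    # Scoped alternative to the global config flag: the options dict is
    # applied to the Inductor config for this compiled function only.
    opt_fn = torch.compile(fn, options={"cpp_wrapper": True})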

Example code
------------

We will use the following model code as an example:

.. code:: python

    import torch
    import torch._inductor.config as config

    config.cpp_wrapper = True

    def fn(x, y):
        return (x + y).sum()

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    x = torch.randn(128, 128, device=device)
    y = torch.randn(128, 128, device=device)

    opt_fn = torch.compile(fn)
    result = opt_fn(x, y)
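To inspect the wrapper code that Inductor actually generates for this example (such as the listings shown next), you can enable output-code logging. A minimal sketch using ``torch._logging.set_logs``, equivalent to running with ``TORCH_LOGS="output_code"``:

.. code:: python

    import torch

    # Print the Inductor-generated output code, including the wrapper,
    # to the console during compilation.
    torch._logging.set_logs(output_code=True)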

**For CPU**

The main part of the TorchInductor-generated code with the default Python wrapper will look like this:

.. code:: python

    class Runner:
        def __init__(self, partitions):
            self.partitions = partitions

        def call(self, args):
            arg0_1, arg1_1 = args
            args.clear()
            assert_size_stride(arg0_1, (128, 128), (128, 1))
            assert_size_stride(arg1_1, (128, 128), (128, 1))
            buf0 = empty_strided_cpu((), (), torch.float32)
            cpp_fused_add_sum_0(arg0_1, arg1_1, buf0)
            del arg0_1
            del arg1_1
            return (buf0, )
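Note that ``buf0`` is allocated with empty size and stride tuples, that is, as a zero-dimensional buffer, because ``(x + y).sum()`` reduces to a scalar. In plain PyTorch terms this allocation corresponds roughly to:

.. code:: python

    import torch

    # A 0-d scalar buffer, analogous to empty_strided_cpu((), (), torch.float32)
    # in the generated wrapper above.
    buf0 = torch.empty_strided((), (), dtype=torch.float32)
    print(buf0.shape)  # torch.Size([])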

By turning on the C++ wrapper, the generated code for the ``call`` function becomes a C++ function
``inductor_entry_impl``:

.. code:: python

    cpp_wrapper_src = (
    r'''
    #include <torch/csrc/inductor/cpp_wrapper/cpu.h>
    extern "C" void cpp_fused_add_sum_0(const float* in_ptr0,
                                        const float* in_ptr1,
                                        float* out_ptr0);
    CACHE_TORCH_DTYPE(float32);
    CACHE_TORCH_DEVICE(cpu);

    void inductor_entry_impl(
        AtenTensorHandle*
            input_handles, // array of input AtenTensorHandle; handles
                           // are stolen; the array itself is borrowed
        AtenTensorHandle*
            output_handles // array for writing output AtenTensorHandle; handles
                           // will be stolen by the caller; the array itself is
                           // borrowed
    ) {
        py::gil_scoped_release_simple release;

        auto inputs = steal_from_raw_handles_to_raii_handles(input_handles, 2);
        auto arg0_1 = std::move(inputs[0]);
        auto arg1_1 = std::move(inputs[1]);
        static constexpr int64_t *int_array_0 = nullptr;
        AtenTensorHandle buf0_handle;
        AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_empty_strided(0, int_array_0, int_array_0, cached_torch_dtype_float32, cached_torch_device_type_cpu, 0, &buf0_handle));
        RAIIAtenTensorHandle buf0(buf0_handle);
        cpp_fused_add_sum_0((const float*)(arg0_1.data_ptr()), (const float*)(arg1_1.data_ptr()), (float*)(buf0.data_ptr()));
        arg0_1.reset();
        arg1_1.reset();
        output_handles[0] = buf0.release();
    } // inductor_entry_impl
    ...
    '''
    )

    inductor_entry = CppWrapperCodeCache.load_pybinding(
        argtypes=["std::vector<AtenTensorHandle>"],
        main_code=cpp_wrapper_src,
        device_type="cpu",
        num_outputs=1,
        kernel_code=None,
    )

    call = _wrap_func(inductor_entry)
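
Regardless of which wrapper is used, the compiled function computes the same result as eager execution; the C++ wrapper changes how kernels are invoked, not what they compute. A quick sanity check, sketched with only public APIs:

.. code:: python

    import torch
    import torch._inductor.config as config

    config.cpp_wrapper = True

    def fn(x, y):
        return (x + y).sum()

    x = torch.randn(128, 128)
    y = torch.randn(128, 128)

    # The compiled result should match eager execution numerically.
    torch.testing.assert_close(torch.compile(fn)(x, y), fn(x, y))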

**For GPU**

...

With the C++ wrapper turned on, the following equivalent C++ code will be generated:

.. code:: python

    inductor_entry = CppWrapperCodeCache.load_pybinding(
        argtypes=["std::vector<AtenTensorHandle>"],
        main_code=cpp_wrapper_src,
        device_type="cuda",
        num_outputs=1,
        kernel_code=None,
    )

    def _wrap_func(f):
        def g(args):
            input_tensors = [arg if isinstance(arg, torch.Tensor) else torch.tensor(arg, device='cpu') for arg in args]
            input_handles = torch._C._aoti.unsafe_alloc_void_ptrs_from_tensors(input_tensors)

            args.clear()
            del input_tensors

            output_handles = f(input_handles)
            output_tensors = torch._C._aoti.alloc_tensors_by_stealing_from_void_ptrs(output_handles)
            return output_tensors

        return g

    call = _wrap_func(inductor_entry)
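
Because the C++ wrapper removes per-launch Python overhead rather than changing the kernels themselves, its benefit is most visible for small graphs that are invoked many times. Below is a rough benchmarking sketch (assuming a CUDA device, and that the ``options`` dict is forwarded to the Inductor config; ``timed`` is a helper defined here, and actual speedups will vary):

.. code:: python

    import time
    import torch

    def fn(x, y):
        return (x + y).sum()

    x = torch.randn(128, 128, device="cuda")
    y = torch.randn(128, 128, device="cuda")

    def timed(f, iters=1000):
        # Synchronize around the timed region so queued GPU work is counted.
        torch.cuda.synchronize()
        start = time.perf_counter()
        for _ in range(iters):
            f(x, y)
        torch.cuda.synchronize()
        return time.perf_counter() - start

    py_fn = torch.compile(fn)                                  # default Python wrapper
    cpp_fn = torch.compile(fn, options={"cpp_wrapper": True})  # C++ wrapper

    py_fn(x, y); cpp_fn(x, y)  # warm up: trigger compilation first
    print(f"python wrapper: {timed(py_fn):.3f}s  cpp wrapper: {timed(cpp_fn):.3f}s")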

Conclusion
------------

This tutorial introduced the **C++ wrapper** feature in TorchInductor, designed
to improve model performance with minimal code modification. We described the
motivation for this feature, detailed the experimental API used to enable it,
and compared the outputs generated by the default Python wrapper and the new
C++ wrapper on both CPU and GPU backends to illustrate their differences.

.. For more information on torch.compile, see
..
.. .. _torch.compile tutorial: https://docs.pytorch.org/tutorials/intermediate/torch_compile_tutorial.html
.. .. _TORCH_LOGS tutorial: https://docs.pytorch.org/tutorials/recipes/torch_logs.html
