TorchInductor C++ Wrapper Tutorial
==================================

**Author**: `Chunyuan Wu <https://github.com/chunyuan-w>`_, `Bin Bao <https://github.com/desertfire>`__, `Jiong Gong <https://github.com/jgong5>`__

Prerequisites:

Introduction
------------

In ``torch.compile``, the default backend **TorchInductor** emits Python wrapper
code that manages memory allocation and kernel invocation. This design provides
flexibility and ease of debugging, but the interpreted nature of Python
introduces runtime overhead in performance-sensitive environments.

To address this limitation, TorchInductor includes a specialized mode that
generates **C++ wrapper code** in place of the Python wrapper, enabling faster
execution with minimal Python involvement.


Enabling the C++ wrapper mode
-----------------------------

To enable the C++ wrapper mode for TorchInductor, add the following config to your code:

.. code:: python

    import torch._inductor.config as config
    config.cpp_wrapper = True

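The same flag can also be applied per ``torch.compile`` call through the
``options`` dictionary, or set via the ``TORCHINDUCTOR_CPP_WRAPPER=1``
environment variable. A minimal sketch; the exact spelling of these knobs may
vary across PyTorch versions:

.. code:: python

    import torch

    def fn(x, y):
        return (x + y).sum()

    # Enable the C++ wrapper for this compiled function only,
    # without touching the global inductor config.
    opt_fn = torch.compile(fn, options={"cpp_wrapper": True})
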

Example code
------------

We will use the following model code as an example:

.. code:: python

    import torch
    import torch._inductor.config as config

    config.cpp_wrapper = True

    def fn(x, y):
        return (x + y).sum()

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    x = torch.randn(128, 128, device=device)
    y = torch.randn(128, 128, device=device)

    opt_fn = torch.compile(fn)
    result = opt_fn(x, y)

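To see which wrapper Inductor actually generated for a model, you can enable
the ``output_code`` logging artifact before compiling (a minimal sketch,
equivalent to running the script with ``TORCH_LOGS="output_code"``):

.. code:: python

    import torch._logging

    # Dump the generated wrapper code (Python or C++) to the log
    # whenever a graph is compiled.
    torch._logging.set_logs(output_code=True)
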
**For CPU**

The main part of the TorchInductor-generated code with the default Python wrapper will look like this:

.. code:: python

    class Runner:
        def __init__(self, partitions):
            self.partitions = partitions

        def call(self, args):
            arg0_1, arg1_1 = args
            args.clear()
            assert_size_stride(arg0_1, (128, 128), (128, 1))
            assert_size_stride(arg1_1, (128, 128), (128, 1))
            buf0 = empty_strided_cpu((), (), torch.float32)
            cpp_fused_add_sum_0(arg0_1, arg1_1, buf0)
            del arg0_1
            del arg1_1
            return (buf0, )

With the C++ wrapper turned on, the ``call`` function becomes a C++ function
``inductor_entry_impl``:

.. code:: python

    cpp_wrapper_src = (
    r'''
    #include <torch/csrc/inductor/cpp_wrapper/cpu.h>
    extern "C" void cpp_fused_add_sum_0(const float* in_ptr0,
                                        const float* in_ptr1,
                                        float* out_ptr0);
    CACHE_TORCH_DTYPE(float32);
    CACHE_TORCH_DEVICE(cpu);

    void inductor_entry_impl(
        AtenTensorHandle*
            input_handles, // array of input AtenTensorHandle; handles
                           // are stolen; the array itself is borrowed
        AtenTensorHandle*
            output_handles  // array for writing output AtenTensorHandle; handles
                            // will be stolen by the caller; the array itself is
                            // borrowed
    ) {
        py::gil_scoped_release_simple release;

        auto inputs = steal_from_raw_handles_to_raii_handles(input_handles, 2);
        auto arg0_1 = std::move(inputs[0]);
        auto arg1_1 = std::move(inputs[1]);
        static constexpr int64_t* int_array_0 = nullptr;
        AtenTensorHandle buf0_handle;
        AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_empty_strided(0, int_array_0, int_array_0, cached_torch_dtype_float32, cached_torch_device_type_cpu, 0, &buf0_handle));
        RAIIAtenTensorHandle buf0(buf0_handle);
        cpp_fused_add_sum_0((const float*)(arg0_1.data_ptr()), (const float*)(arg1_1.data_ptr()), (float*)(buf0.data_ptr()));
        arg0_1.reset();
        arg1_1.reset();
        output_handles[0] = buf0.release();
    } // inductor_entry_impl
    ...
    '''
    )

    inductor_entry = CppWrapperCodeCache.load_pybinding(
        argtypes=["std::vector<AtenTensorHandle>"],
        main_code=cpp_wrapper_src,
        device_type="cpu",
        num_outputs=1,
        kernel_code=None,
    )

    call = _wrap_func(inductor_entry)

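Although the wrapper is now C++, ``load_pybinding`` still exposes
``inductor_entry`` to Python as an ordinary callable, so the compiled function
can be validated against eager mode as usual. A quick sanity check (a minimal
sketch reusing ``fn``, ``x``, ``y``, and ``result`` from the example above):

.. code:: python

    # The C++-wrapped compiled function should match eager execution.
    torch.testing.assert_close(result, fn(x, y))
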
**For GPU**

With the C++ wrapper turned on, the below equivalent C++ code will be generated:

.. code:: python

    inductor_entry = CppWrapperCodeCache.load_pybinding(
        argtypes=["std::vector<AtenTensorHandle>"],
        main_code=cpp_wrapper_src,
        device_type="cuda",
        num_outputs=1,
        kernel_code=None,
    )

    def _wrap_func(f):
        def g(args):
            input_tensors = [arg if isinstance(arg, torch.Tensor) else torch.tensor(arg, device='cpu') for arg in args]
            input_handles = torch._C._aoti.unsafe_alloc_void_ptrs_from_tensors(input_tensors)

            args.clear()
            del input_tensors

            output_handles = f(input_handles)
            output_tensors = torch._C._aoti.alloc_tensors_by_stealing_from_void_ptrs(output_handles)
            return output_tensors

        return g

    call = _wrap_func(inductor_entry)

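The benefit of the C++ wrapper is lower per-call overhead, which is most
visible on small graphs where kernel execution time does not dominate. A rough
way to measure the difference (a minimal sketch based on the example model;
absolute numbers will vary by machine and PyTorch version):

.. code:: python

    import torch
    import torch.utils.benchmark as benchmark

    # Two identical functions so each compiled variant gets its own cache entry.
    def fn_default(x, y):
        return (x + y).sum()

    def fn_cpp(x, y):
        return (x + y).sum()

    x = torch.randn(128, 128)
    y = torch.randn(128, 128)

    compiled_default = torch.compile(fn_default)
    compiled_cpp = torch.compile(fn_cpp, options={"cpp_wrapper": True})

    # Warm up both variants so one-time compilation cost is excluded.
    compiled_default(x, y)
    compiled_cpp(x, y)

    for label, f in [("Python wrapper", compiled_default),
                     ("C++ wrapper", compiled_cpp)]:
        timer = benchmark.Timer(stmt="f(x, y)", globals={"f": f, "x": x, "y": y})
        print(label, timer.blocked_autorange())
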

Conclusion
----------

This tutorial introduced the **C++ wrapper** feature in TorchInductor, designed
to improve model performance with minimal code modification. We described the
motivation for this feature, detailed the experimental API used to enable it,
and compared the generated outputs of the default Python wrapper and the new
C++ wrapper on both CPU and GPU backends to illustrate their distinctions.

.. For more information on torch.compile, see
..
.. .. _torch.compile tutorial: https://docs.pytorch.org/tutorials/intermediate/torch_compile_tutorial.html
.. .. TORCH_LOGS tutorial: https://docs.pytorch.org/tutorials/recipes/torch_logs.html