@@ -124,72 +124,6 @@ def can_match_buffer_size(input_buf: BufferLike, output_buf: BufferLike):
124
124
return False
125
125
126
126
127
- def convert_arg_type (arg : torch .Argument ) -> str :
128
- from .cpp import CONTAINER_PYTHON_TO_CPP , PYTHON_TO_CPP
129
-
130
- # use x.real_type instead of x.type so that we get ScalarType instead of int
131
- python_type = repr (arg .real_type ) # type: ignore[attr-defined]
132
-
133
- if python_type == "Tensor" :
134
- # Conversions rules follow https://github.com/pytorch/pytorch/tree/main/aten/src/ATen/native#func
135
- if arg .alias_info is not None and arg .alias_info .is_write :
136
- return f"at::{ python_type } &"
137
- else :
138
- return f"at::{ python_type } const&"
139
-
140
- if python_type in PYTHON_TO_CPP :
141
- cpp_type = PYTHON_TO_CPP [python_type ]
142
- return cpp_type
143
-
144
- # Convert args of container types e.g. Optional[*]
145
- for py_container , cpp_container in CONTAINER_PYTHON_TO_CPP .items ():
146
- container_match = re .findall (py_container + r"\[([a-zA-Z_]+)]" , python_type )
147
- if len (container_match ) == 1 :
148
- contained_type = container_match [0 ]
149
- assert contained_type in PYTHON_TO_CPP , (
150
- f"unsupported { py_container } type in convert_arg_type: { contained_type } "
151
- )
152
- cpp_contained_type = PYTHON_TO_CPP [contained_type ]
153
- return f"{ cpp_container } <{ cpp_contained_type } >"
154
-
155
- raise AssertionError (f"unsupport python_type: { python_type } " )
156
-
157
-
158
- def convert_return_type (ret : torch .Argument ) -> str :
159
- # use x.real_type instead of x.type so that we get ScalarType instead of int
160
- python_type = repr (ret .real_type ) # type: ignore[attr-defined]
161
- python_to_cpp = {
162
- "Tensor" : "at::Tensor" ,
163
- "List[Tensor]" : "std::vector<at::Tensor>" ,
164
- }
165
-
166
- cpp_type = python_to_cpp .get (python_type , None )
167
- assert cpp_type is not None , f"NYI return type: { python_type } "
168
- # An output aliasing an input is returned by reference only when it's a
169
- # Tensor, not when it's a Tensor[]. For example, aten.split.Tensor's output
170
- # aliases the input tensor, but the op returns a vector by value.
171
- if python_type == "Tensor" and ret .alias_info is not None :
172
- cpp_type += "&"
173
- return cpp_type
174
-
175
-
176
- def get_cpp_op_schema (kernel : torch ._ops .OpOverload ) -> str :
177
- args = kernel ._schema .arguments
178
- returns = kernel ._schema .returns
179
-
180
- num_returns = len (returns )
181
- assert num_returns > 0 , "must have at least one return value"
182
-
183
- if num_returns == 1 :
184
- cpp_return_value = convert_return_type (returns [0 ])
185
- elif num_returns > 1 :
186
- tuple_returns = ", " .join ([convert_return_type (r ) for r in returns ])
187
- cpp_return_value = f"std::tuple<{ tuple_returns } >"
188
-
189
- cpp_arg_type = [f"{ convert_arg_type (arg )} { arg .name } " for arg in args ]
190
- return f"{ cpp_return_value } ({ ', ' .join (cpp_arg_type )} )" # type: ignore[possibly-undefined]
191
-
192
-
193
127
# TODO: Move to a well known place
194
128
TritonMetaParams = dict [str , int ]
195
129
TritonGrid = Union [
0 commit comments