From 4df24d14ab66afcedb72219978c1c09af7fa437a Mon Sep 17 00:00:00 2001 From: titaiwangms Date: Fri, 25 Oct 2024 20:50:29 +0000 Subject: [PATCH] This reverts commit 7bfe292a4d6f02d73a0c0ab1091de4ac62fa9310. --- onnxscript/_internal/param_manipulation.py | 15 +++++++++++++++ .../graph_building/_graph_building_torch.py | 2 ++ 2 files changed, 17 insertions(+) diff --git a/onnxscript/_internal/param_manipulation.py b/onnxscript/_internal/param_manipulation.py index 5d1332315..7523ded00 100644 --- a/onnxscript/_internal/param_manipulation.py +++ b/onnxscript/_internal/param_manipulation.py @@ -131,3 +131,18 @@ def tag_arguments_with_param_schemas( raise TypeError(f"Required input/attribute '{param}' was not provided") return tagged_args, tagged_kwargs + + +def return_to_args_order( + param_schemas: Sequence[values.ParamSchema], + inputs: list[Any], + attributes: dict[str, Any], +) -> list[Sequence[Any]]: + """Return the inputs and attributes to the order of the function signature.""" + args = [] + for param in param_schemas: + if param.name in attributes: + args.append(attributes.pop(param.name)) + else: + args.append(inputs.pop(0)) + return args diff --git a/onnxscript/function_libs/torch_lib/graph_building/_graph_building_torch.py b/onnxscript/function_libs/torch_lib/graph_building/_graph_building_torch.py index 383d1bdc5..9d36e65ea 100644 --- a/onnxscript/function_libs/torch_lib/graph_building/_graph_building_torch.py +++ b/onnxscript/function_libs/torch_lib/graph_building/_graph_building_torch.py @@ -422,6 +422,8 @@ def eval_function( # type: ignore[override] if function.traceable: inputs = self._graph.preprocess_inputs(inputs) inputs = _wrap_torch_value_to_tensor(inputs) # type: ignore[assignment] + # return to args order, as it's traced onnx function + args = param_manipulation.return_to_args_order(param_schemas, inputs, attributes) # Trace the function call instead of adding the function as a node return function.function(*args) return self._graph.add_function_call(function, inputs, attributes)