Skip to content

Commit

Permalink
This reverts commit 7bfe292.
Browse files Browse the repository at this point in the history
  • Loading branch information
titaiwangms committed Oct 25, 2024
1 parent 7bfe292 commit 4df24d1
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 0 deletions.
15 changes: 15 additions & 0 deletions onnxscript/_internal/param_manipulation.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,3 +131,18 @@ def tag_arguments_with_param_schemas(
raise TypeError(f"Required input/attribute '{param}' was not provided")

return tagged_args, tagged_kwargs


def return_to_args_order(
param_schemas: Sequence[values.ParamSchema],
inputs: list[Any],
attributes: dict[str, Any],
) -> list[Sequence[Any]]:
"""Return the inputs and attributes to the order of the function signature."""
args = []
for param in param_schemas:
if param.name in attributes:
args.append(attributes.pop(param.name))
else:
args.append(inputs.pop(0))
return args
Original file line number Diff line number Diff line change
Expand Up @@ -422,6 +422,8 @@ def eval_function( # type: ignore[override]
if function.traceable:
inputs = self._graph.preprocess_inputs(inputs)
inputs = _wrap_torch_value_to_tensor(inputs) # type: ignore[assignment]
# return to args order, as it's traced onnx function
args = param_manipulation.return_to_args_order(param_schemas, inputs, attributes)
# Trace the function call instead of adding the function as a node
return function.function(*args)
return self._graph.add_function_call(function, inputs, attributes)
Expand Down

0 comments on commit 4df24d1

Please sign in to comment.