Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@

def torch_lib_onnx_functions_from_registry() -> Generator[onnxscript.OnnxFunction, None, None]:
for op in registration.default_registry.values():
for func in (*op.overloads, *op.privates, *op.complex):
for func in (*op.overloads, *op.complex):
if isinstance(func, onnxscript.OnnxFunction):
yield func

Expand Down
32 changes: 21 additions & 11 deletions onnxscript/function_libs/torch_lib/ops/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -3759,7 +3759,6 @@ def aten_grid_sampler(
padding_mode_options = ("zeros", "border", "reflection")
padding_mode_str = padding_mode_options[padding_mode]

# Only one onnx Op so don't put into private function
return op.GridSample(
input,
grid,
Expand All @@ -3785,7 +3784,6 @@ def aten_grid_sampler_2d(
padding_mode_options = ("zeros", "border", "reflection")
padding_mode_str = padding_mode_options[padding_mode]

# Only one onnx Op so don't put into private function
return op.GridSample(
input,
grid,
Expand Down Expand Up @@ -4115,7 +4113,9 @@ def _aten_index_onnx(


@torch_op(("aten::index.Tensor", "aten::_unsafe_index.Tensor"), trace_only=True)
def aten_index(self: TensorType, indices: Sequence[Optional[INT64]]) -> TensorType:
def aten_index(
self: TensorType, indices: Sequence[Optional[Union[INT64, BOOL]]]
) -> TensorType:
"""index.Tensor(Tensor self, Tensor?[] indices) -> Tensor

NOTE: Understanding `aten::index`
Expand All @@ -4135,14 +4135,19 @@ def aten_index(self: TensorType, indices: Sequence[Optional[INT64]]) -> TensorTy

None in `indices` are like fillers for dimensions that cannot be removed in the process.
"""
# Handle Boolean indexing first
for index in indices:
if index is None:
continue
if index.dtype == BOOL.dtype:
return _aten_index_bool(self, indices)

index_ranks = [len(index.shape) for index in indices if index is not None]

return _aten_index_onnx(self, indices, index_ranks)


@torch_op(("aten::index.Tensor", "aten::_unsafe_index.Tensor"), trace_only=True)
def aten_index_bool(self: TensorType, indices: Sequence[Optional[BOOL]]) -> TensorType: # pylint: disable=inconsistent-return-statements
def _aten_index_bool(self: TensorType, indices: Sequence[Optional[BOOL]]) -> TensorType: # pylint: disable=inconsistent-return-statements
index_ranks = [len(index.shape) for index in indices if index is not None]

if index_ranks[0] == 1:
Expand Down Expand Up @@ -4177,9 +4182,9 @@ def aten_index_bool(self: TensorType, indices: Sequence[Optional[BOOL]]) -> Tens
finla_rank = input_rank - (len(index.shape) - 1)
trans_perm = list(range(finla_rank))
trans_perm = trans_perm[-1:] + trans_perm[:-1]
for _ in range(count_of_none):
result = op.Transpose(result, perm=trans_perm)
return result
for _ in range(count_of_none):
result = op.Transpose(result, perm=trans_perm)
return result


def aten_index_add(
Expand All @@ -4201,7 +4206,7 @@ def aten_index_copy(
@torch_op(("aten::index_put", "aten::_unsafe_index_put"), trace_only=True)
def aten_index_put(
self: TReal,
indices: Sequence[INT64],
indices: Sequence[Optional[Union[INT64, BOOL]]],
values: TReal,
accumulate: bool = False,
) -> TReal:
Expand All @@ -4210,6 +4215,12 @@ def aten_index_put(
See implementation of `torch.onnx.symbolic_opset11.index_put
<https://github.com/pytorch/pytorch/blob/main/torch/onnx/symbolic_opset11.py#L212>`_.
"""
# Handle Boolean indexing first
for index in indices:
if index is None:
continue
if index.dtype == BOOL.dtype:
return _aten_index_put_bool(self, indices, values, accumulate=accumulate)

def _make_reshape_list_broadcastable(reshape_list, values_shape):
# Remove ones until the rank of reshape_list matches values_shape.
Expand Down Expand Up @@ -4287,8 +4298,7 @@ def _make_reshape_list_broadcastable(reshape_list, values_shape):
return result


@torch_op("aten::index_put", trace_only=True)
def aten_index_put_bool(
def _aten_index_put_bool(
self: TReal,
indices: Sequence[BOOL],
values: TReal,
Expand Down
1 change: 0 additions & 1 deletion onnxscript/function_libs/torch_lib/ops/nn.py
Original file line number Diff line number Diff line change
Expand Up @@ -328,7 +328,6 @@ def aten_col2im(
else: # assert len(padding) == 4, already [w, x, y, z]
pads = padding

# Only one ONNX op here so didn't write a private function
return op.Col2Im(
self,
output_size,
Expand Down
29 changes: 17 additions & 12 deletions onnxscript/function_libs/torch_lib/registration.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,14 +22,12 @@ class OverloadedFunction:
Attributes:
name: Name of the op. E.g. "aten::add".
overloads: Overloads function.
privates: Private functions not exposed to users.
complex: Support complex functions.
"""

def __init__(self, name: str):
self.name = name
self.overloads: list[Any] = []
self.privates: list[Any] = []
self.complex: list[Any] = []


Expand All @@ -39,17 +37,22 @@ class Registry:
def __init__(self):
self._registry: dict[str, OverloadedFunction] = {}

def register(
self, func: Any, name: str, *, private: bool = False, complex: bool = False
) -> None:
def register(self, func: Any, name: str, *, complex: bool = False) -> None:
"""Register a function."""

if private:
self._registry.setdefault(name, OverloadedFunction(name)).privates.append(func)
elif complex:
self._registry.setdefault(name, OverloadedFunction(name)).complex.append(func)
overloaded_function = self._registry.setdefault(name, OverloadedFunction(name))

if complex:
if overloaded_function.complex:
raise ValueError(
f"Complex overload for '{name}' already registered: {overloaded_function.complex}."
)
overloaded_function.complex.append(func)
else:
self._registry.setdefault(name, OverloadedFunction(name)).overloads.append(func)
if overloaded_function.overloads:
raise ValueError(
f"Real overload for '{name}' already registered: {overloaded_function.overloads}."
)
overloaded_function.overloads.append(func)

def __getitem__(self, name):
return self._registry[name]
Expand Down Expand Up @@ -131,7 +134,9 @@ def wrapper(

assert registry is not None
for name_ in _check_and_normalize_names(name):
registry.register(processed_func, name_, private=private, complex=complex)
if private:
continue
registry.register(processed_func, name_, complex=complex)
return processed_func

return wrapper
5 changes: 5 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,11 @@ onnx = ["py.typed"]

[tool.pytest.ini_options]
addopts = "-rsfEX --tb=short --color=yes"
norecursedirs = [
# Skip test collection because pytest will try to import the modules twice,
# causing the torchlib registry to complain that functions are redefined.
"onnxscript/function_libs/torch_lib/ops",
]

[tool.mypy]
# TODO disallow_incomplete_defs = true
Expand Down
18 changes: 2 additions & 16 deletions tests/function_libs/torch_lib/ops_test_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -721,23 +721,10 @@ def _where_input_wrangler(
# TorchLibOpInfo("is_same_size", core_ops.aten_is_same_size), # no test case in OPS_DB
# TorchLibOpInfo("is_nonzero", core_ops.aten_is_nonzero), # no test case in OPS_DB
TorchLibOpInfo("ops.aten.index.Tensor", core_ops.aten_index),
TorchLibOpInfo("ops.aten.index.Tensor.bool", core_ops.aten_index_bool),
TorchLibOpInfo(
"index_put_bool",
core_ops.aten_index_put_bool,
input_wrangler=_index_put_input_wrangler,
).skip(
matcher=lambda sample: sample.args[0][0].dtype != torch.bool,
reason="this Aten overload only supports tensor(bool) as indices",
),
TorchLibOpInfo("ops.aten.index.Tensor.bool", core_ops.aten_index),
TorchLibOpInfo(
"index_put", core_ops.aten_index_put, input_wrangler=_index_put_input_wrangler
)
.skip(
matcher=lambda sample: sample.args[0][0].dtype != torch.int64,
reason="this Aten overload only supports tensor(int) as indices",
)
.xfail(
).skip(
dtypes=(torch.float16,),
matcher=lambda sample: sample.kwargs.get("accumulate") is True,
reason="fixme: ORT only supports float32 when accumulate is True: MLFloat16 data type is not supported with ScatterND when reduction is 'add'",
Expand Down Expand Up @@ -1806,7 +1793,6 @@ def _where_input_wrangler(
ops_test_common.duplicate_opinfo(OPS_DB, "cat", ("concat", "concatenate"))
ops_test_common.duplicate_opinfo(OPS_DB, "clone", ("lift_fresh_copy",))
ops_test_common.duplicate_opinfo(OPS_DB, "div", ("div_mode",))
ops_test_common.duplicate_opinfo(OPS_DB, "index_put", ("index_put_bool",))
ops_test_common.duplicate_opinfo(OPS_DB, "max", ("max_dim",))
ops_test_common.duplicate_opinfo(OPS_DB, "mean", ("mean_dim",))
ops_test_common.duplicate_opinfo(OPS_DB, "min", ("min_dim",))
Expand Down
Loading