Skip to content

Commit

Permalink
Add unique op
Browse files Browse the repository at this point in the history
  • Loading branch information
a-gardner1 committed May 20, 2024
1 parent 69ae7f4 commit 14d03b5
Show file tree
Hide file tree
Showing 2 changed files with 113 additions and 1 deletion.
73 changes: 72 additions & 1 deletion onnxscript/function_libs/torch_lib/ops/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
from __future__ import annotations

import math
import re
from typing import Any, Optional, Sequence, Tuple, Union

from onnxscript import (
Expand Down Expand Up @@ -8372,6 +8373,63 @@ def aten_unique_consecutive(
raise NotImplementedError()


_NOT_IMPLEMENTED_UNIQUE = re.compile(
r"NOT_IMPLEMENTED\s*:\s*Could\s+not\s+find\s+an\s+implementation\s+for\s+Unique"
)
"""
A pattern to detect an unsupported (not implemented) Unique operator
"""

@torch_op("aten::unique", trace_only=True)
def aten_unique(
self: TensorType,
sorted: bool = True,
return_inverse: bool = False,
return_counts: bool = False,
dim: Optional[int] = None,
) -> tuple[TensorType, TensorType, TensorType]:
"""unique(Tensor self, bool sorted=True, bool return_inverse=False, bool return_counts=False, int? dim=None) -> (Tensor, Tensor?, Tensor?)"""

try:
if dim is None:
unique_values, inverse_indices, counts = aten_unique2(self, sorted, return_inverse, return_counts)
else:
unique_values, inverse_indices, counts = aten_unique_dim(self, dim, sorted, return_inverse, return_counts)
except Exception as e:
# try to provide a more informative error message
if _NOT_IMPLEMENTED_UNIQUE.search(str(e)) is not None:
raise NotImplementedError(
f"'onnxruntime' does not yet support Unique(11) operator with dtype={self.dtype}'"
) from e
raise
if return_inverse:
if return_counts:
result = unique_values, inverse_indices, counts
else:
result = unique_values, inverse_indices
elif return_counts:
result = unique_values, counts
else:
result = unique_values
return result


@torch_op("aten::_unique2", traceable=True)
def aten_unique2(
self: TensorType,
sorted: bool = True,
return_inverse: bool = False,
return_counts: bool = False
) -> tuple[TensorType, TensorType, TensorType]:
"""unique(Tensor self, bool sorted=True, bool return_inverse=False, bool return_counts=False) -> (Tensor, Tensor, Tensor)"""

unique_values, _, inverse_indices, counts = op.Unique(self, axis=None, sorted=sorted)
input_size = op.Shape(self)
inverse_indices = op.Reshape(inverse_indices, input_size)
return unique_values, inverse_indices, counts


@torch_op("aten::unique_dim", traceable=True)
def aten_unique_dim(
self: TensorType,
dim: int,
Expand All @@ -8381,7 +8439,20 @@ def aten_unique_dim(
) -> tuple[TensorType, TensorType, TensorType]:
"""unique_dim(Tensor self, int dim, bool sorted=True, bool return_inverse=False, bool return_counts=False) -> (Tensor, Tensor, Tensor)"""

raise NotImplementedError()
unique_values, _, inverse_indices, counts = op.Unique(self, axis=dim, sorted=sorted)
input_size = op.Shape(self)
# PyTorch accepts negative dim as reversed counting
input_rank = op.Size(input_size)
dim = input_rank + dim
dim = dim % input_rank
starts = op.Reshape(dim, [-1])
ends = op.Reshape(dim + 1, [-1])
input_dim_size = op.Slice(input_size, starts=starts, ends=ends)
inverse_indices = op.Reshape(inverse_indices, input_dim_size)
output_size = op.Shape(unique_values)
output_dim_size = op.Slice(output_size, starts=starts, ends=ends)
counts = op.Reshape(counts, output_dim_size)
return unique_values, inverse_indices, counts


def aten_unique_dim_consecutive(
Expand Down
41 changes: 41 additions & 0 deletions tests/function_libs/torch_lib/ops_test_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -438,6 +438,34 @@ def _where_input_wrangler(
return args, kwargs


def _unique_unsorted_xfail_matcher(
sample: Any
) -> bool:
# torch.unique always sorts, so the results are not guaranteed to be
# equivalent to the output of the ONNX op
expect_fail = False
if sample.kwargs.get("sorted", None) is False and sample.input.numel() > 1:
# sorted == None is equivalent to True
# the result will be mismatched if the input is not sorted
# To expect equality, we must verify that the first appearance
# of each unique value is in ascending order in the input
test_kwargs = dict(sample.kwargs)
test_kwargs['return_inverse'] = True
test_kwargs['return_counts'] = False
_, inverse = torch.unique(sample.input, **test_kwargs)
observed = set()
max_observed: Optional[int] = None
for inv_idx in inverse.flatten().tolist():
if inv_idx not in observed:
if max_observed is not None and inv_idx < max_observed:
expect_fail = True
break
observed.add(inv_idx)
if max_observed is None or inv_idx > max_observed:
max_observed = inv_idx
return expect_fail


# Ops to be tested for numerical consistency between onnx and pytorch
# Find the names of the OpInfos in torch/testing/_internal/common_methods_invocations.py
TESTED_TORCHLIB_OPS: tuple[TorchLibOpInfo, ...] = (
Expand Down Expand Up @@ -2325,6 +2353,19 @@ def _where_input_wrangler(
TorchLibOpInfo(
"transpose", core_ops.aten_transpose_complex, trace_only=True, complex=True
),
TorchLibOpInfo(
"unique",
core_ops.aten_unique,
trace_only=True
).xfail(
matcher=lambda sample: sample.input.dtype not in {
torch.float64, torch.float32, torch.float16, torch.int64, torch.int8,
},
reason="'onnxruntime' does not implement Unique(11) with the given dtype",
).xfail(
matcher=_unique_unsorted_xfail_matcher,
reason="torch.unique always sorts, so passing 'sorted=False' leads to mismatched outputs",
),
TorchLibOpInfo(
"var_mean",
core_ops.aten_var_mean,
Expand Down

0 comments on commit 14d03b5

Please sign in to comment.