Skip to content

Commit 14d03b5

Browse files
committed
Add unique op
1 parent 69ae7f4 commit 14d03b5

File tree

2 files changed

+113
-1
lines changed

2 files changed

+113
-1
lines changed

onnxscript/function_libs/torch_lib/ops/core.py

Lines changed: 72 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
from __future__ import annotations
1414

1515
import math
16+
import re
1617
from typing import Any, Optional, Sequence, Tuple, Union
1718

1819
from onnxscript import (
@@ -8372,6 +8373,63 @@ def aten_unique_consecutive(
83728373
raise NotImplementedError()
83738374

83748375

8376+
# Matches onnxruntime's NOT_IMPLEMENTED failure text for the Unique operator
# (e.g. "NOT_IMPLEMENTED : Could not find an implementation for Unique...").
# Used by aten_unique below to re-raise that opaque runtime error as a
# clearer NotImplementedError that names the unsupported dtype.
_NOT_IMPLEMENTED_UNIQUE = re.compile(
    r"NOT_IMPLEMENTED\s*:\s*Could\s+not\s+find\s+an\s+implementation\s+for\s+Unique"
)
"""
A pattern to detect an unsupported (not implemented) Unique operator
"""
8382+
8383+
@torch_op("aten::unique", trace_only=True)
8384+
def aten_unique(
8385+
self: TensorType,
8386+
sorted: bool = True,
8387+
return_inverse: bool = False,
8388+
return_counts: bool = False,
8389+
dim: Optional[int] = None,
8390+
) -> tuple[TensorType, TensorType, TensorType]:
8391+
"""unique(Tensor self, bool sorted=True, bool return_inverse=False, bool return_counts=False, int? dim=None) -> (Tensor, Tensor?, Tensor?)"""
8392+
8393+
try:
8394+
if dim is None:
8395+
unique_values, inverse_indices, counts = aten_unique2(self, sorted, return_inverse, return_counts)
8396+
else:
8397+
unique_values, inverse_indices, counts = aten_unique_dim(self, dim, sorted, return_inverse, return_counts)
8398+
except Exception as e:
8399+
# try to provide a more informative error message
8400+
if _NOT_IMPLEMENTED_UNIQUE.search(str(e)) is not None:
8401+
raise NotImplementedError(
8402+
f"'onnxruntime' does not yet support Unique(11) operator with dtype={self.dtype}'"
8403+
) from e
8404+
raise
8405+
if return_inverse:
8406+
if return_counts:
8407+
result = unique_values, inverse_indices, counts
8408+
else:
8409+
result = unique_values, inverse_indices
8410+
elif return_counts:
8411+
result = unique_values, counts
8412+
else:
8413+
result = unique_values
8414+
return result
8415+
8416+
8417+
@torch_op("aten::_unique2", traceable=True)
8418+
def aten_unique2(
8419+
self: TensorType,
8420+
sorted: bool = True,
8421+
return_inverse: bool = False,
8422+
return_counts: bool = False
8423+
) -> tuple[TensorType, TensorType, TensorType]:
8424+
"""unique(Tensor self, bool sorted=True, bool return_inverse=False, bool return_counts=False) -> (Tensor, Tensor, Tensor)"""
8425+
8426+
unique_values, _, inverse_indices, counts = op.Unique(self, axis=None, sorted=sorted)
8427+
input_size = op.Shape(self)
8428+
inverse_indices = op.Reshape(inverse_indices, input_size)
8429+
return unique_values, inverse_indices, counts
8430+
8431+
8432+
@torch_op("aten::unique_dim", traceable=True)
83758433
def aten_unique_dim(
83768434
self: TensorType,
83778435
dim: int,
@@ -8381,7 +8439,20 @@ def aten_unique_dim(
83818439
) -> tuple[TensorType, TensorType, TensorType]:
83828440
"""unique_dim(Tensor self, int dim, bool sorted=True, bool return_inverse=False, bool return_counts=False) -> (Tensor, Tensor, Tensor)"""
83838441

8384-
raise NotImplementedError()
8442+
unique_values, _, inverse_indices, counts = op.Unique(self, axis=dim, sorted=sorted)
8443+
input_size = op.Shape(self)
8444+
# PyTorch accepts negative dim as reversed counting
8445+
input_rank = op.Size(input_size)
8446+
dim = input_rank + dim
8447+
dim = dim % input_rank
8448+
starts = op.Reshape(dim, [-1])
8449+
ends = op.Reshape(dim + 1, [-1])
8450+
input_dim_size = op.Slice(input_size, starts=starts, ends=ends)
8451+
inverse_indices = op.Reshape(inverse_indices, input_dim_size)
8452+
output_size = op.Shape(unique_values)
8453+
output_dim_size = op.Slice(output_size, starts=starts, ends=ends)
8454+
counts = op.Reshape(counts, output_dim_size)
8455+
return unique_values, inverse_indices, counts
83858456

83868457

83878458
def aten_unique_dim_consecutive(

tests/function_libs/torch_lib/ops_test_data.py

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -438,6 +438,34 @@ def _where_input_wrangler(
438438
return args, kwargs
439439

440440

441+
def _unique_unsorted_xfail_matcher(
442+
sample: Any
443+
) -> bool:
444+
# torch.unique always sorts, so the results are not guaranteed to be
445+
# equivalent to the output of the ONNX op
446+
expect_fail = False
447+
if sample.kwargs.get("sorted", None) is False and sample.input.numel() > 1:
448+
# sorted == None is equivalent to True
449+
# the result will be mismatched if the input is not sorted
450+
# To expect equality, we must verify that the first appearance
451+
# of each unique value is in ascending order in the input
452+
test_kwargs = dict(sample.kwargs)
453+
test_kwargs['return_inverse'] = True
454+
test_kwargs['return_counts'] = False
455+
_, inverse = torch.unique(sample.input, **test_kwargs)
456+
observed = set()
457+
max_observed: Optional[int] = None
458+
for inv_idx in inverse.flatten().tolist():
459+
if inv_idx not in observed:
460+
if max_observed is not None and inv_idx < max_observed:
461+
expect_fail = True
462+
break
463+
observed.add(inv_idx)
464+
if max_observed is None or inv_idx > max_observed:
465+
max_observed = inv_idx
466+
return expect_fail
467+
468+
441469
# Ops to be tested for numerical consistency between onnx and pytorch
442470
# Find the names of the OpInfos in torch/testing/_internal/common_methods_invocations.py
443471
TESTED_TORCHLIB_OPS: tuple[TorchLibOpInfo, ...] = (
@@ -2325,6 +2353,19 @@ def _where_input_wrangler(
23252353
TorchLibOpInfo(
23262354
"transpose", core_ops.aten_transpose_complex, trace_only=True, complex=True
23272355
),
2356+
TorchLibOpInfo(
2357+
"unique",
2358+
core_ops.aten_unique,
2359+
trace_only=True
2360+
).xfail(
2361+
matcher=lambda sample: sample.input.dtype not in {
2362+
torch.float64, torch.float32, torch.float16, torch.int64, torch.int8,
2363+
},
2364+
reason="'onnxruntime' does not implement Unique(11) with the given dtype",
2365+
).xfail(
2366+
matcher=_unique_unsorted_xfail_matcher,
2367+
reason="torch.unique always sorts, so passing 'sorted=False' leads to mismatched outputs",
2368+
),
23282369
TorchLibOpInfo(
23292370
"var_mean",
23302371
core_ops.aten_var_mean,

0 commit comments

Comments
 (0)