From 14d03b539095a80df070a2d570e023db2a4bee58 Mon Sep 17 00:00:00 2001
From: a-gardner1
Date: Wed, 15 May 2024 21:30:20 +0000
Subject: [PATCH] Add unique op

---
 .../function_libs/torch_lib/ops/core.py      | 73 ++++++++++++++++++-
 .../function_libs/torch_lib/ops_test_data.py | 41 +++++++++++
 2 files changed, 113 insertions(+), 1 deletion(-)

diff --git a/onnxscript/function_libs/torch_lib/ops/core.py b/onnxscript/function_libs/torch_lib/ops/core.py
index c66a978e9..9da750a27 100644
--- a/onnxscript/function_libs/torch_lib/ops/core.py
+++ b/onnxscript/function_libs/torch_lib/ops/core.py
@@ -13,6 +13,7 @@ from __future__ import annotations
 
 import math
+import re
 from typing import Any, Optional, Sequence, Tuple, Union
 
 from onnxscript import (
@@ -8372,6 +8373,63 @@ def aten_unique_consecutive(
     raise NotImplementedError()
 
 
+_NOT_IMPLEMENTED_UNIQUE = re.compile(
+    r"NOT_IMPLEMENTED\s*:\s*Could\s+not\s+find\s+an\s+implementation\s+for\s+Unique"
+)
+"""
+A pattern to detect an unsupported (not implemented) Unique operator.
+"""
+
+
+@torch_op("aten::unique", trace_only=True)
+def aten_unique(
+    self: TensorType,
+    sorted: bool = True,
+    return_inverse: bool = False,
+    return_counts: bool = False,
+    dim: Optional[int] = None,
+) -> tuple[TensorType, TensorType, TensorType]:
+    """unique(Tensor self, bool sorted=True, bool return_inverse=False, bool return_counts=False, int? dim=None) -> (Tensor, Tensor?, Tensor?)"""
+
+    try:
+        if dim is None:
+            unique_values, inverse_indices, counts = aten_unique2(self, sorted, return_inverse, return_counts)
+        else:
+            unique_values, inverse_indices, counts = aten_unique_dim(self, dim, sorted, return_inverse, return_counts)
+    except Exception as e:
+        # Try to provide a more informative error message.
+        if _NOT_IMPLEMENTED_UNIQUE.search(str(e)) is not None:
+            raise NotImplementedError(
+                f"'onnxruntime' does not yet support the Unique(11) operator with dtype={self.dtype}"
+            ) from e
+        raise
+    if return_inverse:
+        if return_counts:
+            result = unique_values, inverse_indices, counts
+        else:
+            result = unique_values, inverse_indices
+    elif return_counts:
+        result = unique_values, counts
+    else:
+        result = unique_values
+    return result
+
+
+@torch_op("aten::_unique2", traceable=True)
+def aten_unique2(
+    self: TensorType,
+    sorted: bool = True,
+    return_inverse: bool = False,
+    return_counts: bool = False,
+) -> tuple[TensorType, TensorType, TensorType]:
+    """_unique2(Tensor self, bool sorted=True, bool return_inverse=False, bool return_counts=False) -> (Tensor, Tensor, Tensor)"""
+
+    unique_values, _, inverse_indices, counts = op.Unique(self, axis=None, sorted=sorted)
+    input_size = op.Shape(self)
+    inverse_indices = op.Reshape(inverse_indices, input_size)
+    return unique_values, inverse_indices, counts
+
+
+@torch_op("aten::unique_dim", traceable=True)
 def aten_unique_dim(
     self: TensorType,
     dim: int,
@@ -8381,7 +8439,20 @@
 ) -> tuple[TensorType, TensorType, TensorType]:
     """unique_dim(Tensor self, int dim, bool sorted=True, bool return_inverse=False, bool return_counts=False) -> (Tensor, Tensor, Tensor)"""
 
-    raise NotImplementedError()
+    unique_values, _, inverse_indices, counts = op.Unique(self, axis=dim, sorted=sorted)
+    input_size = op.Shape(self)
+    # PyTorch accepts a negative dim as counting from the end; normalize it.
+    input_rank = op.Size(input_size)
+    dim = input_rank + dim
+    dim = dim % input_rank
+    starts = op.Reshape(dim, [-1])
+    ends = op.Reshape(dim + 1, [-1])
+    input_dim_size = op.Slice(input_size, starts=starts, ends=ends)
+    inverse_indices = op.Reshape(inverse_indices, input_dim_size)
+    output_size = op.Shape(unique_values)
+    output_dim_size = op.Slice(output_size, starts=starts, ends=ends)
+    counts = op.Reshape(counts, output_dim_size)
+    return unique_values, inverse_indices, counts
 
 
 def aten_unique_dim_consecutive(
diff --git a/tests/function_libs/torch_lib/ops_test_data.py b/tests/function_libs/torch_lib/ops_test_data.py
index cff34897d..37f0b4693 100644
--- a/tests/function_libs/torch_lib/ops_test_data.py
+++ b/tests/function_libs/torch_lib/ops_test_data.py
@@ -438,6 +438,34 @@ def _where_input_wrangler(
     return args, kwargs
 
 
+def _unique_unsorted_xfail_matcher(
+    sample: Any
+) -> bool:
+    # torch.unique always sorts, so its results are not guaranteed to be
+    # equivalent to the output of the ONNX Unique op when sorted=False.
+    expect_fail = False
+    if sample.kwargs.get("sorted", None) is False and sample.input.numel() > 1:
+        # Omitting `sorted` (None) is equivalent to sorted=True.
+        # The unsorted result only matches the sorted one when the first
+        # appearance of each unique value occurs in ascending order in the
+        # input, so verify that before expecting a failure.
+        test_kwargs = dict(sample.kwargs)
+        test_kwargs["return_inverse"] = True
+        test_kwargs["return_counts"] = False
+        _, inverse = torch.unique(sample.input, **test_kwargs)
+        observed = set()
+        max_observed: Optional[int] = None
+        for inv_idx in inverse.flatten().tolist():
+            if inv_idx not in observed:
+                if max_observed is not None and inv_idx < max_observed:
+                    expect_fail = True
+                    break
+                observed.add(inv_idx)
+                if max_observed is None or inv_idx > max_observed:
+                    max_observed = inv_idx
+    return expect_fail
+
+
 # Ops to be tested for numerical consistency between onnx and pytorch
 # Find the names of the OpInfos in torch/testing/_internal/common_methods_invocations.py
 TESTED_TORCHLIB_OPS: tuple[TorchLibOpInfo, ...] = (
@@ -2325,6 +2353,19 @@ def _where_input_wrangler(
     TorchLibOpInfo(
         "transpose", core_ops.aten_transpose_complex, trace_only=True, complex=True
     ),
+    TorchLibOpInfo(
+        "unique",
+        core_ops.aten_unique,
+        trace_only=True,
+    ).xfail(
+        matcher=lambda sample: sample.input.dtype not in {
+            torch.float64, torch.float32, torch.float16, torch.int64, torch.int8,
+        },
+        reason="'onnxruntime' does not implement Unique(11) with the given dtype",
+    ).xfail(
+        matcher=_unique_unsorted_xfail_matcher,
+        reason="torch.unique always sorts, so passing 'sorted=False' leads to mismatched outputs",
+    ),
     TorchLibOpInfo(
         "var_mean",
         core_ops.aten_var_mean,
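
Note (illustration only, not part of the patch): the op.Reshape of inverse_indices in
aten_unique2 exists because ONNX Unique(11) without an axis flattens its input and
returns a flat 1-D inverse-index tensor, whereas torch.unique returns inverse indices
shaped like the input. A minimal PyTorch sketch of the shape contract being matched:

    import torch

    x = torch.tensor([[2, 1], [1, 3]])
    values, inverse, counts = torch.unique(
        x, sorted=True, return_inverse=True, return_counts=True
    )
    print(values)   # tensor([1, 2, 3])
    print(inverse)  # tensor([[1, 0], [0, 2]]) -- same shape as x
    print(counts)   # tensor([2, 1, 1])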