13
13
from __future__ import annotations
14
14
15
15
import math
16
+ import re
16
17
from typing import Any , Optional , Sequence , Tuple , Union
17
18
18
19
from onnxscript import (
@@ -8372,6 +8373,63 @@ def aten_unique_consecutive(
8372
8373
raise NotImplementedError ()
8373
8374
8374
8375
8376
+ _NOT_IMPLEMENTED_UNIQUE = re .compile (
8377
+ r"NOT_IMPLEMENTED\s*:\s*Could\s+not\s+find\s+an\s+implementation\s+for\s+Unique"
8378
+ )
8379
+ """
8380
+ A pattern to detect an unsupported (not implemented) Unique operator
8381
+ """
8382
+
8383
+ @torch_op ("aten::unique" , trace_only = True )
8384
+ def aten_unique (
8385
+ self : TensorType ,
8386
+ sorted : bool = True ,
8387
+ return_inverse : bool = False ,
8388
+ return_counts : bool = False ,
8389
+ dim : Optional [int ] = None ,
8390
+ ) -> tuple [TensorType , TensorType , TensorType ]:
8391
+ """unique(Tensor self, bool sorted=True, bool return_inverse=False, bool return_counts=False, int? dim=None) -> (Tensor, Tensor?, Tensor?)"""
8392
+
8393
+ try :
8394
+ if dim is None :
8395
+ unique_values , inverse_indices , counts = aten_unique2 (self , sorted , return_inverse , return_counts )
8396
+ else :
8397
+ unique_values , inverse_indices , counts = aten_unique_dim (self , dim , sorted , return_inverse , return_counts )
8398
+ except Exception as e :
8399
+ # try to provide a more informative error message
8400
+ if _NOT_IMPLEMENTED_UNIQUE .search (str (e )) is not None :
8401
+ raise NotImplementedError (
8402
+ f"'onnxruntime' does not yet support Unique(11) operator with dtype={ self .dtype } '"
8403
+ ) from e
8404
+ raise
8405
+ if return_inverse :
8406
+ if return_counts :
8407
+ result = unique_values , inverse_indices , counts
8408
+ else :
8409
+ result = unique_values , inverse_indices
8410
+ elif return_counts :
8411
+ result = unique_values , counts
8412
+ else :
8413
+ result = unique_values
8414
+ return result
8415
+
8416
+
8417
+ @torch_op ("aten::_unique2" , traceable = True )
8418
+ def aten_unique2 (
8419
+ self : TensorType ,
8420
+ sorted : bool = True ,
8421
+ return_inverse : bool = False ,
8422
+ return_counts : bool = False
8423
+ ) -> tuple [TensorType , TensorType , TensorType ]:
8424
+ """unique(Tensor self, bool sorted=True, bool return_inverse=False, bool return_counts=False) -> (Tensor, Tensor, Tensor)"""
8425
+
8426
+ unique_values , _ , inverse_indices , counts = op .Unique (self , axis = None , sorted = sorted )
8427
+ input_size = op .Shape (self )
8428
+ inverse_indices = op .Reshape (inverse_indices , input_size )
8429
+ return unique_values , inverse_indices , counts
8430
+
8431
+
8432
+ @torch_op ("aten::unique_dim" , traceable = True )
8375
8433
def aten_unique_dim (
8376
8434
self : TensorType ,
8377
8435
dim : int ,
@@ -8381,7 +8439,20 @@ def aten_unique_dim(
8381
8439
) -> tuple [TensorType , TensorType , TensorType ]:
8382
8440
"""unique_dim(Tensor self, int dim, bool sorted=True, bool return_inverse=False, bool return_counts=False) -> (Tensor, Tensor, Tensor)"""
8383
8441
8384
- raise NotImplementedError ()
8442
+ unique_values , _ , inverse_indices , counts = op .Unique (self , axis = dim , sorted = sorted )
8443
+ input_size = op .Shape (self )
8444
+ # PyTorch accepts negative dim as reversed counting
8445
+ input_rank = op .Size (input_size )
8446
+ dim = input_rank + dim
8447
+ dim = dim % input_rank
8448
+ starts = op .Reshape (dim , [- 1 ])
8449
+ ends = op .Reshape (dim + 1 , [- 1 ])
8450
+ input_dim_size = op .Slice (input_size , starts = starts , ends = ends )
8451
+ inverse_indices = op .Reshape (inverse_indices , input_dim_size )
8452
+ output_size = op .Shape (unique_values )
8453
+ output_dim_size = op .Slice (output_size , starts = starts , ends = ends )
8454
+ counts = op .Reshape (counts , output_dim_size )
8455
+ return unique_values , inverse_indices , counts
8385
8456
8386
8457
8387
8458
def aten_unique_dim_consecutive (
0 commit comments