
Add unique op #1547

Open

a-gardner1 wants to merge 12 commits into main
Conversation

a-gardner1

Add support for exporting torch.unique following the conclusion of pytorch/pytorch#113118.


codecov bot commented May 15, 2024

Codecov Report

Attention: Patch coverage is 57.14286% with 18 lines in your changes missing coverage. Please review.

Project coverage is 77.50%. Comparing base (69ae7f4) to head (f9885f1).
Report is 171 commits behind head on main.

| Files with missing lines | Patch % | Lines |
| --- | --- | --- |
| onnxscript/function_libs/torch_lib/ops/core.py | 57.14% | 16 Missing and 2 partials ⚠️ |

Additional details and impacted files

```
@@            Coverage Diff             @@
##             main    #1547      +/-   ##
==========================================
- Coverage   77.56%   77.50%   -0.07%     
==========================================
  Files         214      216       +2     
  Lines       23186    23381     +195     
  Branches     3975     4033      +58     
==========================================
+ Hits        17984    18121     +137     
- Misses       4433     4477      +44     
- Partials      769      783      +14     
```


@a-gardner1 a-gardner1 marked this pull request as draft May 15, 2024 22:27
Collaborator

@justinchuby justinchuby left a comment


Thanks for your contribution! Could you follow the CLA bot's instructions to get that cleared?

Comment on lines 8385 to 8390
```python
except Exception as e:
    # try to provide a more informative error message
    if _NOT_IMPLEMENTED_UNIQUE.search(str(e)) is not None:
        raise NotImplementedError(
            f"'onnxruntime' does not yet support Unique(11) operator with dtype={self.dtype}'"
        ) from e
```
Collaborator

@justinchuby justinchuby May 15, 2024


I would remove this try-catch, as the function here is symbolic; we don't expect it to raise any errors.

Author


Addressed in b528a6a

@justinchuby justinchuby added the topic: torch_lib Related to the torch/aten function lib in development label May 15, 2024
@a-gardner1
Author

> Thanks for your contribution! Could you follow the CLA bot's instructions to get that cleared?

Yeah, I may have jumped the gun a bit. I'm working on officially getting permission from my employer.

@a-gardner1
Author

a-gardner1 commented May 16, 2024

> @a-gardner1 please read the following Contributor License Agreement (CLA). If you agree with the CLA, please reply with the following information.
>
> @microsoft-github-policy-service agree [company="{your company}"]
>
> Options:
>
> - (default - no company specified) I have sole ownership of intellectual property rights to my Submissions and I am not making Submissions in the course of work for my employer.
>   @microsoft-github-policy-service agree
> - (when company given) I am making Submissions in the course of work for my employer (or my employer has intellectual property rights in my Submissions by contract or applicable law). I have permission from my employer to make Submissions and enter into this Agreement on behalf of my employer. By signing below, the defined term "You" includes me and my employer.
>   @microsoft-github-policy-service agree company="Microsoft"
>
> Contributor License Agreement

@microsoft-github-policy-service agree [company="Radiance Technologies"]

@microsoft-github-policy-service agree company="Radiance Technologies"

@a-gardner1
Author

@microsoft-github-policy-service agree company="Radiance Technologies"

```diff
@@ -438,6 +438,34 @@ def _where_input_wrangler(
     return args, kwargs


def _unique_unsorted_xfail_matcher(
```
Author


@justinchuby I'm not sure what the preferred behavior is here. Should we match `torch.unique` and ignore the `sorted` argument (i.e., always sort in `aten_unique`), or respect the argument and deviate in accordance with this matcher?

Collaborator


I wonder if the argument leads to different behavior on CUDA vs. CPU? I assume `sorted=False` means the output can be sorted but doesn't need to be, and there is some potential performance gain from turning it off. If that's the interpretation, I would keep the argument. Otherwise, ignoring the argument and matching behavior would also be nice.
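A quick way to probe that interpretation locally (my own sketch, not part of this PR; run on a default CPU build):

```python
import torch

# On CPU, torch.unique typically returns sorted output even with sorted=False,
# because sorting is how uniqueness is computed there; CUDA may differ.
x = torch.tensor([3, 1, 2, 1, 3])
print(torch.unique(x, sorted=True))   # tensor([1, 2, 3])
print(torch.unique(x, sorted=False))  # may still come back sorted on CPU
```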

Author

@a-gardner1 a-gardner1 May 21, 2024


I am investigating differences in behavior between CUDA and CPU and have found at least one already: `unique_dim` on CPU ignores the `return_inverse` and `return_counts` arguments, whereas the CUDA implementation does not. How should these differences be handled? Can the op registration be conditioned on the device somehow, or should I favor CUDA over CPU?
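For reference, a minimal sketch of the kind of probe that surfaces this (assumes a CUDA device is available; `torch.ops.aten.unique_dim` takes positional `dim, sorted, return_inverse, return_counts` after the input):

```python
import torch

x = torch.tensor([[1, 2], [1, 2], [3, 4]])
for device in ("cpu", "cuda"):
    values, inverse, counts = torch.ops.aten.unique_dim(
        x.to(device), 0, True, False, False
    )
    # Comparing `inverse`/`counts` across devices shows the divergence
    # described above (CPU computes them regardless of the flags).
    print(device, values, inverse, counts)
```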

Collaborator


Matching CUDA for now is preferable. Thanks!

@a-gardner1 a-gardner1 marked this pull request as ready for review May 17, 2024 20:35
```python
# HACK: force indices to be in the graph so that it gets a name during optimization
# Otherwise an error will be raised in `onnxscript.Scope.lookup_or_create`
indices_size = op.Shape(indices)
counts = op.Reshape(counts, indices_size)
```
Author


I want to note that the way this function was written in 1d74d59 is functionally equivalent but yields an error in `onnxscript.Scope.lookup_or_create`: it causes `modified` to be `True` in `onnxscript.optimizer.optimize`, triggering a second optimization pass that crashes in the first call to `inline_simple_functions`.

This seems indicative of a potential bug to me, but I am not knowledgeable enough about the codebase to suggest a cause or fix.


@justinchuby
Collaborator

Thanks for completing the CLA. I will take a look next week

@justinchuby justinchuby self-assigned this May 18, 2024
```python
        result = unique_values, counts
    else:
        result = unique_values
    return result
```
Collaborator


I think we need to always return the same number of values. Consider returning None when they are not available?

Author


Doing so deviates from the behavior of `torch.unique` and causes this assertion in the unit tests to fail:

```python
assert len(flattened_torch_outputs) == len(flattened_function_outputs)
```

Please advise on how to address this.
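To illustrate the mismatch (a quick sketch): `torch.unique` returns one, two, or three tensors depending on the flags, so padding the ONNX side with `None` changes the flattened output count that the test compares.

```python
import torch

x = torch.tensor([1, 1, 2])
print(torch.unique(x))                      # a single tensor
print(torch.unique(x, return_counts=True))  # a tuple of 2 tensors
print(torch.unique(x, return_inverse=True,
                   return_counts=True))     # a tuple of 3 tensors
```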

Collaborator


Does `torch.ops.aten.unique` exhibit the same behavior? If it always returns three variables, consider creating a new OpInfo for `torch.ops.aten.unique`, similar to:

```python
opinfo_core.OpInfo(
    "ops.aten._native_batch_norm_legit.no_stats",
    aten_name="_native_batch_norm_legit.no_stats",
    dtypes=common_dtype.floating_types_and(torch.bfloat16),
    dtypesIfCUDA=common_dtype.floating_types_and(torch.float16, torch.bfloat16),
    supports_forward_ad=True,
    supports_fwgrad_bwgrad=True,
    assert_jit_shape_analysis=True,
    sample_inputs_func=sample_inputs__native_batch_norm_legit_no_stats,
),
```

With a custom OpInfo you may also be able to simply remove the xfail cases.

You may adapt the sample function from https://github.com/pytorch/pytorch/blob/b948b1ad7a9cf61c9692506c60c295fd40e00f43/torch/testing/_internal/common_methods_invocations.py#L3346-L3372

Author


Thanks for the pointer to extra_opinfo. It turns out `torch.ops.aten.unique` does not exist, but `torch.ops.aten._unique` does. Added OpInfos for it, `_unique2`, and `unique_dim` in 14d03b5.
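For reference, a minimal sketch of what one such entry might look like, written as a list item in the style of the excerpt above (the sampler, dtypes, and names here are illustrative, not the actual entries in 14d03b5):

```python
def sample_inputs__unique(op_info, device, dtype, requires_grad, **kwargs):
    # Hypothetical sampler: a small 1-D tensor with duplicates.
    yield opinfo_core.SampleInput(
        torch.tensor([2, 0, 2, 1], device=device, dtype=dtype),
        kwargs={"sorted": True, "return_inverse": True},
    )


opinfo_core.OpInfo(
    "ops.aten._unique2",
    aten_name="_unique2",
    dtypes=common_dtype.all_types(),
    sample_inputs_func=sample_inputs__unique,
),
```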

"""unique(Tensor self, bool sorted=True, bool return_inverse=False, bool return_counts=False) -> (Tensor, Tensor, Tensor)"""

unique_values, indices, inverse_indices, counts = op.Unique(self, axis=None, sorted=sorted)
# HACK: force indices to be in the graph so that it gets a name during optimization
Collaborator

@justinchuby justinchuby May 20, 2024


Is this a bug we should fix elsewhere? (Saw the comment below.)

Author


I think this could be considered a different bug. The other one is a side effect of `onnxscript.optimizer.constant_folding.fold_constants`, whereas this one is a side effect of the function linked below, which renames unused outputs to empty strings but only removes them if they are trailing. Since `inverse_indices` and `counts` are used, the empty-string name given to `indices` raises an error in `onnxscript.Scope.lookup_or_create`.

```python
def remove_unused_optional_outputs(
```

```diff
@@ -8380,8 +8380,21 @@ def aten__unique(
 ) -> tuple[TensorType, TensorType]:
     """_unique(Tensor self, bool sorted=True, bool return_inverse=False) -> (Tensor, Tensor)"""

-    unique_values, _, inverse_indices, _ = op.Unique(self, axis=None, sorted=True)
+    unique_values, indices, inverse_indices, _ = op.Unique(self, axis=None, sorted=True)
+    # HACK: force indices to be in the graph so that it gets a name during optimization
```
Collaborator


I suggest removing all the hacks; I will go fix what's necessary where the bug is. We are also moving to prefer `trace_only=True` for new functions, so if you can include the flag in `@torch_op` that would be awesome.

Author


That would be awesome. The hacks are definitely getting out of hand. I'll wait for that fix so that I can continue to test with this locally.

Collaborator


Do you have a short script handy that will reproduce the error?

Author

@a-gardner1 a-gardner1 May 21, 2024


```python
if __name__ == '__main__':
    import logging
    from typing import Any  # needed for the `Any` return annotation below

    import numpy as np
    import onnx
    import onnxruntime as ort
    import torch

    # Sweep every combination of sorted/return_inverse/return_counts/dim.
    for i in range(16):
        sorted = bool(i & 1)
        return_inverse = bool((i & 2) > 1)
        return_counts = bool((i & 4) > 1)
        dim = 0 if bool((i & 8) > 1) else None

        print(
            f"Testing sorted={sorted}, return_inverse={return_inverse}, return_counts={return_counts}, dim={dim}"
        )

        def test_function(
                x: torch.Tensor,
                s: bool = sorted,
                ri: bool = return_inverse,
                rc: bool = return_counts,
                d: int | None = dim) -> Any:
            result = torch.unique(
                x,
                sorted=s,
                return_inverse=ri,
                return_counts=rc,
                dim=d)
            return result

        onnx_program = torch.onnx.dynamo_export(
            test_function,
            torch.arange(10),
            export_options=torch.onnx.ExportOptions(
                dynamic_shapes=True,
                diagnostic_options=torch.onnx.DiagnosticOptions(
                    verbosity_level=logging.DEBUG)))
        onnx_program.save("torch_unique.onnx")
        onnx_inputs = onnx_program.adapt_torch_inputs_to_onnx(torch.arange(10))
        onnx_outputs = onnx_program(*onnx_inputs)
        loaded_onnx_program = onnx.load("torch_unique.onnx")
        onnx.checker.check_model(loaded_onnx_program)
        ort_session = ort.InferenceSession("torch_unique.onnx")
        inputs = np.random.randint(0, 10, 10)
        print(f"Inputs: {inputs}")
        outputs = ort_session.run(None, {"l_x_": inputs})
        print(f"Outputs: {outputs}")
    print("Success")
```

Author


Oh, you should also test using the nightly release of PyTorch with the changes in pytorch/pytorch#126561.

Author


Is `trace_only=True` expected to require significant changes to the way one implements an op? It appears that enabling the flag breaks passing a value to `op.ConstantOfShape` and also breaks indexing a shape.

For example, `op.ConstantOfShape([0], value=[0])` must become `op.Cast(op.ConstantOfShape([0]), to=INT64.dtype)`, and `output_size[dim]` must become `op.Slice(output_size, [dim], [dim+1])`.
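Side by side, the rewrites described above (a sketch, not runnable on its own; `output_size` and `dim` stand in for whatever the op implementation has in scope):

```python
# Script-mode form that reportedly breaks under trace_only=True:
empty = op.ConstantOfShape([0], value=[0])
size_d = output_size[dim]

# Equivalent forms that trace successfully:
empty = op.Cast(op.ConstantOfShape([0]), to=INT64.dtype)
size_d = op.Slice(output_size, [dim], [dim + 1])
```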

Collaborator


Your observation is correct. This is likely because of gaps in our implementation. Bridging those gaps is on our roadmap but is not the highest priority for the team.

pytorchmergebot pushed a commit to pytorch/pytorch that referenced this pull request Jun 1, 2024
Follow-up to #113118 and #124306.

Developed in coordination with the solution to microsoft/onnxscript#1547

This PR adds the missing fake tensor implementation for `aten.unique_dim`, thus enabling tracing and compilation of `torch.unique` when `dim` is not None.

Local testing has proceeded with the following simple script (provided that one has checked out the changes in microsoft/onnxscript#1547):

```python
import logging

import numpy as np
import onnx
import onnxruntime as ort
import torch

onnx_program = torch.onnx.dynamo_export(
    lambda x: torch.unique(x,
                           dim=0,
                           return_inverse=True),
    torch.arange(10),
    export_options=torch.onnx.ExportOptions(
        dynamic_shapes=True,
        diagnostic_options=torch.onnx.DiagnosticOptions(
            verbosity_level=logging.DEBUG)))
onnx_program.save("torch_unique.onnx")
onnx_inputs = onnx_program.adapt_torch_inputs_to_onnx(torch.arange(10))
onnx_outputs = onnx_program(*onnx_inputs)
loaded_onnx_program = onnx.load("torch_unique.onnx")
onnx.checker.check_model(loaded_onnx_program)
ort_session = ort.InferenceSession("torch_unique.onnx")
inputs = np.random.randint(0, 10, 10)
print(f"Inputs: {inputs}")
outputs = ort_session.run(None, {"l_x_": inputs})
print(f"Outputs: {outputs}")
print("Success")
```

Co-authored-by: Edward Z. Yang <ezyang@meta.com>
Pull Request resolved: #126561
Approved by: https://github.com/ezyang
petrex pushed a commit to petrex/pytorch that referenced this pull request Jun 5, 2024
Labels
topic: torch_lib Related to the torch/aten function lib in development