Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

onnx: fix checks in TorchOnnxToTorch pass to match the ONNX spec #2848

Merged
merged 3 commits into from
Feb 8, 2024

Conversation

ashay
Copy link
Collaborator

@ashay ashay commented Feb 1, 2024

This PR contains three commits to update the validation checks in the ONNX ->
Torch conversion pass for the AveragePool, Pad, and Slice operators:

onnx: fix preconditions for lowering AveragePool ops

The pads attribute of the AveragePool operator specifies the value to
pad at both the beginning as well as the end of the axis (see
https://onnx.ai/onnx/operators/onnx__AveragePool.html#attributes), so
the size of this attribute should be twice the rank of the input tensor.
However, our TorchOnnxToTorch bails out early since it incorrectly
compares the pads attribute with the rank (not twice the rank) of the
input tensor.

This patch fixes the code to match the spec and adds a lit test.

onnx: allow optional constant value for Pad operator

The constant_value input of the onnx.Pad operator is optional (see
https://onnx.ai/onnx/operators/onnx__Pad.html#inputs), but the existing
logic for lowering the operator into the Torch dialect assumes that it
is mandatory.

This patch makes the attribute optional and constructs a default value
(a list of zeros the size of the input tensor) if the attribute was not
specified.

onnx: fix checks for axes and steps inputs of Slice operator

The ONNX Spec for the Slice operator allows the starts and ends
inputs to have fewer indices that the dimensions of the data tensor
(see https://onnx.ai/onnx/operators/onnx__Slice.html), but our code
expects these inputs to be as many as the data tensor's dimensions.

More precisely, the spec requires that the starts and ends inputs
are only as long as the axes input, but since the axes input is
optional, the default type for the axes input has to match the type
for the starts and ends inputs. Moreover, the number of indices in
the steps input also has to match those in the axes inputs (instad
of matching the dimensions of the data input).

This patch fixes the checks in the TorchOnnxToTorch conversion so that
they match the ONNX spec.

ashay added 3 commits February 1, 2024 16:05
The `pads` attribute of the AveragePool operator specifies the value to
pad at both the beginning as well as the end of the axis (see
https://onnx.ai/onnx/operators/onnx__AveragePool.html#attributes), so
the size of this attribute should be twice the rank of the input tensor.
However, our TorchOnnxToTorch bails out early since it incorrectly
compares the pads attribute with the rank (not twice the rank) of the
input tensor.

This patch fixes the code to match the spec and adds a lit test.
The `constant_value` input of the onnx.Pad operator is optional (see
https://onnx.ai/onnx/operators/onnx__Pad.html#inputs), but the existing
logic for lowering the operator into the Torch dialect assumes that it
is mandatory.

This patch makes the attribute optional and constructs a default value
(a list of zeros the size of the input tensor) if the attribute was not
specified.
The ONNX Spec for the Slice operator allows the `starts` and `ends`
inputs to have fewer indices that the dimensions of the `data` tensor
(see https://onnx.ai/onnx/operators/onnx__Slice.html), but our code
expects these inputs to be as many as the `data` tensor's dimensions.

More precisely, the spec requires that the `starts` and `ends` inputs
are only as long as the `axes` input, but since the `axes` input is
optional, the default type for the `axes` input has to match the type
for the `starts` and `ends` inputs.  Moreover, the number of indices in
the `steps` input also has to match those in the `axes` inputs (instad
of matching the dimensions of the `data` input).

This patch fixes the checks in the TorchOnnxToTorch conversion so that
they match the ONNX spec.
Copy link
Collaborator

@ramiro050 ramiro050 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks!

@ashay ashay merged commit 21f070e into llvm:main Feb 8, 2024
3 checks passed
@ashay ashay deleted the ashay/onnx-spec-mismatch branch February 8, 2024 05:19
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants