
ONNX tests failing on master #3251

Closed
datumbox opened this issue Jan 14, 2021 · 11 comments · Fixed by #3319

@datumbox
Contributor

datumbox commented Jan 14, 2021

🐛 Bug

It seems that the ONNX tests are failing today on the latest master, and the problem is probably related to changes upstream.

This was originally spotted on an unrelated PR, but to confirm it we reran the tests on the previous day's passing master and they failed with the following errors:

======================================================================
ERROR: test_faster_rcnn (__main__.ONNXExporterTester)
----------------------------------------------------------------------
Traceback (most recent call last):
  File "test/test_onnx.py", line 376, in test_faster_rcnn
    tolerate_small_mismatch=True)
  File "test/test_onnx.py", line 53, in run_model
    self.ort_validate(onnx_io, test_inputs, test_ouputs, tolerate_small_mismatch)
  File "test/test_onnx.py", line 72, in ort_validate
    ort_outs = ort_session.run(None, ort_inputs)
  File "/home/circleci/.local/lib/python3.7/site-packages/onnxruntime/capi/onnxruntime_inference_collection.py", line 124, in run
    return self._sess.run(output_names, input_feed, run_options)
onnxruntime.capi.onnxruntime_pybind11_state.RuntimeException: [ONNXRuntimeError] : 6 : RUNTIME_EXCEPTION : Non-zero status code returned while running ReduceMax node. Name:'ReduceMax_1833' Status Message: /onnxruntime_src/onnxruntime/core/providers/cpu/reduction/reduction_ops.cc:487 void onnxruntime::CommonReduce(onnxruntime::OpKernelContext*, std::vector<long int>, int64_t, onnxruntime::ResultsNoTransposePrepareForReduce&, bool) [with T = float; AGG = onnxruntime::ReduceAggregatorMax<float, float>; int64_t = long int] keepdims_ was false. Can't reduce on dim with value of 0 if 'keepdims' is false. Invalid output shape would be produced. input_shape:{0,4}


======================================================================
ERROR: test_keypoint_rcnn (__main__.ONNXExporterTester)
----------------------------------------------------------------------
Traceback (most recent call last):
  File "test/test_onnx.py", line 477, in test_keypoint_rcnn
    tolerate_small_mismatch=True)
  File "test/test_onnx.py", line 53, in run_model
    self.ort_validate(onnx_io, test_inputs, test_ouputs, tolerate_small_mismatch)
  File "test/test_onnx.py", line 72, in ort_validate
    ort_outs = ort_session.run(None, ort_inputs)
  File "/home/circleci/.local/lib/python3.7/site-packages/onnxruntime/capi/onnxruntime_inference_collection.py", line 124, in run
    return self._sess.run(output_names, input_feed, run_options)
onnxruntime.capi.onnxruntime_pybind11_state.RuntimeException: [ONNXRuntimeError] : 6 : RUNTIME_EXCEPTION : Non-zero status code returned while running ReduceMax node. Name:'ReduceMax_1833' Status Message: /onnxruntime_src/onnxruntime/core/providers/cpu/reduction/reduction_ops.cc:487 void onnxruntime::CommonReduce(onnxruntime::OpKernelContext*, std::vector<long int>, int64_t, onnxruntime::ResultsNoTransposePrepareForReduce&, bool) [with T = float; AGG = onnxruntime::ReduceAggregatorMax<float, float>; int64_t = long int] keepdims_ was false. Can't reduce on dim with value of 0 if 'keepdims' is false. Invalid output shape would be produced. input_shape:{0,4}


======================================================================
ERROR: test_mask_rcnn (__main__.ONNXExporterTester)
----------------------------------------------------------------------
Traceback (most recent call last):
  File "test/test_onnx.py", line 429, in test_mask_rcnn
    tolerate_small_mismatch=True)
  File "test/test_onnx.py", line 53, in run_model
    self.ort_validate(onnx_io, test_inputs, test_ouputs, tolerate_small_mismatch)
  File "test/test_onnx.py", line 72, in ort_validate
    ort_outs = ort_session.run(None, ort_inputs)
  File "/home/circleci/.local/lib/python3.7/site-packages/onnxruntime/capi/onnxruntime_inference_collection.py", line 124, in run
    return self._sess.run(output_names, input_feed, run_options)
onnxruntime.capi.onnxruntime_pybind11_state.RuntimeException: [ONNXRuntimeError] : 6 : RUNTIME_EXCEPTION : Non-zero status code returned while running ReduceMax node. Name:'ReduceMax_1833' Status Message: /onnxruntime_src/onnxruntime/core/providers/cpu/reduction/reduction_ops.cc:487 void onnxruntime::CommonReduce(onnxruntime::OpKernelContext*, std::vector<long int>, int64_t, onnxruntime::ResultsNoTransposePrepareForReduce&, bool) [with T = float; AGG = onnxruntime::ReduceAggregatorMax<float, float>; int64_t = long int] keepdims_ was false. Can't reduce on dim with value of 0 if 'keepdims' is false. Invalid output shape would be produced. input_shape:{0,4}

cc @neginraoof, @spandantiwari, @jiafatom
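
For reference, the ONNX Runtime error corresponds to taking a ReduceMax over a dimension of size 0 with keepdims=False, presumably because no candidate boxes survive at that point (the input_shape:{0,4} in the message). A minimal, purely illustrative sketch in eager PyTorch (not the exported detection model) hits the analogous condition:

```python
import torch

# Illustrative only: reducing over an empty dimension, like the {0, 4} boxes
# tensor from the error message above.
boxes = torch.zeros((0, 4))        # zero surviving boxes, four coordinates each
try:
    boxes.max(dim=0)               # reduce over the empty dim (keepdim=False)
except (RuntimeError, IndexError) as exc:
    # The exact exception type and message vary by PyTorch version, but the
    # empty reduction is rejected, just as ONNX Runtime rejects it here.
    print(exc)
```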

@jiafatom
Contributor

jiafatom commented Jan 15, 2021

Found the offending PR on the pytorch side: pytorch/pytorch#49410
Details below:

The torch_onnx_test job on the master branch started failing at around 2:00 am on 1/14/2021:
https://app.circleci.com/pipelines/github/pytorch/vision/5741/workflows/5c6087df-4f7f-458d-b3a7-34061cff1a82/jobs/370853
The same test was still passing four days earlier, on 1/10/2021:
https://app.circleci.com/pipelines/github/pytorch/vision/5741/workflows/5c6087df-4f7f-458d-b3a7-34061cff1a82/jobs/367109

The root cause is pytorch/pytorch#50163, which merges the pytorch-onnx dev branch into pytorch; the merge happened at 1:56 pm on 1/13/2021, which fits the timeline between the two dates above. That merge bundles 9 PRs together, so I went through them to check which one introduced the failure.

I imported the failing tests into the pytorch test suite and ran test_faster_rcnn against pytorch/pytorch@7b920b0, which confirmed the offending PR on the pytorch side: pytorch/pytorch#49410.
The test passes when torch._C._jit_pass_onnx_fold_if(graph) is commented out.
This is understandable to me: the ONNX model exported from a dummy input has only one NonMaxSuppression node, while the model exported from a real image has two NonMaxSuppression nodes, because the export logic handles some conditions differently depending on the input.
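
To illustrate the kind of divergence being described, here is a hypothetical toy example (not the actual detection post-processing): tracing a data-dependent branch with different example inputs produces different graphs, so a pass that folds If nodes based on one traced example can drop a path that other inputs need.

```python
import torch

# Hypothetical sketch: the same function traced with different example inputs
# yields different graphs because the Python branch is baked in at trace time.
def postprocess(scores):
    keep = scores > 0.5
    if bool(keep.any()):                  # data-dependent branch
        scores = scores[keep]
    return scores

dummy = torch.zeros(5)                                  # branch not taken
real = torch.tensor([0.1, 0.9, 0.8, 0.2, 0.7])          # branch taken

traced_dummy = torch.jit.trace(postprocess, dummy)
traced_real = torch.jit.trace(postprocess, real)
print(traced_dummy.graph)   # no masking path in the traced graph
print(traced_real.graph)    # contains the masked-select path
```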

@datumbox
Contributor Author

@jiafatom Thanks for the great analysis/investigation.

I see you were involved in the review of PR pytorch/pytorch#49410. What would you recommend as the right course of action? Do you plan to raise a ticket to let the author of the original PR know about the issue, or will you send a PR upstream?

Let me know if you need anything from my side. Thanks!

@jiafatom
Contributor

jiafatom commented Jan 15, 2021

@datumbox I have a PR to fix this issue upstream: pytorch/pytorch#50582
I imported the above three torchvision tests into the pytorch test suite, and they pass locally; the torchvision tests look good "mostly" (see detail [A] below).

It will take some time for this PR to get merged. Under the current policy with Facebook, we merge to the pytorch branch when we have ~10 PRs in a batch, so we estimate this PR may land in around 10-14 days. That means torchvision test_onnx will still be red during this time. Do you have any comments on this? Thanks.

Detail [A]:
When I ran the torchvision tests against this PR, test_faster_rcnn and test_mask_rcnn passed, and test_keypoint_rcnn failed on a single data point out of 561:
With rtol=0.001 and atol=1e-05, found 1 element(s) (out of 561) whose difference(s) exceeded the margin of error (including 0 nan comparisons). The greatest difference was 2.6005320250988007e-05 (-0.014647360891103745 vs. -0.014621355570852757), which occurred at index (29, 4).

The difference corresponds to roughly rtol=0.0017 and atol=2.7e-5, slightly larger than the bound of rtol=0.001 and atol=1e-05. I feel it is acceptable; we can relax the error bar to unblock the torchvision unit tests. Further analysis is a separate issue.
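
For concreteness, a minimal sketch of the comparison being discussed, using the single reported outlier and torch.testing.assert_allclose (the same style of check the test harness performs). The relaxed atol below is only one illustrative choice, not necessarily the bound that was eventually adopted:

```python
import torch

# The one outlying element reported above (index (29, 4)).
expected = torch.tensor([-0.014647360891103745])
actual = torch.tensor([-0.014621355570852757])

try:
    # Default test bounds: allowed error = atol + rtol * |expected|
    #                                    = 1e-5 + 1e-3 * 0.01465 ~= 2.46e-5,
    # smaller than the observed difference of ~2.60e-5, so this fails.
    torch.testing.assert_allclose(actual, expected, rtol=1e-3, atol=1e-5)
except AssertionError as exc:
    print(exc)

# A slightly relaxed absolute tolerance absorbs the outlier and passes.
torch.testing.assert_allclose(actual, expected, rtol=1e-3, atol=5e-5)
```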

@jiafatom
Contributor

jiafatom commented Jan 15, 2021

@datumbox I just brought this issue up at our group meeting; please feel free to disable the ONNX tests for now if needed. Thanks.

@datumbox
Contributor Author

@jiafatom Thanks for looking into it.

We are currently completing the work of adding a FasterRCNN with MobileNetV3 backbone (#3253). Given that this bug affects the tests of the *rcnn models, it is hard to confirm that the new model will be ONNX compatible. I wonder if your team could land the PR faster as an exception for this use case?

@jiafatom
Contributor

> @jiafatom Thanks for looking into it.
>
> We are currently completing the work of adding a FasterRCNN with MobileNetV3 backbone (#3253). Given that this bug affects the tests of the *rcnn models, it is hard to confirm that the new model will be ONNX compatible. I wonder if your team could land the PR faster as an exception for this use case?

cc @neginraoof, @spandantiwari
Can we expedite the PR to unblock this issue? Thanks.

BowenBao pushed a commit to pytorch/pytorch that referenced this issue Jan 20, 2021
…50582)

Fixing pytorch/vision#3251 (PR #49410 triggers the torchvision test build failure in three tests: test_faster_rcnn, test_mask_rcnn, test_keypoint_rcnn.)

The offending PR is fine on pytorch UT, because the torchvision and pytorch tests have a gap when we merge them: we are using different test APIs on the two sides, which causes some discrepancy.

This PR bridges the gap for the above three tests and disables the _jit_pass_onnx_fold_if pass until it gets fixed.
Allow _jit_pass_onnx_fold_if only when dynamic_axes is None.
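
A rough, hypothetical sketch of the guard the commit message describes (this is not the actual upstream diff; the helper name and call site are illustrative, only torch._C._jit_pass_onnx_fold_if comes from the discussion above):

```python
import torch

def run_onnx_if_folding(graph, dynamic_axes=None):
    # Only fold If nodes when no dynamic axes were requested, since folding a
    # branch that was chosen for one example input can be wrong for others.
    if dynamic_axes is None:
        torch._C._jit_pass_onnx_fold_if(graph)
    return graph
```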
@jiafatom
Contributor

jiafatom commented Jan 20, 2021

The above-mentioned pytorch/pytorch#50582 got merged into our dev branch. Once the dev branch is merged into master, the torchvision ONNX tests should be fine.

@datumbox
Contributor Author

@jiafatom Much appreciated, thanks for the flexibility!

@jiafatom
Contributor

jiafatom commented Jan 20, 2021

Another thing I would bring up on the torchvision side:
If any PR changes test_onnx.py, it is highly recommended to mirror the change in the pytorch ONNX test CI build as well, so that torchvision failures can be caught on the pytorch side.
For example, #3205 is a PR that may need this sync.

@datumbox
Contributor Author

> If any PR changes test_onnx.py, it is highly recommended to mirror the change in the pytorch ONNX test CI build as well, so that torchvision failures can be caught on the pytorch side.

@jiafatom You mean updating the tests in test_onnx.py to ensure we are not breaking anything for ONNX? I think the quoted PR does that. Do you suggest doing something more? Thanks!

facebook-github-bot pushed a commit to pytorch/pytorch that referenced this issue Jan 28, 2021
…50582) (#50910)

Summary:
Pull Request resolved: #50910

Fixing pytorch/vision#3251 (PR #49410 triggers the torchvision test build failure in three tests: test_faster_rcnn, test_mask_rcnn, test_keypoint_rcnn.)

The offending PR is fine on pytorch UT, because the torchvision and pytorch tests have a gap when we merge them: we are using different test APIs on the two sides, which causes some discrepancy.

This PR bridges the gap for the above three tests and disables the _jit_pass_onnx_fold_if pass until it gets fixed.
Allow _jit_pass_onnx_fold_if only when dynamic_axes is None.

Test Plan: Imported from OSS

Reviewed By: pbelevich

Differential Revision: D26050886

Pulled By: SplitInfinity

fbshipit-source-id: b765ffe30914261866dcc761f0d0999fd16169e3
@jiafatom
Contributor

jiafatom commented Mar 4, 2021

> If any PR changes test_onnx.py, it is highly recommended to mirror the change in the pytorch ONNX test CI build as well, so that torchvision failures can be caught on the pytorch side.

> @jiafatom You mean updating the tests in test_onnx.py to ensure we are not breaking anything for ONNX? I think the quoted PR does that. Do you suggest doing something more? Thanks!

Yes, what I mean is to add the test_onnx.py changes to the pytorch repo as well. This time I added them in test/onnx/test_pytorch_onnx_onnxruntime.py in pytorch:
jiafatom/pytorch@ab7fef9#diff-6616e68665d29ba8ab30d14fc99a8459f974fb58368d71daf7bca89701e557dbR173
