-
Couldn't load subscription status.
- Fork 87
Unsqueeze unbatched input of avg_pool #2646
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Conversation
|
@wodesuck please read the following Contributor License Agreement(CLA). If you agree with the CLA, please reply with the following information.
Contributor License AgreementContribution License AgreementThis Contribution License Agreement (“Agreement”) is agreed to by the party signing below (“You”),
|
|
Could you create a test in https://github.com/microsoft/onnxscript/blob/main/tests/function_libs/torch_lib/e2e_ops_tests.py ? |
Codecov Report❌ Patch coverage is
Additional details and impacted files@@ Coverage Diff @@
## main #2646 +/- ##
==========================================
- Coverage 70.46% 70.45% -0.01%
==========================================
Files 224 224
Lines 26572 26577 +5
Branches 2637 2639 +2
==========================================
+ Hits 18723 18724 +1
- Misses 6928 6930 +2
- Partials 921 923 +2 ☔ View full report in Codecov by Sentry. |
|
@justinchuby Test added. |
| class Model(torch.nn.Module): | ||
| def forward(self, x2d, x3d, x4d, x5d): | ||
| return ( | ||
| torch.nn.functional.avg_pool1d(x2d, 2), |
Check failure
Code scanning / lintrunner
PYLINT/E1102 Error test
See not-callable. To disable, use # pylint: disable=not-callable
| def forward(self, x2d, x3d, x4d, x5d): | ||
| return ( | ||
| torch.nn.functional.avg_pool1d(x2d, 2), | ||
| torch.nn.functional.avg_pool1d(x3d, 2), |
Check failure
Code scanning / lintrunner
PYLINT/E1102 Error test
See not-callable. To disable, use # pylint: disable=not-callable
| return ( | ||
| torch.nn.functional.avg_pool1d(x2d, 2), | ||
| torch.nn.functional.avg_pool1d(x3d, 2), | ||
| torch.nn.functional.avg_pool2d(x3d, 2), |
Check failure
Code scanning / lintrunner
PYLINT/E1102 Error test
See not-callable. To disable, use # pylint: disable=not-callable
| torch.nn.functional.avg_pool1d(x2d, 2), | ||
| torch.nn.functional.avg_pool1d(x3d, 2), | ||
| torch.nn.functional.avg_pool2d(x3d, 2), | ||
| torch.nn.functional.avg_pool2d(x4d, 2), |
Check failure
Code scanning / lintrunner
PYLINT/E1102 Error test
See not-callable. To disable, use # pylint: disable=not-callable
| torch.nn.functional.avg_pool1d(x3d, 2), | ||
| torch.nn.functional.avg_pool2d(x3d, 2), | ||
| torch.nn.functional.avg_pool2d(x4d, 2), | ||
| torch.nn.functional.avg_pool3d(x4d, 2), |
Check failure
Code scanning / lintrunner
PYLINT/E1102 Error test
See not-callable. To disable, use # pylint: disable=not-callable
| torch.nn.functional.avg_pool2d(x3d, 2), | ||
| torch.nn.functional.avg_pool2d(x4d, 2), | ||
| torch.nn.functional.avg_pool3d(x4d, 2), | ||
| torch.nn.functional.avg_pool3d(x5d, 2), |
Check failure
Code scanning / lintrunner
PYLINT/E1102 Error test
See not-callable. To disable, use # pylint: disable=not-callable
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Pull Request Overview
This PR adds support for unbatched input tensors in average pooling operations to match PyTorch's behavior. While ONNX's AveragePool requires NCHW format, PyTorch accepts both batched (NCHW) and unbatched (CHW) inputs. The changes handle unbatched inputs by automatically unsqueezing/squeezing dimensions, similar to the existing max_pool implementation.
Key Changes:
- Introduced a helper function
_aten_avg_pool_onnxthat handles both batched and unbatched inputs - Refactored
avg_pool1d,avg_pool2d, andavg_pool3dto use the new helper function - Added comprehensive tests covering all pooling dimensions with both batched and unbatched inputs
Reviewed Changes
Copilot reviewed 2 out of 2 changed files in this pull request and generated 1 comment.
| File | Description |
|---|---|
| onnxscript/function_libs/torch_lib/ops/nn.py | Refactored avg_pool operations to support unbatched inputs via new helper function |
| tests/function_libs/torch_lib/e2e_ops_tests.py | Added test cases for avg_pool operations with various input dimensions |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thank you. Minor comments and please check this page for lint: https://github.com/microsoft/onnxscript#coding-style
|
There is still something wrong with lint. Would you check? |
|
@titaiwangms Pylint says "torch.nn.functional.avg_pool1d is not callable", that's not true. I have run lintrunner locally without wrong, don't known why it still blame. |
You can go ahead and disable it: To disable, use |
Onnx's
AveragePoolrequire input shape asN,C,H,W, but torch accept bothN,C,H,WandC,H,W. Unsqueeze if input is unbatched, just like whatmax_pooldoes.