Merge output shape with input shape instead of override #2578
Conversation
Pull Request Overview
This PR fixes a shape handling issue in the constant folding optimization for Cast operations. The change prevents static shapes from being incorrectly converted to dynamic shapes by merging output and input shapes instead of overriding them.
- Replaces the direct shape override with shape-merging logic
- Preserves static shape information where possible during Cast constant folding
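The merging idea can be sketched as follows. This is a minimal illustration only: the `merge_shapes` function, the dim representation, and the example shapes here are hypothetical, not onnxscript's actual `_merge_shapes` implementation.

```python
# Sketch of shape merging for constant-folded nodes (illustrative names).
# A dim is static (int), named dynamic (str), or unknown (None).
from typing import Optional, Sequence, Union

Dim = Union[int, str, None]

def merge_shapes(
    output_shape: Optional[Sequence[Dim]],
    input_shape: Optional[Sequence[Dim]],
) -> Optional[Sequence[Dim]]:
    """Merge two shapes of the same tensor, preferring static dims."""
    if output_shape is None:
        return list(input_shape) if input_shape is not None else None
    if input_shape is None:
        return list(output_shape)
    assert len(output_shape) == len(input_shape), "rank mismatch"
    merged: list[Dim] = []
    for out_dim, in_dim in zip(output_shape, input_shape):
        if isinstance(out_dim, int):
            merged.append(out_dim)   # keep static info already on the output
        elif isinstance(in_dim, int):
            merged.append(in_dim)    # recover static info from the input
        else:
            merged.append(out_dim if out_dim is not None else in_dim)
    return merged

# With an override, a static [4, 4] would be clobbered by a dynamic
# ["batch", 4]; merging keeps the static dims either way:
print(merge_shapes([4, 4], ["batch", 4]))  # [4, 4]
print(merge_shapes(["batch", 4], [4, 4]))  # [4, 4]
```

The asymmetry matters: an override always takes one side, while a merge keeps the most precise dim available on either side.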
Codecov Report: ✅ All modified and coverable lines are covered by tests.

```
@@ Coverage Diff @@
##             main    #2578   +/-   ##
=======================================
  Coverage   70.04%   70.04%
=======================================
  Files         223      223
  Lines       26215    26213       -2
  Branches     2583     2582       -1
=======================================
- Hits        18363    18362       -1
  Misses       6946     6946
+ Partials      906      905       -1
```
Could you share a concrete example when you see this case?
Thanks!
Similar to #2578, for this case:

```python
import torch
import torch.nn as nn

class Model(nn.Module):
    def forward(self, x):
        return x.new_zeros(x.shape)

def main():
    model = Model()
    args = (torch.rand(4, 4),)
    batch = torch.export.Dim("batch")
    dynamic_shapes = {"x": {0: batch}}
    torch.onnx.export(
        model,
        args,
        "model_test.onnx",
        dynamic_shapes=dynamic_shapes,
        dynamo=True,
    )

if __name__ == "__main__":
    main()
```

Co-authored-by: Justin Chu <justinchuby@users.noreply.github.com>

The Cast handler in `_constant_folding` overrides `output.shape` with `input.shape`, which may turn a static shape into a dynamic one. It should use `_merge_shapes` instead.