
Conversation

@wodesuck
Contributor

_constant_folding.cast overrides output.shape with input.shape, which may turn a static shape into a dynamic one. It should use _merge_shapes instead.
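Below is a minimal sketch of the idea, not the actual `_merge_shapes` implementation in onnxscript: when two views of the same shape are combined dimension by dimension, a concrete dimension should win over a symbolic or unknown one, so previously inferred static information is not lost. The helper name and dimension representation are illustrative only.

```python
# Illustrative sketch only: the real _merge_shapes in onnxscript may use
# different names, types, and precedence rules.
def merge_shapes_sketch(existing_shape, new_shape):
    """Merge two per-dimension views of the same shape.

    Dims are ints when static, strings when symbolic (e.g. "batch"),
    and None when unknown. Concrete ints take precedence, then symbols.
    """
    if existing_shape is None:
        return new_shape
    if new_shape is None:
        return existing_shape
    assert len(existing_shape) == len(new_shape), "rank mismatch"
    merged = []
    for a, b in zip(existing_shape, new_shape):
        if isinstance(a, int):
            merged.append(a)  # keep the static dim
        elif isinstance(b, int):
            merged.append(b)
        else:
            merged.append(a if a is not None else b)  # prefer a named symbol
    return merged


# Overriding ["batch", 4] with [None, None] would lose the static info;
# merging keeps it:
print(merge_shapes_sketch(["batch", 4], [None, None]))  # ['batch', 4]
```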

Contributor

Copilot AI left a comment


Pull Request Overview

This PR fixes a shape handling issue in the constant folding optimization for Cast operations. The change prevents static shapes from being incorrectly converted to dynamic shapes by merging output and input shapes instead of overriding them.

  • Replaces direct shape override with intelligent shape merging logic
  • Preserves static shape information when possible during Cast operations

@justinchuby
Collaborator

@gramalingam @titaiwangms

@codecov

codecov bot commented Sep 29, 2025

Codecov Report

✅ All modified and coverable lines are covered by tests.
✅ Project coverage is 70.04%. Comparing base (dddf0c2) to head (4b79f9b).
⚠️ Report is 3 commits behind head on main.

Additional details and impacted files
@@           Coverage Diff           @@
##             main    #2578   +/-   ##
=======================================
  Coverage   70.04%   70.04%           
=======================================
  Files         223      223           
  Lines       26215    26213    -2     
  Branches     2583     2582    -1     
=======================================
- Hits        18363    18362    -1     
  Misses       6946     6946           
+ Partials      906      905    -1     

☔ View full report in Codecov by Sentry.

@justinchuby justinchuby added this to the 0.5.3 milestone Sep 29, 2025
Collaborator

@justinchuby justinchuby left a comment


Could you share a concrete example of when you see this case?

@wodesuck
Contributor Author

```python
import torch
import torch.nn as nn


class Model(nn.Module):
    def forward(self, x):
        return x.new_zeros(x.shape, dtype=torch.int64)


def main():
    model = Model()
    args = torch.rand(4, 4),
    batch = torch.export.Dim("batch")
    dynamic_shapes = {"x": {0: batch}}
    torch.onnx.export(
        model,
        args,
        "model_test.onnx",
        dynamic_shapes=dynamic_shapes,
        dynamo=True,
    )


if __name__ == "__main__":
    main()
```
[screenshot of the exported model's output shape]

The output shape should be `batch x 4`, but it is currently exported as `? x ?`.
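For reference, here is a minimal sketch (assuming the repro above was run and `model_test.onnx` exists) of how the exported output shape can be inspected with the standard `onnx` package:

```python
import onnx

# Load the model exported by the repro script above and print each graph
# output's per-dimension value: dim_param for symbolic dims, dim_value
# for static ones.
model = onnx.load("model_test.onnx")
for output in model.graph.output:
    dims = [
        d.dim_param if d.dim_param else d.dim_value
        for d in output.type.tensor_type.shape.dim
    ]
    print(output.name, dims)
# Expected with the fix: ['batch', 4]; before it, both dims were unknown.
```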

Collaborator

@justinchuby justinchuby left a comment


Thanks!

@justinchuby justinchuby merged commit 3a26097 into microsoft:main Sep 30, 2025
55 of 56 checks passed
@wodesuck wodesuck deleted the patch-2 branch September 30, 2025 06:04
justinchuby added a commit that referenced this pull request Sep 30, 2025
Similar to #2578, for this case:

```python
import torch
import torch.nn as nn


class Model(nn.Module):
    def forward(self, x):
        return x.new_zeros(x.shape)


def main():
    model = Model()
    args = torch.rand(4, 4),
    batch = torch.export.Dim("batch")
    dynamic_shapes = {"x": {0: batch}}
    torch.onnx.export(
        model,
        args,
        "model_test.onnx",
        dynamic_shapes=dynamic_shapes,
        dynamo=True,
    )


if __name__ == "__main__":
    main()
```

---------

Co-authored-by: Justin Chu <justinchuby@users.noreply.github.com>
