Skip to content

Commit 7227655

Browse files
Merge input and output shape when removing identity (#2588)
Similar with #2578, for this case: ```python import torch import torch.nn as nn class Model(nn.Module): def forward(self, x): return x.new_zeros(x.shape) def main(): model = Model() args = torch.rand(4, 4), batch = torch.export.Dim("batch") dynamic_shapes = {"x": {0: batch}} torch.onnx.export( model, args, "model_test.onnx", dynamic_shapes=dynamic_shapes, dynamo=True, ) if __name__ == "__main__": main() ``` --------- Co-authored-by: Justin Chu <justinchuby@users.noreply.github.com>
1 parent 9b54ad5 commit 7227655

File tree

1 file changed

+3
-0
lines changed

1 file changed

+3
-0
lines changed

onnxscript/optimizer/_constant_folding.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -608,6 +608,9 @@ def identity(node: ir.Node, op, state: OptimizerState) -> ReturnValue:
608608
input = node.inputs[0]
609609
output = node.outputs[0]
610610
if input is not None and output is not None:
611+
input.shape = _merge_shapes(input.shape, output.shape)
612+
if input.type is None:
613+
input.type = output.type
611614
state.set_sym_value(output, input)
612615
return None
613616

0 commit comments

Comments
 (0)