Skip to content

Commit ee908eb

Browse files
author
KeDengMS
authored
Symbolic shape inference: fix rank for ConstantOfShape (#5912)
1 parent c2d6100 commit ee908eb

File tree

1 file changed

+2
-1
lines changed

1 file changed

+2
-1
lines changed

onnxruntime/python/tools/symbolic_shape_infer.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -623,7 +623,8 @@ def _infer_ConstantOfShape(self, node):
623623
self.sympy_data_[node.output[0]] = np.ones([int(x) for x in sympy_shape], dtype=np.int64) * numpy_helper.to_array(get_attribute(node, 'value', 0))
624624
else:
625625
# create new dynamic shape
626-
sympy_shape = self._new_symbolic_shape(self._get_shape_rank(node,0), node)
626+
# note input0 is a 1D vector of shape, the new symbolic shape has the rank of the shape vector length
627+
sympy_shape = self._new_symbolic_shape(self._get_shape(node,0)[0], node)
627628

628629
vi.CopyFrom(helper.make_tensor_value_info(node.output[0],
629630
vi.type.tensor_type.elem_type,

0 commit comments

Comments
 (0)