Skip to content

Commit 19eb510

Browse files
authored
[Dy2stat]Fix error in tensor_shape_transformer. (#37999) (#38168)
修复tensor_shape_transformer中的错误。 之前在类似if len(paddle.shape(x)[0]) > 0中,paddle会被当做一个变量被传入convert_var_shape函数中
1 parent 8100c16 commit 19eb510

File tree

2 files changed

+16
-0
lines changed

2 files changed

+16
-0
lines changed

python/paddle/fluid/dygraph/dygraph_to_static/tensor_shape_transformer.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -282,6 +282,10 @@ def _is_var_shape(self, node):
282282
return False
283283

284284
if isinstance(node, gast.Attribute):
285+
# If node is `paddle.shape`, return False
286+
if (node.attr == 'shape' and isinstance(node.value, gast.Name) and
287+
node.value.id == 'paddle'):
288+
return False
285289
if node.attr != 'shape':
286290
return False
287291
return True

python/paddle/fluid/tests/unittests/dygraph_to_static/test_tensor_shape.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -217,6 +217,12 @@ def dyfunc_change_shape_after_assign(x):
217217
return res
218218

219219

220+
def dyfunc_len_paddle_shape():
221+
x = paddle.to_tensor([1, 2, 3])
222+
if len(paddle.shape(x)) > 0:
223+
print(x)
224+
225+
220226
# 1. Basic tests without control flow
221227
class TestTensorShapeBasic(unittest.TestCase):
222228
def setUp(self):
@@ -582,5 +588,11 @@ def test(self):
582588
func.concrete_program
583589

584590

591+
class TestPaddleShape(unittest.TestCase):
592+
def test_paddle_shape(self):
593+
func = paddle.jit.to_static(dyfunc_len_paddle_shape)
594+
self.assertEqual('paddle.shape(x)' in func.code, True)
595+
596+
585597
if __name__ == '__main__':
586598
unittest.main()

0 commit comments

Comments
 (0)