Skip to content

Commit eb3c7d0

Browse files
authored
[Dy2St]Refine AnnAssign in static_analysis (#39572)
1 parent d414461 commit eb3c7d0

File tree

2 files changed

+8
-2
lines changed

2 files changed

+8
-2
lines changed

python/paddle/fluid/dygraph/dygraph_to_static/static_analysis.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -349,7 +349,11 @@ def _get_node_var_type(self, cur_wrapper):
349349
ret_type = {NodeVarType.type_from_annotation(node.annotation)}
350350
# if annotation and value(Constant) are diffent type, we use value type
351351
if node.value:
352-
ret_type = self.node_to_wrapper_map[node.value].node_var_type
352+
node_value_type = self.node_to_wrapper_map[
353+
node.value].node_var_type
354+
if not (node_value_type &
355+
{NodeVarType.UNKNOWN, NodeVarType.STATEMENT}):
356+
ret_type = node_value_type
353357
if isinstance(node.target, gast.Name):
354358
self.node_to_wrapper_map[node.target].node_var_type = ret_type
355359
self.var_env.set_var_type(node.target.id, ret_type)

python/paddle/fluid/tests/unittests/dygraph_to_static/test_static_analysis.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -147,6 +147,7 @@ def add(x, y):
147147
def func_to_test7(a: int, b: float, c: paddle.Tensor, d: float='diff'):
148148
a = True
149149
e, f = paddle.shape(c)
150+
g: paddle.Tensor = len(c)
150151

151152

152153
result_var_type7 = {
@@ -155,7 +156,8 @@ def func_to_test7(a: int, b: float, c: paddle.Tensor, d: float='diff'):
155156
'c': {NodeVarType.TENSOR},
156157
'd': {NodeVarType.STRING},
157158
'e': {NodeVarType.PADDLE_RETURN_TYPES},
158-
'f': {NodeVarType.PADDLE_RETURN_TYPES}
159+
'f': {NodeVarType.PADDLE_RETURN_TYPES},
160+
'g': {NodeVarType.TENSOR}
159161
}
160162

161163
test_funcs = [

0 commit comments

Comments
 (0)