Skip to content

Commit 31ed9a5

Browse files
authored
[Dy2Stat] Use Paddle2.0 api paddle.tensor.array_* (#30156)
1 parent ad55f60 commit 31ed9a5

File tree

1 file changed

+5
-5
lines changed

1 file changed

+5
-5
lines changed

python/paddle/fluid/dygraph/dygraph_to_static/list_transformer.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -126,7 +126,7 @@ def _transform_slice_to_tensor_write(self, node):
126126
i = "paddle.cast(" \
127127
"x=paddle.jit.dy2static.to_static_variable({})," \
128128
"dtype='int64')".format(ast_to_source_code(slice_node))
129-
assign_code = "{} = fluid.layers.array_write(x={}, i={}, array={})" \
129+
assign_code = "{} = paddle.tensor.array_write(x={}, i={}, array={})" \
130130
.format(target_name, value_code, i, target_name)
131131
assign_node = gast.parse(assign_code).body[0]
132132
return assign_node
@@ -168,7 +168,7 @@ def _is_list_append_tensor(self, node):
168168
# return False
169169
# if NodeVarType.TENSOR not in var_type_set and NodeVarType.PADDLE_RETURN_TYPES not in var_type_set:
170170
# return False
171-
# # TODO: Consider that `arg` may be a gast.Call about Paddle Api. eg: list_a.append(fluid.layers.reshape(x))
171+
# # TODO: Consider that `arg` may be a gast.Call about Paddle Api. eg: list_a.append(paddle.reshape(x))
172172
# # else:
173173
# # return True
174174
self.list_name_to_updated[value_name.strip()] = True
@@ -187,16 +187,16 @@ def _need_to_create_tensor_array(self, node):
187187

188188
def _create_tensor_array(self):
189189
# Although `dtype='float32'`, other types such as `int32` can also be supported
190-
func_code = "fluid.layers.create_array(dtype='float32')"
190+
func_code = "paddle.tensor.create_array(dtype='float32')"
191191
func_node = gast.parse(func_code).body[0].value
192192
return func_node
193193

194194
def _to_array_write_node(self, node):
195195
assert isinstance(node, gast.Call)
196196
array = astor.to_source(gast.gast_to_ast(node.func.value))
197197
x = astor.to_source(gast.gast_to_ast(node.args[0]))
198-
i = "fluid.layers.array_length({})".format(array)
199-
func_code = "fluid.layers.array_write(x={}, i={}, array={})".format(
198+
i = "paddle.tensor.array_length({})".format(array)
199+
func_code = "paddle.tensor.array_write(x={}, i={}, array={})".format(
200200
x, i, array)
201201
return gast.parse(func_code).body[0].value
202202

0 commit comments

Comments
 (0)