Skip to content

Commit 4e66010

Browse files
authored
fix bug of optional_tensor in amp logic (#42561)
1 parent 80015c0 commit 4e66010

File tree

1 file changed

+4
-1
lines changed
  • paddle/fluid/eager/auto_code_generator/final_state_generator

1 file changed

+4
-1
lines changed

paddle/fluid/eager/auto_code_generator/final_state_generator/eager_gen.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -865,7 +865,10 @@ def GenerateForwardDefinition(self, is_inplaced):
865865
f"if ({name}.get_ptr() != nullptr) amp_tensors_vector.push_back({{ *({name}.get_ptr()) }});\n"
866866
)
867867
amp_autocast_optional_list.append(
868-
f"auto NEW_{name} = ({name}.get_ptr() != nullptr) ? paddle::make_optional<const paddle::experimental::Tensor&>(egr::EagerAmpAutoCast(\"{name}\", *({name}.get_ptr()), amp_dst_dtype, op_name)) : {name};\n"
868+
f"auto NEW_{name}_temp_tensor = ({name}.get_ptr() != nullptr) ? egr::EagerAmpAutoCast(\"{name}\", *({name}.get_ptr()), amp_dst_dtype, op_name) : paddle::experimental::Tensor();\n"
869+
)
870+
amp_autocast_optional_list.append(
871+
f"auto NEW_{name} = ({name}.get_ptr() != nullptr) ? paddle::make_optional<const paddle::experimental::Tensor&>(NEW_{name}_temp_tensor) : {name};\n"
869872
)
870873
else:
871874
if is_inplaced and inplace_map and name in inplace_map.keys(

0 commit comments

Comments
 (0)