Skip to content

Enabled AutoCodeGen for Eager Dygraph #37639

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 9 commits into from
Nov 29, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions paddle/fluid/eager/.gitignore
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
generated/**
autocodegen/generated_example/
17 changes: 15 additions & 2 deletions paddle/fluid/eager/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1,9 +1,22 @@
set(eager_deps pten pten_api hook_utils tensor_utils utils global_utils backward pten_tensor legacy autograd_meta grad_node_info grad_tensor_holder gradient_accumulation accumulation_node)
set(fluid_deps tracer layer proto_desc operator op_registry variable_helper memcpy)
set(generated_deps dygraph_function dygraph_node)

if(NOT DEFINED ON_INFER)
message("Performing Eager Dygraph Auto Code Generation")
add_subdirectory(auto_code_generator)
endif()

add_subdirectory(api)
add_subdirectory(accumulation)
add_subdirectory(tests)
add_subdirectory(legacy)

cc_library(autograd_meta SRCS autograd_meta.cc DEPS pten pten_api)
cc_library(grad_node_info SRCS grad_node_info.cc DEPS pten pten_api)
cc_library(grad_tensor_holder SRCS grad_tensor_holder.cc DEPS grad_node_info gradient_accumulation)

cc_library(autograd_meta SRCS autograd_meta.cc DEPS pten pten_api)
cc_library(utils SRCS utils.cc DEPS pten pten_api global_utils layer proto_desc operator op_registry variable_helper memcpy scale_op autograd_meta)
cc_library(legacy SRCS ${DYGRAPH_LEGACY} DEPS global_utils proto_desc operator pten pten_api op_registry variable_helper memcpy)
cc_library(backward SRCS backward.cc DEPS grad_tensor_holder utils autograd_meta grad_node_info)

add_subdirectory(tests)
1 change: 1 addition & 0 deletions paddle/fluid/eager/api/generated/.gitignore
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
fluid_generated/**
4 changes: 4 additions & 0 deletions paddle/fluid/eager/api/generated/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1 +1,5 @@
add_subdirectory(eager_generated)

if(NOT DEFINED ON_INFER)
add_subdirectory(fluid_generated)
endif()
15 changes: 13 additions & 2 deletions paddle/fluid/eager/auto_code_generator/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -6,13 +6,24 @@ target_link_libraries(eager_generator ${EAGER_GENERETOR_DEPS})
get_property (os_dependency_modules GLOBAL PROPERTY OS_DEPENDENCY_MODULES)
target_link_libraries(eager_generator ${os_dependency_modules})

if(WITH_ROCM)
target_link_libraries(eager_generator ${ROCM_HIPRTC_LIB})
endif()

# Prepare file structure
message("Generate dygraph file structure at path: ${PADDLE_SOURCE_DIR}/paddle/fluid/eager/generated")
execute_process(
COMMAND "${PYTHON_EXECUTABLE}" "${PADDLE_SOURCE_DIR}/paddle/fluid/eager/auto_code_generator/generate_file_structures.py" "${PADDLE_SOURCE_DIR}/paddle/fluid/eager/"
)

add_custom_target(eager_codegen
COMMAND "${CMAKE_CURRENT_BINARY_DIR}/eager_generator" "${PADDLE_SOURCE_DIR}/paddle/fluid/eager/api/generated/fluid_generated"
if(WIN32)
add_custom_target(eager_codegen
COMMAND "${CMAKE_CURRENT_BINARY_DIR}/eager_generator.exe" "${PADDLE_SOURCE_DIR}/paddle/fluid/eager/api/generated/fluid_generated"
DEPENDS eager_generator
VERBATIM)
else()
add_custom_target(eager_codegen
COMMAND "${CMAKE_CURRENT_BINARY_DIR}/eager_generator" "${PADDLE_SOURCE_DIR}/paddle/fluid/eager/api/generated/fluid_generated"
DEPENDS eager_generator
VERBATIM)
endif()
36 changes: 14 additions & 22 deletions paddle/fluid/eager/auto_code_generator/eager_generator.cc
Original file line number Diff line number Diff line change
Expand Up @@ -577,11 +577,6 @@ static std::string GenerateGradNodeCreationContent(
// If single output slotname and not duplicable,
// then generate: "egr::AutogradMeta* p_autograd_out =
// egr::EagerUtils::autograd_meta("op_proto->outputs()[0].name()")"

// TODO(zhanlve): in case of multiple slotname but none of which are
// duplicable,
// avoid constructing vector<AutogradMeta*>, generate seperate
// AutogradMeta* objects respectively.
std::string get_autograd_meta_str = " // Prepare Autograd Meta \n";
for (const proto::OpProto::Var& input : op_proto.inputs()) {
const std::string& input_name = input.name();
Expand All @@ -607,11 +602,6 @@ static std::string GenerateGradNodeCreationContent(
// If single output slotname and not duplicable,
// then generate: "egr::AutogradMeta* p_autograd_out =
// egr::EagerUtils::autograd_meta("op_proto.outputs()[0].name()")"

// TODO(zhanlve): in case of multiple slotname but none of which are
// duplicable,
// avoid constructing vector<AutogradMeta*>, generate seperate
// AutogradMeta* objects respectively.
for (const proto::OpProto::Var& output : op_proto.outputs()) {
const std::string& output_name = output.name();
const std::string& output_autograd_name = "p_autograd_" + output_name;
Expand Down Expand Up @@ -725,9 +715,9 @@ static std::string GenerateGradNodeCreationContent(
// [Generation] GradNode Creation
const char* GRAD_NODE_CREATION_TEMPLATE =
" %s"
" bool require_any_grad = egr::ComputeRequireGrad(%s);\n"
" bool require_any_grad = egr::EagerUtils::ComputeRequireGrad(%s);\n"
" if(require_any_grad) {\n"
" egr::PassStopGradient(%s);\n"
" egr::EagerUtils::PassStopGradient(%s);\n"
"%s\n }";
std::string grad_node_creation_body_str = paddle::string::Sprintf(
GRAD_NODE_CREATION_TEMPLATE, prepare_autograd_meta_str,
Expand Down Expand Up @@ -793,7 +783,7 @@ static std::pair<std::string, std::string> GenerateForwardFunctionContents(
Controller.Instance().GetExpectedPlace(), {});

// According to fwd_outputs_names
std::vector<egr::EagerTensor> Out0 = GetOutputs(outs["Out0"]);
std::vector<egr::EagerTensor> Out0 = GGetOutputetOutputs(outs["Out0"]);
egr::EagerTensor Out1 = GetOutputs(outs["Out1"][0]);
std::vector<egr::EagerTensor> Out2 = GetOutputs(outs["Out2"]);

Expand Down Expand Up @@ -830,7 +820,8 @@ static std::pair<std::string, std::string> GenerateForwardFunctionContents(
input_args_str_list[input_position] =
paddle::string::Sprintf(FWD_INS_ARG_TEMPLATE, input_name);
}
const char* FWD_INS_CONTENT_TEMPLATE = "{ \"%s\", egr::SyncToVars(%s) },";
const char* FWD_INS_CONTENT_TEMPLATE =
"{ \"%s\", egr::EagerUtils::SyncToVars(%s) },";
ins_contents_str += paddle::string::Sprintf(FWD_INS_CONTENT_TEMPLATE,
input_name, input_name);
}
Expand Down Expand Up @@ -925,14 +916,14 @@ static std::pair<std::string, std::string> GenerateForwardFunctionContents(
if (output.duplicable()) {
const char* FWD_OUT_TENSORS_TEMPLATE =
" std::vector<egr::EagerTensor> %s = "
"egr::GetOutputs(outs[\"%s\"]);\n";
"egr::EagerUtils::GetOutputs(outs[\"%s\"]);\n";
out_tensor_str = paddle::string::Sprintf(FWD_OUT_TENSORS_TEMPLATE,
output_name, output_name);
return_types[return_position] = "std::vector<egr::EagerTensor>";
} else {
const char* FWD_OUT_TENSOR_TEMPLATE =
" egr::EagerTensor %s = "
"egr::GetOutput(outs[\"%s\"][0]);\n";
"egr::EagerUtils::GetOutput(outs[\"%s\"][0]);\n";
out_tensor_str = paddle::string::Sprintf(FWD_OUT_TENSOR_TEMPLATE,
output_name, output_name);
return_types[return_position] = "egr::EagerTensor";
Expand Down Expand Up @@ -1093,7 +1084,8 @@ static std::string GenerateGradNodeCCContents(
grad_ins_fwd_slotname_map.at(grad_input_name) + "_";
const char* GRAD_INS_FWD_CONTENT_TEMPLATE =
"{ \"%s\", "
"egr::SyncToVars(egr::EagerUtils::RecoverTensorWrapper(&this->%s, "
"egr::EagerUtils::SyncToVars(egr::EagerUtils::RecoverTensorWrapper(&"
"this->%s, "
"nullptr)) },";
ins_contents_str +=
paddle::string::Sprintf(GRAD_INS_FWD_CONTENT_TEMPLATE,
Expand All @@ -1104,7 +1096,7 @@ static std::string GenerateGradNodeCCContents(
size_t fwd_output_position = fwd_outputs_name_pos_map.at(
grad_ins_grad_slotname_map.at(grad_input_name));
const char* GRAD_INS_GRAD_CONTENT_TEMPLATE =
"{ \"%s\", egr::SyncToVars(grads[%d]) },";
"{ \"%s\", egr::EagerUtils::SyncToVars(grads[%d]) },";
ins_contents_str += paddle::string::Sprintf(
GRAD_INS_GRAD_CONTENT_TEMPLATE, grad_input_name, fwd_output_position);

Expand Down Expand Up @@ -1206,7 +1198,7 @@ static std::string GenerateGradNodeCCContents(
fwd_inputs_name_pos_map.at(grad_outs_slotname_map.at(grad_out_name));

const char* BWD_OUTPUT_TEMPLATE =
" outputs[%d] = GetOutputs(outs[\"%s\"]);\n";
" outputs[%d] = egr::EagerUtils::GetOutputs(outs[\"%s\"]);\n";
outputs_str += paddle::string::Sprintf(BWD_OUTPUT_TEMPLATE,
fwd_input_position, grad_out_name);
}
Expand Down Expand Up @@ -1526,6 +1518,9 @@ static void DygraphCodeGeneration(const std::string& output_dir) {
GenerateForwardHFile(output_dir, dygraph_forward_api_str);
}

} // namespace framework
} // namespace paddle

int main(int argc, char* argv[]) {
if (argc != 2) {
std::cerr << "argc must be 2" << std::endl;
Expand All @@ -1537,6 +1532,3 @@ int main(int argc, char* argv[]) {

return 0;
}

} // namespace framework
} // namespace paddle
2 changes: 1 addition & 1 deletion paddle/fluid/eager/legacy/prepared_operator.cc
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
#include "paddle/fluid/framework/pten_utils.h"
#include "paddle/utils/small_vector.h"
#ifdef PADDLE_WITH_XPU
#include "paddle/fluid/platform/xpu/xpu_op_list.h"
#include "paddle/fluid/platform/device/xpu/xpu_op_list.h"
#endif
DECLARE_bool(check_nan_inf);
DECLARE_bool(run_pten_kernel);
Expand Down
3 changes: 0 additions & 3 deletions paddle/fluid/eager/tests/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1,5 +1,2 @@
set(eager_deps pten pten_api hook_utils tensor_utils utils global_utils backward pten_tensor autograd_meta grad_node_info grad_tensor_holder gradient_accumulation accumulation_node)
set(fluid_deps tracer layer proto_desc operator op_registry variable_helper memcpy)

add_subdirectory(data_structure_tests)
add_subdirectory(task_tests)
4 changes: 4 additions & 0 deletions paddle/fluid/eager/tests/task_tests/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -5,3 +5,7 @@ cc_test(test_egr_task_backward SRCS backward_test.cc DEPS ${eager_deps} ${fluid_
cc_test(test_egr_task_hook SRCS hook_test.cc DEPS ${eager_deps} ${fluid_deps} eager_scale scale_node)
cc_test(test_egr_task_cross_batch SRCS cross_batch_accumulation_test.cc DEPS ${eager_deps} ${fluid_deps} eager_scale scale_node)
cc_test(test_egr_task_fwd_bwd_joint SRCS fwd_bwd_joint_test.cc DEPS ${eager_deps} ${fluid_deps} eager_scale scale_node)

if(NOT DEFINED ON_INFER)
cc_test(test_egr_task_autocodegen SRCS generated_test.cc DEPS ${eager_deps} ${fluid_deps} ${generated_deps})
endif()
95 changes: 95 additions & 0 deletions paddle/fluid/eager/tests/task_tests/generated_test.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,95 @@
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

// Eager Dygraph

#include <chrono>

#include "gtest/gtest.h"

#include "paddle/fluid/eager/api/all.h"
#include "paddle/fluid/eager/api/utils/tensor_utils.h"
#include "paddle/fluid/eager/autograd_meta.h"
#include "paddle/fluid/eager/backward.h"
#include "paddle/fluid/eager/utils.h"

#include "paddle/fluid/eager/tests/test_utils.h"
#include "paddle/fluid/imperative/tracer.h"

#include "paddle/fluid/eager/api/generated/fluid_generated/dygraph_forward_api.h"
#include "paddle/pten/core/kernel_registry.h"

// TODO(jiabin): remove nolint here!!!
using namespace egr; // NOLINT

namespace eager_test {

TEST(Generated, Sigmoid) {
// Prepare Device Contexts
InitEnv(paddle::platform::CPUPlace());
VLOG(6) << "Init Env";
// 1. Prepare Input
paddle::framework::DDim ddim = paddle::framework::make_ddim({2, 4, 4, 4});
VLOG(6) << "Make Dim";
egr::EagerTensor tensor = CreateTensorWithValue(
ddim, paddle::platform::CPUPlace(), pten::DataType::FLOAT32,
pten::DataLayout::NCHW, 0.0, true);
VLOG(6) << "Make EagerTensor";
RetainGradForTensor(tensor);
VLOG(6) << "Retain Grad for Tensor";
auto output_tensor = sigmoid_dygraph_function(tensor, {});
VLOG(6) << "Run Backward";
CompareVariableWithValue<float>(output_tensor, 0.5);

std::vector<egr::EagerTensor> target_tensors = {output_tensor};
VLOG(6) << "Runing Backward";
RunBackward(target_tensors, {});

VLOG(6) << "Finish Backward";
CompareGradVariableWithValue<float>(tensor, 0.25);
}

TEST(Generated, Matmul_v2) {
// Prepare Device Contexts
InitEnv(paddle::platform::CPUPlace());

auto tracer = std::make_shared<paddle::imperative::Tracer>();
paddle::imperative::SetCurrentTracer(tracer);

// 1. Prepare Input
paddle::framework::DDim ddimX = paddle::framework::make_ddim({4, 16});
egr::EagerTensor X = CreateTensorWithValue(
ddimX, paddle::platform::CPUPlace(), pten::DataType::FLOAT32,
pten::DataLayout::NCHW, 3.0, true);
RetainGradForTensor(X);

paddle::framework::DDim ddimY = paddle::framework::make_ddim({16, 20});
egr::EagerTensor Y = CreateTensorWithValue(
ddimY, paddle::platform::CPUPlace(), pten::DataType::FLOAT32,
pten::DataLayout::NCHW, 2.0, true);
RetainGradForTensor(Y);

auto output_tensor = matmul_v2_dygraph_function(
X, Y, {{"trans_x", false}, {"trans_y", false}});

CompareVariableWithValue<float>(output_tensor, 96);

std::vector<egr::EagerTensor> target_tensors = {output_tensor};
RunBackward(target_tensors, {});

CompareGradVariableWithValue<float>(X, 2.0 * 20);
CompareGradVariableWithValue<float>(Y, 3.0 * 4);
}

} // namespace eager_test
16 changes: 16 additions & 0 deletions paddle/fluid/eager/utils.cc
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@

#include "paddle/fluid/eager/utils.h"
#include "paddle/fluid/eager/api/utils/global_utils.h"
#include "paddle/fluid/eager/tensor_wrapper.h"

#include "paddle/pten/api/all.h"
#include "paddle/pten/common/layout.h"
Expand Down Expand Up @@ -188,4 +189,19 @@ egr::EagerTensor EagerUtils::GetOutput(
return EagerTensor((*(out.get())));
}

EagerTensor EagerUtils::RecoverTensorWrapper(
TensorWrapper* tw, const std::shared_ptr<GradNodeBase>& grad_node) {
return tw->recover(grad_node);
}

std::vector<EagerTensor> EagerUtils::RecoverTensorWrapper(
std::vector<TensorWrapper>* tw,
const std::shared_ptr<GradNodeBase>& grad_node) {
std::vector<EagerTensor> ret;
for (auto& t : *tw) {
ret.emplace_back(t.recover(grad_node));
}
return ret;
}

} // namespace egr
9 changes: 9 additions & 0 deletions paddle/fluid/eager/utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,8 @@

namespace egr {

class TensorWrapper;

/**
* EagerUtils is utils used to do some static conversion or autograd
* members access, this class is desinged to be a full static functional
Expand Down Expand Up @@ -131,6 +133,13 @@ class EagerUtils {
iter.apply(std::forward<Args>(args)...);
}

// TensorWrapper Utils
static egr::EagerTensor RecoverTensorWrapper(
egr::TensorWrapper* tw, const std::shared_ptr<GradNodeBase>& grad_node);
static std::vector<egr::EagerTensor> RecoverTensorWrapper(
std::vector<egr::TensorWrapper>* tw,
const std::shared_ptr<GradNodeBase>& grad_node);

// Intermidate needed remove this once we don't need legacy
static std::vector<std::shared_ptr<egr::EagerTensor>> SyncToVars(
const egr::EagerTensor& tensor);
Expand Down
14 changes: 14 additions & 0 deletions paddle/fluid/framework/details/nan_inf_utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
#include <string>
#include <vector>

#include "paddle/fluid/eager/legacy/type_def.h"
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/imperative/type_defs.h"
Expand Down Expand Up @@ -53,6 +54,19 @@ void CheckOpHasNanOrInfInDygraph(const std::string& op_type,
}
}

template <typename TensorType>
static void CheckOpHasNanOrInfInEager(const std::string& op_type,
const egr::NameMap<TensorType>& op_outs,
platform::Place place) {
for (const auto& pair : op_outs) {
for (const auto& tensor : pair.second) {
auto* var = tensor->MutableVar();
if (var == nullptr) continue;
CheckVarHasNanOrInf(op_type, tensor->name(), var, place);
}
}
}

#ifdef PADDLE_WITH_ASCEND_CL
void NPUAllocAndClearFloatStatus(const framework::OperatorBase& op,
const framework::ScopeBase& scope,
Expand Down