
Commit a2b8014

[DoubleGrad PR #7] paddle.grad() to copy backward graph before backward run (PaddlePaddle#41306)
* [Refactor] refactored eager_gen.py PR #2
* [DoubleGrad PR #1] Decoupled code generation logics for Dygraph ForwardFunctions and GradNodes
* Fixed minor issue
* Adjusted logics of GenerateNodeCreationCodes and GenerateForwardDefinition
* Fixed issues
* Supported higher-order grad node generation
* [DoubleGrad PR #4] Supported higher-order GradNode generation
* [DoubleGrad #4] Bug Fixes to Double Grad Node Generation
* Fixed yaml typo
* Fixed yaml typo
* fixed minor issues
* [DoubleGrad PR #5] Enabled gradient computations for grad_tensors passed to paddle.grad()
* Fixed minor issue
* Fixed CI-Inference issue
* Fixed CI-inference issues
* [DoubleGrad PR #7] paddle.grad() to copy backward graph before backward run
* Fixed minor issues
* Fixed issue with backward graph construction logic
* Fixed implementation issues with backward graph reconstruction
* Fixed unittest issue
* Fixed issues
1 parent 5936fa6 commit a2b8014
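Every per-node diff in this commit implements the same two interface changes: a virtual Copy() so the engine can clone a backward-graph node, and base-class bookkeeping for the "tensor wrappers cleared" flag that each node previously carried itself. A minimal sketch of the assumed base-class surface these diffs build on follows; the real declarations live in paddle/fluid/eager/grad_node_info.h (not excerpted here), so the names and signatures below are illustrative assumptions, not the verbatim header.

// Minimal sketch, assuming the GradNodeBase surface used by the diffs below.
// Illustrative only: the real header is paddle/fluid/eager/grad_node_info.h.
#include <memory>
#include <vector>

class Edge {};  // stand-in for the real egr::Edge

class GradNodeBase {
 public:
  virtual ~GradNodeBase() = default;

  // Each concrete grad node clones itself so paddle.grad() can copy the
  // reachable backward graph before the backward run mutates it.
  virtual std::shared_ptr<GradNodeBase> Copy() const = 0;

  // "Tensor wrappers cleared" bookkeeping moves into the base class,
  // replacing the per-node is_tensor_wrappers_cleared flags removed below.
  void SetIsTensorWrappersCleared(bool cleared) {
    is_tensor_wrappers_cleared_ = cleared;
  }
  bool IsTensorWrappersCleared() { return is_tensor_wrappers_cleared_; }

  // Non-const edge access (added in grad_node_info.cc below) lets the engine
  // rewire copied edges while reconstructing the copied graph.
  std::vector<std::vector<Edge>>& GetMutableEdges() { return adj_edges_; }

 private:
  std::vector<std::vector<Edge>> adj_edges_;
  bool is_tensor_wrappers_cleared_{false};
};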

File tree

12 files changed: +237 additions, -54 deletions


paddle/fluid/eager/accumulation/accumulation_node.h

Lines changed: 9 additions & 6 deletions
@@ -25,7 +25,10 @@ class GradNodeAccumulation : public GradNodeBase {
   // Constructor: configure fwd input tensors to grad node
   explicit GradNodeAccumulation(AutogradMeta* meta) : GradNodeBase(1, 1) {
     VLOG(6) << "Construct GradNodeAccumulation";
-    weak_grad_ = meta->WeakGrad();
+    if (meta) {
+      weak_grad_ = meta->WeakGrad();
+    }
+
     SetDefaultGradInOutMeta();
   }
 
@@ -40,11 +43,6 @@ class GradNodeAccumulation : public GradNodeBase {
 
   void ClearTensorWrappers() override { VLOG(6) << "Do nothing here now"; }
 
-  bool IsTensorWrappersCleared() override {
-    VLOG(6) << "Do nothing here now";
-    return false;
-  }
-
   std::string name() { return "GradNodeAccumulation"; }
 
   /**
@@ -58,6 +56,11 @@ class GradNodeAccumulation : public GradNodeBase {
   inline bool ReduceHooksRegistered() { return reduce_hooks_.size() != 0; }
   void ApplyReduceHooks();
 
+  std::shared_ptr<GradNodeBase> Copy() const override {
+    return std::shared_ptr<GradNodeAccumulation>(
+        new GradNodeAccumulation(nullptr));
+  }
+
  private:
   std::weak_ptr<paddle::experimental::Tensor> weak_grad_;

paddle/fluid/eager/api/generated/eager_generated/backwards/scale_node.h

Lines changed: 6 additions & 5 deletions
@@ -44,18 +44,19 @@ class GradNodeScale : public GradNodeBase {
 
   void ClearTensorWrappers() override { VLOG(6) << "Do nothing here now"; }
 
-  bool IsTensorWrappersCleared() override {
-    VLOG(6) << "Do nothing here now";
-    return false;
-  }
-
   void SetTensorWrappers_X(
       const std::vector<paddle::experimental::Tensor>& tensors);
 
   void SetAttributes_scale(float scale);
   std::string name() override { return ""; }
   // Members: define fwd input tensors
   // For Scale there is no fwd input tensor needed
+
+  std::shared_ptr<GradNodeBase> Copy() const override {
+    auto copied_node = std::make_shared<GradNodeScale>(*this);
+    return copied_node;
+  }
+
  private:
   float scale_{1.0};
 };
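The two headers above also show the two Copy() idioms used throughout the commit: GradNodeAccumulation rebuilds itself with a nullptr meta, so the copy never re-binds weak_grad_ (which is why the constructor now guards against a null AutogradMeta), while GradNodeScale copy-constructs so attributes such as scale_ travel with the copy. A self-contained sketch of the contrast, using stand-in types rather than the real Paddle classes:

#include <iostream>
#include <memory>

// Stand-in base; the real classes derive from egr::GradNodeBase.
struct NodeBase {
  virtual ~NodeBase() = default;
  virtual std::shared_ptr<NodeBase> Copy() const = 0;
};

// Mirrors GradNodeAccumulation: Copy() constructs a fresh node with no meta,
// so the copied node deliberately drops the binding to the original tensor.
struct AccumulationLike : NodeBase {
  explicit AccumulationLike(int* meta) {
    if (meta) {
      bound_meta = meta;
    }
  }
  std::shared_ptr<NodeBase> Copy() const override {
    return std::shared_ptr<AccumulationLike>(new AccumulationLike(nullptr));
  }
  int* bound_meta = nullptr;
};

// Mirrors GradNodeScale: Copy() copy-constructs, so attributes carry over.
struct ScaleLike : NodeBase {
  std::shared_ptr<NodeBase> Copy() const override {
    return std::make_shared<ScaleLike>(*this);
  }
  float scale = 2.0f;
};

int main() {
  int meta = 42;
  AccumulationLike acc(&meta);
  auto acc_copy = std::static_pointer_cast<AccumulationLike>(acc.Copy());
  std::cout << (acc_copy->bound_meta == nullptr) << "\n";  // 1: binding dropped

  ScaleLike scale_node;
  auto scale_copy = std::static_pointer_cast<ScaleLike>(scale_node.Copy());
  std::cout << scale_copy->scale << "\n";  // 2: attribute carried over
  return 0;
}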

paddle/fluid/eager/auto_code_generator/eager_generator.cc

Lines changed: 10 additions & 8 deletions
@@ -2479,22 +2479,23 @@ static std::string GenerateGradNodeHeaderContents(
       "\n"
       "   void ClearTensorWrappers() override { \n"
       "%s\n"
-      "    is_tensor_wrappers_cleared = true;\n"
+      "    SetIsTensorWrappersCleared(true);\n"
       "  }\n"
       "  std::string name() override { return \" GradNode%s \"; } \n "
       "\n"
+      "std::shared_ptr<GradNodeBase> Copy() const override {{\n "
+      "    auto copied_node = std::shared_ptr<GradNode%s>(new "
+      "GradNode%s(*this));\n "
+      "    return copied_node;\n "
+      "}}\n "
+      "\n"
       "   // SetX, SetY, ...\n"
       "%s\n"
       "   // SetAttrMap\n"
       "%s\n"
-      "   bool IsTensorWrappersCleared() override { \n"
-      "       return is_tensor_wrappers_cleared;\n"
-      "  }\n"
       " private:\n"
       "   // TensorWrappers\n"
       "%s\n"
-      "   bool is_tensor_wrappers_cleared = false;\n"
-      "\n"
       "   // Attribute Map\n"
       "%s\n"
       "};";
@@ -2601,8 +2602,9 @@ static std::string GenerateGradNodeHeaderContents(
 
   std::string grad_node_str = paddle::string::Sprintf(
       GRAD_NODE_TEMPLATE, op_type, op_type, op_type, op_type, op_type, op_type,
-      op_type, clear_tensor_wrappers_str, op_type, set_tensor_wrappers_str,
-      set_attr_map_str, tensor_wrapper_members_str, attr_members_str);
+      op_type, clear_tensor_wrappers_str, op_type, op_type, op_type,
+      set_tensor_wrappers_str, set_attr_map_str, tensor_wrapper_members_str,
+      attr_members_str);
 
   return grad_node_str;
 }

paddle/fluid/eager/auto_code_generator/final_state_generator/eager_gen.py

Lines changed: 11 additions & 9 deletions
@@ -125,23 +125,24 @@ class {} : public egr::GradNodeBase {{
 
   void ClearTensorWrappers() override {{
     {}
-    is_tensor_wrappers_cleared = true;
+    SetIsTensorWrappersCleared(true);
+  }}
+
+  std::shared_ptr<GradNodeBase> Copy() const override {{
+    auto copied_node = std::shared_ptr<{}>(new {}(*this));
+
+    return copied_node;
   }}
 
   // SetTensorWrapperX, SetTensorWrapperY, ...
   {}
   // SetAttributes
   {}
 
-  bool IsTensorWrappersCleared() override {{
-    return is_tensor_wrappers_cleared;
-  }}
  private:
   // TensorWrappers
   {}
 
-  bool is_tensor_wrappers_cleared = false;
-
   // Attributes
   {}
 }};
@@ -1218,9 +1219,10 @@ def GenerateNodeDeclaration(self):
         grad_node_name = GetGradNodeName(forward_op_name)
         self.node_declaration_str = NODE_DECLARATION_TEMPLATE.format(
             grad_node_name, grad_node_name, grad_node_name, grad_node_name,
-            grad_node_name, clear_tensor_wrapper_str,
-            set_tensor_wrapper_methods_str, set_attribute_methods_str,
-            tensor_wrapper_members_str, attribute_members_str)
+            grad_node_name, clear_tensor_wrapper_str, grad_node_name,
+            grad_node_name, set_tensor_wrapper_methods_str,
+            set_attribute_methods_str, tensor_wrapper_members_str,
+            attribute_members_str)
 
         logging.info(f"Generated Node Declaration: {self.node_declaration_str}")
 
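With these template changes, both generators emit node declarations that carry the Copy() method and call SetIsTensorWrappersCleared(true) instead of touching a per-node flag. Below is a hand-expanded sketch of roughly what the generated declaration looks like for a hypothetical GradNodeFoo with one tensor wrapper and one attribute; stub types stand in for egr::GradNodeBase and egr::TensorWrapper so the sketch compiles on its own, and the member names are illustrative rather than actual generator output.

#include <memory>
#include <string>

// Stand-ins so the sketch is self-contained; generated code actually derives
// from egr::GradNodeBase and stores egr::TensorWrapper members.
class StubGradNodeBase {
 public:
  virtual ~StubGradNodeBase() = default;
  virtual std::shared_ptr<StubGradNodeBase> Copy() const = 0;
  void SetIsTensorWrappersCleared(bool cleared) { cleared_ = cleared; }

 private:
  bool cleared_ = false;
};

class StubTensorWrapper {
 public:
  void clear() {}
};

// Hypothetical GradNodeFoo: roughly the shape the updated templates emit.
class GradNodeFoo : public StubGradNodeBase {
 public:
  void ClearTensorWrappers() {
    x_.clear();                        // one clear() per recorded wrapper
    SetIsTensorWrappersCleared(true);  // was: is_tensor_wrappers_cleared = true;
  }
  std::string name() { return "GradNodeFoo"; }

  std::shared_ptr<StubGradNodeBase> Copy() const override {
    auto copied_node = std::shared_ptr<GradNodeFoo>(new GradNodeFoo(*this));
    return copied_node;
  }

  // SetTensorWrapperX, SetTensorWrapperY, ... / SetAttributes
  void SetTensorWrapperx(const StubTensorWrapper& x) { x_ = x; }
  void SetAttributescale(float scale) { scale_ = scale; }

 private:
  // TensorWrappers
  StubTensorWrapper x_;
  // Attributes
  float scale_ = 1.0f;
};

int main() {
  GradNodeFoo node;
  auto copy = node.Copy();  // the copy carries x_ and scale_ with it
  return copy != nullptr ? 0 : 1;
}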
paddle/fluid/eager/backward.cc

Lines changed: 111 additions & 5 deletions
@@ -50,7 +50,16 @@ class GeneralGrad {
     for (size_t i = 0; i < num_inputs; i++) {
       AutogradMeta* auto_grad_meta =
           EagerUtils::unsafe_autograd_meta(inputs[i]);
-      auto target_node = auto_grad_meta->GetMutableGradNode().get();
+      auto* target_node = auto_grad_meta->GetMutableGradNode().get();
+
+      if (orig_to_copied_node_mapping_.count(target_node)) {
+        target_node = orig_to_copied_node_mapping_[target_node];
+      } else {
+        VLOG(6) << "Unable to find target node in "
+                   "orig_to_copied_node_mapping_, likely indicating an "
+                   "unused input";
+      }
+
       PADDLE_ENFORCE_NOT_NULL(target_node,
                               paddle::platform::errors::Fatal(
                                   "There is no grad op for %s:[%d] or it's"
@@ -249,7 +258,15 @@ class GeneralGrad {
     for (size_t i = 0; i < inputs.size(); ++i) {
       auto& input = inputs[i];
       AutogradMeta* auto_grad_meta = EagerUtils::unsafe_autograd_meta(input);
-      auto target_node = auto_grad_meta->GetMutableGradNode().get();
+
+      auto* target_node = auto_grad_meta->GetMutableGradNode().get();
+      if (orig_to_copied_node_mapping_.count(target_node)) {
+        target_node = orig_to_copied_node_mapping_[target_node];
+      } else {
+        VLOG(6) << "Unable to find target node in "
+                   "orig_to_copied_node_mapping_, likely indicating an unused "
+                   "input";
+      }
 
       auto iter = results_map.find(target_node);
       if (iter != results_map.end()) {
@@ -326,6 +343,78 @@ class GeneralGrad {
     potential_stop_nodes.clear();
     depending_nodes.clear();
     results_map.clear();
+    copied_grad_nodes_.clear();
+    orig_to_copied_node_mapping_.clear();
+  }
+
+  GradNodeBase* CopyGradNode(const std::shared_ptr<GradNodeBase>& orig_node) {
+    if (orig_to_copied_node_mapping_.count(orig_node.get())) {
+      return orig_to_copied_node_mapping_[orig_node.get()];
+    }
+    std::shared_ptr<GradNodeBase> copied_node = orig_node->Copy();
+
+    // Save node and update mapping
+    orig_to_copied_node_mapping_[orig_node.get()] = copied_node.get();
+    copied_grad_nodes_.push_back(copied_node);
+
+    return copied_node.get();
+  }
+
+  void ReconstructBackwardGraph(
+      const std::queue<GradNodeBase*>& orig_init_queue) {
+    std::queue<GradNodeBase*> queue = orig_init_queue;
+    std::unordered_set<GradNodeBase*> visited;
+
+    // BFS and recursively copy the grad nodes
+    while (!queue.empty()) {
+      GradNodeBase* orig_node = queue.front();
+      queue.pop();
+      if (visited.count(orig_node)) {
+        continue;
+      }
+      visited.insert(orig_node);
+
+      PADDLE_ENFORCE(
+          orig_to_copied_node_mapping_.count(orig_node),
+          paddle::platform::errors::Fatal(
+              "Cannot reconstruct backward graph,"
+              "unable to find copied target for certain grad node."));
+      GradNodeBase* copied_node = orig_to_copied_node_mapping_[orig_node];
+
+      const std::vector<std::vector<Edge>>& orig_edges = orig_node->GetEdges();
+      std::vector<std::vector<Edge>>& copied_edges =
+          copied_node->GetMutableEdges();
+      for (size_t i = 0; i < orig_edges.size(); i++) {
+        for (size_t j = 0; j < orig_edges[i].size(); j++) {
+          const Edge& orig_edge = orig_edges[i][j];
+          Edge& copied_edge = copied_edges[i][j];
+
+          std::shared_ptr<GradNodeBase> orig_next_node =
+              orig_edge.GetMutableGradNode();
+          if (!orig_next_node) continue;
+
+          // Copy Next Node
+          std::shared_ptr<GradNodeBase> copied_next_node;
+          if (orig_to_copied_node_mapping_.count(orig_next_node.get())) {
+            copied_next_node =
+                orig_to_copied_node_mapping_[orig_next_node.get()]
+                    ->shared_from_this();
+
+          } else {
+            copied_next_node = orig_next_node->Copy();
+            orig_to_copied_node_mapping_[orig_next_node.get()] =
+                copied_next_node.get();
+            copied_grad_nodes_.push_back(copied_next_node);
+          }
+
+          // Update Edge's Grad Node
+          copied_edge.SetGradNode(copied_next_node);
+
+          // Update BFS queue
+          queue.push(orig_next_node.get());
+        }
+      }
+    }
   }
 
  private:
@@ -345,6 +434,10 @@ class GeneralGrad {
                      std::unordered_set<GradNodeBase*> /* pre nodes */>
       depending_nodes;
   std::unordered_map<GradNodeBase*, paddle::experimental::Tensor> results_map;
+
+  std::vector<std::shared_ptr<GradNodeBase>> copied_grad_nodes_;
+  std::unordered_map<GradNodeBase*, GradNodeBase*> orig_to_copied_node_mapping_;
+
   DISABLE_COPY_AND_ASSIGN(GeneralGrad);
 };
 
@@ -444,6 +537,7 @@ std::vector<paddle::experimental::Tensor> RunBackward(
   // 1. Init queue with starting nodes
   // 2. Prepare initial input buffers
   std::queue<GradNodeBase*> queue;
+  std::queue<GradNodeBase*> orig_queue;
   std::unordered_map<GradNodeBase*, std::unique_ptr<GradTensorHolder>>
       node_input_buffers_dict;
   for (size_t i = 0; i < tensors.size(); i++) {
@@ -468,6 +562,16 @@ std::vector<paddle::experimental::Tensor> RunBackward(
 
     // TODO(zhanlve): Copy and Modify GradNode if is_general_grad
    GradNodeBase* grad_node = shared_grad_node.get();
+    if (is_general_grad) {
+      // Save orig grad node
+      orig_queue.push(grad_node);
+
+      // Replace grad_node with copied grad_node
+      grad_node = GeneralGrad::Instance().CopyGradNode(shared_grad_node);
+
+      // Record potential startup grad node
+      GeneralGrad::Instance().GetPotentialStartupNodes()->insert(grad_node);
+    }
 
     // Prepare GradTensorHolder
     if (!node_input_buffers_dict.count(grad_node)) {
@@ -504,9 +608,11 @@ std::vector<paddle::experimental::Tensor> RunBackward(
 
     // Prepare queue, potential startup_nodes
     queue.push(grad_node);
-    if (is_general_grad) {
-      GeneralGrad::Instance().GetPotentialStartupNodes()->emplace(grad_node);
-    }
+  }
+
+  if (is_general_grad) {
+    // Copy Backward Graph
+    GeneralGrad::Instance().ReconstructBackwardGraph(orig_queue);
   }
 
   VLOG(6) << "Update In degree Map for backward";
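Note the ownership model behind the copy: copied_grad_nodes_ holds the owning shared_ptrs, orig_to_copied_node_mapping_ stores only raw pointers, and the shared_from_this() call above is valid because every copied node has already been handed to copied_grad_nodes_ (GradNodeBase evidently supports shared_from_this()). A small self-contained sketch of that pattern with a stand-in Node type:

#include <cassert>
#include <memory>
#include <unordered_map>
#include <vector>

// Stand-in for a grad node that supports shared_from_this().
struct Node : std::enable_shared_from_this<Node> {};

int main() {
  std::vector<std::shared_ptr<Node>> copied_owners;  // like copied_grad_nodes_
  std::unordered_map<Node*, Node*> orig_to_copied;   // like orig_to_copied_node_mapping_

  auto orig = std::make_shared<Node>();
  auto copied = std::make_shared<Node>();

  // The owning container must hold the copy before any later lookup promotes
  // the raw pointer back to shared ownership via shared_from_this().
  copied_owners.push_back(copied);
  orig_to_copied[orig.get()] = copied.get();

  std::shared_ptr<Node> promoted = orig_to_copied[orig.get()]->shared_from_this();
  assert(promoted == copied);
  return 0;
}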

paddle/fluid/eager/custom_operator/custom_operator_node.h

Lines changed: 10 additions & 7 deletions
@@ -36,9 +36,10 @@ class RunCustomOpNode : public GradNodeBase {
   }
 
   // Functor: perform backward computations
-  virtual std::vector<std::vector<paddle::experimental::Tensor>> operator()(
-      std::vector<std::vector<paddle::experimental::Tensor>>& grads,
-      bool create_graph = false)  // NOLINT
+  virtual std::vector<std::vector<paddle::experimental::Tensor>>
+  operator()(  // NOLINT
+      std::vector<std::vector<paddle::experimental::Tensor>>& grads,  // NOLINT
+      bool create_graph = false)  // NOLINT
       override;
 
   std::string name() {
@@ -64,13 +65,15 @@ class RunCustomOpNode : public GradNodeBase {
   }
 
   void ClearTensorWrappers() override { VLOG(6) << "Do nothing here now"; }
-  bool IsTensorWrappersCleared() override {
-    VLOG(6) << "Do nothing here now";
-    return false;
-  }
 
   void SetAttrs(const std::vector<paddle::any>& attr) { attrs_ = attr; }
 
+  std::shared_ptr<GradNodeBase> Copy() const override {
+    auto copied_node =
+        std::shared_ptr<RunCustomOpNode>(new RunCustomOpNode(*this));
+    return copied_node;
+  }
+
  public:
   std::unordered_map<int, std::vector<egr::TensorWrapper>> fwd_outs;
   std::unordered_map<int, std::vector<egr::TensorWrapper>> fwd_ins;

paddle/fluid/eager/grad_node_info.cc

Lines changed: 4 additions & 0 deletions
@@ -326,6 +326,10 @@ const std::vector<std::vector<Edge>>& GradNodeBase::GetEdges() const {
   return adj_edges_;
 }
 
+std::vector<std::vector<Edge>>& GradNodeBase::GetMutableEdges() {
+  return adj_edges_;
+}
+
 std::vector<std::vector<paddle::experimental::Tensor>>
 GradNodeBase::ApplyGradientHooks(
     const std::vector<std::vector<paddle::experimental::Tensor>>& tensors) {
0 commit comments
