
Commit 15378ef

Cherry-pick PR 30103. Add Inplace strategy (Output reuse Input Varbase) in dygraph (#30103)
* add view strategy on squeeze, unsqueeze, reshape, flatten
* add squeeze unittest
* add unittests
* use View strategy as name rather than Reuse Allocation
* fix view api doc
* fix format
* use core.ops when input of reshape2 is Tensor
* fix test_cross_entropy_loss error because of reshape2
* fix test_cross_entropy_loss error because of reshape2
* add inplace strategy
* add elementwise_add sub
* let backward op not use inplace
* grad op does not use inplace
* fix memory increase error and add leaf error message
* delete selected_rows
* change op_function
* little change
* solve HandleViewBetweenInputAndOutput
* add unittest and leaf error message
* merge view error
* optimize op_function_generator format and support sum inplace op
* fix format of basic_engine
* fix format for framework
* little change of variable wrapper
* add reshape, squeeze, unsqueeze, scatter api
* add relu elu tanh softmax inplace api
* fix test_squeeze_op unittest
* fix test_relu_op unittest
* fix comment problems
* delete sample code of inplace api
* add reference of grad_pending_nodes in basic_engine
* fix unittest name
* add inplace apis into wlist
* fix error message
* add PADDLE_ENFORCE for set grad op twice
* fix header file error
1 parent badc6f2 commit 15378ef

29 files changed, +1102 -257 lines
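For orientation, the sketch below (stand-in types only, not Paddle code) illustrates the core idea behind the "Output reuse Input Varbase" strategy named in the title: an inplace/view op hands out an output handle that shares the input's underlying allocation instead of copying it, so a write through either handle is visible through both. The C++ changes in the files below are about keeping gradient bookkeeping correct when storage has been reused this way.

#include <iostream>
#include <memory>
#include <string>
#include <vector>

// Toy stand-in for a VarBase-like handle: a name plus a shared buffer.
struct ToyVar {
  std::string name;
  std::shared_ptr<std::vector<float>> buffer;
};

// An "inplace" op: the output reuses the input's allocation (no copy).
ToyVar ReshapeInplace(const ToyVar& x) {
  return ToyVar{x.name + "_out", x.buffer};
}

int main() {
  ToyVar x{"x", std::make_shared<std::vector<float>>(4, 1.0f)};
  ToyVar out = ReshapeInplace(x);
  (*out.buffer)[0] = 7.0f;                        // write through the output
  std::cout << (*x.buffer)[0] << "\n";            // 7: the input sees it too
  std::cout << (out.buffer == x.buffer) << "\n";  // 1: same allocation
}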

paddle/fluid/framework/details/op_registry.h

Lines changed: 4 additions & 2 deletions
@@ -14,6 +14,7 @@ limitations under the License. */
 
 #pragma once
 
+#include <map>
 #include <memory>
 #include <string>
 #include <tuple>
@@ -247,8 +248,9 @@ struct OpInfoFiller<T, kGradOpBaseMaker> {
             const std::string& type,
             const imperative::NameVarBaseMap& var_base_map_in,
             const imperative::NameVarBaseMap& var_base_map_out,
-            const framework::AttributeMap& attrs) {
-          T maker(type, var_base_map_in, var_base_map_out, attrs);
+            const framework::AttributeMap& attrs,
+            const std::map<std::string, std::string>& inplace_map) {
+          T maker(type, var_base_map_in, var_base_map_out, attrs, inplace_map);
           return maker();
         };
 }
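To show the shape of this change with self-contained code, here is a minimal sketch (stand-in types, not the real OpInfoFiller or grad-op maker) of a registration callback that now receives an inplace map and forwards it to the maker it constructs and invokes. The {"Out", "X"} pair is a hypothetical example meaning "output Out reuses input X".

#include <functional>
#include <iostream>
#include <map>
#include <string>

// Stand-ins for the real framework types (illustration only).
struct FakeGradNode {
  std::map<std::string, std::string> inplace;
};

// A toy "maker": constructed with the op description plus the inplace map,
// invoked to produce a grad node, mirroring `T maker(...); return maker();`.
struct FakeGradOpMaker {
  FakeGradOpMaker(const std::string& type,
                  const std::map<std::string, std::string>& inplace_map)
      : type_(type), inplace_map_(inplace_map) {}
  FakeGradNode operator()() const { return FakeGradNode{inplace_map_}; }
  std::string type_;
  std::map<std::string, std::string> inplace_map_;
};

using ToyGradOpBaseMakerFN = std::function<FakeGradNode(
    const std::string&, const std::map<std::string, std::string>&)>;

int main() {
  // The registration callback threads the extra argument through to the maker.
  ToyGradOpBaseMakerFN fn =
      [](const std::string& type,
         const std::map<std::string, std::string>& inplace_map) {
        FakeGradOpMaker maker(type, inplace_map);
        return maker();
      };
  FakeGradNode node = fn("reshape2", {{"Out", "X"}});  // hypothetical pair
  std::cout << node.inplace.size() << "\n";            // 1
}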

paddle/fluid/framework/grad_op_desc_maker.h

Lines changed: 4 additions & 0 deletions
@@ -221,6 +221,10 @@ class SingleGradOpMaker<imperative::OpBase>
 
   std::shared_ptr<imperative::GradOpNode> operator()() const final {
     auto node = this->NewGradNode();
+    auto& inplace_map = this->GetInplaceMap();
+    if (!inplace_map.empty()) {
+      node->SetInplaceGradNameMap(inplace_map);
+    }
     {
       imperative::TracedGradOp traced_grad_op(node);
       try {
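The four added lines only record a mapping on the grad node when the forward op actually declared inplace pairs. A compact stand-alone illustration of that store-and-query pattern, using a simplified stand-in class rather than the real imperative::GradOpNode:

#include <iostream>
#include <map>
#include <string>

// Simplified stand-in for a grad node that can remember inplace name pairs.
class ToyGradNode {
 public:
  void SetInplaceGradNameMap(const std::map<std::string, std::string>& m) {
    inplace_grad_name_map_ = m;
  }
  const std::map<std::string, std::string>& InplaceGradNameMap() const {
    return inplace_grad_name_map_;
  }

 private:
  std::map<std::string, std::string> inplace_grad_name_map_;
};

int main() {
  std::map<std::string, std::string> inplace_map = {{"Out", "X"}};
  ToyGradNode node;
  if (!inplace_map.empty()) {  // same guard as in the diff above
    node.SetInplaceGradNameMap(inplace_map);
  }
  std::cout << node.InplaceGradNameMap().at("Out") << "\n";  // prints "X"
}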

paddle/fluid/framework/type_defs.h

Lines changed: 2 additions & 1 deletion
@@ -59,7 +59,8 @@ using DygraphGradOpMakerFN =
         const std::string& /*op_type*/,
         const imperative::NameVarBaseMap& /*var_base_map_in*/,
         const imperative::NameVarBaseMap& /*var_base_map_out*/,
-        const framework::AttributeMap& /*attributes*/)>;
+        const framework::AttributeMap& /*attributes*/,
+        const std::map<std::string, std::string>& /*inplace_map*/)>;
 
 using InferVarTypeFN =
     std::function<void(framework::InferVarTypeContext* /*context*/)>;
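The alias change means every registered dygraph grad-op maker now takes a fifth argument; ops without inplace pairs simply receive an empty map. A simplified sketch of that calling convention with stand-in types (the real alias uses NameVarBaseMap and AttributeMap, and the op names here are only examples):

#include <functional>
#include <iostream>
#include <map>
#include <string>

// Simplified stand-in for the extended maker alias: the last parameter is the
// new inplace map (empty when the op was not run inplace).
using ToyGradOpMakerFN = std::function<std::string(
    const std::string& /*op_type*/,
    const std::map<std::string, std::string>& /*inplace_map*/)>;

int main() {
  ToyGradOpMakerFN make_grad =
      [](const std::string& op_type,
         const std::map<std::string, std::string>& inplace_map) {
        return op_type + (inplace_map.empty() ? " (no inplace)" : " (inplace)");
      };
  std::cout << make_grad("matmul_v2", {}) << "\n";             // no inplace pairs
  std::cout << make_grad("reshape2", {{"Out", "X"}}) << "\n";  // hypothetical pair
}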

paddle/fluid/imperative/basic_engine.cc

Lines changed: 158 additions & 33 deletions
@@ -108,7 +108,9 @@ void BasicEngine::CheckBackwardInputs(const OpBase& op) {
   }
 }
 
-void BasicEngine::PrepareGradAccumulators(const OpBase& op) {
+void BasicEngine::PrepareGradAccumulators(
+    const OpBase& op,
+    const std::vector<std::shared_ptr<GradOpNode>>& grad_pending_nodes) {
   for (const auto& pair : op.GetOutsMap()) {
     if (!pair.second.IsGrad()) {
       continue;
@@ -117,29 +119,94 @@ void BasicEngine::PrepareGradAccumulators(const OpBase& op) {
     for (const auto& var : pair.second) {
       if (!var) continue;
 
-      auto& accumulator = accumulators_[var.get()];
-      if (!accumulator) {
-        if (FLAGS_sort_sum_gradient) {
-          accumulator.reset(new SortedGradientAccumulator(var.get()));
-        } else {
-          accumulator.reset(new EagerGradientAccumulator(var.get()));
+      if (!var->HasGradNode()) {
+        auto& accumulator = accumulators_[var.get()];
+        if (!accumulator) {
+          if (FLAGS_sort_sum_gradient) {
+            accumulator.reset(new SortedGradientAccumulator(var.get()));
+          } else {
+            accumulator.reset(new EagerGradientAccumulator(var.get()));
+          }
         }
-      }
 
-      accumulator->IncreaseRefCnt();
+        accumulator->IncreaseRefCnt();
 
-      VLOG(3) << "Prepare to acccumulate variable grad " << var->Name() << "("
-              << var.get() << ") with reference count "
-              << accumulator->RefCnt();
+        VLOG(3) << "Prepare to acccumulate variable grad " << var->Name() << "("
+                << var.get()
+                << ") that don't have grad node with reference count "
+                << accumulator->RefCnt();
+
+        if (var->HasLeafHooks()) {
+          VLOG(3) << "Grad variable wrapper (" << var->Name()
+                  << ") has leaf grad hooks.";
+          PADDLE_ENFORCE_NE(
+              var->HasGradNode(), true,
+              platform::errors::PermissionDenied(
+                  "Only leaf Tensor's gradient can append hook to "
+                  "Gradientaccumulator."));
+          accumulator->SetPostHooks(var->GetLeafHooks());
+        }
+      } else {
+        // Because Inplace op overwrites the grad_node of the input grad_var. So
+        // only the information of grad_pending_node can be used to find the
+        // grad_node of grad_var.
+        bool find_grad_node_of_var = false;
+        for (auto& grad_pending_node : grad_pending_nodes) {
+          PADDLE_ENFORCE_NOT_NULL(
+              grad_pending_node,
+              platform::errors::NotFound("Grad pending node is nullptr."));
+          for (auto& grad_pending_op : *grad_pending_node) {
+            VLOG(6) << "Determine whether var (" << var->Name()
+                    << ") is the input var of grad_pending_op ("
+                    << grad_pending_op.Type() << ").";
+            grad_pending_op.EnforceHasInOut();
+            for (const auto& grad_pending_op_ins_pair :
+                 grad_pending_op.GetInsMap()) {
+              if (!grad_pending_op_ins_pair.second.IsGrad()) {
+                continue;
+              }
+              for (const auto& pending_in_var :
+                   grad_pending_op_ins_pair.second) {
+                if (var == pending_in_var) {
+                  VLOG(6) << "Var (" << var->Name()
+                          << ") is the input var of grad_pending_op ("
+                          << grad_pending_op.Type() << ").";
+                  find_grad_node_of_var = true;
+                  break;
+                }
+              }
+              if (find_grad_node_of_var) {
+                break;
+              }
+            }
+          }
 
-      if (var->HasLeafHooks()) {
-        VLOG(3) << "Grad variable wrapper (" << var->Name()
-                << ") has leaf grad hooks.";
-        PADDLE_ENFORCE_NE(var->HasGradNode(), true,
-                          platform::errors::PermissionDenied(
-                              "Only leaf Tensor's gradient can append hook to "
-                              "Gradientaccumulator."));
-        accumulator->SetPostHooks(var->GetLeafHooks());
+          if (find_grad_node_of_var) {
+            auto& accumulator =
+                accumulators_with_grad_node_[grad_pending_node][var.get()];
+
+            if (!accumulator) {
+              if (FLAGS_sort_sum_gradient) {
+                accumulator.reset(new SortedGradientAccumulator(var.get()));
+              } else {
+                accumulator.reset(new EagerGradientAccumulator(var.get()));
+              }
+            }
+
+            accumulator->IncreaseRefCnt();
+
+            VLOG(3) << "Prepare to acccumulate variable grad " << var->Name()
+                    << "(" << var.get()
+                    << ") that has grad node with reference count "
+                    << accumulator->RefCnt();
+            break;
+          }
+        }
+        PADDLE_ENFORCE_EQ(
+            find_grad_node_of_var, true,
+            platform::errors::NotFound(
+                "No grad node corresponding to grad Tensor (%s) was found.",
+                var->Name()));
       }
     }
   }
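The new else-branch cannot rely on the grad var's own grad node, because an inplace op overwrites the grad node of its input grad var; instead it searches the current node's grad_pending_nodes for an op that consumes this grad var as an input. A stripped-down sketch of that search, with stand-in types rather than the real OpBase/GradOpNode API:

#include <iostream>
#include <memory>
#include <string>
#include <vector>

// Stand-ins: a grad "op" lists the grad variables it reads, and a grad "node"
// groups several ops.
struct ToyVar { std::string name; };
struct ToyOp { std::vector<const ToyVar*> inputs; };
struct ToyNode { std::vector<ToyOp> ops; };

// Return the first pending node whose ops take `var` as an input, mirroring the
// find_grad_node_of_var loop above (nullptr if none is found).
const ToyNode* FindNodeConsuming(
    const std::vector<std::shared_ptr<ToyNode>>& pending, const ToyVar* var) {
  for (const auto& node : pending) {
    for (const auto& op : node->ops) {
      for (const ToyVar* in : op.inputs) {
        if (in == var) return node.get();
      }
    }
  }
  return nullptr;
}

int main() {
  ToyVar gx{"x@GRAD"};
  auto consumer = std::make_shared<ToyNode>();
  consumer->ops.push_back(ToyOp{{&gx}});
  std::vector<std::shared_ptr<ToyNode>> pending = {consumer};
  std::cout << (FindNodeConsuming(pending, &gx) == consumer.get()) << "\n";  // 1
}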
@@ -148,10 +215,13 @@ void BasicEngine::PrepareGradAccumulators(const OpBase& op) {
 void BasicEngine::PrepareDeps() {
   PADDLE_ENFORCE_EQ(
       node_deps_.empty(), true,
-      platform::errors::AlreadyExists("Op deps must be initialized here"));
+      platform::errors::AlreadyExists("Op deps must be initialized."));
   PADDLE_ENFORCE_EQ(
       accumulators_.empty(), true,
-      platform::errors::AlreadyExists("Accumulators must be initialized here"));
+      platform::errors::AlreadyExists("Accumulators must be initialized."));
+  PADDLE_ENFORCE_EQ(
+      accumulators_with_grad_node_.empty(), true,
+      platform::errors::AlreadyExists("Accumulators must be initialized."));
 
   std::queue<GradOpNode*> q;
   std::unordered_set<GradOpNode*> visited;
@@ -163,16 +233,17 @@ void BasicEngine::PrepareDeps() {
     auto* cur_node = q.front();
     q.pop();
 
+    const auto& grad_pending_nodes = cur_node->GradPendingNodes();
+
     for (auto& cur_op : *cur_node) {
       cur_op.EnforceHasInOut();
-      PrepareGradAccumulators(cur_op);
+      PrepareGradAccumulators(cur_op, grad_pending_nodes);
     }
 
-    const auto& grad_pending_nodes = cur_node->GradPendingNodes();
     for (auto& grad_pending_node : grad_pending_nodes) {
       PADDLE_ENFORCE_NOT_NULL(
           grad_pending_node,
-          platform::errors::NotFound("Grad pending node should not be null"));
+          platform::errors::NotFound("Grad pending node is nullptr."));
       ++node_deps_[grad_pending_node.get()];
       if (visited.count(grad_pending_node.get()) == 0) {
         visited.insert(grad_pending_node.get());
@@ -198,6 +269,8 @@ void BasicEngine::Execute() {
     auto shared_cur_node = std::move(q.front());
     q.pop();
 
+    auto& inplace_grad_name_map = shared_cur_node->InplaceGradNameMap();
+
     for (auto& cur_op : *shared_cur_node) {
       ++op_num;
 
@@ -222,11 +295,38 @@ void BasicEngine::Execute() {
           continue;
         }
 
-        auto iter = accumulators_.find(var.get());
-        PADDLE_ENFORCE_EQ(
-            iter != accumulators_.end(), true,
-            platform::errors::NotFound("Cannot find gradient of variable %s",
-                                       var->Name()));
+        std::unordered_map<VariableWrapper*,
+                           std::unique_ptr<GradientAccumulator>>::iterator
+            iter;
+        if (!var->HasGradNode()) {
+          VLOG(10) << "Find gradient of var (" << var->Name()
+                   << ") with no grad_node.";
+          iter = accumulators_.find(var.get());
+          PADDLE_ENFORCE_EQ(
+              iter != accumulators_.end(), true,
+              platform::errors::NotFound(
+                  "Cannot find gradient of variable %s", var->Name()));
+        } else {
+          bool flag_find_grad = false;
+          VLOG(10) << "Find gradient of var (" << var->Name()
+                   << ") with grad_node.";
+          for (auto& grad_pending_node :
+               shared_cur_node->GradPendingNodes()) {
+            const auto& iter_grad_node =
+                accumulators_with_grad_node_.find(grad_pending_node);
+            if (iter_grad_node != accumulators_with_grad_node_.end()) {
+              iter = iter_grad_node->second.find(var.get());
+              if (iter != iter_grad_node->second.end()) {
+                flag_find_grad = true;
+                break;
+              }
+            }
+          }
+          PADDLE_ENFORCE_EQ(
+              flag_find_grad, true,
+              platform::errors::NotFound(
+                  "Cannot find gradient of variable %s", var->Name()));
+        }
 
         // leaf_accumulators_ : hooks and accumulate-grad for leaf tensor
         if (var->IsLeafGrad()) {
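At execution time the lookup mirrors the preparation step: a grad var without a grad node is looked up in the flat map, otherwise the per-pending-node maps are searched. A minimal sketch of that two-path lookup over simplified containers (stand-in types, not the real accumulator classes):

#include <iostream>
#include <memory>
#include <string>
#include <unordered_map>
#include <vector>

// Simplified stand-ins for the two accumulator containers in BasicEngine.
struct ToyVar { std::string name; bool has_grad_node; };
struct ToyNode {};
using FlatMap = std::unordered_map<const ToyVar*, int>;  // like accumulators_
using PerNodeMap = std::unordered_map<
    const ToyNode*,
    std::unordered_map<const ToyVar*, int>>;  // like accumulators_with_grad_node_

// Find the accumulator id for `var`, trying the flat map first and otherwise
// scanning the pending nodes, like the branch added in Execute().
bool FindAccumulator(const ToyVar* var, const std::vector<const ToyNode*>& pending,
                     const FlatMap& flat, const PerNodeMap& per_node, int* out) {
  if (!var->has_grad_node) {
    auto it = flat.find(var);
    if (it == flat.end()) return false;
    *out = it->second;
    return true;
  }
  for (const ToyNode* node : pending) {
    auto node_it = per_node.find(node);
    if (node_it == per_node.end()) continue;
    auto it = node_it->second.find(var);
    if (it != node_it->second.end()) { *out = it->second; return true; }
  }
  return false;
}

int main() {
  ToyVar leaf{"w@GRAD", false}, mid{"x@GRAD", true};
  ToyNode node;
  FlatMap flat{{&leaf, 1}};
  PerNodeMap per_node;
  per_node[&node][&mid] = 2;
  int id = 0;
  FindAccumulator(&leaf, {&node}, flat, per_node, &id);
  std::cout << id << "\n";  // 1
  FindAccumulator(&mid, {&node}, flat, per_node, &id);
  std::cout << id << "\n";  // 2
}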
@@ -245,6 +345,25 @@ void BasicEngine::Execute() {
           need_accu_var_list_.emplace_back(iter->second.get(), var);
           VLOG(10) << "create temporary var of " << var->Name()
                    << " for sum gradient within this graph!";
+        } else if (!inplace_grad_name_map.empty() &&
+                   inplace_grad_name_map.count(pair.first)) {
+          // When calculate Inplace grad op, create a new output var.
+          // If a tmp var has been created, there is no need to create it
+          // again.
+          for (auto& in_var :
+               bwd_ins.at(inplace_grad_name_map.at(pair.first))) {
+            if (in_var == var) {
+              auto tmp_var = std::make_shared<VariableWrapper>(var->Name());
+              tmp_var->SetType(var->Type());
+              tmp_var->SetForwardDataType(var->ForwardDataType());
+              inplace_output_grad_var_list_.emplace_back(var, tmp_var);
+              var = tmp_var;
+              VLOG(10) << "Inplace grad op does not use the Inplace "
+                          "strategy, a temporary output var ("
+                       << var->Name() << ") will be created.";
+              break;
+            }
+          }
         }
       }
     }
@@ -280,6 +399,10 @@ void BasicEngine::Execute() {
                             cur_op.place());
       }
 
+      for (auto& pair : inplace_output_grad_var_list_) {
+        *pair.first = std::move(*pair.second);
+      }
+
       // Step 2: Sum Gradient of This graph
       for (auto& pair : need_accu_var_list_) {
         pair.first->SumGrad(std::move(pair.second), cur_op.id());
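Because the grad op itself is not run inplace, the code added in the two hunks above gives it a fresh output grad var when the output would otherwise alias the backward input, remembers the (original, temporary) pair, and after the op runs moves the temporary's contents back into the original. A self-contained sketch of that create-then-move-back pattern with a stand-in variable type:

#include <iostream>
#include <memory>
#include <string>
#include <utility>
#include <vector>

// Stand-in for a variable wrapper holding some payload.
struct ToyVar {
  explicit ToyVar(std::string n) : name(std::move(n)) {}
  std::string name;
  std::vector<float> data;
};

int main() {
  auto grad_out = std::make_shared<ToyVar>("x@GRAD");  // would alias the bwd input

  // Step 1: run the grad op into a temporary instead of the aliased var.
  std::vector<std::pair<std::shared_ptr<ToyVar>, std::shared_ptr<ToyVar>>>
      inplace_output_grad_var_list;
  auto tmp = std::make_shared<ToyVar>(grad_out->name);
  inplace_output_grad_var_list.emplace_back(grad_out, tmp);
  tmp->data = {1.f, 2.f, 3.f};  // pretend the grad kernel wrote here

  // Step 2: after the op, move the temporary result into the real output var.
  for (auto& pair : inplace_output_grad_var_list) {
    *pair.first = std::move(*pair.second);
  }
  std::cout << grad_out->data.size() << "\n";  // 3
}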
@@ -302,6 +425,7 @@ void BasicEngine::Execute() {
       }
 
       need_accu_var_list_.clear();
+      inplace_output_grad_var_list_.clear();
       leaf_accumulators_.clear();
 
       if (!retain_graph_) {
@@ -312,9 +436,9 @@ void BasicEngine::Execute() {
 
     // Step 3: Collect ready ops
     for (auto& grad_pending_node : shared_cur_node->GradPendingNodes()) {
-      PADDLE_ENFORCE_NOT_NULL(grad_pending_node,
-                              platform::errors::NotFound(
-                                  "Grad pending node should not be nullptr"));
+      PADDLE_ENFORCE_NOT_NULL(
+          grad_pending_node,
+          platform::errors::NotFound("Grad pending node is nullptr."));
       auto iter = node_deps_.find(grad_pending_node.get());
       if (iter == node_deps_.end()) {
         continue;
@@ -334,6 +458,7 @@ void BasicEngine::Clear() {
   init_node_.reset();
   node_deps_.clear();
   accumulators_.clear();
+  accumulators_with_grad_node_.clear();
   need_accu_var_list_.clear();
   leaf_accumulators_.clear();
 }

paddle/fluid/imperative/basic_engine.h

Lines changed: 19 additions & 1 deletion
@@ -39,15 +39,33 @@ class BasicEngine : public Engine {
 
   void CheckBackwardInputs(const OpBase& op);
 
-  void PrepareGradAccumulators(const OpBase& op);
+  void PrepareGradAccumulators(
+      const OpBase& op,
+      const std::vector<std::shared_ptr<GradOpNode>>& grad_pending_nodes);
 
   void Clear();
 
  private:
  std::shared_ptr<GradOpNode> init_node_;
  std::unordered_map<GradOpNode*, size_t> node_deps_;
+  // The input and output of Inplace op are the same. If only `var` is used
+  // as the key, then the input and output of inplace op must be gradient
+  // accumulated. Therefore, add the `grad_node` as the key to prevent the
+  // problem of gradient accumulation in inplace op.
+  std::unordered_map<std::shared_ptr<GradOpNode>,
+                     std::unordered_map<VariableWrapper*,
+                                        std::unique_ptr<GradientAccumulator>>>
+      accumulators_with_grad_node_;
+  // Leaf var doesn't have grad_node, and leaf var with `stop_gradient=False`
+  // can't use Inplace strategy. If a var doesn't have grad_node, only use
+  // `var` as the key.
   std::unordered_map<VariableWrapper*, std::unique_ptr<GradientAccumulator>>
       accumulators_;
+  // The output grad var of Inplace grad op. Because Inplace grad op does not
+  // use the Inplace strategy, a new output grad var needs to be created.
+  std::vector<std::pair<std::shared_ptr<VariableWrapper>,
+                        std::shared_ptr<VariableWrapper>>>
+      inplace_output_grad_var_list_;
   std::vector<std::pair<GradientAccumulator*, std::shared_ptr<VariableWrapper>>>
       need_accu_var_list_;
   // leaf_accumulators_ is only for leaf tensor(hooks/accumulate grad)
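The header comments above explain the keying scheme: with an inplace op the forward input and output can be the very same variable wrapper, so a map keyed only by the variable pointer would merge their accumulators, while keying by grad node first keeps them apart. A toy demonstration of that difference with stand-in types (not the real containers):

#include <iostream>
#include <unordered_map>

// Stand-ins only: with an inplace op, the input and output can be the same
// variable object, so a single Var* key would merge their accumulators.
struct ToyVar {};
struct ToyNode {};

int main() {
  ToyVar shared_var;  // same wrapper used as input and output
  ToyNode node_for_input, node_for_output;

  // Flat map: both uses collide on one entry.
  std::unordered_map<const ToyVar*, int> flat;
  ++flat[&shared_var];
  ++flat[&shared_var];
  std::cout << "flat entries: " << flat.size() << "\n";  // 1 (merged)

  // Keyed by grad node first, the two uses stay separate.
  std::unordered_map<const ToyNode*,
                     std::unordered_map<const ToyVar*, int>> keyed;
  ++keyed[&node_for_input][&shared_var];
  ++keyed[&node_for_output][&shared_var];
  std::cout << "keyed entries: " << keyed.size() << "\n";  // 2 (kept apart)
}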
