1818#include < map>
1919#include < memory>
2020#include < string>
21+ #include < utility>
2122#include < vector>
2223
2324#include " paddle/fluid/framework/data_type.h"
3132namespace paddle {
3233namespace tape {
3334
35+ using std::map;
36+ using std::pair;
37+ using std::unordered_map;
38+ using std::string;
39+ using std::unique_ptr;
40+ using std::vector;
41+
3442// borrowed from
3543// https://stackoverflow.com/questions/874134/find-if-string-ends-with-another-string-in-c
36- inline bool ends_with (std:: string const &value, std:: string const &ending) {
44+ inline bool ends_with (string const &value, string const &ending) {
3745 if (ending.size () > value.size ()) return false ;
3846 return std::equal (ending.rbegin (), ending.rend (), value.rbegin ());
3947}
@@ -50,26 +58,26 @@ std::ostream &operator<<(std::ostream &os, const framework::VarDesc &var_desc) {
5058 return os;
5159}
5260
53- std:: string to_string (const std:: string &type,
54- const VariableHandleMap &in_vars,
55- const VariableHandleMap &out_vars,
56- const framework::AttributeMap &attrs) {
61+ string to_string (const string &type,
62+ const VariableHandleMap &in_vars,
63+ const VariableHandleMap &out_vars,
64+ const framework::AttributeMap &attrs) {
5765 std::stringstream ss;
5866 ss << type << " " ;
5967 for (auto ¶m_name : in_vars) {
6068 for (auto &var : param_name.second ) {
61- ss << param_name.first << " :(" << var << " ) " ;
69+ ss << param_name.first << " :(" << var-> Name () << " ) " ;
6270 }
6371 }
6472 for (auto ¶m_name : out_vars) {
6573 for (auto &var : param_name.second ) {
66- ss << param_name.first << " :(" << var << " ) " ;
74+ ss << param_name.first << " :(" << var-> Name () << " ) " ;
6775 }
6876 }
6977 return ss.str ();
7078}
7179
72- framework::OpDesc CreateOpDesc (const std:: string &type,
80+ framework::OpDesc CreateOpDesc (const string &type,
7381 const VariableHandleMap &in_vars,
7482 const VariableHandleMap &out_vars,
7583 const framework::AttributeMap &attrs) {
@@ -90,7 +98,7 @@ framework::OpDesc CreateOpDesc(const std::string &type,
9098 return op_desc;
9199}
92100
93- void InferShapeAndVarType (const std:: string &type,
101+ void InferShapeAndVarType (const string &type,
94102 const VariableHandleMap &in_vars,
95103 VariableHandleMap *out_vars,
96104 const framework::AttributeMap &attrs) {
@@ -114,11 +122,12 @@ void InferShapeAndVarType(const std::string &type,
114122 op_with_kernel->InferShape (&infer_shape_ctx);
115123}
116124
117- void Tape::AddOp (const std:: string &type,
125+ void Tape::AddOp (const string &type,
118126 const VariableHandleMap &in_vars,
119127 VariableHandleMap out_vars,
120128 const framework::AttributeMap &attrs) {
121129 PADDLE_ENFORCE (!has_been_backwarded_);
130+ LOG (INFO) << " AddOp " << to_string (type, in_vars, out_vars, attrs);
122131 InferShapeAndVarType (type, in_vars, &out_vars, attrs);
123132 tape_.emplace_back (type, in_vars, out_vars, attrs);
124133}
@@ -138,6 +147,41 @@ void Tape::Forward() {
138147 VLOG (3 ) << " Finishing forward -------------------------" ;
139148}
140149
150+ void DescMapToVarMap (
151+ const unordered_map<string, VariableHandle> &name2var,
152+ const framework::VariableNameMap &vmp,
153+ VariableHandleMap *vhm,
154+ vector<pair<VariableHandle, VariableHandle>> *duplicated_grad,
155+ bool is_output) {
156+ for (auto &p2a : vmp) {
157+ for (auto &argu : p2a.second ) {
158+ if (name2var.count (argu)) {
159+ (*vhm)[p2a.first ].push_back (name2var.at (argu));
160+ } else {
161+ PADDLE_ENFORCE (ends_with (argu, framework::kGradVarSuffix ),
162+ " %s not end with %s" ,
163+ argu,
164+ framework::kGradVarSuffix );
165+ string name =
166+ argu.substr (0 , argu.size () - strlen (framework::kGradVarSuffix ));
167+ PADDLE_ENFORCE (name2var.count (name), " %s not found" , name);
168+ if (is_output && name2var.at (name)->GradExist ()) {
169+ VariableHandle temp_grad (new Variable (
170+ name + framework::kGradVarSuffix + framework::kTempVarName ));
171+ (*vhm)[p2a.first ].emplace_back (temp_grad);
172+ duplicated_grad->emplace_back (name2var.at (name)->Grad (),
173+ temp_grad); // name2var[name]->Grad has
174+ // to be the first element
175+ // since sum_op use X[0] ==
176+ // Out to determine inplace
177+ } else {
178+ (*vhm)[p2a.first ].push_back (name2var.at (name)->Grad ());
179+ }
180+ }
181+ }
182+ }
183+ }
184+
141185void Tape::Backward (VariableHandle target) {
142186 PADDLE_ENFORCE (!has_been_backwarded_);
143187
@@ -146,21 +190,20 @@ void Tape::Backward(VariableHandle target) {
146190 // TODO(tonyyang-svail): check output of last op is target
147191 backward_tape_.reset (new Tape ());
148192
149- // FIXME(tonyyang-svail): Need to infer_data_type
150193 backward_tape_->AddOp (
151194 " fill_ones_like" , {{" X" , {target}}}, {{" Out" , {target->Grad ()}}}, {});
152195
153196 for (auto it = tape_.rbegin (); it != tape_.rend (); ++it) {
154197 framework::OpDesc op_desc =
155198 CreateOpDesc (it->type_ , it->inputs_ , it->outputs_ , it->attrs_ );
156- std:: unordered_map<std:: string, std:: string> grad_to_var;
157- std:: vector<std:: unique_ptr<framework::OpDesc>> grad_op_descs =
199+ unordered_map<string, string> grad_to_var;
200+ vector<unique_ptr<framework::OpDesc>> grad_op_descs =
158201 framework::OpInfoMap::Instance ()
159202 .Get (op_desc.Type ())
160203 .GradOpMaker ()(op_desc, {}, &grad_to_var, {});
161204
162- for (auto &op_desc : grad_op_descs) {
163- std:: unordered_map<std:: string, VariableHandle> name2var;
205+ for (auto &op_grad_desc : grad_op_descs) {
206+ unordered_map<string, VariableHandle> name2var;
164207 for (auto ¶m2vars : it->inputs_ ) {
165208 for (auto &a : param2vars.second ) {
166209 name2var[a->Name ()] = a;
@@ -172,36 +215,25 @@ void Tape::Backward(VariableHandle target) {
172215 }
173216 }
174217
175- VariableHandleMap in_vars;
176- VariableHandleMap out_vars;
177- std::map<const framework::VariableNameMap *, VariableHandleMap *>
178- loop_over{{&op_desc->Inputs (), &in_vars},
179- {&op_desc->Outputs (), &out_vars}};
180- for (auto &each : loop_over) {
181- auto &vmp = *each.first ;
182- auto &vhm = *each.second ;
183- for (auto &p2a : vmp) {
184- for (auto &argu : p2a.second ) {
185- if (name2var.count (argu)) {
186- vhm[p2a.first ].push_back (name2var[argu]);
187- } else {
188- PADDLE_ENFORCE (ends_with (argu, framework::kGradVarSuffix ),
189- argu.c_str ());
190- std::string name = argu.substr (
191- 0 , argu.size () - std::strlen (framework::kGradVarSuffix ));
192- PADDLE_ENFORCE (name2var.count (name), name.c_str ());
193- vhm[p2a.first ].push_back (name2var[name]->Grad ());
194- }
195- }
196- }
197- }
218+ vector<pair<VariableHandle, VariableHandle>>
219+ duplicated_grad; // vector of {grad, grad@temp}
220+ VariableHandleMap in_vars, out_vars;
221+ DescMapToVarMap (
222+ name2var, op_grad_desc->Inputs (), &in_vars, &duplicated_grad, false );
223+ DescMapToVarMap (
224+ name2var, op_grad_desc->Outputs (), &out_vars, &duplicated_grad, true );
198225
199226 backward_tape_->AddOp (
200- op_desc->Type (), in_vars, out_vars, op_desc->GetAttrMap ());
227+ op_grad_desc->Type (), in_vars, out_vars, op_grad_desc->GetAttrMap ());
228+ for (auto &pair : duplicated_grad) {
229+ backward_tape_->AddOp (" sum" ,
230+ {{" X" , {pair.first , pair.second }}},
231+ {{" Out" , {pair.first }}},
232+ {});
233+ }
201234 }
202235
203236 // TODO(tonyyang-svail): how to fill empty grad?
204- // TODO(tonyyang-svail): Sum var grad is necessary
205237 }
206238
207239 backward_tape_->Forward ();
0 commit comments