@@ -147,35 +147,41 @@ void Tape::Forward() {
147147 VLOG (3 ) << " Finishing forward -------------------------" ;
148148}
149149
150- void DescMapToVarMap (
150+ void Tape:: DescMapToVarMap (
151151 const unordered_map<string, VariableHandle> &name2var,
152- const framework::VariableNameMap &vmp ,
152+ const framework::VariableNameMap &variable_name_map ,
153153 VariableHandleMap *vhm,
154154 vector<pair<VariableHandle, VariableHandle>> *duplicated_grad,
155155 bool is_output) {
156- for (auto &p2a : vmp) {
157- for (auto &argu : p2a.second ) {
158- if (name2var.count (argu)) {
159- (*vhm)[p2a.first ].push_back (name2var.at (argu));
156+ for (auto &p2a : variable_name_map) {
157+ for (auto &arg : p2a.second ) {
158+ auto ¶m = p2a.first ;
159+ if (name2var.count (arg)) {
160+ (*vhm)[param].push_back (name2var.at (arg));
160161 } else {
161- PADDLE_ENFORCE (ends_with (argu , framework::kGradVarSuffix ),
162+ PADDLE_ENFORCE (ends_with (arg , framework::kGradVarSuffix ),
162163 " %s not end with %s" ,
163- argu ,
164+ arg ,
164165 framework::kGradVarSuffix );
165166 string name =
166- argu .substr (0 , argu .size () - strlen (framework::kGradVarSuffix ));
167+ arg .substr (0 , arg .size () - strlen (framework::kGradVarSuffix ));
167168 PADDLE_ENFORCE (name2var.count (name), " %s not found" , name);
168169 if (is_output && name2var.at (name)->GradExist ()) {
170+ // Sum duplicated grad
169171 VariableHandle temp_grad (new Variable (
170172 name + framework::kGradVarSuffix + framework::kTempVarName ));
171- (*vhm)[p2a.first ].emplace_back (temp_grad);
172- duplicated_grad->emplace_back (name2var.at (name)->Grad (),
173- temp_grad); // name2var[name]->Grad has
174- // to be the first element
175- // since sum_op use X[0] ==
176- // Out to determine inplace
173+ // name2var[name]->Grad has to be the first element since sum_op use
174+ // X[0] == Out to determine inplace
175+ duplicated_grad->emplace_back (name2var.at (name)->Grad (), temp_grad);
176+ (*vhm)[param].emplace_back (temp_grad);
177+ } else if (!is_output && !name2var.at (name)->GradExist ()) {
178+ // zero initialize empty grad
179+ auto var = name2var.at (name);
180+ backward_tape_->AddOp (
181+ " fill_zeros_like" , {{" X" , {var}}}, {{" Out" , {var->Grad ()}}}, {});
182+ (*vhm)[param].push_back (var->Grad ());
177183 } else {
178- (*vhm)[p2a. first ].push_back (name2var.at (name)->Grad ());
184+ (*vhm)[param ].push_back (name2var.at (name)->Grad ());
179185 }
180186 }
181187 }
@@ -232,8 +238,6 @@ void Tape::Backward(VariableHandle target) {
232238 {});
233239 }
234240 }
235-
236- // TODO(tonyyang-svail): how to fill empty grad?
237241 }
238242
239243 backward_tape_->Forward ();
0 commit comments