Skip to content
This repository was archived by the owner on Jan 24, 2024. It is now read-only.

Commit 3e5c6f1

Browse files
author
Yang Yang
committed
clean up
1 parent 0b5b13e commit 3e5c6f1

File tree

2 files changed

+15
-12
lines changed

2 files changed

+15
-12
lines changed

src/tape.cc

Lines changed: 14 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -159,23 +159,26 @@ void Tape::DescMapToVarMap(
159159
if (name2var.count(arg)) {
160160
(*vhm)[param].push_back(name2var.at(arg));
161161
} else {
162-
PADDLE_ENFORCE(ends_with(arg, framework::kGradVarSuffix),
163-
"%s not end with %s",
164-
arg,
165-
framework::kGradVarSuffix);
162+
PADDLE_ENFORCE(
163+
ends_with(arg, framework::kGradVarSuffix),
164+
"Backward can only add gradient variable. %s not end with %s",
165+
arg,
166+
framework::kGradVarSuffix);
166167
string name =
167168
arg.substr(0, arg.size() - strlen(framework::kGradVarSuffix));
168169
PADDLE_ENFORCE(name2var.count(name), "%s not found", name);
169-
if (is_output && name2var.at(name)->GradExist()) {
170-
// Sum duplicated grad
170+
if (is_output &&
171+
name2var.at(name)->GradExist()) { // Sum duplicated grad
171172
VariableHandle temp_grad(new Variable(
172173
name + framework::kGradVarSuffix + framework::kTempVarName));
173-
// name2var[name]->Grad has to be the first element since sum_op use
174-
// X[0] == Out to determine inplace
174+
// we want sum duplicated grad to be in-place
175+
// since sum_op uses X[0] == Out to determine inplace
176+
// we assign name2var[name]->Grad to be the first element
175177
duplicated_grad->emplace_back(name2var.at(name)->Grad(), temp_grad);
176178
(*vhm)[param].emplace_back(temp_grad);
177-
} else if (!is_output && !name2var.at(name)->GradExist()) {
178-
// zero initialize empty grad
179+
} else if (!is_output &&
180+
!name2var.at(name)
181+
->GradExist()) { // zero initialize empty grad
179182
auto var = name2var.at(name);
180183
backward_tape_->AddOp(
181184
"fill_zeros_like", {{"X", {var}}}, {{"Out", {var->Grad()}}}, {});
@@ -189,7 +192,7 @@ void Tape::DescMapToVarMap(
189192
}
190193

191194
void Tape::Backward(VariableHandle target) {
192-
PADDLE_ENFORCE(!has_been_backwarded_);
195+
PADDLE_ENFORCE(!has_been_backwarded_, "A tape can only backward once.");
193196

194197
Forward();
195198

src/tape.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -84,7 +84,7 @@ class Tape {
8484
bool HasBeenBackwarded() { return has_been_backwarded_; }
8585

8686
private:
87-
void Tape::DescMapToVarMap(
87+
void DescMapToVarMap(
8888
const std::unordered_map<std::string, VariableHandle> &name2var,
8989
const framework::VariableNameMap &variable_name_map,
9090
VariableHandleMap *vhm,

0 commit comments

Comments
 (0)