This repository was archived by the owner on Jan 24, 2024. It is now read-only.

Commit 0b5b13e

Author: Yang Yang (committed)
Commit message: backward pass all tests

1 parent: 78e3c8f

File tree

5 files changed: +104 -33 lines

src/function.h

Lines changed: 7 additions & 0 deletions

@@ -378,6 +378,13 @@ VariableHandle cross_entropy(VariableHandle x, VariableHandle label) {
   return out;
 }
 
+VariableHandle add(VariableHandle x, VariableHandle y) {
+  VariableHandle out(new Variable("add"));
+  get_global_tape().AddOp(
+      "elementwise_add", {{"X", {x}}, {"Y", {y}}}, {{"Out", {out}}}, {});
+  return out;
+}
+
 VariableHandle CreateRecordioFileReader(std::string filename,
                                         std::vector<int> shape_concat,
                                         std::vector<int> ranks,
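For orientation, the new add() mirrors the existing helpers in this file: it records an elementwise_add op on the global tape and returns the output handle, which is what lets the tests below build a scalar loss from two tape variables. A minimal usage sketch, assuming the fill_with_vector() helper added in src/test_backward.cc further down (everything else is existing API from src/function.h and src/tape.h):

// Sketch only: how the new add() composes with the tape API.
// fill_with_vector() is the test helper defined in src/test_backward.cc below.
using paddle::tape::add;
using paddle::tape::mean;
using paddle::tape::get_global_tape;
using paddle::tape::reset_global_tape;
using paddle::tape::VariableHandle;

reset_global_tape();
VariableHandle x = fill_with_vector({1.0f, 2.0f});
VariableHandle loss = add(mean(x), mean(x));  // records two "mean" ops and one "elementwise_add"
get_global_tape().Backward(loss);             // x's gradient accumulates through both mean ops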

src/tape.cc

Lines changed: 22 additions & 18 deletions

@@ -147,35 +147,41 @@ void Tape::Forward() {
   VLOG(3) << "Finishing forward -------------------------";
 }
 
-void DescMapToVarMap(
+void Tape::DescMapToVarMap(
     const unordered_map<string, VariableHandle> &name2var,
-    const framework::VariableNameMap &vmp,
+    const framework::VariableNameMap &variable_name_map,
     VariableHandleMap *vhm,
     vector<pair<VariableHandle, VariableHandle>> *duplicated_grad,
     bool is_output) {
-  for (auto &p2a : vmp) {
-    for (auto &argu : p2a.second) {
-      if (name2var.count(argu)) {
-        (*vhm)[p2a.first].push_back(name2var.at(argu));
+  for (auto &p2a : variable_name_map) {
+    for (auto &arg : p2a.second) {
+      auto &param = p2a.first;
+      if (name2var.count(arg)) {
+        (*vhm)[param].push_back(name2var.at(arg));
       } else {
-        PADDLE_ENFORCE(ends_with(argu, framework::kGradVarSuffix),
+        PADDLE_ENFORCE(ends_with(arg, framework::kGradVarSuffix),
                        "%s not end with %s",
-                       argu,
+                       arg,
                        framework::kGradVarSuffix);
         string name =
-            argu.substr(0, argu.size() - strlen(framework::kGradVarSuffix));
+            arg.substr(0, arg.size() - strlen(framework::kGradVarSuffix));
         PADDLE_ENFORCE(name2var.count(name), "%s not found", name);
         if (is_output && name2var.at(name)->GradExist()) {
+          // Sum duplicated grad
           VariableHandle temp_grad(new Variable(
              name + framework::kGradVarSuffix + framework::kTempVarName));
-          (*vhm)[p2a.first].emplace_back(temp_grad);
-          duplicated_grad->emplace_back(name2var.at(name)->Grad(),
-                                        temp_grad);  // name2var[name]->Grad has
-                                                     // to be the first element
-                                                     // since sum_op use X[0] ==
-                                                     // Out to determine inplace
+          // name2var[name]->Grad has to be the first element since sum_op use
+          // X[0] == Out to determine inplace
+          duplicated_grad->emplace_back(name2var.at(name)->Grad(), temp_grad);
+          (*vhm)[param].emplace_back(temp_grad);
+        } else if (!is_output && !name2var.at(name)->GradExist()) {
+          // zero initialize empty grad
+          auto var = name2var.at(name);
+          backward_tape_->AddOp(
+              "fill_zeros_like", {{"X", {var}}}, {{"Out", {var->Grad()}}}, {});
+          (*vhm)[param].push_back(var->Grad());
         } else {
-          (*vhm)[p2a.first].push_back(name2var.at(name)->Grad());
+          (*vhm)[param].push_back(name2var.at(name)->Grad());
         }
       }
     }
@@ -232,8 +238,6 @@ void Tape::Backward(VariableHandle target) {
         {});
   }
 }
-
-  // TODO(tonyyang-svail): how to fill empty grad?
 }
 
   backward_tape_->Forward();
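The two new branches in DescMapToVarMap are what the new tests exercise: if a variable already has a gradient and appears again as an output, its fresh gradient is written to a temporary variable and the (real grad, temp grad) pair is queued in duplicated_grad for later summation; if an input's gradient does not exist yet, a fill_zeros_like op is recorded so the backward op reads zeros instead of uninitialized memory. A sketch of how duplicated_grad is presumably consumed afterwards (the consuming code lives in Tape::Backward and is outside this hunk; the comment above explains why the real gradient must be X[0]):

// Assumed consumer of duplicated_grad, for illustration only.
// sum_op treats X[0] == Out as in-place accumulation, hence the ordering above.
for (auto &grad_pair : duplicated_grad) {
  VariableHandle real_grad = grad_pair.first;   // name2var.at(name)->Grad()
  VariableHandle temp_grad = grad_pair.second;  // the temporary created above
  backward_tape_->AddOp(
      "sum", {{"X", {real_grad, temp_grad}}}, {{"Out", {real_grad}}}, {});
}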

src/tape.h

Lines changed: 8 additions & 7 deletions

@@ -16,6 +16,7 @@
 #include <map>
 #include <memory>
 #include <string>
+#include <utility>
 #include <vector>
 
 #include "src/variable.h"
@@ -83,18 +84,18 @@ class Tape {
   bool HasBeenBackwarded() { return has_been_backwarded_; }
 
  private:
+  void DescMapToVarMap(
+      const std::unordered_map<std::string, VariableHandle> &name2var,
+      const framework::VariableNameMap &variable_name_map,
+      VariableHandleMap *vhm,
+      std::vector<std::pair<VariableHandle, VariableHandle>> *duplicated_grad,
+      bool is_output);
+
   bool has_been_backwarded_ = false;
   size_t current_position_ = 0;
 
   std::vector<OpHandle> tape_;
   std::shared_ptr<Tape> backward_tape_;
-
-  // void DescMapToVarMap(const std::unordered_map<std::string, VariableHandle>
-  //                      &name2var,
-  //                      const std::vector<std::pair<VariableHandle,
-  //                      VariableHandle>> &duplicated_grad,
-  //                      const std::pair<const framework::VariableNameMap
-  //                      *const, VariableHandleMap *> &each) const;
 };
 
 Tape &get_global_tape();

src/test_backward.cc

Lines changed: 63 additions & 4 deletions

@@ -12,26 +12,85 @@
 // See the License for the specific language governing permissions and
 // limitations under the License.
 
+#include <vector>
+
 #include "gtest/gtest.h"
 #include "src/function.h"
 
+using paddle::tape::reset_global_tape;
+using paddle::tape::get_global_tape;
+
+using paddle::framework::LoDTensor;
+
+using paddle::tape::add;
+using paddle::tape::mean;
+using paddle::tape::Variable;
+using paddle::tape::VariableHandle;
+
+VariableHandle fill_with_vector(std::vector<float> data) {
+  VariableHandle var(new Variable("fill"));
+  auto* tensor = var->GetMutable<LoDTensor>();
+  tensor->Resize({static_cast<int64_t>(data.size())});
+  LOG(INFO) << tensor->dims();
+  auto* ptr = tensor->mutable_data<float>(paddle::platform::CPUPlace());
+  for (size_t i = 0; i < data.size(); ++i) {
+    ptr[i] = data[i];
+  }
+  return var;
+}
+
 /*
  * y = op(x)
  * z = op(x)
  * loss = y + z
  */
-TEST(Backward, TestMultipleAssignment) {}
+TEST(Backward, TestMultipleAssignment) {
+  reset_global_tape();
+
+  auto x = fill_with_vector({42});
+  auto y = mean(x);
+  auto z = mean(x);
+  auto loss = add(y, z);
+
+  get_global_tape().Backward(loss);
+
+  LOG(INFO) << x->Value();
+  LOG(INFO) << x->Grad()->Value();
+  PADDLE_ENFORCE_EQ(x->Grad()->Get<LoDTensor>().data<float>()[0], 2.0);
+}
 
 /*
  * loss = x + x
  */
-TEST(Backward, TestInplaceSum) {}
+TEST(Backward, TestInplaceSum) {
+  reset_global_tape();
+
+  auto x = fill_with_vector({42});
+  auto loss = add(x, x);
+
+  get_global_tape().Backward(loss);
+
+  PADDLE_ENFORCE_EQ(x->Grad()->Get<LoDTensor>().data<float>()[0], 2.0);
+}
 
 /*
  * y = op(x)  // y@grad is not initialized
 * loss = op(z)
 */
-TEST(Backward, TestEmptyGrad) {}
+TEST(Backward, TestEmptyGrad) {
+  reset_global_tape();
+  auto x = fill_with_vector({42});
+  auto y = mean(x);
+
+  auto z = fill_with_vector({42});
+  auto loss = mean(z);
+
+  get_global_tape().Backward(loss);
+
+  PADDLE_ENFORCE_EQ(x->Grad()->Get<LoDTensor>().data<float>()[0], 0.0);
+  PADDLE_ENFORCE_EQ(y->Grad()->Get<LoDTensor>().data<float>()[0], 0.0);
+  PADDLE_ENFORCE_EQ(z->Grad()->Get<LoDTensor>().data<float>()[0], 1.0);
+}
 
 /*
  * vector<> v
@@ -41,7 +100,7 @@ TEST(Backward, TestEmptyGrad) {}
  * v.push_back(out)
  * loss = v.back()
 */
-TEST(Backward, TestForLoop) {}
+TEST(Backward, TestForLoop) { reset_global_tape(); }
 
 int main(int argc, char** argv) {
   std::vector<paddle::platform::Place> places;
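The asserted values follow from the chain rule on one-element tensors: in TestMultipleAssignment, loss = mean(x) + mean(x) with x of size 1, so dloss/dx = 1 + 1 = 2; TestInplaceSum has loss = x + x, again giving dloss/dx = 2 and exercising the in-place sum path; in TestEmptyGrad, loss = mean(z) with z of size 1 gives dloss/dz = 1, while x and y never feed into the loss and end up with the zero-filled gradients produced by the new fill_zeros_like branch.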

src/variable.h

Lines changed: 4 additions & 4 deletions

@@ -45,6 +45,8 @@ class Variable {
 
   ~Variable() { VLOG(10) << "Deleting " << Name(); }
 
+  bool GradExist() { return !grad_.expired(); }
+
   VariableHandle Grad() {
     if (grad_.expired()) {
       VariableHandle new_grad(new Variable(name_, true));
@@ -55,8 +57,6 @@ class Variable {
     }
   }
 
-  bool GradExist() { return !grad_.expired(); }
-
   // Evaluate a variable by running Forward() on the global tape
   const Variable& Value();
 
@@ -85,8 +85,8 @@ class Variable {
   }
 
  private:
-  int count() {
-    static int counter = 0;
+  int64_t count() {
+    static int64_t counter = 0;
     return counter++;
   }
