Skip to content
This repository was archived by the owner on Jan 24, 2024. It is now read-only.

Commit 051d42a

Browse files
author
Yang Yang
committed
add sum; pass mnist
1 parent e5b5539 commit 051d42a

File tree

4 files changed

+84
-46
lines changed

4 files changed

+84
-46
lines changed

src/CMakeLists.txt

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -40,11 +40,14 @@ include_directories(${CMAKE_BINARY_DIR}/third_party/eigen3/src/extern_eigen3)
4040
# hack in paddle/cmake/external/eigen.cmake
4141

4242
cc_library(tape_variable SRCS variable.cc DEPS operator)
43-
cc_library(tape SRCS tape.cc DEPS ${GLOB_OP_LIB} tape_variable)
43+
cc_library(tape SRCS tape.cc backward.cc DEPS ${GLOB_OP_LIB} tape_variable)
4444
cc_library(tape_function SRCS function.cc DEPS ${GLOB_OP_LIB} tape_variable tape)
4545

4646
cc_test(test_tape
4747
SRCS test_tape.cc
4848
DEPS tape tape_variable tape_function)
49+
cc_test(test_backward
50+
SRCS test_backward.cc
51+
DEPS tape tape_variable tape_function)
4952

5053
add_subdirectory(example)

src/tape.cc

Lines changed: 72 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
#include <map>
1919
#include <memory>
2020
#include <string>
21+
#include <utility>
2122
#include <vector>
2223

2324
#include "paddle/fluid/framework/data_type.h"
@@ -31,9 +32,16 @@
3132
namespace paddle {
3233
namespace tape {
3334

35+
using std::map;
36+
using std::pair;
37+
using std::unordered_map;
38+
using std::string;
39+
using std::unique_ptr;
40+
using std::vector;
41+
3442
// borrowed from
3543
// https://stackoverflow.com/questions/874134/find-if-string-ends-with-another-string-in-c
36-
inline bool ends_with(std::string const &value, std::string const &ending) {
44+
inline bool ends_with(string const &value, string const &ending) {
3745
if (ending.size() > value.size()) return false;
3846
return std::equal(ending.rbegin(), ending.rend(), value.rbegin());
3947
}
@@ -50,26 +58,26 @@ std::ostream &operator<<(std::ostream &os, const framework::VarDesc &var_desc) {
5058
return os;
5159
}
5260

53-
std::string to_string(const std::string &type,
54-
const VariableHandleMap &in_vars,
55-
const VariableHandleMap &out_vars,
56-
const framework::AttributeMap &attrs) {
61+
string to_string(const string &type,
62+
const VariableHandleMap &in_vars,
63+
const VariableHandleMap &out_vars,
64+
const framework::AttributeMap &attrs) {
5765
std::stringstream ss;
5866
ss << type << " ";
5967
for (auto &param_name : in_vars) {
6068
for (auto &var : param_name.second) {
61-
ss << param_name.first << ":(" << var << ") ";
69+
ss << param_name.first << ":(" << var->Name() << ") ";
6270
}
6371
}
6472
for (auto &param_name : out_vars) {
6573
for (auto &var : param_name.second) {
66-
ss << param_name.first << ":(" << var << ") ";
74+
ss << param_name.first << ":(" << var->Name() << ") ";
6775
}
6876
}
6977
return ss.str();
7078
}
7179

72-
framework::OpDesc CreateOpDesc(const std::string &type,
80+
framework::OpDesc CreateOpDesc(const string &type,
7381
const VariableHandleMap &in_vars,
7482
const VariableHandleMap &out_vars,
7583
const framework::AttributeMap &attrs) {
@@ -90,7 +98,7 @@ framework::OpDesc CreateOpDesc(const std::string &type,
9098
return op_desc;
9199
}
92100

93-
void InferShapeAndVarType(const std::string &type,
101+
void InferShapeAndVarType(const string &type,
94102
const VariableHandleMap &in_vars,
95103
VariableHandleMap *out_vars,
96104
const framework::AttributeMap &attrs) {
@@ -114,11 +122,12 @@ void InferShapeAndVarType(const std::string &type,
114122
op_with_kernel->InferShape(&infer_shape_ctx);
115123
}
116124

117-
void Tape::AddOp(const std::string &type,
125+
void Tape::AddOp(const string &type,
118126
const VariableHandleMap &in_vars,
119127
VariableHandleMap out_vars,
120128
const framework::AttributeMap &attrs) {
121129
PADDLE_ENFORCE(!has_been_backwarded_);
130+
LOG(INFO) << "AddOp " << to_string(type, in_vars, out_vars, attrs);
122131
InferShapeAndVarType(type, in_vars, &out_vars, attrs);
123132
tape_.emplace_back(type, in_vars, out_vars, attrs);
124133
}
@@ -138,6 +147,41 @@ void Tape::Forward() {
138147
VLOG(3) << "Finishing forward -------------------------";
139148
}
140149

150+
void DescMapToVarMap(
151+
const unordered_map<string, VariableHandle> &name2var,
152+
const framework::VariableNameMap &vmp,
153+
VariableHandleMap *vhm,
154+
vector<pair<VariableHandle, VariableHandle>> *duplicated_grad,
155+
bool is_output) {
156+
for (auto &p2a : vmp) {
157+
for (auto &argu : p2a.second) {
158+
if (name2var.count(argu)) {
159+
(*vhm)[p2a.first].push_back(name2var.at(argu));
160+
} else {
161+
PADDLE_ENFORCE(ends_with(argu, framework::kGradVarSuffix),
162+
"%s not end with %s",
163+
argu,
164+
framework::kGradVarSuffix);
165+
string name =
166+
argu.substr(0, argu.size() - strlen(framework::kGradVarSuffix));
167+
PADDLE_ENFORCE(name2var.count(name), "%s not found", name);
168+
if (is_output && name2var.at(name)->GradExist()) {
169+
VariableHandle temp_grad(new Variable(
170+
name + framework::kGradVarSuffix + framework::kTempVarName));
171+
(*vhm)[p2a.first].emplace_back(temp_grad);
172+
duplicated_grad->emplace_back(name2var.at(name)->Grad(),
173+
temp_grad); // name2var[name]->Grad has
174+
// to be the first element
175+
// since sum_op use X[0] ==
176+
// Out to determine inplace
177+
} else {
178+
(*vhm)[p2a.first].push_back(name2var.at(name)->Grad());
179+
}
180+
}
181+
}
182+
}
183+
}
184+
141185
void Tape::Backward(VariableHandle target) {
142186
PADDLE_ENFORCE(!has_been_backwarded_);
143187

@@ -146,21 +190,20 @@ void Tape::Backward(VariableHandle target) {
146190
// TODO(tonyyang-svail): check output of last op is target
147191
backward_tape_.reset(new Tape());
148192

149-
// FIXME(tonyyang-svail): Need to infer_data_type
150193
backward_tape_->AddOp(
151194
"fill_ones_like", {{"X", {target}}}, {{"Out", {target->Grad()}}}, {});
152195

153196
for (auto it = tape_.rbegin(); it != tape_.rend(); ++it) {
154197
framework::OpDesc op_desc =
155198
CreateOpDesc(it->type_, it->inputs_, it->outputs_, it->attrs_);
156-
std::unordered_map<std::string, std::string> grad_to_var;
157-
std::vector<std::unique_ptr<framework::OpDesc>> grad_op_descs =
199+
unordered_map<string, string> grad_to_var;
200+
vector<unique_ptr<framework::OpDesc>> grad_op_descs =
158201
framework::OpInfoMap::Instance()
159202
.Get(op_desc.Type())
160203
.GradOpMaker()(op_desc, {}, &grad_to_var, {});
161204

162-
for (auto &op_desc : grad_op_descs) {
163-
std::unordered_map<std::string, VariableHandle> name2var;
205+
for (auto &op_grad_desc : grad_op_descs) {
206+
unordered_map<string, VariableHandle> name2var;
164207
for (auto &param2vars : it->inputs_) {
165208
for (auto &a : param2vars.second) {
166209
name2var[a->Name()] = a;
@@ -172,36 +215,25 @@ void Tape::Backward(VariableHandle target) {
172215
}
173216
}
174217

175-
VariableHandleMap in_vars;
176-
VariableHandleMap out_vars;
177-
std::map<const framework::VariableNameMap *, VariableHandleMap *>
178-
loop_over{{&op_desc->Inputs(), &in_vars},
179-
{&op_desc->Outputs(), &out_vars}};
180-
for (auto &each : loop_over) {
181-
auto &vmp = *each.first;
182-
auto &vhm = *each.second;
183-
for (auto &p2a : vmp) {
184-
for (auto &argu : p2a.second) {
185-
if (name2var.count(argu)) {
186-
vhm[p2a.first].push_back(name2var[argu]);
187-
} else {
188-
PADDLE_ENFORCE(ends_with(argu, framework::kGradVarSuffix),
189-
argu.c_str());
190-
std::string name = argu.substr(
191-
0, argu.size() - std::strlen(framework::kGradVarSuffix));
192-
PADDLE_ENFORCE(name2var.count(name), name.c_str());
193-
vhm[p2a.first].push_back(name2var[name]->Grad());
194-
}
195-
}
196-
}
197-
}
218+
vector<pair<VariableHandle, VariableHandle>>
219+
duplicated_grad; // vector of {grad, grad@temp}
220+
VariableHandleMap in_vars, out_vars;
221+
DescMapToVarMap(
222+
name2var, op_grad_desc->Inputs(), &in_vars, &duplicated_grad, false);
223+
DescMapToVarMap(
224+
name2var, op_grad_desc->Outputs(), &out_vars, &duplicated_grad, true);
198225

199226
backward_tape_->AddOp(
200-
op_desc->Type(), in_vars, out_vars, op_desc->GetAttrMap());
227+
op_grad_desc->Type(), in_vars, out_vars, op_grad_desc->GetAttrMap());
228+
for (auto &pair : duplicated_grad) {
229+
backward_tape_->AddOp("sum",
230+
{{"X", {pair.first, pair.second}}},
231+
{{"Out", {pair.first}}},
232+
{});
233+
}
201234
}
202235

203236
// TODO(tonyyang-svail): how to fill empty grad?
204-
// TODO(tonyyang-svail): Sum var grad is necessary
205237
}
206238

207239
backward_tape_->Forward();

src/tape.h

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -88,6 +88,13 @@ class Tape {
8888

8989
std::vector<OpHandle> tape_;
9090
std::shared_ptr<Tape> backward_tape_;
91+
92+
// void DescMapToVarMap(const std::unordered_map<std::string, VariableHandle>
93+
// &name2var,
94+
// const std::vector<std::pair<VariableHandle,
95+
// VariableHandle>> &duplicated_grad,
96+
// const std::pair<const framework::VariableNameMap
97+
// *const, VariableHandleMap *> &each) const;
9198
};
9299

93100
Tape &get_global_tape();

src/variable.h

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -55,11 +55,7 @@ class Variable {
5555
}
5656
}
5757

58-
// Stochastic Gradient Descent with Momentum
59-
// VariableHandle Momentum ();
60-
61-
// void init(const std::string& initializer,
62-
// const framework::AttributeMap& attrs);
58+
bool GradExist() { return !grad_.expired(); }
6359

6460
// Evaluate a variable by running Forward() on the global tape
6561
const Variable& Value();

0 commit comments

Comments
 (0)