Commit aee3708

Author: Yang Yang
1 parent: 5ae4c56

    add sgd optimizer

5 files changed (+52, -2 lines)


paddle/contrib/tape/function.h

Lines changed: 33 additions & 0 deletions
@@ -80,10 +80,43 @@ class Linear {
     return y;
   }
 
+  std::vector<VariableHandle> Params() { return {w_}; }
+
  private:
   VariableHandle w_;
   VariableHandle b_;
   std::string act_;
 };
+
+class SGD {
+ public:
+  SGD(float learning_rate) : learning_rate_(new Variable("sgd")) {
+    Tape init_tape;
+
+    std::string initializer = "fill_constant";
+    framework::AttributeMap attrs;
+    attrs["dtype"] = paddle::framework::proto::VarType::Type::VarType_Type_FP32;
+    attrs["shape"] = std::vector<int>{1};
+    attrs["value"] = learning_rate;
+    init_tape.AddOp(initializer, {}, {{"Out", {learning_rate_}}}, attrs);
+
+    init_tape.Forward();
+  }
+
+  void operator()(VariableHandle input) {
+    Tape temp_tape;
+    temp_tape.AddOp("sgd",
+                    {{"Param", {input}},
+                     {"LearningRate", {learning_rate_}},
+                     {"Grad", {input->Grad()}}},
+                    {{"ParamOut", {input}}},
+                    {});
+    temp_tape.Forward();
+    input->ResetGrad();
+  }
+
+ private:
+  VariableHandle learning_rate_;
+};
 }
 }
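The new SGD class drives its updates through the tape machinery itself: the constructor records a single fill_constant op on a private init_tape and runs it once, materializing the scalar learning rate as a tensor Variable that later "sgd" ops consume through their LearningRate input. operator() then records and immediately executes one "sgd" op per parameter, writing the result back in place (ParamOut is the same handle as Param) and clearing the gradient. As a minimal sketch, the element-wise update the "sgd" op computes is vanilla gradient descent (plain C++, not Paddle's actual kernel; the function name is illustrative):

    #include <cstddef>
    #include <vector>

    // Sketch: the vanilla SGD rule param <- param - lr * grad,
    // applied element-wise over the parameter tensor.
    void SgdUpdate(std::vector<float>* param,
                   const std::vector<float>& grad,
                   float lr) {
      for (std::size_t i = 0; i < param->size(); ++i) {
        (*param)[i] -= lr * grad[i];
      }
    }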

paddle/contrib/tape/tape.cc

Lines changed: 6 additions & 2 deletions
@@ -135,12 +135,16 @@ class ScopeWrapper : public framework::Scope {
                const VariableHandleMap &out_vars) {
     for (auto &v : in_vars) {
       for (auto &vv : v.second) {
-        vars_[vv->Name()].reset(vv->Var());
+        if (!vars_.count(vv->Name())) {
+          vars_[vv->Name()].reset(vv->Var());
+        }
       }
     }
     for (auto &v : out_vars) {
       for (auto &vv : v.second) {
-        vars_[vv->Name()].reset(vv->Var());
+        if (!vars_.count(vv->Name())) {
+          vars_[vv->Name()].reset(vv->Var());
+        }
       }
     }
   }
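The guard is likely needed because the new sgd op passes the same variable as both its Param input and its ParamOut output, so the same name now appears in in_vars and out_vars of one op. The old unconditional reset would then re-seat a unique_ptr on the raw pointer it already owns, deleting the live object before storing the now-dangling pointer. The count() check makes the binding first-wins and idempotent. A standalone sketch of the idiom under that assumption (toy types, not Paddle's Scope):

    #include <memory>
    #include <string>
    #include <unordered_map>

    struct Var { int data = 0; };

    // First-wins binding: reset() with a pointer the entry already owns
    // would delete the object and then store a dangling pointer, so we
    // only take ownership the first time a name is seen.
    void Bind(std::unordered_map<std::string, std::unique_ptr<Var>>* vars,
              const std::string& name, Var* raw) {
      if (!vars->count(name)) {
        (*vars)[name].reset(raw);
      }
    }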

paddle/contrib/tape/test_tape.cc

Lines changed: 9 additions & 0 deletions
@@ -21,6 +21,8 @@ TEST(Tape, TestMLP) {
   paddle::tape::Linear linear2(3, 3, "relu");
   paddle::tape::Mean mean;
 
+  paddle::tape::SGD sgd(0.001);
+
   for (int i = 0; i < 2; ++i) {
     paddle::tape::reset_global_tape();
 
@@ -36,6 +38,13 @@
     auto loss = mean(linear2(linear1(input)));
 
     paddle::tape::get_global_tape().Backward(loss);
+
+    for (auto w : linear1.Params()) {
+      sgd(w);
+    }
+    for (auto w : linear2.Params()) {
+      sgd(w);
+    }
   }
 }
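The test now exercises a complete train step per iteration: reset the global tape, record the forward pass, run Backward to populate each parameter's Grad(), then apply sgd to every parameter handle (which also resets its gradient, see variable.h below). A self-contained toy analogue of that rhythm, fitting y = w*x with one weight (plain C++, not Paddle code):

    #include <cstdio>

    // Toy analogue of TestMLP's loop: forward, backward, SGD step, grad reset.
    int main() {
      const float x = 1.0f, target = 2.0f, lr = 0.1f;
      float w = 0.0f;
      for (int i = 0; i < 5; ++i) {
        float y = w * x;                       // forward pass
        float loss = (y - target) * (y - target);
        float grad = 2.0f * (y - target) * x;  // backward: dloss/dw
        w -= lr * grad;                        // the sgd(w) step
        grad = 0.0f;                           // mirrors ResetGrad()
        std::printf("iter %d: loss=%f w=%f\n", i, loss, w);
      }
      return 0;
    }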

paddle/contrib/tape/variable.h

Lines changed: 2 additions & 0 deletions
@@ -52,6 +52,8 @@ class Variable {
     return grad_;
   }
 
+  void ResetGrad() { grad_ = nullptr; }
+
   // Stochastic Gradient Descent with Momentum
   // VariableHandle Momentum ();
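Note that ResetGrad() drops the gradient handle rather than zero-filling the tensor: the next Backward() can then bind a fresh gradient variable instead of accumulating into last iteration's buffer. (That rationale is a reading of the code; the commit itself does not state it.)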

paddle/fluid/framework/operator.cc

Lines changed: 2 additions & 0 deletions
@@ -85,6 +85,7 @@ static LoD GetLoD(const Scope& scope, const std::string& name) {
 }
 
 void OperatorBase::Run(const Scope& scope, const platform::Place& place) {
+  VLOG(10) << "- " << DebugStringEx(&scope);
   if (platform::is_gpu_place(place)) {
 #ifndef PADDLE_WITH_CUDA
     PADDLE_THROW("Cannot run operator on place %s", place);
@@ -94,6 +95,7 @@ void OperatorBase::Run(const Scope& scope, const platform::Place& place) {
 #endif
   }
   RunImpl(scope, place);
+  VLOG(10) << "+ " << DebugStringEx(&scope);
 }
 
 bool OperatorBase::HasInputs(const std::string& name) const {
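These two VLOG(10) lines bracket every operator run, printing DebugStringEx(&scope) — the operator and its input/output variables as seen in the scope — before ("-") and after ("+") RunImpl. With glog-based logging they are compiled in but silent by default; raising the verbosity, e.g. running a test binary with the GLOG_v=10 environment variable, makes them print, which is handy for tracing what the new sgd op reads and writes.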
