Rewrite StaticRNN with Executor #5224


Merged · 47 commits · Nov 2, 2017

Commits
e1d22b0
Init commit
reyoung Oct 29, 2017
c35f956
Make executor use ProgramDescBind
reyoung Oct 29, 2017
aaf9173
Change Attribute from BlockDesc to BlockDescBind
reyoung Oct 29, 2017
76575d1
Add DeviceContext to Executor API
reyoung Oct 29, 2017
c65528f
Rewrite RNN
reyoung Oct 29, 2017
9197a26
Merge branch 'develop' into feature/rnn_use_executor
reyoung Oct 29, 2017
61fa04d
Pass Python
reyoung Oct 29, 2017
4b3a3ab
AddBiasOp does not care num_flatten_dims
reyoung Oct 30, 2017
711da5a
Merge branch 'feature/error_in_add_bias' into feature/rnn_use_executor
reyoung Oct 30, 2017
1e2954c
Merge branch 'develop' of github.com:baidu/Paddle into feature/rnn_us…
reyoung Oct 30, 2017
705329b
Stash
reyoung Oct 30, 2017
4aca6c2
Fix MacOS Compile
reyoung Oct 30, 2017
cccc17c
Merge branch 'feature/fill_compile' into feature/rnn_use_executor
reyoung Oct 30, 2017
20bdfe2
Pass RNN forward
reyoung Oct 30, 2017
425f283
add python test
Oct 30, 2017
28090cf
refactor test
Oct 30, 2017
7610c30
Make compile pass
reyoung Oct 30, 2017
de7738e
Merge remote-tracking branch 'pr/5224' into 5224
Oct 30, 2017
f207187
add gradopmaker
Oct 30, 2017
1aafbe6
First draft done
reyoung Oct 30, 2017
d39534f
Polish code
reyoung Oct 30, 2017
4f72621
add grad op maker and grad infershape
Oct 30, 2017
0e757f5
Polish code
reyoung Oct 30, 2017
afcf54d
Merge remote-tracking branch 'tony/add_rnn_test' into feature/rnn_use…
reyoung Oct 30, 2017
3457931
Fix backward.cc bug
reyoung Oct 30, 2017
2147542
Fix infershape
reyoung Oct 31, 2017
71be7bc
Rename function
reyoung Oct 31, 2017
3d8653f
add backward test
Oct 31, 2017
4739392
Merge remote-tracking branch 'tony/add_rnn_test' into feature/rnn_use…
reyoung Oct 31, 2017
e8976a0
simplify recurrent test
Oct 31, 2017
1fac53c
Merge remote-tracking branch 'tony/add_rnn_test' into feature/rnn_use…
reyoung Oct 31, 2017
ff3de12
Update
reyoung Oct 31, 2017
4cc5e9c
Pass unittest
reyoung Oct 31, 2017
b197e3c
Add comments & refine test
reyoung Oct 31, 2017
058f5e0
Add comments
reyoung Oct 31, 2017
7cc88ce
refactor test
Oct 31, 2017
58bf36f
Merge remote-tracking branch 'tony/add_rnn_test' into feature/rnn_use…
reyoung Oct 31, 2017
adc448e
Complete Unittest
reyoung Oct 31, 2017
f456256
fix StepScopes enforce
Oct 31, 2017
9504d94
Remove unused unittest
reyoung Oct 31, 2017
c2ea4d2
Merge remote-tracking branch 'tony/add_rnn_test' into feature/rnn_use…
reyoung Oct 31, 2017
db0a122
Merge branch 'develop' of github.com:baidu/Paddle into feature/rnn_us…
reyoung Oct 31, 2017
9289391
no type error
Oct 31, 2017
3a06076
Merge remote-tracking branch 'tony/add_rnn_test' into feature/rnn_use…
reyoung Oct 31, 2017
48e2543
Update
reyoung Oct 31, 2017
ced9496
Merge branch 'develop' of github.com:baidu/Paddle into feature/rnn_us…
reyoung Nov 1, 2017
9e7db4e
Make RNN Pass unittest
reyoung Nov 1, 2017
43 changes: 16 additions & 27 deletions paddle/framework/backward.cc
@@ -24,7 +24,6 @@
#include "paddle/framework/op_registry.h"
#include "paddle/operators/dynamic_recurrent_op.h"
#include "paddle/operators/net_op.h"
#include "paddle/operators/recurrent_op.h"

namespace paddle {
namespace framework {
@@ -38,7 +37,7 @@ static inline std::unique_ptr<OperatorBase> CreateGradOp(
op_desc.SetType(op.Type());
op_desc.SetAttrMap(op.Attrs());
auto& info = OpInfoMap::Instance().Get(op.Type());
auto grad_descs = info.GradOpMaker()(op_desc, no_grad_set, grad_to_var);
auto grad_descs = info.GradOpMaker()(op_desc, no_grad_set, grad_to_var, {});
std::vector<std::unique_ptr<OperatorBase>> grad_ops;
grad_ops.reserve(grad_descs.size());
std::transform(grad_descs.begin(), grad_descs.end(),
@@ -220,19 +219,7 @@ static std::unique_ptr<OperatorBase> BackwardRecursive(
});

// process recurrent gradient op as a special operator.
if (forwardOp.Type() == "recurrent") {
// NOTE clean up cycle call somewhere (RNN's stepnet constains itself),
// or this will result in infinite loop.
const auto& rnnop =
*static_cast<const operators::RecurrentOp*>(&forwardOp);
auto rnn_grad_op =
static_cast<operators::RecurrentGradientOp*>(grad_op.get());
const auto& stepnet_op =
*static_cast<const OperatorBase*>(&rnnop.stepnet());
// create stepnet's gradient op
rnn_grad_op->set_stepnet(
BackwardRecursive(stepnet_op, no_grad_names, grad_to_var, uniq_id));
} else if (forwardOp.Type() == "dynamic_recurrent") {
if (forwardOp.Type() == "dynamic_recurrent") {
// NOTE clean up cycle call somewhere (RNN's stepnet constains itself),
// or this will result in infinite loop.
const auto& rnnop =
@@ -331,7 +318,7 @@ static void CreateGradVarInBlock(
continue;
}
auto pname = FwdName(arg);
auto* param = block_desc->FindVar(pname);
auto* param = block_desc->FindVarRecursive(pname);
auto* grad = block_desc->FindVar(arg);
if (param == nullptr) {
LOG(WARNING) << "Cannot find forward variable of " << arg
@@ -348,7 +335,9 @@

std::vector<std::unique_ptr<OpDescBind>> MakeOpGrad(
const OpDescBind* op_desc, std::unordered_set<std::string>* no_grad_vars,
std::unordered_map<std::string, std::string>* grad_to_var) {
std::unordered_map<std::string, std::string>* grad_to_var,
const std::vector<BlockDescBind*>& grad_block =
std::vector<BlockDescBind*>()) {
std::vector<std::unique_ptr<OpDescBind>> grad_op_descs;
// All input gradients of forwarding operator do not need to calculate.
const std::vector<std::string>& inputs = op_desc->InputArgumentNames();
@@ -364,9 +353,10 @@ std::vector<std::unique_ptr<OpDescBind>> MakeOpGrad(
return grad_op_descs; // empty vector
}

grad_op_descs = OpInfoMap::Instance()
.Get(op_desc->Type())
.GradOpMaker()(*op_desc, *no_grad_vars, grad_to_var);
grad_op_descs =
OpInfoMap::Instance()
.Get(op_desc->Type())
.GradOpMaker()(*op_desc, *no_grad_vars, grad_to_var, grad_block);

std::list<std::unique_ptr<OpDescBind>> pending_fill_zeros_ops;
for (auto& desc : grad_op_descs) {
@@ -400,21 +390,20 @@ std::vector<std::unique_ptr<OpDescBind>> MakeBlockBackward(
std::vector<std::unique_ptr<OpDescBind>> backward_descs;

for (auto it = op_descs.rbegin(); it != op_descs.rend(); ++it) {
std::vector<std::unique_ptr<OpDescBind>> op_grads =
MakeOpGrad(*it, no_grad_vars, grad_to_var);
std::vector<std::unique_ptr<OpDescBind>> op_grads;

if ((*it)->Type() == "recurrent") {
PADDLE_ENFORCE_EQ(
op_grads.size(), static_cast<size_t>(1),
"rnn_op's gradient process should contain only one op.");
int step_block_idx = (*it)->GetBlockAttr("step_block");
auto backward_block_op_descs = MakeBlockBackward(
program_desc, step_block_idx, no_grad_vars, grad_to_var);
BlockDescBind* backward_block = program_desc.AppendBlock(*cur_block);
BlockDescBind* backward_block =
program_desc.AppendBlock(*program_desc.MutableBlock(step_block_idx));
for (auto& ptr : backward_block_op_descs) {
backward_block->AppendAllocatedOp(std::move(ptr));
}
op_grads[0]->SetBlockAttr("step_block", *backward_block);
op_grads = MakeOpGrad(*it, no_grad_vars, grad_to_var, {backward_block});
} else {
op_grads = MakeOpGrad(*it, no_grad_vars, grad_to_var);
}

for (const auto& desc : op_grads) {
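The net effect of the backward.cc changes: the old special case for `recurrent` (with its chain of `static_cast`s over raw operators) is gone, and `MakeBlockBackward` now builds the step net's backward ops into a freshly appended block, then hands that block to the grad-op maker through the new `grad_block` parameter. Here is a standalone sketch of that control flow; `OpDesc`, `BlockDesc`, and `MakeOpGrad` below are simplified mocks, not Paddle's real types:

#include <deque>
#include <string>
#include <vector>

struct BlockDesc {
  std::vector<std::string> ops;  // stand-in for the block's operator list
};

struct OpDesc {
  std::string type;
  BlockDesc* step_block = nullptr;  // stand-in for the "step_block" attribute
};

// Mimics the widened MakeOpGrad: grad_block defaults to empty, so the
// signature change is invisible to ordinary (non-recurrent) ops.
std::vector<OpDesc> MakeOpGrad(const OpDesc& fwd,
                               const std::vector<BlockDesc*>& grad_block = {}) {
  OpDesc grad;
  grad.type = fwd.type + "_grad";
  if (!grad_block.empty()) grad.step_block = grad_block[0];
  return {grad};
}

// A deque keeps block pointers stable across appends.
std::vector<OpDesc> MakeBlockBackward(std::deque<BlockDesc>& program,
                                      const std::vector<OpDesc>& fwd_ops) {
  std::vector<OpDesc> backward;
  for (auto it = fwd_ops.rbegin(); it != fwd_ops.rend(); ++it) {
    std::vector<OpDesc> grads;
    if (it->type == "recurrent") {
      // Build the step net's backward ops into a freshly appended block
      // first (a placeholder op stands in for the real recursion)...
      program.emplace_back();
      BlockDesc* backward_block = &program.back();
      backward_block->ops.push_back("step_net_grad");
      // ...then hand the block to the grad maker via grad_block.
      grads = MakeOpGrad(*it, {backward_block});
    } else {
      grads = MakeOpGrad(*it);
    }
    backward.insert(backward.end(), grads.begin(), grads.end());
  }
  return backward;
}

int main() {
  std::deque<BlockDesc> program(1);  // block 0: the forward block
  std::vector<OpDesc> fwd(2);
  fwd[0].type = "mul";
  fwd[1].type = "recurrent";
  auto bwd = MakeBlockBackward(program, fwd);  // recurrent_grad comes first
  return bwd[0].step_block == &program[1] ? 0 : 1;
}

Building the backward block before invoking the maker keeps the maker purely descriptive: it records a block reference instead of recursively constructing grad operators, which is what made the casting special case above removable.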
2 changes: 2 additions & 0 deletions paddle/framework/block_desc.h
@@ -88,6 +88,8 @@ class BlockDescBind {

BlockDesc *Proto();

ProgramDescBind *Program() { return this->prog_; }

private:
void ClearPBOps();
void ClearPBVars();
5 changes: 3 additions & 2 deletions paddle/framework/details/op_registry.h
@@ -108,8 +108,9 @@ struct OpInfoFiller<T, kGradOpDescMaker> {
info->grad_op_maker_ = [](
const OpDescBind& fwd_op,
const std::unordered_set<std::string>& no_grad_set,
std::unordered_map<std::string, std::string>* grad_to_var) {
T maker(fwd_op, no_grad_set, grad_to_var);
std::unordered_map<std::string, std::string>* grad_to_var,
const std::vector<BlockDescBind*>& grad_block) {
T maker(fwd_op, no_grad_set, grad_to_var, grad_block);
return maker();
};
}
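For context, this filler lambda is the one place where the new fourth argument threads through registration. A self-contained mock of that forwarding; none of these types are Paddle's, and the names are hypothetical:

#include <functional>
#include <string>
#include <vector>

struct BlockDesc {};
struct OpDesc { std::string type; };

// Mirrors the widened GradOpMakerFN: note the extra grad_block parameter.
using GradOpMakerFN = std::function<std::vector<OpDesc>(
    const OpDesc&, const std::vector<BlockDesc*>& /*grad_block*/)>;

// Mirrors the filler: the registered lambda forwards grad_block straight
// into the maker's constructor, so makers that ignore it need no changes.
template <typename Maker>
GradOpMakerFN FillGradOpMaker() {
  return [](const OpDesc& fwd, const std::vector<BlockDesc*>& grad_block) {
    Maker maker(fwd, grad_block);
    return maker();  // the maker's operator() emits the grad op descs
  };
}

// A toy maker satisfying the interface.
struct ToyGradMaker {
  const OpDesc& fwd;
  std::vector<BlockDesc*> blocks;
  ToyGradMaker(const OpDesc& f, const std::vector<BlockDesc*>& b)
      : fwd(f), blocks(b) {}
  std::vector<OpDesc> operator()() const {
    return {OpDesc{fwd.type + "_grad"}};
  }
};

int main() {
  GradOpMakerFN fn = FillGradOpMaker<ToyGradMaker>();
  OpDesc fwd{"mul"};
  return fn(fwd, {})[0].type == "mul_grad" ? 0 : 1;
}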
61 changes: 40 additions & 21 deletions paddle/framework/executor.cc
@@ -31,7 +31,7 @@ namespace framework {
const std::string kFeedOpType = "feed";
const std::string kFetchOpType = "fetch";

Executor::Executor(const std::vector<platform::Place>& places) {
Executor::Executor(const std::vector<platform::Place>& places) : own_(true) {
PADDLE_ENFORCE_GT(places.size(), 0);
device_contexts_.resize(places.size());
for (size_t i = 0; i < places.size(); i++) {
@@ -52,8 +52,10 @@ Executor::Executor(const std::vector<platform::Place>& places) {
}

Executor::~Executor() {
for (auto& device_context : device_contexts_) {
delete device_context;
if (own_) {
for (auto& device_context : device_contexts_) {
delete device_context;
}
}
}

@@ -66,44 +68,61 @@ static void CreateTensor(Variable* var, VarDesc::VarType var_type) {
var->GetMutable<FeedFetchList>();
} else if (var_type == VarDesc::FETCH_LIST) {
var->GetMutable<FeedFetchList>();
} else if (var_type == VarDesc::STEP_SCOPES) {
var->GetMutable<std::vector<framework::Scope>>();
} else {
PADDLE_THROW(
"Variable type must be "
"LoDTensor/SelectedRows/FEED_MINIBATCH/FETCH_LIST.");
"Variable type %d is not in "
"[LoDTensor, SelectedRows, FEED_MINIBATCH, FETCH_LIST]",
var_type);
}
}

void Executor::Run(const ProgramDescBind& pdesc, Scope* scope, int block_id) {
void Executor::Run(const ProgramDescBind& pdesc, Scope* scope, int block_id,
bool create_local_scope) {
// TODO(tonyyang-svail):
// - only runs on the first device (i.e. no interdevice communication)
// - will change to use multiple blocks for RNN op and Cond Op
PADDLE_ENFORCE_LT(block_id, pdesc.Size());
auto& block = pdesc.Block(block_id);
auto& device = device_contexts_[0];

Scope& local_scope = scope->NewScope();

for (auto& var : block.AllVars()) {
if (var->Persistable()) {
auto* ptr = scope->Var(var->Name());
CreateTensor(ptr, var->GetType());
VLOG(3) << "Create Variable " << var->Name()
<< " global, which pointer is " << ptr;
} else {
auto* ptr = local_scope.Var(var->Name());
Scope* local_scope = scope;
if (create_local_scope) {
local_scope = &scope->NewScope();
for (auto& var : block.AllVars()) {
if (var->Persistable()) {
auto* ptr = scope->Var(var->Name());
CreateTensor(ptr, var->GetType());
VLOG(3) << "Create Variable " << var->Name()
<< " global, which pointer is " << ptr;
} else {
auto* ptr = local_scope->Var(var->Name());
CreateTensor(ptr, var->GetType());
VLOG(3) << "Create Variable " << var->Name()
<< " locally, which pointer is " << ptr;
}
}
} else {
for (auto& var : block.AllVars()) {
auto* ptr = local_scope->Var(var->Name());
CreateTensor(ptr, var->GetType());
VLOG(3) << "Create Variable " << var->Name()
<< " locally, which pointer is " << ptr;
VLOG(3) << "Create variable " << var->Name() << ", which pointer is "
<< ptr;
}
}

for (auto& op_desc : block.AllOps()) {
auto op = paddle::framework::OpRegistry::CreateOp(*op_desc);
op->Run(local_scope, *device);
op->Run(*local_scope, *device);
}
if (create_local_scope) {
scope->DeleteScope(local_scope);
}

scope->DeleteScope(&local_scope);
}

Executor::Executor(const platform::DeviceContext& device)
: device_contexts_({&device}), own_(false) {}

} // namespace framework
} // namespace paddle
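Two behavioral changes land in executor.cc: `Run` can now execute a block directly in the caller's scope (`create_local_scope = false`), which is what lets the rewritten RNN op drive its step block inside per-step scopes, and an `Executor` can borrow an externally owned `DeviceContext`, with `own_` guarding the destructor. A minimal standalone sketch of the ownership flag, assuming a placeholder `DeviceContext` (not Paddle's class):

#include <vector>

struct DeviceContext {};  // placeholder, not paddle::platform::DeviceContext

class Executor {
 public:
  // Owning mode: contexts are created here and deleted in the destructor.
  Executor() : own_(true) {
    device_contexts_.push_back(new DeviceContext);
  }
  // Borrowing mode: the caller (e.g. an op's Run) keeps ownership.
  explicit Executor(const DeviceContext& ctx)
      : device_contexts_({&ctx}), own_(false) {}

  Executor(const Executor&) = delete;
  Executor& operator=(const Executor&) = delete;

  ~Executor() {
    if (own_) {
      for (auto* ctx : device_contexts_) delete ctx;
    }
  }

 private:
  std::vector<const DeviceContext*> device_contexts_;
  bool own_;
};

int main() {
  DeviceContext ctx;       // owned by the caller
  Executor borrowed(ctx);  // own_ == false: destructor leaves ctx alone
  Executor owning;         // own_ == true: destructor frees its context
}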
6 changes: 4 additions & 2 deletions paddle/framework/executor.h
@@ -25,6 +25,7 @@ namespace framework {
class Executor {
public:
explicit Executor(const std::vector<platform::Place>& places);
explicit Executor(const platform::DeviceContext& devices);
~Executor();

/* @Brief
@@ -34,10 +35,11 @@ class Executor {
* ProgramDesc
* Scope
*/
void Run(const ProgramDescBind&, Scope*, int);
void Run(const ProgramDescBind&, Scope*, int, bool create_local_scope = true);

private:
std::vector<platform::DeviceContext*> device_contexts_;
std::vector<const platform::DeviceContext*> device_contexts_;
bool own_;
};

} // namespace framework
13 changes: 11 additions & 2 deletions paddle/framework/grad_op_desc_maker.h
@@ -15,6 +15,7 @@
#pragma once
#include <string>
#include <unordered_set>
#include <vector>
#include "paddle/framework/op_desc.h"
#include "paddle/framework/operator.h"

@@ -26,8 +27,13 @@ class GradOpDescMakerBase {
explicit GradOpDescMakerBase(
const OpDescBind& fwd_op,
const std::unordered_set<std::string>& no_grad_set,
std::unordered_map<std::string, std::string>* grad_to_var)
: fwd_op_(fwd_op), no_grad_set_(no_grad_set), grad_to_var_(grad_to_var) {}
std::unordered_map<std::string, std::string>* grad_to_var,
const std::vector<BlockDescBind*>& grad_block =
std::vector<BlockDescBind*>())
: fwd_op_(fwd_op),
no_grad_set_(no_grad_set),
grad_to_var_(grad_to_var),
grad_block_(grad_block) {}

virtual ~GradOpDescMakerBase() = default;
virtual std::vector<std::unique_ptr<OpDescBind>> operator()() const = 0;
@@ -102,6 +108,9 @@ class GradOpDescMakerBase {
const OpDescBind& fwd_op_;
const std::unordered_set<std::string>& no_grad_set_;
std::unordered_map<std::string, std::string>* grad_to_var_;

protected:
std::vector<BlockDescBind*> grad_block_;
};

class SingleGradOpDescMaker : public GradOpDescMakerBase {
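Derived makers now see the pre-built backward blocks through the protected `grad_block_` member. A sketch of how a recurrent-style maker might consume it; all types here are simplified stand-ins, and `RecurrentGradOpMaker` is a hypothetical name, not necessarily the maker this PR registers:

#include <map>
#include <string>
#include <vector>

struct BlockDesc {
  int idx = 0;  // stand-in for the block's index in the program
};

struct OpDesc {
  std::string type;
  std::map<std::string, int> block_attrs;  // stand-in for SetBlockAttr
};

class GradOpDescMakerBase {
 public:
  GradOpDescMakerBase(const OpDesc& fwd_op,
                      const std::vector<BlockDesc*>& grad_block = {})
      : fwd_op_(fwd_op), grad_block_(grad_block) {}
  virtual ~GradOpDescMakerBase() = default;
  virtual std::vector<OpDesc> operator()() const = 0;

 private:
  const OpDesc& fwd_op_;

 protected:
  const OpDesc& FwdOp() const { return fwd_op_; }
  std::vector<BlockDesc*> grad_block_;  // mirrors the new protected member
};

// Hypothetical maker: the grad op's "step_block" attribute is pointed at
// the backward block pre-built by MakeBlockBackward.
class RecurrentGradOpMaker : public GradOpDescMakerBase {
 public:
  using GradOpDescMakerBase::GradOpDescMakerBase;
  std::vector<OpDesc> operator()() const override {
    OpDesc grad;
    grad.type = FwdOp().type + "_grad";
    if (!grad_block_.empty()) {
      grad.block_attrs["step_block"] = grad_block_[0]->idx;
    }
    return {grad};
  }
};

int main() {
  OpDesc fwd;
  fwd.type = "recurrent";
  BlockDesc backward_block;
  backward_block.idx = 3;
  RecurrentGradOpMaker maker(fwd, {&backward_block});
  return maker()[0].block_attrs.at("step_block") == 3 ? 0 : 1;
}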
13 changes: 13 additions & 0 deletions paddle/framework/op_desc.cc
@@ -327,6 +327,19 @@ void OpDescBind::InferShape(const BlockDescBind &block) const {
PADDLE_ENFORCE(static_cast<bool>(infer_shape),
"%s's infer_shape has not been registered", this->Type());
CompileTimeInferShapeContext ctx(*this, block);
if (VLOG_IS_ON(10)) {
std::ostringstream sout;
auto inames = this->InputArgumentNames();
sout << " From [";
std::copy(inames.begin(), inames.end(),
std::ostream_iterator<std::string>(sout, ", "));
sout << "] to [";
auto onames = this->OutputArgumentNames();
std::copy(onames.begin(), onames.end(),
std::ostream_iterator<std::string>(sout, ", "));
sout << "]";
VLOG(10) << sout.str();
}
infer_shape(&ctx);
}

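The guarded-logging idiom above is worth noting: the name lists are joined with `std::ostream_iterator`, and the string is only built when the log level is enabled. A runnable, framework-free version (a plain bool stands in for glog's `VLOG_IS_ON`, and `std::cout` for `VLOG`):

#include <algorithm>
#include <iostream>
#include <iterator>
#include <sstream>
#include <string>
#include <vector>

int main() {
  const bool verbose_enabled = true;  // stand-in for VLOG_IS_ON(10)
  std::vector<std::string> inames{"x", "h_boot"};
  std::vector<std::string> onames{"h"};

  if (verbose_enabled) {  // skip the formatting work entirely when disabled
    std::ostringstream sout;
    sout << " From [";
    std::copy(inames.begin(), inames.end(),
              std::ostream_iterator<std::string>(sout, ", "));
    sout << "] to [";
    std::copy(onames.begin(), onames.end(),
              std::ostream_iterator<std::string>(sout, ", "));
    sout << "]";
    std::cout << sout.str() << "\n";  // stand-in for VLOG(10)
  }
}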
16 changes: 14 additions & 2 deletions paddle/framework/operator.cc
@@ -126,7 +126,7 @@ OperatorBase::OperatorBase(const std::string& type,

std::vector<std::string> OperatorBase::InputVars() const {
std::vector<std::string> ret_val;
for (auto& o : outputs_) {
for (auto& o : inputs_) {
ret_val.reserve(ret_val.size() + o.second.size());
ret_val.insert(ret_val.end(), o.second.begin(), o.second.end());
}
@@ -394,7 +394,19 @@ class RuntimeInferShapeContext : public InferShapeContext {

void OperatorWithKernel::Run(const Scope& scope,
const platform::DeviceContext& dev_ctx) const {
VLOG(3) << "Running operator " << this->Type();
if (VLOG_IS_ON(1)) {
auto inputs = this->InputVars();
auto outputs = this->OutputVars(true);
std::ostringstream sout;
sout << "Run operator " << this->Type() << " From [";
std::ostream_iterator<std::string> out_it(sout, ",");
std::copy(inputs.begin(), inputs.end(), out_it);
sout << "] to [";
std::copy(outputs.begin(), outputs.end(), out_it);
sout << "]";
VLOG(1) << sout.str();
}

RuntimeInferShapeContext infer_shape_ctx(*this, scope);
this->InferShape(&infer_shape_ctx);

8 changes: 6 additions & 2 deletions paddle/framework/scope.cc
@@ -47,8 +47,12 @@ Variable* Scope::Var(const std::string& name) {
return v;
}

Variable* Scope::Var() {
return Var(string::Sprintf("%p.%d", this, vars_.size()));
Variable* Scope::Var(std::string* name) {
auto var_name = string::Sprintf("%p.%d", this, vars_.size());
if (name != nullptr) {
*name = var_name;
}
return Var(var_name);
}

Variable* Scope::FindVar(const std::string& name) const {
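A runnable sketch of the new out-parameter overload, with a hypothetical minimal `Scope` (not Paddle's): the unique name combines the scope's address with its current variable count, and is presumably reported back so callers that autogenerate intermediate variables can refer to them by name later:

#include <cstdio>
#include <iostream>
#include <map>
#include <string>

struct Variable {};

class Scope {
 public:
  Variable* Var(const std::string& name) { return &vars_[name]; }

  // Mirrors Scope::Var(std::string* name): the caller may pass a string
  // to learn which scope-unique name was generated.
  Variable* Var(std::string* name = nullptr) {
    char buf[64];
    std::snprintf(buf, sizeof(buf), "%p.%zu", static_cast<void*>(this),
                  vars_.size());
    std::string var_name(buf);
    if (name != nullptr) *name = var_name;
    return Var(var_name);
  }

 private:
  std::map<std::string, Variable> vars_;  // node-based: stable pointers
};

int main() {
  Scope scope;
  std::string generated;
  scope.Var(&generated);
  std::cout << "generated name: " << generated << "\n";
}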
2 changes: 1 addition & 1 deletion paddle/framework/scope.h
@@ -49,7 +49,7 @@ class Scope {
Variable* Var(const std::string& name);

/// Create a variable with a scope-unique name.
Variable* Var();
Variable* Var(std::string* name = nullptr);

/// Find a variable in the scope or any of its ancestors. Returns
/// nullptr if cannot find.
2 changes: 1 addition & 1 deletion paddle/framework/tensor.h
@@ -125,7 +125,7 @@
* @param[in] end_idx The index of the end row(exclusive) to slice.
* The index number begins from 0.
*/
inline Tensor Slice(const int& begin_idx, const int& end_idx) const;
inline Tensor Slice(int begin_idx, int end_idx) const;

platform::Place place() const {
PADDLE_ENFORCE_NOT_NULL(
2 changes: 1 addition & 1 deletion paddle/framework/tensor_impl.h
@@ -228,7 +228,7 @@ inline void Tensor::CopyFromVector(const std::vector<T>& src,
#endif
}

inline Tensor Tensor::Slice(const int& begin_idx, const int& end_idx) const {
inline Tensor Tensor::Slice(int begin_idx, int end_idx) const {
check_memory_size();
PADDLE_ENFORCE_GE(begin_idx, 0,
"The start row index must be greater than 0.");
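`Slice` now takes its scalar indices by value, which is the idiomatic choice over `const int&` for built-in types: same semantics, no pointless indirection. A toy, self-contained illustration with the same inclusive/exclusive convention; note this copies rows, unlike `Tensor::Slice`, which shares the underlying buffer:

#include <cassert>
#include <vector>

struct Matrix {
  std::vector<float> data;
  int rows = 0, cols = 0;

  // begin_idx inclusive, end_idx exclusive, as in Tensor::Slice.
  Matrix Slice(int begin_idx, int end_idx) const {
    assert(begin_idx >= 0);
    assert(end_idx <= rows);
    assert(begin_idx < end_idx);
    Matrix out;
    out.rows = end_idx - begin_idx;
    out.cols = cols;
    out.data.assign(data.begin() + begin_idx * cols,
                    data.begin() + end_idx * cols);
    return out;
  }
};

int main() {
  Matrix m{{1, 2, 3, 4, 5, 6}, 3, 2};
  Matrix s = m.Slice(1, 3);  // rows 1 and 2
  assert(s.rows == 2 && s.data[0] == 3.f);
}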
4 changes: 3 additions & 1 deletion paddle/framework/type_defs.h
@@ -29,6 +29,7 @@ class OpDescBind;
class BlockDescBind;
class BlockDesc;
class InferShapeContext;
class BlockDescBind;

using VariableNameMap = std::map<std::string, std::vector<std::string>>;

@@ -46,7 +47,8 @@ using OpCreator = std::function<OperatorBase*(

using GradOpMakerFN = std::function<std::vector<std::unique_ptr<OpDescBind>>(
const OpDescBind&, const std::unordered_set<std::string>& /*no_grad_set*/,
std::unordered_map<std::string, std::string>* /*grad_to_var*/)>;
std::unordered_map<std::string, std::string>* /*grad_to_var*/,
const std::vector<BlockDescBind*>& grad_block)>;

using InferVarTypeFN = std::function<void(const OpDescBind& /*op_desc*/,
BlockDescBind* /*block*/)>;