
[WIP] Executor RNN support #4910


Closed
Changes from all commits (51 commits)
00c8869
bring back executor_test
Oct 17, 2017
a158ddf
merge pr 4855
Oct 17, 2017
f3bb088
temp save
Oct 18, 2017
a8e6bfe
pass linking; ready to dive in
Oct 18, 2017
48fdd24
before new rnn
Oct 18, 2017
8b8a4de
rnn pass simple forward
Oct 18, 2017
90fe5c1
merge develop; fix conflict
Oct 18, 2017
7e62a52
clean up test
Oct 18, 2017
dc373e1
more log
Oct 19, 2017
af858bc
create vars by var list
Oct 19, 2017
dccd3c6
add var_ check
Oct 19, 2017
a57b76c
fix typo
Oct 19, 2017
0efa6e9
merge develop
Oct 19, 2017
8d067c5
merge develop
Oct 24, 2017
e4348cd
Merge remote-tracking branch 'upstream/develop' into executor_rnn
Oct 24, 2017
65a1709
pass compile
Oct 24, 2017
6f5e435
add add_fill_constant_batch_size_like_op
Oct 24, 2017
82c43f9
add test
Oct 24, 2017
22f05a6
register double type kernel
Oct 24, 2017
f5420a1
Merge remote-tracking branch 'pr/5057' into executor_rnn
Oct 24, 2017
34f4e72
change test name
Oct 24, 2017
cad0d64
Merge remote-tracking branch 'pr/5057' into executor_rnn
Oct 24, 2017
37d12be
fix compile
Oct 24, 2017
7f5a36b
Merge remote-tracking branch 'pr/5057' into executor_rnn
Oct 24, 2017
c46297d
Partial complete rnn
reyoung Oct 24, 2017
8a0e4e6
Merge pull request #2 from reyoung/rnn_exec
Oct 24, 2017
9a1427d
switch from fill_constant to fill_constant_batch_size_like
Oct 24, 2017
9fe3c0f
Merge remote-tracking branch 'pr/4910' into executor_rnn
Oct 24, 2017
3845d6d
before merge
Oct 25, 2017
8a45492
Merge remote-tracking branch 'upstream/develop' into executor_rnn
Oct 25, 2017
e8d8528
Merge remote-tracking branch 'upstream/develop' into executor_rnn
Oct 25, 2017
019f604
rnn pass simple
Oct 25, 2017
1c8c568
pass rnn forward
Oct 25, 2017
4d4029a
fix test_rnn_helpers
Oct 25, 2017
bf5e534
clean up
Oct 26, 2017
9f6a292
Merge remote-tracking branch 'upstream/develop' into executor_rnn
Oct 26, 2017
a6c34a5
test_recurrent_op add get_numerical_gradient
Oct 26, 2017
59d2ae7
Make InferShape as a field in OpInfo
reyoung Oct 26, 2017
cb9d119
add rnn grad infershape
Oct 27, 2017
830bb21
clean up
Oct 27, 2017
2f69947
merge develop
Oct 27, 2017
d897e45
before merge develop
Oct 27, 2017
99781eb
merge develop
Oct 27, 2017
88e1afe
push for review
Oct 27, 2017
0483373
pass backward compile time infershape
Oct 27, 2017
bfbcae2
add step_scope; may not be necesarry
Oct 27, 2017
72e2490
change from Tensor to LoDTensor
Oct 27, 2017
ed53e17
use LoDTensor by default
Oct 27, 2017
acb93f0
add mean to make backward fill op work
Oct 27, 2017
3c7cef4
op grad can't pass compile
Oct 28, 2017
7fe76be
add grad maker; pass compile
Oct 28, 2017
26 changes: 14 additions & 12 deletions paddle/framework/backward.cc
@@ -220,17 +220,18 @@ static std::unique_ptr<OperatorBase> BackwardRecursive(

// process recurrent gradient op as a special operator.
if (forwardOp.Type() == "recurrent") {
PADDLE_THROW("Disable old backward");
// NOTE clean up cycle call somewhere (RNN's stepnet constains itself),
// or this will result in infinite loop.
const auto& rnnop =
*static_cast<const operators::RecurrentOp*>(&forwardOp);
auto rnn_grad_op =
static_cast<operators::RecurrentGradientOp*>(grad_op.get());
const auto& stepnet_op =
*static_cast<const OperatorBase*>(&rnnop.stepnet());
// create stepnet's gradient op
rnn_grad_op->set_stepnet(
BackwardRecursive(stepnet_op, no_grad_names, grad_to_var, uniq_id));
// const auto& rnnop =
// *static_cast<const operators::RecurrentOp*>(&forwardOp);
// auto rnn_grad_op =
// static_cast<operators::RecurrentGradientOp*>(grad_op.get());
// const auto& stepnet_op =
// *static_cast<const OperatorBase*>(&rnnop.stepnet());
// // create stepnet's gradient op
// rnn_grad_op->set_stepnet(
// BackwardRecursive(stepnet_op, no_grad_names, grad_to_var, uniq_id));
} else if (forwardOp.Type() == "dynamic_recurrent") {
// NOTE clean up cycle call somewhere (RNN's stepnet constains itself),
// or this will result in infinite loop.
@@ -382,14 +383,15 @@ std::vector<std::unique_ptr<OpDescBind>> MakeBlockBackward(
PADDLE_ENFORCE_EQ(
op_grads.size(), static_cast<size_t>(1),
"rnn_op's gradient process should contain only one op.");
int step_block_idx = (*it)->GetBlockAttr("step_block");
int step_block_idx = (*it)->GetBlockAttr("block_idx");
auto backward_block_op_descs = MakeBlockBackward(
program_desc, step_block_idx, no_grad_vars, grad_to_var);
BlockDescBind* backward_block = program_desc.AppendBlock(*cur_block);
BlockDescBind* backward_block =
program_desc.AppendBlock(*program_desc.Block(step_block_idx));
for (auto& ptr : backward_block_op_descs) {
backward_block->AppendAllocatedOp(std::move(ptr));
}
op_grads[0]->SetBlockAttr("step_block", *backward_block);
op_grads[0]->SetBlockAttr("block_idx", *backward_block);
}

for (const auto& desc : op_grads) {
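Note on the backward.cc hunk above: the recurrent gradient op now locates its step block through the forward op's `block_idx` attribute, and the backward ops are appended to a fresh block derived from that step block (rather than from the current block). The snippet below is a minimal, self-contained sketch of this pattern only; `Program`, `Block`, `Op`, and the op names are illustrative stand-ins, not Paddle's actual classes.

```cpp
#include <iostream>
#include <string>
#include <vector>

struct Op {
  std::string type;
  int block_idx = -1;  // for "recurrent": index of its step block
};
struct Block { std::vector<Op> ops; };
struct Program { std::vector<Block> blocks; };

// Build gradient ops for one block, walking it in reverse. A "recurrent" op
// triggers a recursion into its step block; the resulting gradient ops are
// appended to the program as a new backward block, and that block's index is
// recorded on the "recurrent_grad" op -- the same wiring that
// SetBlockAttr("block_idx", *backward_block) performs in the hunk above.
std::vector<Op> MakeBackward(Program& prog, int block_idx) {
  std::vector<Op> grads;
  const Block fwd = prog.blocks[block_idx];  // copy; we append blocks below
  for (auto it = fwd.ops.rbegin(); it != fwd.ops.rend(); ++it) {
    Op grad{it->type + "_grad"};
    if (it->type == "recurrent") {
      prog.blocks.push_back(Block{MakeBackward(prog, it->block_idx)});
      grad.block_idx = static_cast<int>(prog.blocks.size()) - 1;
    }
    grads.push_back(grad);
  }
  return grads;
}

int main() {
  Program prog;
  prog.blocks.resize(2);
  prog.blocks[0].ops = {Op{"mul"}, Op{"recurrent", 1}, Op{"mean"}};  // main block
  prog.blocks[1].ops = {Op{"add"}, Op{"tanh"}};                      // step block
  for (const auto& g : MakeBackward(prog, 0))
    std::cout << g.type << " (backward block " << g.block_idx << ")\n";
}
```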
70 changes: 63 additions & 7 deletions paddle/framework/executor.cc
@@ -25,6 +25,8 @@ limitations under the License. */
#include "paddle/framework/op_registry.h"
#include "paddle/framework/scope.h"

#include "paddle/operators/recurrent_op.h"

namespace paddle {
namespace framework {

@@ -57,6 +59,47 @@ Executor::~Executor() {
}
}

int getBlockIdx(const OpDesc& op_desc) {
for (auto& attr : op_desc.attrs()) {
if (attr.has_block_idx()) {
return attr.block_idx();
}
}

PADDLE_THROW("Missing block_idx in recurrent opDesc");
return -1;
}

std::unique_ptr<OperatorBase> create_op(const ProgramDesc& pdesc,
const OpDesc& op_desc) {
auto op = paddle::framework::OpRegistry::CreateOp(
op_desc, const_cast<ProgramDesc*>(&pdesc));

if (op_desc.type() == "recurrent" || op_desc.type() == "recurrent_grad") {
int block_idx = getBlockIdx(op_desc);
std::unique_ptr<std::vector<std::unique_ptr<OperatorBase>>> step_net{
new std::vector<std::unique_ptr<OperatorBase>>};
for (auto& my_op_desc : pdesc.blocks(block_idx).ops()) {
step_net->push_back(create_op(pdesc, my_op_desc));
}
std::vector<std::string> vars;
for (auto& var : pdesc.blocks(block_idx).vars()) {
vars.push_back(var.name());
}
if (auto* rnn_op = dynamic_cast<operators::RecurrentOp*>(op.get())) {
rnn_op->set_stepnet(step_net, vars);
} else if (auto* rnn_op =
dynamic_cast<operators::RecurrentGradientOp*>(op.get())) {
rnn_op->set_stepnet(step_net, vars);
} else {
PADDLE_THROW("dynamic_cast<RecurrentOp*> fail");
}
VLOG(3) << "GO";
}

return op;
}

static void CreateTensor(Variable* var, VarDesc::VarType var_type) {
if (var_type == VarDesc::LOD_TENSOR) {
var->GetMutable<LoDTensor>();
@@ -66,10 +109,12 @@ static void CreateTensor(Variable* var, VarDesc::VarType var_type) {
var->GetMutable<FeedFetchList>();
} else if (var_type == VarDesc::FETCH_LIST) {
var->GetMutable<FeedFetchList>();
} else if (var_type == VarDesc::STEP_SCOPES) {
var->GetMutable<std::vector<Scope*>>();
} else {
PADDLE_THROW(
"Variable type must be "
"LoDTensor/SelectedRows/FEED_MINIBATCH/FETCH_LIST.");
"LoDTensor/SelectedRows/FEED_MINIBATCH/FETCH_LIST/STEP_SCOPES.");
}
}

@@ -85,24 +130,35 @@ void Executor::Run(const ProgramDesc& pdesc, Scope* scope, int block_id) {

for (auto& var : block.vars()) {
if (var.persistable()) {
VLOG(3) << "Create Variable " << var.name();
auto* ptr = scope->Var(var.name());
CreateTensor(ptr, var.type());
VLOG(3) << "Create Variable " << var.name()
<< " global, which pointer is " << ptr;
} else {
VLOG(3) << "Create Variable " << var.name();
auto* ptr = local_scope.Var(var.name());
CreateTensor(ptr, var.type());
VLOG(3) << "Create Variable " << var.name()
<< " locally, which pointer is " << ptr;
}
}

for (auto& op_desc : block.ops()) {
auto op = paddle::framework::OpRegistry::CreateOp(
op_desc, const_cast<ProgramDesc*>(&pdesc));
VLOG(2) << op_desc.type();
auto op = create_op(pdesc, op_desc);
op->Run(local_scope, *device);
}

for (auto& var : block.vars()) {
std::set<std::string> name_to_print{"a", "b", "h_boot"};
if (!var.persistable() && name_to_print.count(var.name())) {
VLOG(2) << var.name();
auto* v = local_scope.Var(var.name());
const float* f = v->GetMutable<LoDTensor>()->data<float>();
const int64_t s = v->GetMutable<LoDTensor>()->numel();
for (int i = 0; i < s; ++i) {
VLOG(10) << f[i];
}
}
}

scope->DeleteScope(&local_scope);
}

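For reference, the core of the executor.cc change is the new `create_op` helper: it instantiates the operator as before, and for `recurrent` / `recurrent_grad` ops it recursively instantiates every op of the block named by `block_idx` and attaches the result as the step net. A rough standalone sketch of that recursion follows; the types and op names are simplified stand-ins (not the Paddle API), intended only to show the shape of the control flow.

```cpp
#include <iostream>
#include <memory>
#include <string>
#include <vector>

// High-level stand-ins for ProgramDesc / BlockDesc / OpDesc.
struct OpDesc {
  std::string type;
  int block_idx = -1;  // set only for recurrent / recurrent_grad
};
struct BlockDesc { std::vector<OpDesc> ops; };
struct ProgramDesc { std::vector<BlockDesc> blocks; };

// Stand-in for OperatorBase: recurrent ops additionally own a step net,
// i.e. one instantiated operator per op in their step block.
struct OpBase {
  std::string type;
  std::vector<std::unique_ptr<OpBase>> step_net;
  explicit OpBase(std::string t) : type(std::move(t)) {}
};

// Same shape as create_op above: build the op, then recurse into the step
// block for recurrent ops and hand the instantiated ops to it as a step net.
std::unique_ptr<OpBase> CreateOp(const ProgramDesc& prog, const OpDesc& desc) {
  auto op = std::unique_ptr<OpBase>(new OpBase(desc.type));
  if (desc.type == "recurrent" || desc.type == "recurrent_grad") {
    for (const auto& sub : prog.blocks[desc.block_idx].ops) {
      op->step_net.push_back(CreateOp(prog, sub));
    }
  }
  return op;
}

int main() {
  ProgramDesc prog;
  prog.blocks.resize(2);
  prog.blocks[0].ops = {OpDesc{"feed"}, OpDesc{"recurrent", 1}, OpDesc{"mean"}};
  prog.blocks[1].ops = {OpDesc{"mul"}, OpDesc{"add"}, OpDesc{"tanh"}};  // step block
  auto rnn = CreateOp(prog, prog.blocks[0].ops[1]);
  std::cout << rnn->type << " carries " << rnn->step_net.size() << " step ops\n";
}
```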
11 changes: 11 additions & 0 deletions paddle/framework/op_desc.cc
@@ -270,11 +270,22 @@ void OpDescBind::InferShape(const BlockDescBind &block) const {
auto &infer_shape = OpInfoMap::Instance().Get(this->Type()).infer_shape_;
PADDLE_ENFORCE(static_cast<bool>(infer_shape),
"%s's infer_shape has not been registered", this->Type());
for (auto &i : Inputs()) {
for (auto &ii : i.second) {
VLOG(10) << "Input: " << i.first << " " << ii;
}
}
for (auto &o : Outputs()) {
for (auto &oo : o.second) {
VLOG(10) << "Output: " << o.first << " " << oo;
}
}
CompileTimeInferShapeContext ctx(*this, block);
infer_shape(&ctx);
}

void OpDescBind::InferVarType(BlockDescBind *block) const {
VLOG(3) << "CompileTime infer varType on " << Type();
auto &info = OpInfoMap::Instance().Get(this->Type());
if (info.infer_var_type_) {
info.infer_var_type_(*this, block);
2 changes: 2 additions & 0 deletions paddle/framework/operator.h
@@ -512,6 +512,7 @@ class RuntimeInferShapeContext : public InferShapeContext {

private:
DDim GetDim(const std::string& name) const override {
VLOG(10) << name;
Variable* var = scope_.FindVar(name);
if (var->IsType<LoDTensor>()) {
return var->Get<LoDTensor>().dims();
@@ -523,6 +524,7 @@ }
}

void SetDim(const std::string& name, const DDim& dim) override {
VLOG(10) << name;
Variable* var = scope_.FindVar(name);
if (var->IsType<LoDTensor>()) {
var->GetMutable<LoDTensor>()->Resize(dim);
106 changes: 80 additions & 26 deletions paddle/operators/recurrent_op.cc
@@ -18,14 +18,12 @@
#include <sstream>

#include "paddle/framework/op_registry.h"
#include "paddle/operators/net_op.h"

namespace paddle {
namespace operators {

using Scope = framework::Scope;
using Variable = framework::Variable;
using Tensor = framework::Tensor;
using LoDTensor = framework::LoDTensor;

void RecurrentAlgorithm::Run(const Scope& scope,
@@ -41,10 +39,13 @@ void RecurrentAlgorithm::Run(const Scope& scope,
InitMemories(step_scopes[0]);

for (size_t step_id = 0; step_id < seq_len; step_id++) {
VLOG(4) << "step " << step_id << " run";
if (step_id > 0) {
rnn::LinkMemories(step_scopes, arg_->states, step_id, -1);
}
(*stepnet_)->Run(*step_scopes[step_id], dev_ctx);
for (auto& op : **stepnet_) {
op->Run(*step_scopes[step_id], dev_ctx);
}
}
rnn::ConcatOutputs(step_scopes, arg_->outlinks, seq_len, dev_ctx);
}
@@ -59,27 +60,15 @@ void RecurrentAlgorithm::CreateScopes(const Scope& scope,

// Now all variables in scope must be created outside of op.
PADDLE_ENFORCE_NOT_NULL(stepnet_);
PADDLE_ENFORCE(!(*stepnet_)->Outputs().empty(),
"step_unit_ op has no outputs");
PADDLE_ENFORCE_NOT_NULL(vars_);
PADDLE_ENFORCE(!(*stepnet_)->empty());

if (seq_len > step_scopes->size()) {
for (size_t i = step_scopes->size(); i < seq_len; ++i) {
auto& step_scope = scope.NewScope();

// create step net's temp inputs
for (auto& input : (*stepnet_)->Inputs()) {
// the weight are located in parent scope
for (auto& var_name : input.second) {
if (!step_scope.FindVar(var_name)) {
step_scope.Var(var_name)->GetMutable<LoDTensor>();
}
}
}
// create stepnet's outputs
for (const auto& output : (*stepnet_)->Outputs()) {
for (auto& var_name : output.second) {
step_scope.Var(var_name);
}
for (auto& var_name : *vars_) {
VLOG(5) << "step " << i << " create " << var_name;
step_scope.Var(var_name)->GetMutable<LoDTensor>();
}
step_scopes->emplace_back(&step_scope);
}
@@ -114,7 +103,7 @@ RecurrentOp::RecurrentOp(const std::string& type,
const framework::AttributeMap& attrs)
: OperatorBase(type, inputs, outputs, attrs) {
rnn::InitArgument(kArgName, &arg_, *this);
alg_.Init(&arg_, &stepnet_);
alg_.Init(&arg_, &stepnet_, &vars_);
}

class RecurrentAlgorithmProtoAndCheckerMaker
@@ -131,30 +120,42 @@ class RecurrentAlgorithmProtoAndCheckerMaker
AddInput(name.initial_states, "variables to initialize states.")
.AsDuplicable();

AddInput("parameters", "parameter variables used inside").AsDuplicable();

AddOutput(name.outlinks, "the outputs that need to concated for all steps.")
.AsDuplicable();
AddOutput(name.step_scopes, "step scopes");

// Attributes stored in AttributeMap
AddAttr<std::vector<std::string>>(name.ex_states, "names of pre-states");
AddAttr<std::vector<std::string>>(name.states, "names of states");
AddAttr<paddle::framework::BlockDesc*>("block_idx", "rnn block idx");

AddComment("This is a recurrent group operator.");
}
};

void RecurrentGradientAlgorithm::Run(
const Scope& scope, const platform::DeviceContext& dev_ctx) const {
VLOG(10) << "---------------------------";
auto* input0 = scope.FindVar(arg_->inlinks[0]);
VLOG(10) << "---------------------------";
PADDLE_ENFORCE_NOT_NULL(input0);
VLOG(10) << "---------------------------";
size_t seq_len = input0->GetMutable<LoDTensor>()->dims()[0];
VLOG(10) << "---------------------------";
auto& step_scopes = GetStepScopes(scope);
VLOG(10) << "---------------------------";
rnn::SegmentInputs(step_scopes, arg_->inlinks, seq_len);
VLOG(10) << "---------------------------";
for (int step_id = seq_len - 1; step_id >= 0; --step_id) {
VLOG(10) << "---------------------------";
if (static_cast<size_t>(step_id) != seq_len - 1) {
rnn::LinkMemories(step_scopes, arg_->states, step_id, 1);
}
(*stepnet_)->Run(*step_scopes[step_id], dev_ctx);
for (auto& op : **stepnet_) {
op->Run(*step_scopes[step_id], dev_ctx);
}
}
rnn::ConcatOutputs(step_scopes, arg_->outlinks, seq_len, dev_ctx);
LinkBootMemoryGradients(step_scopes[0]);
@@ -181,12 +182,65 @@ RecurrentGradientOp::RecurrentGradientOp(
const framework::AttributeMap& attrs)
: OperatorBase(type, inputs, outputs, attrs) {
rnn::InitArgument(kArgName, &arg_, *this, true /*is grad*/);
alg_.Init(&arg_, &stepnet_);
alg_.Init(&arg_, &stepnet_, &vars_);
}

class RecurrentGradOpDescMaker : public framework::SingleGradOpDescMaker {
public:
using framework::SingleGradOpDescMaker::SingleGradOpDescMaker;
using OpDescBind = framework::OpDescBind;

protected:
virtual std::unique_ptr<OpDescBind> Apply() const {
auto* grad = new OpDescBind();
grad->SetType(this->GradOpType());

for (auto& input_param : this->InputNames()) {
grad->SetInput(input_param, this->Input(input_param));
grad->SetOutput(framework::GradVarName(input_param),
this->InputGrad(input_param));
}

for (auto& output_param : this->OutputNames()) {
if (output_param == "step_scopes") {
grad->SetInput(output_param, this->Output(output_param));
grad->SetInput(framework::GradVarName(output_param),
this->Output(output_param));
} else {
grad->SetInput(output_param, this->Output(output_param));
grad->SetInput(framework::GradVarName(output_param),
this->OutputGrad(output_param));
}
}

grad->SetAttrMap(this->Attrs());

return std::unique_ptr<OpDescBind>(grad);
}

virtual std::string GradOpType() const {
return this->ForwardOpType() + "_grad";
}
};

class RecurrentGradientOpShapeInference : public framework::InferShapeBase {
public:
void operator()(framework::InferShapeContext* ctx) const override {
const auto& in_name = RecurrentOp::kArgName;
const auto& out_name = RecurrentGradientOp::kArgName;
PADDLE_ENFORCE(ctx->HasInput(in_name.inlinks));
PADDLE_ENFORCE(ctx->HasInput(in_name.outlinks));
PADDLE_ENFORCE(ctx->HasInput(out_name.inlinks));
PADDLE_ENFORCE(ctx->HasOutput(out_name.outlinks));
ctx->SetOutputDim(out_name.outlinks, ctx->GetInputDim(in_name.inlinks));
}
};

} // namespace operators
} // namespace paddle

REGISTER_OP(recurrent, paddle::operators::RecurrentOp,
paddle::operators::RecurrentAlgorithmProtoAndCheckerMaker,
recurrent_grad, paddle::operators::RecurrentGradientOp);
REGISTER_OPERATOR(recurrent, paddle::operators::RecurrentOp,
paddle::operators::RecurrentAlgorithmProtoAndCheckerMaker,
paddle::operators::RecurrentGradOpDescMaker);
REGISTER_OPERATOR(recurrent_grad, paddle::operators::RecurrentGradientOp,
paddle::operators::RecurrentGradientOpShapeInference);
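A note on the CreateScopes change in this file: instead of mining the step net's `Inputs()`/`Outputs()`, each per-step scope is now pre-populated with exactly the variables listed in the step block's var list (the `vars` passed to `set_stepnet(step_net, vars)`). Below is a rough, standalone sketch of that per-time-step scope layout; `Scope` is simplified to a string-to-tensor map and the variable names are purely illustrative, so this is not Paddle's `Scope` class or its actual linking code.

```cpp
#include <iostream>
#include <map>
#include <string>
#include <vector>

// Simplified stand-ins: a Scope maps variable names to tensors (flat float
// vectors here); real Paddle scopes also chain to a parent for lookups.
using Tensor = std::vector<float>;
using Scope = std::map<std::string, Tensor>;

int main() {
  const size_t seq_len = 3;
  // Variable list of the step block (illustrative names only).
  const std::vector<std::string> step_vars = {"x@step", "h@pre", "h"};

  // One child scope per time step, each pre-populated with every variable of
  // the step block -- the loop CreateScopes now runs over vars_.
  std::vector<Scope> step_scopes(seq_len);
  for (size_t t = 0; t < seq_len; ++t) {
    for (const auto& name : step_vars) step_scopes[t][name] = Tensor(4, 0.f);
  }

  // LinkMemories(step_scopes, states, t, -1) boils down to: step t reads its
  // pre-state from the state produced at step t - 1.
  for (size_t t = 1; t < seq_len; ++t) {
    step_scopes[t]["h@pre"] = step_scopes[t - 1]["h"];
  }

  std::cout << "created " << step_scopes.size() << " step scopes with "
            << step_scopes[0].size() << " vars each\n";
}
```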