Closed

Changes from all commits
64 commits
c418dac
add rnn op interfaces
Superjomn Jul 7, 2017
6042795
add Run
Superjomn Jul 7, 2017
13d8ca9
rename state -> memory
Superjomn Jul 7, 2017
a645ae6
change state -> memory
Superjomn Jul 7, 2017
8640f96
make compilable
Superjomn Jul 8, 2017
d4cde51
add .cc
Superjomn Jul 8, 2017
6e99289
init test
Superjomn Jul 8, 2017
63b5841
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
qingqing01 Jul 8, 2017
08f69f6
Merge branch 'develop' of github.com:PaddlePaddle/Paddle into rnnimpl
Superjomn Jul 10, 2017
007ca1e
add op fake implementation
Superjomn Jul 10, 2017
2538b2f
add CreateStepNet and CreateScopes implementation.
qingqing01 Jul 10, 2017
5eb87f0
add TODO list
luotao1 Jul 10, 2017
4dcb02e
Merge branch 'rnnimpl' of github.com:Superjom/Paddle into rnnimpl
Superjomn Jul 10, 2017
ca53f3a
init memory attributes.
qingqing01 Jul 10, 2017
671cc26
Merge branch 'rnnimpl' of https://github.com/Superjom/Paddle into fea…
qingqing01 Jul 10, 2017
1e48cc8
add LinkMemories
Superjomn Jul 10, 2017
e0cbcd0
Merge branch 'rnnimpl' of github.com:Superjom/Paddle into rnnimpl
Superjomn Jul 10, 2017
f7916a6
add PlainNet fake implementation
Superjomn Jul 10, 2017
089c448
Use std::shared_ptr<Scope> in the OpRunContext.
qingqing01 Jul 10, 2017
bffd11e
add test
Superjomn Jul 10, 2017
c7947de
disable mutable_data
Superjomn Jul 10, 2017
94766b6
Merge branch 'rnnimpl' of github.com:Superjom/Paddle into rnnimpl
Superjomn Jul 10, 2017
6dca711
finish segmentInput function
luotao1 Jul 10, 2017
eabf1bf
Merge branch 'develop' of github.com:PaddlePaddle/Paddle into rnnimpl
Superjomn Jul 11, 2017
d210b0b
enable mutable_data with a trick
Superjomn Jul 11, 2017
6674fee
RNNOp test.
qingqing01 Jul 11, 2017
778ebb4
enable LinkMemories with mutable_data
Superjomn Jul 11, 2017
c60ed35
update
qingqing01 Jul 11, 2017
8642b27
update SegmentInput function with comments
luotao1 Jul 11, 2017
b0938ed
Merge branch 'develop' of github.com:PaddlePaddle/Paddle into rnnimpl
Superjomn Jul 11, 2017
3921fbb
Merge branch 'rnnimpl' of github.com:Superjom/Paddle into rnnimpl
Superjomn Jul 11, 2017
244fe51
create rnn op and step net in unit test.
qingqing01 Jul 11, 2017
020c189
Merge branch 'rnnimpl' of https://github.com/Superjom/Paddle into rnn…
luotao1 Jul 11, 2017
8e70b37
finish ConcatOutput function
luotao1 Jul 11, 2017
4150fa7
Merge branch 'rnnimpl' of github.com:Superjom/Paddle into rnnimpl
Superjomn Jul 12, 2017
1584414
Merge branch 'develop' of github.com:PaddlePaddle/Paddle into rnnimpl
Superjomn Jul 12, 2017
ce802c0
reformat inputs and attributes
Superjomn Jul 12, 2017
a883b4c
Refine unit test.
qingqing01 Jul 12, 2017
b98cae4
Merge branch 'rnnimpl' of https://github.com/Superjom/Paddle into fea…
qingqing01 Jul 12, 2017
a81be58
Refine unit test.
qingqing01 Jul 12, 2017
acde9b7
modify inlinks.
qingqing01 Jul 12, 2017
638384e
update from develop branch.
qingqing01 Jul 12, 2017
82464f5
add OpDesc to Net
Superjomn Jul 12, 2017
bbcc149
Merge branch 'netimpl' into rnnimpl
Superjomn Jul 12, 2017
c92ce74
Merge branch 'develop' into rnnimpl
luotao1 Jul 12, 2017
5c5d890
fix bug and update unit test.
qingqing01 Jul 12, 2017
522445b
resolve conflict.
qingqing01 Jul 12, 2017
01f20be
Merge branch 'rnnimpl' of github.com:Superjom/Paddle into rnnimpl
Superjomn Jul 12, 2017
08003de
Merge branch 'rnnimpl' of github.com:Superjom/Paddle into rnnimpl
Superjomn Jul 12, 2017
a6483e8
move step scopes from inputs to outputs
Superjomn Jul 12, 2017
7b1d123
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
qingqing01 Jul 12, 2017
bcd03bf
fix merge conflict, update SegmentInput function
luotao1 Jul 13, 2017
de319bb
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
qingqing01 Jul 14, 2017
0a4a502
Merge branch 'develop' into rnnimpl
luotao1 Jul 14, 2017
e64b5d3
add RecurrentOpProtoAndCheckerMaker.
qingqing01 Jul 14, 2017
e700bf6
Merge branch 'rnnimpl' of https://github.com/Superjom/Paddle into fea…
qingqing01 Jul 14, 2017
f525390
clean the codes
luotao1 Jul 14, 2017
3a27b02
Abstract GetStepScopes and GetMaxSeqLen function
luotao1 Jul 14, 2017
aede869
refine LinkMemories
luotao1 Jul 14, 2017
45682d2
Refine code and add some comments.
qingqing01 Jul 15, 2017
497c7ff
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
qingqing01 Jul 15, 2017
fc5acee
add backward core
Superjomn Jul 15, 2017
14dd843
update for develop branch.
qingqing01 Jul 15, 2017
3c15641
Merge branch 'rnnimpl' of https://github.com/Superjom/Paddle into fea…
qingqing01 Jul 15, 2017
2 changes: 2 additions & 0 deletions paddle/framework/CMakeLists.txt
@@ -19,6 +19,8 @@ py_proto_compile(framework_py_proto SRCS attr_type.proto op_proto.proto op_desc.
# Generate an empty __init__.py to make framework_py_proto as a valid python module.
add_custom_target(framework_py_proto_init ALL COMMAND ${CMAKE_COMMAND} -E touch __init__.py)
add_dependencies(framework_py_proto framework_py_proto_init)
cc_library(recurrent_network_op SRCS recurrent_network_op.cc DEPS op_desc place)
cc_test(recurrent_network_op_test SRCS recurrent_network_op_test.cc DEPS recurrent_network_op glog gtest gflags ddim op_desc)

proto_library(net_proto SRCS net_proto.proto DEPS op_proto)
cc_library(net SRCS net.cc DEPS net_proto)
6 changes: 3 additions & 3 deletions paddle/framework/net.h
@@ -15,6 +15,7 @@
#pragma once

#include "paddle/framework/net_proto.pb.h"
#include "paddle/framework/op_desc.pb.h"
#include "paddle/framework/op_proto.pb.h"
#include "paddle/framework/scope.h"
#include "paddle/platform/device_context.h"
@@ -31,7 +32,6 @@ typedef int OpIndex;
* keep updating if the concepts related are implemented.
*/

struct OpDesc;
struct OpAttrs {};

class Operator {
@@ -74,7 +74,7 @@ class Net {
/**
* @brief Add an Operator according to `def`.
*/
virtual OpIndex AddOp(const OpProto &def) = 0;
virtual OpIndex AddOp(const OpDesc &def) = 0;

/**
* @brief Add optimizer operators according to `attrs`.
@@ -129,7 +129,7 @@ class PlainNet : public Net {
/**
* @brief Add an operator to this network.
*/
virtual OpIndex AddOp(const OpProto &def) override;
virtual OpIndex AddOp(const OpProto &def);

/**
* @brief Add all optimizer operators related into the network.
248 changes: 248 additions & 0 deletions paddle/framework/recurrent_network_op.cc
@@ -0,0 +1,248 @@
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/framework/recurrent_network_op.h"
#include "paddle/framework/tensor.h"

#include <glog/logging.h>
#include <cstring>

namespace paddle {
namespace framework {

void RecurrentOp::Run(OpContext* context) const {
auto scope = context->scope;

PADDLE_ENFORCE(scope->HasVariable(net_name_), "step net is not in scope.");
Variable* net = scope->GetVariable(net_name_);
PADDLE_ENFORCE(net, "failed to get step net");

LOG(INFO) << "create scopes";
CreateScopes(scope);
LOG(INFO) << "segment input";
SegmentInputs(scope);
Contributor

The LOG statements on lines 32, 34, and 53 could all be moved inside the functions themselves.

Contributor

They will all be removed once testing is done.


// forward
size_t max_seq_len = GetMaxSeqLen(scope);
LOG(INFO) << "sequence length " << max_seq_len;
auto step_scopes = GetStepScopes(scope);
for (size_t step_id = 0; step_id < max_seq_len; step_id++) {
LOG(INFO) << "run step " << step_id;
LinkMemories(step_scopes, step_id);

net->GetMutable<PlainNet>()->Run(step_scopes[step_id]);
}

LOG(INFO) << "concat outputs";
// prepare outputs
ConcatOutputs(scope);
}

void RecurrentOp::Init(const OpDesc& op_desc, AttributeMap& attrs) {
Member

Init should take no parameters; op_desc and attrs can both be fetched from the Op's member variables. A minor issue.

Contributor Author

This will be hooked into OpBase.Run:

if (!is_inited) Init(...);

It's just that OpBase hasn't been introduced yet.
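
A minimal sketch of that pattern (OpBase, RunImpl, and is_inited_ are assumed names; none of them exist in this PR yet):

// Hypothetical lazy-init base class: Run() calls Init() exactly once,
// so Init() can read op_desc_ and attrs_ from member variables instead
// of taking them as parameters.
class OpBase {
 public:
  void Run(OpContext* context) {
    if (!is_inited_) {
      Init();  // reads the op_desc_ and attrs_ members
      is_inited_ = true;
    }
    RunImpl(context);
  }

 protected:
  virtual void Init() = 0;
  virtual void RunImpl(OpContext* context) const = 0;

  OpDesc op_desc_;
  AttributeMap attrs_;
  bool is_inited_ = false;
};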

OperatorBase::Init(op_desc, attrs);

// set original inputs
for (const std::string& input : op_desc.inputs()) {
LOG(INFO) << "set input " << input;
inputs_.push_back(input);
}
// set original outputs
for (const std::string& output : op_desc.outputs()) {
LOG(INFO) << "set output " << output;
outputs_.push_back(output);
}

net_name_ = inputs_.at(GetAttr<int>("step_net"));
step_scopes_name_ = outputs_.back();

// prepare inlinks
PADDLE_ENFORCE(inlinks_.empty(), "RecurrentOp must not be initialized twice.");
LOG(INFO) << "set inlinks";
for (auto id : GetAttr<std::vector<int>>("in_links")) {
inlinks_.push_back(inputs_[id]);
}
auto inlink_alias = GetAttr<std::vector<std::string>>("in_link_alias");
in_link_alias_ =
std::vector<std::string>{inlink_alias.begin(), inlink_alias.end()};
PADDLE_ENFORCE(inlinks_.size() == in_link_alias_.size(),
"in_links/in_link_alias mismatch.");

PADDLE_ENFORCE(
outputs_.size() > 1,
"more than 1 output should be provided and the last is `step_scopes`");
outlinks_ = std::vector<std::string>{outputs_.begin(), outputs_.end() - 1};

auto outlink_alias = GetAttr<std::vector<std::string>>("out_link_alias");
out_link_alias_ =
std::vector<std::string>{outlink_alias.begin(), outlink_alias.end()};
PADDLE_ENFORCE(outlinks_.size() == outlink_alias.size(),
"out_links/out_link_alias mismatch.");

// set memories
auto memories = GetAttr<std::vector<std::string>>("memories");
auto pre_memories = GetAttr<std::vector<std::string>>("pre_memories");
PADDLE_ENFORCE(memories.size() == pre_memories.size(),
"The size of memories and pre_memories doesn't match: %d,%d.",
memories.size(), pre_memories.size());

std::vector<std::string> boot_memories;
LOG(INFO) << "set boot_memories";
for (auto id : GetAttr<std::vector<int>>("boot_memories")) {
boot_memories.push_back(inputs_[id]);
}
PADDLE_ENFORCE(memories.size() == boot_memories.size(),
"the size of memories and boot_memories doesn't match: %d,%d",
memories.size(), boot_memories.size());
for (size_t i = 0; i < memories.size(); ++i) {
details::MemoryAttr mem_attr;
mem_attr.var = memories[i];
mem_attr.pre_var = pre_memories[i];
mem_attr.boot_var = boot_memories[i];
memory_attrs_.push_back(mem_attr);
LOG(INFO) << "set memorys:\t"
<< "memory:" << mem_attr.var << "\tboot:" << mem_attr.boot_var;
}
}

size_t RecurrentOp::GetMaxSeqLen(ScopePtr scope) const {
// TODO: update this function to support variable-length sequences.
return Input(scope, inlinks_[0])->GetMutable<Tensor>()->dims()[0];
}

void RecurrentOp::CreateScopes(ScopePtr scope) const {
size_t max_seq_len = GetMaxSeqLen(scope);
std::vector<ScopePtr>* step_scopes =
scope->GetVariable(step_scopes_name_)
->GetMutable<std::vector<ScopePtr>>();
// TODO: only two scopes are needed for inference; this case will be
// supported later.
if (max_seq_len > step_scopes->size()) {
for (size_t i = step_scopes->size(); i < max_seq_len; ++i) {
step_scopes->push_back(std::make_shared<Scope>(scope));
}
}
}

void RecurrentOp::SegmentInputs(ScopePtr scope) const {
PADDLE_ENFORCE(!inlinks_.empty(), "no in links are provided.");
auto step_scopes = GetStepScopes(scope);
size_t max_seq_len = GetMaxSeqLen(scope);
for (size_t i = 0; i < inlinks_.size(); ++i) {
Tensor* input_tensor = Input(scope, inlinks_[i])->GetMutable<Tensor>();
for (size_t j = 0; j < max_seq_len; j++) {
Variable* input_var = step_scopes[j]->CreateVariable(in_link_alias_[i]);
Tensor* step_input_tensor = input_var->GetMutable<Tensor>();
*step_input_tensor = input_tensor->Slice<float>(j, j + 1);
// TODO (luotao1): use reshape function to decrease the dims of
// step_input_tensor.
}
}
}

void RecurrentOp::ConcatOutputs(ScopePtr scope) const {
auto step_scopes = GetStepScopes(scope);
size_t max_seq_len = GetMaxSeqLen(scope);
// TODO (luotao1): update using CopyFrom function in tensor.
auto dims = Input(scope, inlinks_[0])->GetMutable<Tensor>()->dims();
int batch_size = dims[1];
for (size_t i = 0; i < outlinks_.size(); i++) {
auto output_dims = step_scopes[0]
->GetVariable(out_link_alias_[i])
->GetMutable<Tensor>()
->dims();
int output_dim = output_dims[1];
int length = batch_size * output_dim;
Tensor* output_tensor =
scope->CreateVariable(outlinks_[i])->GetMutable<Tensor>();
float* output = output_tensor->mutable_data<float>(
make_ddim({(int)max_seq_len, batch_size, output_dim}),
platform::CPUPlace());
for (size_t j = 0; j < max_seq_len; j++) {
Variable* output_var = step_scopes[j]->GetVariable(out_link_alias_[i]);
const float* step_output =
output_var->GetMutable<Tensor>()->data<float>();
std::memcpy(output + j * length, step_output, length * sizeof(float));
}
}
}

void RecurrentOp::LinkMemories(std::vector<ScopePtr>& step_scopes,
size_t step_id) const {
PADDLE_ENFORCE(step_id < step_scopes.size(),
"step [%d] out of range of step scopes' size [%d]", step_id,
step_scopes.size());
ScopePtr step_scope = step_scopes[step_id];
for (auto& attr : memory_attrs_) {
Tensor* pre_memory_tensor =
step_scope->CreateVariable(attr.pre_var)->GetMutable<Tensor>();

if (step_id == 0) {
PADDLE_ENFORCE(step_scope->HasVariable(attr.boot_var),
"memory [%s]'s boot variable [%s] not exists", attr.var,
attr.boot_var);
Tensor* boot_tensor =
step_scope->CreateVariable(attr.boot_var)->GetMutable<Tensor>();
PADDLE_ENFORCE(boot_tensor, "boot_tensor should have been created beforehand");
// copy from boot memory
pre_memory_tensor->ShareDataFrom<float>(*boot_tensor);
} else {
// copy the previous step scope's memory into this scope's `pre-memory`
Tensor* pre_step_memory =
step_scopes[step_id - 1]->GetVariable(attr.var)->GetMutable<Tensor>();
pre_memory_tensor->ShareDataFrom<float>(*pre_step_memory);
}

// TODO: the memory of the current step should be allocated by the step net
Tensor* cur_memory_tensor =
step_scopes[step_id]->CreateVariable(attr.var)->GetMutable<Tensor>();
cur_memory_tensor->mutable_data<float>(pre_memory_tensor->dims(),
platform::CPUPlace());
}
}

// TODO: enable the following test once operator.h can be included.

// class RecurrentOpProtoAndCheckerMaker : public OpProtoAndCheckerMaker {
//  public:
//   RecurrentOpProtoAndCheckerMaker(OpProto* proto, OpAttrChecker* op_checker)
//       : OpProtoAndCheckerMaker(proto, op_checker) {
//     // AddInput("input", "input of test op");    // need to support dynamic number
//     // AddOutput("output", "output of test op"); // need to support dynamic number
//     AddAttr<std::vector<int>>("in_links",
//                               "The input link positions among all the inputs.")
//         .SetDefault({0});
//     AddAttr<std::vector<int>>("boot_memories",
//                               "The initial memory positions among all the inputs.");
//     AddAttr<int>("step_net", "The step net position among all the inputs.");
//
//     AddAttr<std::vector<std::string>>("in_link_alias",
//                                       "The input link aliases in the step network.");
//     AddAttr<std::vector<std::string>>("out_link_alias",
//                                       "The output link aliases in the step network.");
//     AddAttr<std::vector<std::string>>("memories", "The memory names.");
//     AddAttr<std::vector<std::string>>("pre_memories",
//                                       "The history/previous memory names.");
//
//     AddType("recurrent_op");
//     AddComment("This is a recurrent group operator.");
//   }
// };
//
// REGISTER_OP(recurrent_op, RecurrentOp, RecurrentOpProtoAndCheckerMaker);

} // namespace framework
} // namespace paddle
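
For reference, a hedged sketch of the OpDesc layout that Init() above parses. The variable names ("x", "h_boot", "rnn/h", ...) and the protobuf setters are illustrative assumptions, not taken from this PR's test file:

// Indices stored in the attributes refer to positions in op_desc.inputs().
OpDesc op_desc;
op_desc.set_type("recurrent_op");
op_desc.add_inputs("x");             // index 0: an in-link (the sequence input)
op_desc.add_inputs("h_boot");        // index 1: a boot memory
op_desc.add_inputs("step_net");      // index 2: the step-net variable
op_desc.add_outputs("h");            // out-links come first...
op_desc.add_outputs("step_scopes");  // ...and step_scopes is always last
// Attributes, matching the names Init() reads via GetAttr:
//   in_links       = {0}            -> inlinks_   = {"x"}
//   in_link_alias  = {"rnn/x"}      -> name of x inside each step scope
//   out_link_alias = {"rnn/h"}      -> name of h inside each step scope
//   boot_memories  = {1}            -> boot memories = {"h_boot"}
//   step_net       = 2              -> net_name_  = "step_net"
//   memories       = {"rnn/h"}      -> paired one-to-one with pre_memories
//   pre_memories   = {"rnn/h_pre"}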