add rnn op interfaces #2775
Closed
Changes from all commits (64 commits):
c418dac  add rnn op interfaces (Superjomn)
6042795  add Run (Superjomn)
13d8ca9  rename state -> memory (Superjomn)
a645ae6  change state -> memory (Superjomn)
8640f96  make compilable (Superjomn)
d4cde51  add .cc (Superjomn)
6e99289  init test (Superjomn)
63b5841  Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into… (qingqing01)
08f69f6  Merge branch 'develop' of github.com:PaddlePaddle/Paddle into rnnimpl (Superjomn)
007ca1e  add op fake implementation (Superjomn)
2538b2f  add CreateStepNet and CreateScopes implementation. (qingqing01)
5eb87f0  add TODO list (luotao1)
4dcb02e  Merge branch 'rnnimpl' of github.com:Superjom/Paddle into rnnimpl (Superjomn)
ca53f3a  init memory attributes. (qingqing01)
671cc26  Merge branch 'rnnimpl' of https://github.com/Superjom/Paddle into fea… (qingqing01)
1e48cc8  add LinkMemories (Superjomn)
e0cbcd0  Merge branch 'rnnimpl' of github.com:Superjom/Paddle into rnnimpl (Superjomn)
f7916a6  add PlainNet fake implementation (Superjomn)
089c448  Use std::shared_ptr<Scope> in the OpRunContext. (qingqing01)
bffd11e  add test (Superjomn)
c7947de  disable mutable_data (Superjomn)
94766b6  Merge branch 'rnnimpl' of github.com:Superjom/Paddle into rnnimpl (Superjomn)
6dca711  finist segmentInput function (luotao1)
eabf1bf  Merge branch 'develop' of github.com:PaddlePaddle/Paddle into rnnimpl (Superjomn)
d210b0b  enable mutable_data with a trick (Superjomn)
6674fee  RNNOp test. (qingqing01)
778ebb4  enable LinkMemories with mutable_data (Superjomn)
c60ed35  update (qingqing01)
8642b27  update SegmentInput function with comments (luotao1)
b0938ed  Merge branch 'develop' of github.com:PaddlePaddle/Paddle into rnnimpl (Superjomn)
3921fbb  Merge branch 'rnnimpl' of github.com:Superjom/Paddle into rnnimpl (Superjomn)
244fe51  create rnn op and step net in unit test. (qingqing01)
020c189  Merge branch 'rnnimpl' of https://github.com/Superjom/Paddle into rnn… (luotao1)
8e70b37  finish ConcatOutput function (luotao1)
4150fa7  Merge branch 'rnnimpl' of github.com:Superjom/Paddle into rnnimpl (Superjomn)
1584414  Merge branch 'develop' of github.com:PaddlePaddle/Paddle into rnnimpl (Superjomn)
ce802c0  reformat inputs and attributes (Superjomn)
a883b4c  Refine unit test. (qingqing01)
b98cae4  Merge branch 'rnnimpl' of https://github.com/Superjom/Paddle into fea… (qingqing01)
a81be58  Refine unit test. (qingqing01)
acde9b7  modify inlinks. (qingqing01)
638384e  update from develop branch. (qingqing01)
82464f5  add OpDesc to Net (Superjomn)
bbcc149  Merge branch 'netimpl' into rnnimpl (Superjomn)
c92ce74  Merge branch 'develop' into rnnimpl (luotao1)
5c5d890  fix bug and update unit test. (qingqing01)
522445b  resolve conflict. (qingqing01)
01f20be  Merge branch 'rnnimpl' of github.com:Superjom/Paddle into rnnimpl (Superjomn)
08003de  Merge branch 'rnnimpl' of github.com:Superjom/Paddle into rnnimpl (Superjomn)
a6483e8  move step scopes from inputs to outputs (Superjomn)
7b1d123  Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into… (qingqing01)
bcd03bf  fix merge conflict, update SegmentInput function (luotao1)
de319bb  Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into… (qingqing01)
0a4a502  Merge branch 'develop' into rnnimpl (luotao1)
e64b5d3  add RecurrentOpProtoAndCheckerMaker. (qingqing01)
e700bf6  Merge branch 'rnnimpl' of https://github.com/Superjom/Paddle into fea… (qingqing01)
f525390  clean the codes (luotao1)
3a27b02  Abstract GetStepScopes and GetMaxSeqLen function (luotao1)
aede869  refine LinkMemories (luotao1)
45682d2  Refine code and add some comments. (qingqing01)
497c7ff  Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into… (qingqing01)
fc5acee  add backward core (Superjomn)
14dd843  update for develop branch. (qingqing01)
3c15641  Merge branch 'rnnimpl' of https://github.com/Superjom/Paddle into fea… (qingqing01)
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
@@ -0,0 +1,248 @@
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/framework/recurrent_network_op.h"
#include "paddle/framework/tensor.h"

#include <glog/logging.h>
#include <cstring>

namespace paddle {
namespace framework {

void RecurrentOp::Run(OpContext* context) const {
  auto scope = context->scope;

  PADDLE_ENFORCE(scope->HasVariable(net_name_), "step net is not in scope.");
  Variable* net = scope->GetVariable(net_name_);
  PADDLE_ENFORCE(net, "failed to get step net");

  LOG(INFO) << "create scopes";
  CreateScopes(scope);
  LOG(INFO) << "segment input";
  SegmentInputs(scope);

  // forward
  size_t max_seq_len = GetMaxSeqLen(scope);
  LOG(INFO) << "sequence length " << max_seq_len;
  auto step_scopes = GetStepScopes(scope);
  for (size_t step_id = 0; step_id < max_seq_len; step_id++) {
    LOG(INFO) << "run step " << step_id;
    LinkMemories(step_scopes, step_id);
    net->GetMutable<PlainNet>()->Run(step_scopes[step_id]);
  }

  LOG(INFO) << "concat outputs";
  // prepare outputs
  ConcatOutputs(scope);
}
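The design above unrolls one shared step net over time and gives every time step its own scope, so same-named variables (inputs, memories, outputs) do not collide between steps. A minimal standalone sketch of that scope-per-step idea, using a std::map as a stand-in for Scope (illustrative only, not framework code):

#include <cassert>
#include <map>
#include <string>
#include <vector>

using FakeScope = std::map<std::string, float>;  // stand-in for a per-step Scope

int main() {
  const int max_seq_len = 3;
  std::vector<FakeScope> step_scopes(max_seq_len);  // CreateScopes

  for (int step_id = 0; step_id < max_seq_len; ++step_id) {
    FakeScope& s = step_scopes[step_id];
    // LinkMemories: step 0 reads a boot value, later steps read step_id - 1.
    s["pre_h"] = (step_id == 0) ? 0.0f : step_scopes[step_id - 1]["h"];
    s["h"] = s["pre_h"] + 1.0f;  // the "step net": one update per scope
  }
  // The same variable name "h" exists once per step, with no collisions.
  assert(step_scopes[2]["h"] == 3.0f);
}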

void RecurrentOp::Init(const OpDesc& op_desc, AttributeMap& attrs) {

Review comment: Init should take no arguments; op_desc and attrs can both be fetched from the Op's member variables. A minor issue.
Reply: This will be hooked into OpBase.Run as if (!is_inited) Init(...); that mechanism just has not been introduced yet.

  OperatorBase::Init(op_desc, attrs);

  // set original inputs
  for (const std::string& input : op_desc.inputs()) {
    LOG(INFO) << "set input " << input;
    inputs_.push_back(input);
  }
  // set original outputs
  for (const std::string& output : op_desc.outputs()) {
    LOG(INFO) << "set output " << output;
    outputs_.push_back(output);
  }

  net_name_ = inputs_.at(GetAttr<int>("step_net"));
  step_scopes_name_ = outputs_.back();

  // prepare inlinks
  PADDLE_ENFORCE(inlinks_.empty(), "RecurrentOp duplicate inited");
  LOG(INFO) << "set inlinks";
  for (auto id : GetAttr<std::vector<int>>("in_links")) {
    inlinks_.push_back(inputs_[id]);
  }
  auto inlink_alias = GetAttr<std::vector<std::string>>("in_link_alias");
  in_link_alias_ =
      std::vector<std::string>{inlink_alias.begin(), inlink_alias.end()};
  PADDLE_ENFORCE(inlinks_.size() == in_link_alias_.size(),
                 "in_links/in_link_alias mismatch.");

  PADDLE_ENFORCE(
      outputs_.size() > 1,
      "more than 1 output should be provided and the last is `step_scopes`");
  outlinks_ = std::vector<std::string>{outputs_.begin(), outputs_.end() - 1};

  auto outlink_alias = GetAttr<std::vector<std::string>>("out_link_alias");
  out_link_alias_ =
      std::vector<std::string>{outlink_alias.begin(), outlink_alias.end()};
  PADDLE_ENFORCE(outlinks_.size() == outlink_alias.size(),
                 "out_links/out_link_alias mismatch.");

  // set memories
  auto memories = GetAttr<std::vector<std::string>>("memories");
  auto pre_memories = GetAttr<std::vector<std::string>>("pre_memories");
  PADDLE_ENFORCE(memories.size() == pre_memories.size(),
                 "The size of memories and pre_memories doesn't match: %d,%d.",
                 memories.size(), pre_memories.size());

  std::vector<std::string> boot_memories;
  LOG(INFO) << "set boot_memories";
  for (auto id : GetAttr<std::vector<int>>("boot_memories")) {
    boot_memories.push_back(inputs_[id]);
  }
  PADDLE_ENFORCE(memories.size() == boot_memories.size(),
                 "the size of memories and boot_memories doesn't match: %d,%d",
                 memories.size(), boot_memories.size());
  for (size_t i = 0; i < memories.size(); ++i) {
    details::MemoryAttr mem_attr;
    mem_attr.var = memories[i];
    mem_attr.pre_var = pre_memories[i];
    mem_attr.boot_var = boot_memories[i];
    memory_attrs_.push_back(mem_attr);
    LOG(INFO) << "set memories:\t"
              << "memory:" << mem_attr.var << "\tboot:" << mem_attr.boot_var;
  }
}
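Most of the wiring in Init is index indirection: attributes such as in_links, boot_memories, and step_net hold integer positions into the flat inputs_ list rather than variable names. A self-contained sketch of that resolution (the names mirror the attributes above, but this is standard C++ only, not the framework API):

#include <cassert>
#include <string>
#include <vector>

int main() {
  // Flat input list, as an OpDesc would provide it.
  std::vector<std::string> inputs = {"x", "h_boot", "step_net_var"};

  // Attribute values: positions into `inputs`, as read by GetAttr.
  std::vector<int> in_links = {0};       // "in_links"
  std::vector<int> boot_memories = {1};  // "boot_memories"
  int step_net = 2;                      // "step_net"

  // Resolve positions to names, mirroring the loops in Init above.
  std::vector<std::string> inlinks, boots;
  for (int id : in_links) inlinks.push_back(inputs[id]);
  for (int id : boot_memories) boots.push_back(inputs[id]);
  std::string net_name = inputs[step_net];

  assert(inlinks[0] == "x" && boots[0] == "h_boot" && net_name == "step_net_var");
}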

size_t RecurrentOp::GetMaxSeqLen(ScopePtr scope) const {
  // TODO: update this function when supporting variable-length sequences.
  return Input(scope, inlinks_[0])->GetMutable<Tensor>()->dims()[0];
}

void RecurrentOp::CreateScopes(ScopePtr scope) const {
  size_t max_seq_len = GetMaxSeqLen(scope);
  std::vector<ScopePtr>* step_scopes =
      scope->GetVariable(step_scopes_name_)
          ->GetMutable<std::vector<ScopePtr>>();
  // TODO: only two scopes are needed for inference; that case will be
  // supported later.
  if (max_seq_len > step_scopes->size()) {
    for (size_t i = step_scopes->size(); i < max_seq_len; ++i) {
      step_scopes->push_back(std::make_shared<Scope>(scope));
    }
  }
}

void RecurrentOp::SegmentInputs(ScopePtr scope) const {
  PADDLE_ENFORCE(!inlinks_.empty(), "no in links are provided.");
  auto step_scopes = GetStepScopes(scope);
  size_t max_seq_len = GetMaxSeqLen(scope);
  for (size_t i = 0; i < inlinks_.size(); ++i) {
    Tensor* input_tensor = Input(scope, inlinks_[i])->GetMutable<Tensor>();
    for (size_t j = 0; j < max_seq_len; j++) {
      Variable* input_var = step_scopes[j]->CreateVariable(in_link_alias_[i]);
      Tensor* step_input_tensor = input_var->GetMutable<Tensor>();
      *step_input_tensor = input_tensor->Slice<float>(j, j + 1);
      // TODO(luotao1): use a reshape function to reduce the rank of
      // step_input_tensor.
    }
  }
}
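Slice<float>(j, j + 1) on a [seq_len, batch_size, dim] tensor picks out the j-th time step, i.e. the contiguous block of batch_size * dim floats starting at element offset j * batch_size * dim. A small sketch of that offset arithmetic on a raw buffer (illustrative, not the Tensor API):

#include <cassert>
#include <vector>

int main() {
  const int seq_len = 4, batch_size = 2, dim = 3;
  std::vector<float> input(seq_len * batch_size * dim);
  for (int i = 0; i < seq_len * batch_size * dim; ++i)
    input[i] = static_cast<float>(i);

  const int j = 2;                        // time step to slice out
  const int step_len = batch_size * dim;  // floats per step
  const float* step = input.data() + j * step_len;  // what Slice(j, j + 1) exposes
  assert(step[0] == static_cast<float>(j * step_len));
  assert(step[step_len - 1] == static_cast<float>((j + 1) * step_len - 1));
}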

void RecurrentOp::ConcatOutputs(ScopePtr scope) const {
  auto step_scopes = GetStepScopes(scope);
  size_t max_seq_len = GetMaxSeqLen(scope);
  // TODO(luotao1): update using the CopyFrom function in tensor.
  auto dims = Input(scope, inlinks_[0])->GetMutable<Tensor>()->dims();
  int batch_size = dims[1];
  for (size_t i = 0; i < outlinks_.size(); i++) {
    auto output_dims = step_scopes[0]
                           ->GetVariable(out_link_alias_[i])
                           ->GetMutable<Tensor>()
                           ->dims();
    int output_dim = output_dims[1];
    int length = batch_size * output_dim;
    Tensor* output_tensor =
        scope->CreateVariable(outlinks_[i])->GetMutable<Tensor>();
    float* output = output_tensor->mutable_data<float>(
        make_ddim({(int)max_seq_len, batch_size, output_dim}),
        platform::CPUPlace());
    for (size_t j = 0; j < max_seq_len; j++) {
      Variable* output_var = step_scopes[j]->GetVariable(out_link_alias_[i]);
      const float* step_output =
          output_var->GetMutable<Tensor>()->data<float>();
      // memcpy counts bytes, so scale the element count by sizeof(float).
      std::memcpy(output + j * length, step_output, length * sizeof(float));
    }
  }
}
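ConcatOutputs is the inverse strided copy: step j's [batch_size, output_dim] block lands at element offset j * batch_size * output_dim of the [max_seq_len, batch_size, output_dim] output, and since std::memcpy counts bytes, the per-step element count has to be scaled by sizeof(float). A standalone sketch of that copy:

#include <cassert>
#include <cstring>
#include <vector>

int main() {
  const int seq_len = 3, batch_size = 2, output_dim = 4;
  const int length = batch_size * output_dim;  // floats per step

  // Per-step outputs, as the step scopes would hold them.
  std::vector<std::vector<float>> step_outputs(seq_len, std::vector<float>(length));
  for (int j = 0; j < seq_len; ++j)
    step_outputs[j].assign(length, static_cast<float>(j));

  // Concatenate into one [seq_len, batch_size, output_dim] buffer.
  std::vector<float> output(seq_len * length);
  for (int j = 0; j < seq_len; ++j) {
    // memcpy copies bytes, hence the sizeof(float) factor.
    std::memcpy(output.data() + j * length, step_outputs[j].data(),
                length * sizeof(float));
  }
  assert(output[0] == 0.f && output[length] == 1.f && output[2 * length] == 2.f);
}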

void RecurrentOp::LinkMemories(std::vector<ScopePtr>& step_scopes,
                               size_t step_id) const {
  PADDLE_ENFORCE(step_id < step_scopes.size(),
                 "step [%d] out of range of step scopes' size [%d]", step_id,
                 step_scopes.size());
  ScopePtr step_scope = step_scopes[step_id];
  for (auto& attr : memory_attrs_) {
    Tensor* pre_memory_tensor =
        step_scope->CreateVariable(attr.pre_var)->GetMutable<Tensor>();

    if (step_id == 0) {
      PADDLE_ENFORCE(step_scope->HasVariable(attr.boot_var),
                     "memory [%s]'s boot variable [%s] does not exist", attr.var,
                     attr.boot_var);
      Tensor* boot_tensor =
          step_scope->CreateVariable(attr.boot_var)->GetMutable<Tensor>();
      PADDLE_ENFORCE(boot_tensor, "boot_tensor should be retrieved before");
      // share data from the boot memory
      pre_memory_tensor->ShareDataFrom<float>(*boot_tensor);
    } else {
      // share the previous step scope's memory as this scope's pre-memory
      Tensor* pre_step_memory =
          step_scopes[step_id - 1]->GetVariable(attr.var)->GetMutable<Tensor>();
      pre_memory_tensor->ShareDataFrom<float>(*pre_step_memory);
    }

    // TODO: the memory of the current step should be allocated in the step net
    Tensor* cur_memory_tensor =
        step_scopes[step_id]->CreateVariable(attr.var)->GetMutable<Tensor>();
    cur_memory_tensor->mutable_data<float>(pre_memory_tensor->dims(),
                                           platform::CPUPlace());
  }
}
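ShareDataFrom makes the destination tensor alias the source's buffer rather than copying it, so each step's pre-memory is literally the previous step's memory. A sketch of that aliasing with shared_ptr standing in for shared tensor storage (illustrative only, not framework code):

#include <cassert>
#include <memory>
#include <vector>

int main() {
  const int steps = 3;
  std::vector<std::shared_ptr<float>> memory(steps);      // attr.var per step
  std::vector<std::shared_ptr<float>> pre_memory(steps);  // attr.pre_var per step

  auto boot = std::make_shared<float>(0.5f);  // attr.boot_var
  for (int t = 0; t < steps; ++t) {
    // LinkMemories: alias the boot buffer at t == 0, else step t - 1's memory.
    pre_memory[t] = (t == 0) ? boot : memory[t - 1];
    // The step net would read *pre_memory[t] and write a fresh memory[t].
    memory[t] = std::make_shared<float>(*pre_memory[t] + 1.0f);
  }
  assert(pre_memory[1].get() == memory[0].get());  // shared, not copied
  assert(*memory[2] == 3.5f);                      // 0.5 + 1 + 1 + 1
}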

// TODO: test this once operator.h can be included.

// class RecurrentOpProtoAndCheckerMaker : public OpProtoAndCheckerMaker {
// public:
//   RecurrentOpProtoAndCheckerMaker(OpProto* proto, OpAttrChecker* op_checker)
//       : OpProtoAndCheckerMaker(proto, op_checker) {
//     // AddInput("input", "input of test op");    // need to support dynamic number
//     // AddOutput("output", "output of test op"); // need to support dynamic number
//     AddAttr<std::vector<int>>("in_links",
//                               "The input link positions in all the inputs.")
//         .SetDefault({0});
//     AddAttr<std::vector<int>>("boot_memories",
//                               "The initial memory positions in all the inputs.");
//     AddAttr<int>("step_net", "The step net position in all the inputs.");
//
//     AddAttr<std::vector<std::string>>("in_link_alias",
//                                       "The input link aliases in the step network.");
//     AddAttr<std::vector<std::string>>("out_link_alias",
//                                       "The output link aliases in the step network.");
//     AddAttr<std::vector<std::string>>("memories", "The memory names.");
//     AddAttr<std::vector<std::string>>("pre_memories",
//                                       "The history/previous memory names.");
//
//     AddType("recurrent_op");
//     AddComment("This is a recurrent group operator.");
//   }
// };
//
// REGISTER_OP(recurrent_op, RecurrentOp, RecurrentOpProtoAndCheckerMaker);

}  // namespace framework
}  // namespace paddle
Review comment: The LOG calls on lines 32, 34, and 53 could all be moved inside the functions.
Reply: They will all be removed once testing is done.