remove Runtime InferShape for cond op #4518

Merged · 3 commits · Sep 30, 2017
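
At a high level, the diff below folds the shape work that CondOp::InferShape(const Scope&) used to do into CondOp::Run, which is now split into three steps: PrepareDataForSubnet (partition the input rows by the Cond tensor and gather them into per-branch sub-scopes), running the true/false subnets, and MergeDataFromSubnet (scatter each branch's outputs back into the parent scope). The snippet below is a minimal, framework-free sketch of the first of those steps, the runtime index partition; PartitionByCond and the kTrueBranch/kFalseBranch/kBranchNum names are invented for illustration and are not Paddle APIs.

// Minimal, framework-free sketch of the runtime index partition that
// PrepareDataForSubnet performs: rows with a non-zero condition value go to
// the true branch, the rest to the false branch. Names here are illustrative.
#include <iostream>
#include <vector>

enum Branch { kTrueBranch = 0, kFalseBranch = 1, kBranchNum = 2 };

std::vector<std::vector<int>> PartitionByCond(const std::vector<int>& cond) {
  std::vector<std::vector<int>> index_vectors(kBranchNum);
  for (int i = 0; i < static_cast<int>(cond.size()); ++i) {
    index_vectors[cond[i] ? kTrueBranch : kFalseBranch].push_back(i);
  }
  return index_vectors;
}

int main() {
  std::vector<int> cond = {1, 0, 0, 1, 1};
  auto index_vectors = PartitionByCond(cond);
  // Prints "true rows: 3, false rows: 2".
  std::cout << "true rows: " << index_vectors[kTrueBranch].size()
            << ", false rows: " << index_vectors[kFalseBranch].size() << "\n";
  return 0;
}
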
242 changes: 123 additions & 119 deletions paddle/operators/cond_op.cc
@@ -14,12 +14,7 @@ limitations under the License. */

#include "paddle/operators/cond_op.h"

#include <cstring>
#include <sstream>

#include "paddle/framework/op_registry.h"
#include "paddle/operators/gather.h"
#include "paddle/operators/net_op.h"
#include "paddle/operators/scatter.h"

namespace paddle {
@@ -31,175 +26,184 @@ using Tensor = framework::Tensor;
using LoDTensor = framework::LoDTensor;
using DDim = framework::DDim;

void CondOp::CreateScope(const Scope& scope) const {
framework::Scope& CondOp::AddSubScope(const Scope& scope) const {
auto sub_scopes_var = scope.FindVar("SubScopes");
PADDLE_ENFORCE_NOT_NULL(sub_scopes_var,
"Output(SubScopes) of CondOp should not be null.");
auto sub_scopes = sub_scopes_var->GetMutable<std::vector<Scope*>>();
auto& sub_scope = scope.NewScope();
sub_scopes->push_back(&sub_scope);
return sub_scope;
}

void CondOp::CreateIndexTensor(const Scope& scope) const {
std::vector<framework::Scope*>& CondOp::GetSubScopes(
const framework::Scope& scope) const {
auto sub_scopes_var = scope.FindVar("SubScopes");
PADDLE_ENFORCE_NOT_NULL(sub_scopes_var,
"Output(SubScopes) of CondOp should not be null.");
return *sub_scopes_var->GetMutable<std::vector<framework::Scope*>>();
}

LoDTensor& CondOp::AddIndexTensor(const Scope& scope) const {
auto index_tensors_var = scope.FindVar("IndexTensors");
PADDLE_ENFORCE_NOT_NULL(index_tensors_var,
"Output(IndexTensors) of CondOp should not be null.");
auto& index_tensors =
*index_tensors_var->GetMutable<std::vector<LoDTensor>>();
index_tensors.push_back(LoDTensor());
return index_tensors.back();
}

void CondOp::InferShape(const Scope& scope) const {
auto sub_scopes_var = scope.FindVar("SubScopes");
PADDLE_ENFORCE_NOT_NULL(sub_scopes_var,
"Output(SubScopes) of CondOp should not be null.");
auto& sub_scopes = *sub_scopes_var->GetMutable<std::vector<Scope*>>();

for (int i = 0; i < 2; ++i) {
// Create two sub scopes for true and false branches
// sub_scopes[0] for the true branch and sub_scopes[1] for the false
// branch
CreateScope(scope);

// Create two tensors for true and false indices
// index_tensors[0] for the true branch and index_tensors[1] for the false
// branch
CreateIndexTensor(scope);

PADDLE_ENFORCE(!Inputs("Xs").empty(),
"Inputs(Xs) of CondOp can't be empty.");
for (auto& input : Inputs("Xs")) {
// Create a new tensor in sub-scope for input-type tensor
Variable* v = sub_scopes[i]->NewVar(input);
LoDTensor* sub_input = v->GetMutable<LoDTensor>();
sub_input->Resize(scope.FindVar(input)->GetMutable<LoDTensor>()->dims());
}

for (auto& output : (*sub_net_op_[i]).Outputs()) {
for (auto& var_name : output.second) {
sub_scopes[i]->NewVar(var_name);
}
}

// each net calls InferShape
// sub_net_op_[i]->InferShape(*sub_scopes[i]);
}

for (auto& output : Outputs("Outs")) {
LoDTensor* tensor_t_out =
sub_scopes[0]->FindVar(output)->GetMutable<LoDTensor>();
PADDLE_ENFORCE_NOT_NULL(tensor_t_out, "True output should not be NULL");
LoDTensor* tensor_f_out =
sub_scopes[1]->FindVar(output)->GetMutable<LoDTensor>();
PADDLE_ENFORCE_NOT_NULL(tensor_f_out, "False output should not be NULL");

auto* tensor_out_var = scope.FindVar(output);
PADDLE_ENFORCE_NOT_NULL(tensor_out_var, "Output not found");
LoDTensor* tensor_out = tensor_out_var->GetMutable<LoDTensor>();
PADDLE_ENFORCE_NOT_NULL(tensor_t_out,
"True output tensor should not be NULL");

// check that the output shapes are the same
PADDLE_ENFORCE_EQ(tensor_t_out->dims(), tensor_f_out->dims(),
"Outputs not of the same shape");
tensor_out->Resize(tensor_t_out->dims());
// tensor_out->mutable_data<float>(tensor_out->dims(),
// platform::CPUPlace());
tensor_out->mutable_data<float>(platform::CPUPlace());
}
}

void CondOp::Run(const Scope& scope,
const platform::DeviceContext& dev_ctx) const {
auto* sub_scopes_var = scope.FindVar("SubScopes");
PADDLE_ENFORCE_NOT_NULL(sub_scopes_var,
"Output(SubScopes) of CondOp should not be null.");
auto sub_scopes = sub_scopes_var->Get<std::vector<Scope*>>();
std::vector<framework::LoDTensor>& CondOp::GetIndexTensors(
const framework::Scope& scope) const {
auto* index_tensors_var = scope.FindVar("IndexTensors");
PADDLE_ENFORCE_NOT_NULL(index_tensors_var,
"Output(IndexTensors) of CondOp should not be null.");
auto index_tensors = index_tensors_var->Get<std::vector<LoDTensor>>();
return *index_tensors_var->GetMutable<std::vector<framework::LoDTensor>>();
}

std::string cond_name = Input("Cond");
Variable* cond_var = scope.FindVar(cond_name);
void CondOp::PrepareDataForSubnet(
const framework::Scope& scope,
const platform::DeviceContext& dev_ctx) const {
PADDLE_ENFORCE(!Inputs("Xs").empty(), "Inputs(Xs) of CondOp can't be empty.");

for (int i = 0; i < BRANCH_NUM; ++i) {
// Create two sub scopes for true and false branches
// sub_scopes[0] for the true branch
// sub_scopes[1] for the false branch
AddSubScope(scope);
// Create two tensors for true and false indices:
// index_tensors[0] for the true branch
// index_tensors[1] for the false branch
AddIndexTensor(scope);
}

Variable* cond_var = scope.FindVar(Input("Cond"));
PADDLE_ENFORCE_NOT_NULL(cond_var,
"Input(Cond) of CondOp should not be null.");
const LoDTensor* cond = cond_var->GetMutable<LoDTensor>();

// Step 1: get the true/false index at runtime
// index_[0]: vector<int>, contains all index for cond[i] == true
// index_[1]: vector<int>, contains all index for cond[i] == false
for (int i = 0; i < 2; ++i) index_[i].clear();
// get the true/false index at runtime according to cond tensor
// index_vectors[0]: vector<int>, contains all index for cond[i] == true
// index_vectors[1]: vector<int>, contains all index for cond[i] == false
std::vector<std::vector<int>> index_vectors;
index_vectors.resize(BRANCH_NUM);

const int* cond_data = cond->data<int>();
for (int i = 0; i < cond->dims()[0]; ++i) {
if (cond_data[i])
index_[0].push_back(i);
index_vectors[TRUE_BRANCH].push_back(i);
else
index_[1].push_back(i);
index_vectors[FALSE_BRANCH].push_back(i);
}

// put index_[0] and index_[1] into two tensors:
// index_tensor_[0] and index_tensor_[1]
DDim dim = paddle::framework::make_ddim({0});
for (int i = 0; i < 2; ++i) {
dim[0] = index_[i].size();
int* tmp_ptr =
// put index_vectors[0] and index_vectors[1] into two tensors:
// index_tensors[0] and index_tensors[1]
std::vector<framework::LoDTensor>& index_tensors = GetIndexTensors(scope);
std::vector<framework::Scope*>& sub_scopes = GetSubScopes(scope);

for (int i = 0; i < BRANCH_NUM; ++i) {
DDim dim = {static_cast<int64_t>(index_vectors[i].size())};
int* index_tensor_data_ptr =
index_tensors[i].mutable_data<int>(dim, platform::CPUPlace());
index_tensors[i].Resize(dim);
memcpy(tmp_ptr, index_[i].data(), dim[0] * sizeof(int));
memcpy(index_tensor_data_ptr, index_vectors[i].data(),
dim[0] * sizeof(int));
}

// Step 2: collect data by calling gather
for (int i = 0; i < 2; ++i) {
// i = 0/1 for True and False branches respectively
for (auto& input : Inputs("Xs")) {
// find Tensor
Variable* v = scope.FindVar(input);
PADDLE_ENFORCE_NOT_NULL(v);
LoDTensor* tensor_parent = v->GetMutable<LoDTensor>();
// create input in subscopes according to index_vectors
for (auto& input : Inputs("Xs")) {
Variable* var_parent = scope.FindVar(input);
PADDLE_ENFORCE_NOT_NULL(var_parent);
const auto* tensor_parent = &var_parent->Get<LoDTensor>();

v = sub_scopes[i]->FindVar(input);
PADDLE_ENFORCE_NOT_NULL(v);
LoDTensor* tensor_child = v->GetMutable<LoDTensor>();
for (int i = 0; i < BRANCH_NUM; ++i) {
Variable* var_child = sub_scopes[i]->FindVar(input);
PADDLE_ENFORCE_NOT_NULL(var_child);
auto* tensor_child = var_child->GetMutable<LoDTensor>();

// Resize child
DDim dim = tensor_child->dims();
dim[0] = index_[i].size();
tensor_child->Resize(dim);
DDim dim = tensor_parent->dims();
dim[0] = index_tensors[i].dims()[0];
tensor_child->mutable_data<float>(dim, platform::CPUPlace());

Gather<float>(dev_ctx.GetPlace(), tensor_parent, &index_tensors[i],
tensor_child);
}
}

// Step 3: run
for (int i = 0; i < 2; ++i) {
sub_net_op_[i]->Run(*sub_scopes[i], dev_ctx);
// create output_tensors in subscope for sub_net
for (int i = 0; i < BRANCH_NUM; ++i) {
for (auto& output : (*sub_net_op_[i]).Outputs()) {
for (auto& var_name : output.second) {
sub_scopes[i]->NewVar(var_name);
}
}
}
}

// Step 4: merge output results
void CondOp::MergeDataFromSubnet(const framework::Scope& scope,
const platform::DeviceContext& dev_ctx) const {
std::vector<framework::Scope*>& sub_scopes = GetSubScopes(scope);
const std::vector<framework::LoDTensor>& index_tensors =
GetIndexTensors(scope);

// Infer the output dim, out_dim[0] = true_dim[0] + false_dim[0]
PADDLE_ENFORCE(!Outputs("Outs").empty(),
"Outputs(Outs) of CondOp can't be empty.");
for (int i = 0; i < 2; ++i) {
// i = 0/1 for True and False branches respectively
for (auto& output : Outputs("Outs")) {
// find Tensor
Variable* v = scope.FindVar(output);
PADDLE_ENFORCE_NOT_NULL(v);
LoDTensor* tensor_parent = v->GetMutable<LoDTensor>();

v = sub_scopes[i]->FindVar(output);
PADDLE_ENFORCE_NOT_NULL(v);
LoDTensor* tensor_child = v->GetMutable<LoDTensor>();
for (auto& output : Outputs("Outs")) {
const LoDTensor* tensor_t_out =
&sub_scopes[TRUE_BRANCH]->FindVar(output)->Get<LoDTensor>();
PADDLE_ENFORCE_NOT_NULL(tensor_t_out, "True output should not be NULL");
const LoDTensor* tensor_f_out =
&sub_scopes[FALSE_BRANCH]->FindVar(output)->Get<LoDTensor>();
PADDLE_ENFORCE_NOT_NULL(tensor_f_out, "False output should not be NULL");

auto* var_out = scope.FindVar(output);
PADDLE_ENFORCE_NOT_NULL(var_out, "Output not found");
LoDTensor* tensor_out = var_out->GetMutable<LoDTensor>();
PADDLE_ENFORCE_NOT_NULL(tensor_t_out,
"True output tensor should not be NULL");

DDim true_dim = tensor_t_out->dims();
DDim false_dim = tensor_f_out->dims();
true_dim[0] = 0;
false_dim[0] = 0;
PADDLE_ENFORCE_EQ(true_dim, false_dim,
"Outputs not of the same shape except the first dim");

DDim out_dim = tensor_t_out->dims();
out_dim[0] = tensor_t_out->dims()[0] + tensor_f_out->dims()[0];
tensor_out->Resize(out_dim);
tensor_out->mutable_data<float>(platform::CPUPlace());
}

// merge output results:
// output_tensor = true_output_tensor + false_output_tensor
for (auto& output : Outputs("Outs")) {
Variable* var_parent = scope.FindVar(output);
PADDLE_ENFORCE_NOT_NULL(var_parent);
auto* tensor_parent = var_parent->GetMutable<LoDTensor>();

for (int i = 0; i < BRANCH_NUM; ++i) {
Variable* var_child = sub_scopes[i]->FindVar(output);
PADDLE_ENFORCE_NOT_NULL(var_child);
auto* tensor_child = &var_child->Get<LoDTensor>();
ScatterUpdate<float>(dev_ctx.GetPlace(), tensor_child, &index_tensors[i],
tensor_parent);
}
}
}

void CondOp::Run(const Scope& scope,
const platform::DeviceContext& dev_ctx) const {
PrepareDataForSubnet(scope, dev_ctx);
std::vector<framework::Scope*>& sub_scopes = GetSubScopes(scope);
for (int i = 0; i < BRANCH_NUM; ++i) {
sub_net_op_[i]->Run(*sub_scopes[i], dev_ctx);
}
MergeDataFromSubnet(scope, dev_ctx);
}

class CondOpProtoAndCheckerMaker : public framework::OpProtoAndCheckerMaker {
public:
CondOpProtoAndCheckerMaker(framework::OpProto* proto,
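
For each input in Xs, the gather step above copies only the rows selected for a branch out of the parent tensor into that branch's sub-scope tensor, so the child's first dimension equals the number of selected indices and is only known at run time; that is why the resizing now happens inside Run rather than in a separate InferShape pass. Below is a standalone sketch of such a row-wise gather over contiguous row-major float buffers, an illustration of the technique rather than Paddle's Gather implementation.

// Illustrative row-wise gather: dst row j <- src row index[j].
// dst must hold index.size() * slice_size floats; slice_size is the number
// of elements in one row.
#include <cstring>
#include <vector>

void GatherRows(const float* src, const std::vector<int>& index,
                size_t slice_size, float* dst) {
  for (size_t j = 0; j < index.size(); ++j) {
    std::memcpy(dst + j * slice_size, src + index[j] * slice_size,
                slice_size * sizeof(float));
  }
}
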
33 changes: 17 additions & 16 deletions paddle/operators/cond_op.h
@@ -40,8 +40,7 @@ class CondOp : public framework::OperatorBase {
const framework::VariableNameMap& outputs,
const framework::AttributeMap& attrs)
: OperatorBase(type, inputs, outputs, attrs) {
index_.resize(2);
sub_net_op_.resize(2);
sub_net_op_.resize(BRANCH_NUM);
}

CondOp(const CondOp& o)
@@ -51,42 +50,44 @@ class CondOp : public framework::OperatorBase {
PADDLE_THROW("Not implemented");
}

void CreateScope(const framework::Scope& scope) const;
framework::Scope& AddSubScope(const framework::Scope& scope) const;
std::vector<framework::Scope*>& GetSubScopes(
const framework::Scope& scope) const;

void CreateIndexTensor(const framework::Scope& scope) const;
framework::LoDTensor& AddIndexTensor(const framework::Scope& scope) const;
std::vector<framework::LoDTensor>& GetIndexTensors(
const framework::Scope& scope) const;

/*
* InferShape must be called before Run.
* FIXME(yuyang18): Since InferShape has been removed, this implementation
* could be wrong.
*/
void InferShape(const framework::Scope& scope) const;
void PrepareDataForSubnet(const framework::Scope& scope,
const platform::DeviceContext& dev_ctx) const;
void MergeDataFromSubnet(const framework::Scope& scope,
const platform::DeviceContext& dev_ctx) const;

/*
* Set True Block
*/
void set_truenet(std::unique_ptr<OperatorBase>&& net) {
sub_net_op_[0] = std::move(net);
sub_net_op_[TRUE_BRANCH] = std::move(net);
}

/*
* Set False Block
*/
void set_falsenet(std::unique_ptr<OperatorBase>&& net) {
sub_net_op_[1] = std::move(net);
sub_net_op_[FALSE_BRANCH] = std::move(net);
}

void Run(const framework::Scope& scope,
const platform::DeviceContext& dev_ctx) const override;

private:
const int TRUE_BRANCH = 0;
const int FALSE_BRANCH = 1;
const int BRANCH_NUM = 2;

// sub_net_op_[0]: subnet_t
// sub_net_op_[1]: subnet_f
std::vector<std::unique_ptr<framework::OperatorBase>> sub_net_op_;

// index_[0]: True_index;
// index_[1]: False_index;
mutable std::vector<std::vector<int>> index_;
};

} // namespace operators
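
The header now names the branches with TRUE_BRANCH/FALSE_BRANCH/BRANCH_NUM instead of bare 0 and 1, and drops the mutable index_ member, so all per-run bookkeeping is local to Run. A toy version of that layout, with the Paddle operator machinery replaced by a minimal stand-in interface (every name below is invented for illustration):

#include <memory>
#include <vector>

// Stand-in for a sub-network; not a Paddle type.
struct Subnet {
  virtual ~Subnet() = default;
  virtual void Run() const = 0;
};

class ToyCondOp {
 public:
  ToyCondOp() { sub_net_op_.resize(kBranchNum); }

  void set_truenet(std::unique_ptr<Subnet> net) {
    sub_net_op_[kTrueBranch] = std::move(net);
  }
  void set_falsenet(std::unique_ptr<Subnet> net) {
    sub_net_op_[kFalseBranch] = std::move(net);
  }

  // Run executes both branches; in the real op each branch only sees the
  // rows gathered for it, and the results are merged afterwards.
  void Run() const {
    for (int i = 0; i < kBranchNum; ++i) sub_net_op_[i]->Run();
  }

 private:
  enum Branch { kTrueBranch = 0, kFalseBranch = 1, kBranchNum = 2 };
  std::vector<std::unique_ptr<Subnet>> sub_net_op_;
};

One apparent benefit of keeping the index vectors out of the operator state, as this change does, is that Run no longer mutates the op, which is consistent with removing the mutable member above.
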
4 changes: 0 additions & 4 deletions paddle/operators/scatter.h
@@ -78,10 +78,6 @@ void ScatterUpdate(const platform::Place& place,
for (int i = 1; i < src_dims.size(); i++)
PADDLE_ENFORCE(src_dims[i] == dst_dims[i]);

// slice size
size_t slice_size = 1;
for (int i = 0; i < src_dims.size(); ++i) slice_size *= src_dims[i];

if (platform::is_cpu_place(place)) {
CPUScatterUpdate<T>(src, index->data<int>(), index_size, output);
} else {
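
ScatterUpdate is the inverse of the gather step: row i of a branch's output is written back to row index[i] of the merged parent output, which is why the parent's first dimension is the sum of the two branch sizes. The slice_size local deleted above is not referenced by the calls that remain, so removing it looks like a dead-code cleanup. Below is a standalone sketch of a row-wise scatter over contiguous row-major float buffers, illustrative only and not Paddle's implementation.

// Illustrative row-wise scatter: dst row index[i] <- src row i.
// dst must already be sized to hold the merged result.
#include <cstring>
#include <vector>

void ScatterRows(const float* src, const std::vector<int>& index,
                 size_t slice_size, float* dst) {
  for (size_t i = 0; i < index.size(); ++i) {
    std::memcpy(dst + index[i] * slice_size, src + i * slice_size,
                slice_size * sizeof(float));
  }
}
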
3 changes: 0 additions & 3 deletions python/paddle/v2/framework/tests/test_cond_op.py
@@ -112,7 +112,4 @@ def test_forward(self):


if __name__ == "__main__":
exit(
0
) # FIXME(yuyang18): Since infer_shape has been removed, cond op may error
unittest.main()