Add setting Scope function for the graph class (PaddlePaddle#17417)
* add set_not_owned function for graph

* add scope set. test=develop

* add scope_ptr enforce not null before setting. test=develop
wzzju authored May 16, 2019
1 parent 58d5c61 commit 4a1b7fe
Showing 13 changed files with 58 additions and 45 deletions.
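
The pattern replaced throughout this commit: previously a pass attached the parameter scope to the graph as an owned, heap-allocated pointer cell (graph->Set(kParamScopeAttr, new Scope*(&scope))) and read it back through Get<Scope*>, so the graph "owned" a Scope* box while the Scope itself belonged to someone else. Now the scope is attached as a borrowed attribute via SetNotOwned<Scope> and read back as a plain reference via Get<Scope>. Below is a minimal, self-contained C++ sketch of that owned-vs-borrowed attribute idea (a simplified stand-in, not Paddle's actual Graph implementation; the attribute key string is illustrative):

#include <cassert>
#include <functional>
#include <map>
#include <string>

// Simplified stand-in for framework::ir::Graph's attribute map.
class Graph {
 public:
  // Set(): the graph adopts the pointer and deletes it on destruction.
  template <typename T>
  void Set(const std::string &name, T *attr) {
    attrs_[name] = attr;
    deleters_[name] = [attr] { delete attr; };
  }

  // SetNotOwned(): the graph merely borrows the pointer; the caller
  // remains responsible for keeping the object alive and freeing it.
  template <typename T>
  void SetNotOwned(const std::string &name, T *attr) {
    attrs_[name] = attr;
    deleters_[name] = [] { /* nothing to delete */ };
  }

  // Get(): returns a reference to the stored object, owned or borrowed.
  template <typename T>
  T &Get(const std::string &name) const {
    assert(attrs_.count(name) != 0);
    return *static_cast<T *>(attrs_.at(name));
  }

  ~Graph() {
    for (auto &kv : deleters_) kv.second();
  }

 private:
  std::map<std::string, void *> attrs_;
  std::map<std::string, std::function<void()>> deleters_;
};

struct Scope {};  // stand-in for framework::Scope

int main() {
  Scope scope;  // owned by the caller, e.g. the inference Argument
  {
    Graph graph;
    graph.SetNotOwned("__param_scope__", &scope);  // borrow, don't adopt
    Scope &s = graph.Get<Scope>("__param_scope__");
    assert(&s == &scope);
  }  // graph destroyed; the borrowed scope is left untouched
  return 0;
}

The effect is visible in every hunk below: write sites drop the new Scope*(&scope) cell, and read sites go from Get<Scope*> plus an extra dereference to a plain Get<Scope> reference.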
40 changes: 20 additions & 20 deletions paddle/fluid/framework/ir/attention_lstm_fuse_pass.cc
@@ -136,22 +136,22 @@ void PrepareLSTMBias(const LoDTensor& B_forget, const LoDTensor& B_input,
 void PrepareParameters(Graph* graph, const Param& param) {
   // Check parameters
   PADDLE_ENFORCE(graph->Has(kParamScopeAttr));
-  auto* scope = graph->Get<Scope*>(kParamScopeAttr);
+  auto& scope = graph->Get<Scope>(kParamScopeAttr);
 
   // Create new parameters.
-  scope->Var(param.LSTMWeight)->GetMutable<LoDTensor>();
-  scope->Var(param.LSTMBias)->GetMutable<LoDTensor>();
-  scope->Var(param.Hidden)->GetMutable<LoDTensor>();
-  scope->Var(param.Cell)->GetMutable<LoDTensor>();
-  scope->Var(param.AttentionedX)->GetMutable<LoDTensor>();
-  scope->Var(param.AttentionFCOut)->GetMutable<LoDTensor>();
-  scope->Var(param.LSTMX)->GetMutable<LoDTensor>();
-  scope->Var(param.LSTMOUT)->GetMutable<LoDTensor>();
+  scope.Var(param.LSTMWeight)->GetMutable<LoDTensor>();
+  scope.Var(param.LSTMBias)->GetMutable<LoDTensor>();
+  scope.Var(param.Hidden)->GetMutable<LoDTensor>();
+  scope.Var(param.Cell)->GetMutable<LoDTensor>();
+  scope.Var(param.AttentionedX)->GetMutable<LoDTensor>();
+  scope.Var(param.AttentionFCOut)->GetMutable<LoDTensor>();
+  scope.Var(param.LSTMX)->GetMutable<LoDTensor>();
+  scope.Var(param.LSTMOUT)->GetMutable<LoDTensor>();
 
 #define GATE_W(name__)                                               \
-  auto* W_##name__##_w0 = scope->FindVar(#name__ ".w_0");            \
-  auto* W_##name__##_w1 = scope->FindVar(#name__ ".w_1");            \
-  auto* W_##name__##_b0 = scope->FindVar(#name__ ".b_0");            \
+  auto* W_##name__##_w0 = scope.FindVar(#name__ ".w_0");             \
+  auto* W_##name__##_w1 = scope.FindVar(#name__ ".w_1");             \
+  auto* W_##name__##_b0 = scope.FindVar(#name__ ".b_0");             \
   CHECK_P3(W_##name__##_w0, W_##name__##_w1, W_##name__##_b0);       \
   VLOG(4) << #name__ "_w0"                                           \
           << " shape: " << W_##name__##_w0->Get<LoDTensor>().dims(); \
@@ -169,26 +169,26 @@ void PrepareParameters(Graph* graph, const Param& param) {
   GATE_W(c);
 #undef GATE_W
 
-  auto* attention_fc_w = scope->FindVar("attention_fc.w_0");
-  auto* attention_fc_b = scope->FindVar("attention_fc.b_0");
-  auto* attention_output_w = scope->FindVar("attention_output.w_0");
-  auto* attention_output_b = scope->FindVar("attention_output.b_0");
+  auto* attention_fc_w = scope.FindVar("attention_fc.w_0");
+  auto* attention_fc_b = scope.FindVar("attention_fc.b_0");
+  auto* attention_output_w = scope.FindVar("attention_output.w_0");
+  auto* attention_output_b = scope.FindVar("attention_output.b_0");
   CHECK_P4(attention_fc_w, attention_fc_b, attention_output_w,
            attention_output_b);
 
-  auto* lstm_weight = scope->Var(param.LSTMWeight);
+  auto* lstm_weight = scope.Var(param.LSTMWeight);
   auto* lstm_weight_t = lstm_weight->GetMutable<LoDTensor>();
-  auto* lstm_bias = scope->Var(param.LSTMBias);
+  auto* lstm_bias = scope.Var(param.LSTMBias);
   auto* lstm_bias_t = lstm_bias->GetMutable<LoDTensor>();
 
   // reshape attention_bias
   auto* attention_bias_t =
-      scope->FindVar(param.AttentionBias)->GetMutable<LoDTensor>();
+      scope.FindVar(param.AttentionBias)->GetMutable<LoDTensor>();
   PADDLE_ENFORCE_EQ(attention_bias_t->dims().size(), 1);
   attention_bias_t->Resize(make_ddim({1, attention_bias_t->dims()[0]}));
 
   auto* attention_scalar_bias_t =
-      scope->FindVar(param.AttentionScalarBias)->GetMutable<LoDTensor>();
+      scope.FindVar(param.AttentionScalarBias)->GetMutable<LoDTensor>();
   attention_scalar_bias_t->Resize(
       make_ddim({1, attention_scalar_bias_t->dims()[0]}));
4 changes: 2 additions & 2 deletions paddle/fluid/framework/ir/embedding_fc_lstm_fuse_pass.cc
@@ -151,11 +151,11 @@ static int BuildFusion(Graph* graph, const std::string& name_scope,
   op_desc.SetAttr("use_seq", true);
 
   PADDLE_ENFORCE(graph->Has(kParamScopeAttr));
-  auto* scope = graph->Get<Scope*>(kParamScopeAttr);
+  auto& scope = graph->Get<Scope>(kParamScopeAttr);
 #define OP_SET_OUT(x)                            \
   const std::string x = patterns::UniqueKey(#x); \
   op_desc.SetOutput(#x, {x});                    \
-  scope->Var(x)->GetMutable<LoDTensor>()
+  scope.Var(x)->GetMutable<LoDTensor>()
   OP_SET_OUT(BatchedCell);
   OP_SET_OUT(BatchedHidden);
   OP_SET_OUT(ReorderedH0);
11 changes: 5 additions & 6 deletions paddle/fluid/framework/ir/fc_gru_fuse_pass.cc
@@ -69,16 +69,15 @@ static int BuildFusion(Graph* graph, const std::string& name_scope,
 
   auto* op = graph->CreateOpNode(&op_desc);
   PADDLE_ENFORCE(graph->Has(kParamScopeAttr));
-  auto* scope = graph->Get<Scope*>(kParamScopeAttr);
-  PADDLE_ENFORCE(scope);
+  auto& scope = graph->Get<Scope>(kParamScopeAttr);
   if (with_fc_bias) {
     // Fusion GRU bias = fcbias + grubias
-    auto* fusion_bias_var = scope->Var(NEW_NAME(bias) + bias->Name());
+    auto* fusion_bias_var = scope.Var(NEW_NAME(bias) + bias->Name());
     auto* out_bias_tensor =
         fusion_bias_var->GetMutable<framework::LoDTensor>();
     PADDLE_ENFORCE(fusion_bias_var);
-    auto* gru_bias_var = scope->FindVar(bias->Name());
-    auto* fc_bias_var = scope->FindVar(fc_bias->Name());
+    auto* gru_bias_var = scope.FindVar(bias->Name());
+    auto* fc_bias_var = scope.FindVar(fc_bias->Name());
     PADDLE_ENFORCE(gru_bias_var);
     PADDLE_ENFORCE(fc_bias_var);
     const auto& gru_bias_tenosr = gru_bias_var->Get<framework::LoDTensor>();
@@ -94,7 +93,7 @@ static int BuildFusion(Graph* graph, const std::string& name_scope,
 #undef GET_NODE
 
 #define NEW_IMTERMEDIATE_OUT(key) \
-  scope->Var(NEW_NAME(key))->GetMutable<framework::LoDTensor>()
+  scope.Var(NEW_NAME(key))->GetMutable<framework::LoDTensor>()
   NEW_IMTERMEDIATE_OUT(ReorderedH0);
   NEW_IMTERMEDIATE_OUT(XX);
   NEW_IMTERMEDIATE_OUT(BatchedInput);
4 changes: 2 additions & 2 deletions paddle/fluid/framework/ir/fc_lstm_fuse_pass.cc
@@ -100,11 +100,11 @@ int BuildFusion(Graph* graph, const std::string& name_scope, Scope* scope,
   op_desc.SetAttr("use_seq", true);
 
   PADDLE_ENFORCE(graph->Has(kParamScopeAttr));
-  auto* scope = graph->Get<Scope*>(kParamScopeAttr);
+  auto& scope = graph->Get<Scope>(kParamScopeAttr);
 #define OP_SET_OUT(x)                            \
   const std::string x = patterns::UniqueKey(#x); \
   op_desc.SetOutput(#x, {x});                    \
-  scope->Var(x)->GetMutable<LoDTensor>()
+  scope.Var(x)->GetMutable<LoDTensor>()
   OP_SET_OUT(BatchedCell);
   OP_SET_OUT(BatchedHidden);
   OP_SET_OUT(ReorderedH0);
6 changes: 4 additions & 2 deletions paddle/fluid/framework/ir/fuse_pass_base.cc
@@ -13,6 +13,7 @@
 // limitations under the License.
 
 #include "paddle/fluid/framework/ir/fuse_pass_base.h"
+#include <unordered_map>
 
 namespace paddle {
 namespace framework {
@@ -25,7 +26,8 @@ void FusePassBase::Init(const std::string& repr, Graph* graph) const {
 
 Scope* FusePassBase::param_scope() const {
   PADDLE_ENFORCE(graph_->Has(kParamScopeAttr));
-  return graph_->Get<framework::Scope*>(kParamScopeAttr);
+  auto& scope = graph_->Get<framework::Scope>(kParamScopeAttr);
+  return &scope;
 }
 
 void FusePassBase::AddStatis(int count_of_fused) const {
@@ -55,7 +57,7 @@ FuseOptions FusePassBase::FindFuseOption(const Node& node1,
 #else
   return FUSE_NATIVE;
 #endif
-};
+}
 
 }  // namespace ir
 }  // namespace framework
@@ -97,7 +97,7 @@ void MainTest(bool convWithExistingBias) {
     InitTensorHolder(&scope, place, "conv_bias");
     InitTensorHolder(&scope, place, "eltwise_bias");
   }
-  graph->Set(kParamScopeAttr, new framework::Scope*(&scope));
+  graph->SetNotOwned(kParamScopeAttr, &scope);
 
   auto pass = PassRegistry::Instance().Get("conv_bias_mkldnn_fuse_pass");
 
@@ -132,7 +132,7 @@ void MainTest(const ProgramDesc& prog, int conv_count, int pool_count,
     (*scales)[v] = std::make_pair(false, std::move(tensor));
   }
 
-  graph->Set(kParamScopeAttr, new framework::Scope*(&scope));
+  graph->SetNotOwned(kParamScopeAttr, &scope);
 
   auto pass = PassRegistry::Instance().Get("cpu_quantize_pass");
   pass->Set("quant_var_scales", scales);
@@ -119,7 +119,7 @@ void MainTest(const ProgramDesc& prog, int removed_nodes_num) {
     InitTensorHolder(&scope, place, v.c_str());
   }
 
-  graph->Set(kParamScopeAttr, new framework::Scope*(&scope));
+  graph->SetNotOwned(kParamScopeAttr, &scope);
 
   auto pass = PassRegistry::Instance().Get("cpu_quantize_squash_pass");
 
4 changes: 2 additions & 2 deletions paddle/fluid/framework/ir/seqconv_eltadd_relu_fuse_pass.cc
@@ -43,11 +43,11 @@ int BuildFusion(Graph* graph, const std::string& name_scope, Scope* scope) {
   op_desc.SetAttr("contextStart", seqconv->Op()->GetAttr("contextStart"));
   op_desc.SetAttr("contextStride", seqconv->Op()->GetAttr("contextStride"));
   PADDLE_ENFORCE(graph->Has(kParamScopeAttr));
-  auto* scope = graph->Get<Scope*>(kParamScopeAttr);
+  auto& scope = graph->Get<Scope>(kParamScopeAttr);
   const std::string ColMat = patterns::UniqueKey("SeqConvColMat");
   op_desc.SetOutput("ColMat", {ColMat});
   op_desc.SetOutput("Out", {relu_out->Name()});
-  scope->Var(ColMat)->GetMutable<LoDTensor>();
+  scope.Var(ColMat)->GetMutable<LoDTensor>();
 
   auto* op = graph->CreateOpNode(&op_desc);
   IR_NODE_LINK_TO(input, op);
6 changes: 3 additions & 3 deletions paddle/fluid/inference/analysis/ir_pass_manager.cc
@@ -38,9 +38,9 @@ IRPassManager::IRPassManager(Argument *argument) {
   ARGUMENT_CHECK_FIELD(argument, main_program);
   graph_ = std::unique_ptr<Graph>(new Graph(argument->main_program()));
   if (argument->Has("scope")) {
-    graph_->Set(framework::ir::kParamScopeAttr,
-                new framework::Scope *(
-                    const_cast<framework::Scope *>(&argument->scope())));
+    auto *scope_ptr = argument->scope_ptr();
+    PADDLE_ENFORCE(scope_ptr);
+    graph_->SetNotOwned(framework::ir::kParamScopeAttr, scope_ptr);
   }
 
   ARGUMENT_CHECK_FIELD(argument, ir_analysis_passes);
8 changes: 5 additions & 3 deletions paddle/fluid/inference/analysis/passes/ir_graph_build_pass.cc
@@ -13,9 +13,10 @@
 // limitations under the License.
 
 #include "paddle/fluid/inference/analysis/passes/ir_graph_build_pass.h"
-#include <paddle/fluid/framework/ir/fuse_pass_base.h>
+#include <memory>
 #include <string>
 #include "paddle/fluid/framework/executor.h"
+#include "paddle/fluid/framework/ir/fuse_pass_base.h"
 #include "paddle/fluid/inference/io.h"
 #include "paddle/fluid/platform/enforce.h"
 
@@ -56,8 +57,9 @@ void IrGraphBuildPass::RunImpl(Argument *argument) {
 
   auto graph = std::unique_ptr<Graph>(new Graph(argument->main_program()));
   argument->SetMainGraph(graph.release());
-  argument->main_graph().Set(framework::ir::kParamScopeAttr,
-                             new framework::Scope *(argument->scope_ptr()));
+  auto *scope_ptr = argument->scope_ptr();
+  PADDLE_ENFORCE(scope_ptr);
+  argument->main_graph().SetNotOwned(framework::ir::kParamScopeAttr, scope_ptr);
 }
 
 std::unique_ptr<framework::ProgramDesc> IrGraphBuildPass::LoadModel(
5 changes: 3 additions & 2 deletions paddle/fluid/inference/api/mkldnn_quantizer.cc
@@ -353,8 +353,9 @@ void AnalysisPredictor::MkldnnQuantizer::PrepareArgument() const {
   arg.SetMainProgramNotOwned(predictor_.inference_program_.get());
   auto graph = std::unique_ptr<Graph>(new Graph(arg.main_program()));
   arg.SetMainGraph(graph.release());
-  arg.main_graph().Set(framework::ir::kParamScopeAttr,
-                       new framework::Scope*(arg.scope_ptr()));
+  auto* scope_ptr = arg.scope_ptr();
+  PADDLE_ENFORCE(scope_ptr);
+  arg.main_graph().SetNotOwned(framework::ir::kParamScopeAttr, scope_ptr);
 
   auto* builder = predictor_.config_.pass_builder();
   builder->SetPasses({
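
The three inference-side call sites above (IRPassManager's constructor, IrGraphBuildPass::RunImpl, and MkldnnQuantizer::PrepareArgument) now share one guard-then-borrow shape: fetch the raw scope pointer from the Argument, enforce that it is non-null, then attach it to the graph without transferring ownership. A hedged sketch of that shared pattern with stand-in types (names are illustrative, not Paddle's):

#include <cassert>
#include <stdexcept>
#include <string>

struct Scope {};

// Illustrative stand-ins; see the Graph sketch near the top of this page.
struct Graph {
  Scope *borrowed = nullptr;
  void SetNotOwned(const std::string & /*name*/, Scope *scope) {
    borrowed = scope;
  }
};

struct Argument {
  Scope *scope = nullptr;
  Scope *scope_ptr() const { return scope; }
};

// Guard-then-borrow: never hand the graph a null scope, and never hand it
// ownership. Mirrors the PADDLE_ENFORCE(scope_ptr) + SetNotOwned pairs above.
void AttachParamScope(Argument *argument, Graph *graph) {
  Scope *scope_ptr = argument->scope_ptr();
  if (scope_ptr == nullptr) {  // stands in for PADDLE_ENFORCE(scope_ptr);
    throw std::runtime_error("scope_ptr must not be null");
  }
  graph->SetNotOwned("__param_scope__", scope_ptr);
}

int main() {
  Scope scope;
  Argument arg;
  arg.scope = &scope;
  Graph graph;
  AttachParamScope(&arg, &graph);
  assert(graph.borrowed == &scope);
  return 0;
}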
9 changes: 9 additions & 0 deletions paddle/fluid/pybind/ir.cc
@@ -24,6 +24,7 @@
 #include "paddle/fluid/framework/ir/graph_pattern_detector.h"
 #include "paddle/fluid/framework/ir/node.h"
 #include "paddle/fluid/framework/op_desc.h"
+#include "paddle/fluid/framework/scope.h"
 #include "paddle/fluid/framework/var_desc.h"
 #include "pybind11/stl.h"
 
@@ -37,6 +38,7 @@ using paddle::framework::ir::TopologySortOperations;
 using paddle::framework::ir::BuildOperationAdjList;
 using paddle::framework::OpDesc;
 using paddle::framework::ProgramDesc;
+using paddle::framework::Scope;
 using paddle::framework::VarDesc;
 using pybind11::return_value_policy;
 
@@ -57,12 +59,15 @@ void BindGraph(py::module *m) {
       .def(py::init<const ProgramDesc &>())
      .def("clone", &Graph::Clone)
       .def("has", &Graph::Has)
+      .def("get_bool", &Graph::Get<bool>)
       .def("get_int", &Graph::Get<int>)
       .def("get_float", &Graph::Get<float>)
       .def("get_double", &Graph::Get<double>)
       .def("get_string", &Graph::Get<std::string>)
       .def("get_marked_nodes", &Graph::Get<std::unordered_set<const Node *>>,
            return_value_policy::reference)
+      .def("set", [](Graph &self, const std::string &attr_name,
+                     bool attr) { return self.Set(attr_name, new bool(attr)); })
       .def("set", [](Graph &self, const std::string &attr_name,
                      int attr) { return self.Set(attr_name, new int(attr)); })
       .def("set",
@@ -90,6 +95,10 @@
                  return self.Set(attr_name,
                                  new std::unordered_set<std::string>(attr));
                })
+      .def("set_not_owned",
+           [](Graph &self, const std::string &attr_name, Scope &attr) {
+             self.SetNotOwned<Scope>(attr_name, &attr);
+           })
       .def("erase", &Graph::Erase)
       .def("nodes", &Graph::Nodes, return_value_policy::reference)
       .def("create_var_node",
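
The new set_not_owned binding lets Python-side passes attach an existing C++ Scope to a graph without the graph adopting it. Note that the binding as written does not tie the Scope's lifetime to the Graph; the Python caller must keep the Scope object alive for as long as the graph holds the borrowed pointer. A minimal pybind11 sketch of the same binding pattern (module and types are illustrative, not Paddle's), shown with an optional py::keep_alive guard that would enforce that lifetime coupling:

#include <pybind11/pybind11.h>
#include <string>
#include <unordered_map>

namespace py = pybind11;

struct Scope {};  // stand-in for paddle::framework::Scope

struct Graph {  // stand-in for paddle::framework::ir::Graph
  std::unordered_map<std::string, Scope *> borrowed;
  void SetNotOwned(const std::string &name, Scope *s) { borrowed[name] = s; }
};

PYBIND11_MODULE(ir_sketch, m) {
  py::class_<Scope>(m, "Scope").def(py::init<>());
  py::class_<Graph>(m, "Graph")
      .def(py::init<>())
      .def("set_not_owned",
           [](Graph &self, const std::string &attr_name, Scope &attr) {
             self.SetNotOwned(attr_name, &attr);
           },
           // keep_alive<1, 3>: keep argument 3 (the Scope) alive at least
           // as long as argument 1 (the Graph). The real binding above
           // omits this and leaves lifetime management to the caller.
           py::keep_alive<1, 3>());
}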
