Skip to content

Refactorize Scope #3116

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 8 commits into from
Aug 1, 2017
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 9 additions & 9 deletions doc/design/scope.md
Original file line number Diff line number Diff line change
Expand Up @@ -37,8 +37,8 @@ Scope is an association of a name to variable. All variables belong to `Scope`.
```cpp
class Scope {
public:
Variable* CreateVariable(const std::string& name);
const Variable* GetVariable(const std::string& name) const;
Variable* NewVar(const std::string& name);
const Variable* FindVar(const std::string& name) const;

private:
std::unordered_map<std::string, std::unique_ptr<Variable>> vars_;
Expand All @@ -58,12 +58,12 @@ class Scope {
public:
Scope(const std::shared_ptr<Scope>& scope): parent_(scope) {}

Variable* GetVariable(const std::string& name) const {
Variable* FindVar(const std::string& name) const {
auto it = vars_.find(name);
if (it != vars_.end()) {
return it->second.get();
} else if (parent_ != nullptr) {
return parent_->GetVariable(name);
return parent_->FindVar(name);
} else {
return nullptr;
}
Expand Down Expand Up @@ -95,10 +95,10 @@ class Scope {
static std::shared_ptr<Scope> Create(const std::shared_ptr<Scope>& parent = nullptr);

// return nullptr if not found.
Variable* GetVariable(const std::string& name) const;
Variable* FindVar(const std::string& name) const;

// return if already contains same name variable.
Variable* CreateVariable(const std::string& name);
Variable* NewVar(const std::string& name);

private:
std::shared_ptr<Scope> parent_;
Expand All @@ -107,11 +107,11 @@ class Scope {
```
## Only scope can create a variable

To ensure `only scope can create a variable`, we should mark `Variable`'s constructor as a private member function, and Scope is a friend class of Variable. And then only `CreateVariable` can construct `Variable`.
To ensure `only scope can create a variable`, we should mark `Variable`'s constructor as a private member function, and Scope is a friend class of Variable. And then only `NewVar` can construct `Variable`.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Copy link
Contributor

@Superjomn Superjomn Jul 31, 2017

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

not only Variable, but alsoOperator, Tensor andScope should all be DISALLOW_COPY_AND_ASSIGN, this macro will make default and copy constructor = delete.


## When scope destroyed, all variables inside this scope should be destroyed together

The scope hold unique pointers for all variables. User can `GetVariable` from scope, but he should not hold this pointer as a member variable. Because when scope is destroyed, all variables inside this scope will be destroyed together.
The scope hold unique pointers for all variables. User can `FindVar` from scope, but he should not hold this pointer as a member variable. Because when scope is destroyed, all variables inside this scope will be destroyed together.

## Sharing a parent scope

Expand All @@ -121,4 +121,4 @@ Also, as the parent scope is a `shared_ptr`, we can only `Create()` a scope shar

## Orthogonal interface

`GetVariable` will return `nullptr` when `name` is not found. It can be used as `Contains` method. `CreateVariable` will return a `Error` when there is a name conflict locally. Combine `GetVariable` and `CreateVariable`, we can implement `CreateOrGetVariable` easily.
`FindVar` will return `nullptr` when `name` is not found. It can be used as `Contains` method. `NewVar` will return a `Error` when there is a name conflict locally. Combine `FindVar` and `NewVar`, we can implement `NewVar` easily.
6 changes: 4 additions & 2 deletions paddle/framework/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -8,15 +8,17 @@ cc_test(tensor_test SRCS tensor_test.cc DEPS tensor)
cc_test(eigen_test SRCS eigen_test.cc DEPS tensor)

cc_test(variable_test SRCS variable_test.cc)
cc_test(scope_test SRCS scope_test.cc)

cc_library(scope SRCS scope.cc)
cc_test(scope_test SRCS scope_test.cc DEPS scope)

proto_library(attr_type SRCS attr_type.proto)
proto_library(op_proto SRCS op_proto.proto DEPS attr_type)
proto_library(op_desc SRCS op_desc.proto DEPS attr_type)
cc_test(op_proto_test SRCS op_proto_test.cc DEPS op_proto protobuf)
cc_test(op_desc_test SRCS op_desc_test.cc DEPS op_desc protobuf)

cc_library(operator SRCS operator.cc DEPS op_desc device_context tensor)
cc_library(operator SRCS operator.cc DEPS op_desc device_context tensor scope)
cc_test(operator_test SRCS operator_test.cc DEPS operator op_registry)

cc_library(grad_op_builder SRCS grad_op_builder.cc DEPS op_proto operator)
Expand Down
4 changes: 2 additions & 2 deletions paddle/framework/net.h
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ class NetOp : public OperatorBase {
* Infer all the operators' input and output variables' shapes, will be called
* before every mini-batch
*/
void InferShape(const std::shared_ptr<Scope>& scope) const override {
void InferShape(const Scope& scope) const override {
for (auto& op : ops_) {
op->InferShape(scope);
}
Expand All @@ -56,7 +56,7 @@ class NetOp : public OperatorBase {
* scope will be used instead. If no OpContext is provicded, default context
* will be used.
*/
void Run(const std::shared_ptr<Scope>& scope,
void Run(const Scope& scope,
const platform::DeviceContext& dev_ctx) const override {
for (auto& op : ops_) {
op->Run(scope, dev_ctx);
Expand Down
6 changes: 3 additions & 3 deletions paddle/framework/net_op_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,10 @@ static int run_cnt = 0;

class TestOp : public OperatorBase {
public:
void InferShape(const std::shared_ptr<Scope>& scope) const override {
void InferShape(const framework::Scope& scope) const override {
++infer_shape_cnt;
}
void Run(const std::shared_ptr<framework::Scope>& scope,
void Run(const framework::Scope& scope,
const paddle::platform::DeviceContext& dev_ctx) const override {
++run_cnt;
}
Expand Down Expand Up @@ -61,7 +61,7 @@ TEST(OpKernel, all) {
ASSERT_EQ(1UL, tmp_idx.size());
ASSERT_EQ("y", net->outputs_[tmp_idx[0]]);

auto scope = std::make_shared<Scope>();
Scope scope;
platform::CPUDeviceContext dev_ctx;

net->InferShape(scope);
Expand Down
14 changes: 7 additions & 7 deletions paddle/framework/op_registry_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,9 @@ namespace paddle {
namespace framework {
class CosineOp : public OperatorBase {
public:
void Run(const std::shared_ptr<Scope>& scope,
void Run(const Scope& scope,
const platform::DeviceContext& dev_ctx) const override {}
void InferShape(const std::shared_ptr<Scope>& scope) const override {}
void InferShape(const Scope& scope) const override {}
};

class CosineOpProtoAndCheckerMaker : public OpProtoAndCheckerMaker {
Expand All @@ -27,8 +27,8 @@ class CosineOpProtoAndCheckerMaker : public OpProtoAndCheckerMaker {

class MyTestOp : public OperatorBase {
public:
void InferShape(const std::shared_ptr<Scope>& scope) const override {}
void Run(const std::shared_ptr<Scope>& scope,
void InferShape(const Scope& scope) const override {}
void Run(const Scope& scope,
const platform::DeviceContext& dev_ctx) const override {}
};

Expand Down Expand Up @@ -69,7 +69,7 @@ TEST(OpRegistry, CreateOp) {

std::shared_ptr<paddle::framework::OperatorBase> op =
paddle::framework::OpRegistry::CreateOp(op_desc);
auto scope = std::make_shared<paddle::framework::Scope>();
paddle::framework::Scope scope;
paddle::platform::CPUDeviceContext dev_ctx;
op->Run(scope, dev_ctx);
float scale_get = op->GetAttr<float>("scale");
Expand Down Expand Up @@ -111,7 +111,7 @@ TEST(OpRegistry, DefaultValue) {

std::shared_ptr<paddle::framework::OperatorBase> op =
paddle::framework::OpRegistry::CreateOp(op_desc);
auto scope = std::make_shared<paddle::framework::Scope>();
paddle::framework::Scope scope;
paddle::platform::CPUDeviceContext dev_ctx;
op->Run(scope, dev_ctx);
ASSERT_EQ(op->GetAttr<float>("scale"), 1.0);
Expand Down Expand Up @@ -173,7 +173,7 @@ TEST(OpRegistry, CustomChecker) {
SetInputFormat(&op_desc);
auto op = paddle::framework::OpRegistry::CreateOp(op_desc);
paddle::platform::CPUDeviceContext dev_ctx;
auto scope = std::make_shared<paddle::framework::Scope>();
paddle::framework::Scope scope;
op->Run(scope, dev_ctx);
int test_attr = op->GetAttr<int>("test_attr");
ASSERT_EQ(test_attr, 4);
Expand Down
40 changes: 20 additions & 20 deletions paddle/framework/operator.h
Original file line number Diff line number Diff line change
Expand Up @@ -71,10 +71,10 @@ class OperatorBase {

/// InferShape infer the size of Variables used by this Operator with
/// information inside scope
virtual void InferShape(const std::shared_ptr<Scope>& scope) const = 0;
virtual void InferShape(const Scope& scope) const = 0;

/// Net will call this function to Run an op.
virtual void Run(const std::shared_ptr<Scope>& scope,
virtual void Run(const Scope& scope,
const platform::DeviceContext& dev_ctx) const = 0;

virtual bool IsNetOp() const { return false; }
Expand All @@ -101,27 +101,27 @@ class OperatorBase {

class OperatorContext {
public:
OperatorContext(const OperatorBase* op, const std::shared_ptr<Scope>& scope)
OperatorContext(const OperatorBase* op, const Scope& scope)
: op_(*op), scope_(scope) {}

size_t InputSize() const { return op_.inputs_.size(); }

size_t OutputSize() const { return op_.outputs_.size(); }

const Variable* InputVar(const size_t& index) const {
return scope_->GetVariable(op_.inputs_.at(index));
const Variable* InputVar(const size_t index) const {
return scope_.FindVar(op_.inputs_.at(index));
}

Variable* OutputVar(const size_t& index) const {
return scope_->GetVariable(op_.outputs_.at(index));
Variable* OutputVar(const size_t index) const {
return scope_.FindVar(op_.outputs_.at(index));
}

const Variable* InputVar(const std::string& name) const {
return scope_->GetVariable(op_.Input(name));
return scope_.FindVar(op_.Input(name));
}

Variable* OutputVar(const std::string& name) const {
return scope_->GetVariable(op_.Output(name));
return scope_.FindVar(op_.Output(name));
}

const std::vector<const Variable*> MultiInputVar(
Expand All @@ -131,7 +131,7 @@ class OperatorContext {
res.reserve(names.size());
std::transform(
names.begin(), names.end(), std::back_inserter(res),
[this](const std::string& name) { return scope_->GetVariable(name); });
[this](const std::string& name) { return scope_.FindVar(name); });
return res;
}

Expand All @@ -141,17 +141,17 @@ class OperatorContext {
res.reserve(names.size());
std::transform(
names.begin(), names.end(), std::back_inserter(res),
[this](const std::string& name) { return scope_->GetVariable(name); });
[this](const std::string& name) { return scope_.FindVar(name); });
return res;
}

template <typename T>
const T* Input(const size_t& index) const {
const T* Input(const size_t index) const {
return &(InputVar(index)->Get<T>());
}

template <typename T>
T* Output(const size_t& index) const {
T* Output(const size_t index) const {
return OutputVar(index)->GetMutable<T>();
}

Expand All @@ -172,7 +172,7 @@ class OperatorContext {
res.reserve(names.size());
std::transform(names.begin(), names.end(), std::back_inserter(res),
[this](const std::string& name) {
return &scope_->GetVariable(name)->Get<T>();
return &scope_.FindVar(name)->Get<T>();
});
return res;
}
Expand All @@ -184,18 +184,18 @@ class OperatorContext {
res.reserve(names.size());
std::transform(names.begin(), names.end(), std::back_inserter(res),
[this](const std::string& name) {
return scope_->GetVariable(name)->GetMutable<T>();
return scope_.FindVar(name)->GetMutable<T>();
});
return res;
}

const OperatorBase& op_;
const std::shared_ptr<Scope>& scope_;
const Scope& scope_;
};

class InferShapeContext : public OperatorContext {
public:
InferShapeContext(const OperatorBase* op, const std::shared_ptr<Scope>& scope)
InferShapeContext(const OperatorBase* op, const Scope& scope)
: OperatorContext(op, scope) {}
};

Expand All @@ -216,7 +216,7 @@ struct EigenDeviceConverter<platform::GPUPlace> {

class ExecutionContext : public OperatorContext {
public:
ExecutionContext(const OperatorBase* op, const std::shared_ptr<Scope>& scope,
ExecutionContext(const OperatorBase* op, const Scope& scope,
const platform::DeviceContext& device_context)
: OperatorContext(op, scope), device_context_(device_context) {}

Expand Down Expand Up @@ -269,11 +269,11 @@ class OperatorWithKernel : public OperatorBase {
using OpKernelMap =
std::unordered_map<OpKernelKey, std::unique_ptr<OpKernel>, OpKernelHash>;

void InferShape(const std::shared_ptr<Scope>& scope) const {
void InferShape(const Scope& scope) const {
InferShape(InferShapeContext(this, scope));
}

void Run(const std::shared_ptr<Scope>& scope,
void Run(const Scope& scope,
const platform::DeviceContext& dev_ctx) const final {
auto& opKernel = AllOpKernels().at(type_).at(OpKernelKey(dev_ctx));
opKernel->Compute(ExecutionContext(this, scope, dev_ctx));
Expand Down
38 changes: 18 additions & 20 deletions paddle/framework/operator_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -24,16 +24,15 @@ static int op_run_num = 0;
class OpWithoutKernelTest : public OperatorBase {
public:
void Init() override { x = 1; }
void InferShape(
const std::shared_ptr<framework::Scope>& scope) const override {}
void Run(const std::shared_ptr<Scope>& scope,
void InferShape(const Scope& scope) const override {}
void Run(const Scope& scope,
const platform::DeviceContext& dev_ctx) const override {
op_run_num++;
ASSERT_EQ((int)inputs_.size(), 1);
ASSERT_EQ((int)outputs_.size(), 1);
ASSERT_EQ(scope->GetVariable(inputs_[0]), nullptr);
ASSERT_EQ(scope.FindVar(inputs_[0]), nullptr);
ASSERT_EQ(x, 1);
ASSERT_NE(scope->GetVariable(outputs_[0]), nullptr);
ASSERT_NE(scope.FindVar(outputs_[0]), nullptr);
}

public:
Expand Down Expand Up @@ -69,10 +68,10 @@ TEST(OperatorBase, all) {
attr->set_f(3.14);

paddle::platform::CPUDeviceContext device_context;
auto scope = std::make_shared<paddle::framework::Scope>();
paddle::framework::Scope scope;

auto op = paddle::framework::OpRegistry::CreateOp(op_desc);
scope->CreateVariable("OUT1");
scope.NewVar("OUT1");
ASSERT_EQ(paddle::framework::op_run_num, 0);
op->InferShape(scope);
op->Run(scope, device_context);
Expand Down Expand Up @@ -118,13 +117,12 @@ class CPUKernelTest : public OpKernel {
class OperatorMultiInputsTest : public OperatorBase {
public:
void Init() override { x = 1; }
void InferShape(
const std::shared_ptr<framework::Scope>& scope) const override {}
void Run(const std::shared_ptr<Scope>& scope,
void InferShape(const Scope& scope) const override {}
void Run(const Scope& scope,
const platform::DeviceContext& dev_ctx) const override {
ASSERT_EQ(scope->GetVariable(inputs_[0]), nullptr);
ASSERT_EQ(scope.FindVar(inputs_[0]), nullptr);
ASSERT_EQ(x, 1);
ASSERT_NE(scope->GetVariable(outputs_[0]), nullptr);
ASSERT_NE(scope.FindVar(outputs_[0]), nullptr);
ASSERT_EQ(Input("x"), "IN1");
ASSERT_EQ(Input("y"), "OUT1");
}
Expand Down Expand Up @@ -206,7 +204,7 @@ TEST(OpKernel, all) {
attr->set_f(3.14);

paddle::platform::CPUDeviceContext cpu_device_context;
auto scope = std::make_shared<paddle::framework::Scope>();
paddle::framework::Scope scope;

auto op = paddle::framework::OpRegistry::CreateOp(op_desc);
ASSERT_EQ(paddle::framework::cpu_kernel_run_num, 0);
Expand Down Expand Up @@ -252,13 +250,13 @@ TEST(OpKernel, multi_inputs) {
output_format->Add(2); // y1

paddle::platform::CPUDeviceContext cpu_device_context;
auto scope = std::make_shared<Scope>();
scope->CreateVariable("x0")->GetMutable<Tensor>();
scope->CreateVariable("x1")->GetMutable<Tensor>();
scope->CreateVariable("x2")->GetMutable<Tensor>();
scope->CreateVariable("k0")->GetMutable<Tensor>();
scope->CreateVariable("y0")->GetMutable<Tensor>();
scope->CreateVariable("y1")->GetMutable<Tensor>();
paddle::framework::Scope scope;
scope.NewVar("x0")->GetMutable<Tensor>();
scope.NewVar("x1")->GetMutable<Tensor>();
scope.NewVar("x2")->GetMutable<Tensor>();
scope.NewVar("k0")->GetMutable<Tensor>();
scope.NewVar("y0")->GetMutable<Tensor>();
scope.NewVar("y1")->GetMutable<Tensor>();

auto op = paddle::framework::OpRegistry::CreateOp(op_desc);
op->Run(scope, cpu_device_context);
Expand Down
Loading