Skip to content

Commit 0d878e4

Browse files
authored
Add Go_op, Channel_create, channel_close, channel_send and channel_receive ops (#8593)
* Adding Python boilerplate code for Go op * Add very basic test case * Adding the python logic for go routine * Fix syntax * Changing test to notest * Rename Routine to Go * Combining GoGuard and Go in one class * Modify test * Adding fluid close channel * Fixing __init__.py for calling fluid.go() * Adding stubs for channel methods and updating test case * Removing import * * Adding imports from concurrency * Initial commit of GO_OP (for varun) * Creating local scopes and go through them * Updated go op inputs persistability enforcement * Add thread execution; compile failing though * Fix go op * Cleaned up Go op * Fix yapf format issue * Readd warp ctc dir for unit tests * Updated make_channel, channel_send, channel_recv and channel_close * Moved thread function to another method, update unit tests * remove output var * Add stubs for channel operators * Updating concurrency with signatures * Updated the signature with return status * Fixed dtype in variables * Updating stub of ChannelSend + add infershape * Updating stub of ChannelRecv + add infershape * Updated signature * Adding the channel_create operator * Merge channel send+receive ops * Update concurrency tests using all operators * Updating the create op with ChannelHolder * Fix issues with channel_create_op * Add the implementation for channel_close op * Add channel close operator, fix channel close op * Adding the channel_send op * Comment channels C++ and Python code * Concurrency python api comment fix * Update unit test to add Status variable * Adding channel receive operator * Update concurrency test to demonstrate a complete CSP flow * Fix clang-format issues * Fixed "Out" parameter name * Fixing merge conflict in framework.py * Add channel ops to framework.py no_kernel_op_set * Seperating channel_send and channel_recv operators * Documenting capacity type * Update concurrency test to create go block as child block of main program * Changing set status implementation
1 parent 2edeb63 commit 0d878e4

File tree

11 files changed

+830
-12
lines changed

11 files changed

+830
-12
lines changed

paddle/fluid/framework/CMakeLists.txt

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -96,3 +96,5 @@ cc_test(op_kernel_type_test SRCS op_kernel_type_test.cc DEPS place device_contex
9696
cc_test(cow_ptr_tests SRCS details/cow_ptr_test.cc)
9797

9898
cc_test(channel_test SRCS channel_test.cc)
99+
cc_test(concurrency_test SRCS concurrency_test.cc DEPS go_op channel_close_op channel_create_op
100+
channel_send_op channel_recv_op sum_op elementwise_add_op executor proto_desc)

paddle/fluid/framework/channel.h

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -91,6 +91,11 @@ class ChannelHolder {
9191

9292
inline bool IsInitialized() const { return holder_ != nullptr; }
9393

94+
inline const std::type_index Type() {
95+
PADDLE_ENFORCE_EQ(IsInitialized(), true);
96+
return holder_->Type();
97+
}
98+
9499
private:
95100
/**
96101
* @note Placeholder hides type T, so it doesn't appear as a template
Lines changed: 122 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,122 @@
1+
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
2+
3+
Licensed under the Apache License, Version 2.0 (the "License");
4+
you may not use this file except in compliance with the License.
5+
You may obtain a copy of the License at
6+
7+
http://www.apache.org/licenses/LICENSE-2.0
8+
9+
Unless required by applicable law or agreed to in writing, software
10+
distributed under the License is distributed on an "AS IS" BASIS,
11+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
See the License for the specific language governing permissions and
13+
limitations under the License. */
14+
15+
#include <thread>
16+
17+
#include "gtest/gtest.h"
18+
#include "paddle/fluid/framework/block_desc.h"
19+
#include "paddle/fluid/framework/channel.h"
20+
#include "paddle/fluid/framework/executor.h"
21+
#include "paddle/fluid/framework/op_registry.h"
22+
#include "paddle/fluid/framework/program_desc.h"
23+
24+
USE_NO_KERNEL_OP(go);
25+
USE_NO_KERNEL_OP(channel_close);
26+
USE_NO_KERNEL_OP(channel_create);
27+
USE_NO_KERNEL_OP(channel_recv);
28+
USE_NO_KERNEL_OP(channel_send);
29+
USE_NO_KERNEL_OP(elementwise_add);
30+
31+
namespace f = paddle::framework;
32+
namespace p = paddle::platform;
33+
34+
namespace paddle {
35+
namespace framework {
36+
37+
template <typename T>
38+
void CreateIntVariable(Scope &scope, p::CPUPlace &place, std::string name,
39+
T value) {
40+
// Create LoDTensor<int> of dim [1,1]
41+
auto var = scope.Var(name);
42+
auto tensor = var->GetMutable<LoDTensor>();
43+
tensor->Resize({1, 1});
44+
T *expect = tensor->mutable_data<T>(place);
45+
expect[0] = value;
46+
}
47+
48+
void InitTensorsInScope(Scope &scope, p::CPUPlace &place) {
49+
p::CPUDeviceContext ctx(place);
50+
51+
// Create channel variable
52+
scope.Var("Channel");
53+
54+
// Create Variables, x0 will be put into channel,
55+
// result will be pulled from channel
56+
CreateIntVariable(scope, place, "Status", false);
57+
CreateIntVariable(scope, place, "x0", 99);
58+
CreateIntVariable(scope, place, "result", 0);
59+
}
60+
61+
void AddOp(const std::string &type, const VariableNameMap &inputs,
62+
const VariableNameMap &outputs, AttributeMap attrs,
63+
BlockDesc *block) {
64+
// insert op
65+
auto op = block->AppendOp();
66+
op->SetType(type);
67+
for (auto &kv : inputs) {
68+
op->SetInput(kv.first, kv.second);
69+
}
70+
for (auto &kv : outputs) {
71+
op->SetOutput(kv.first, kv.second);
72+
}
73+
op->SetAttrMap(attrs);
74+
}
75+
76+
TEST(Concurrency, Go_Op) {
77+
Scope scope;
78+
p::CPUPlace place;
79+
80+
// Initialize scope variables
81+
InitTensorsInScope(scope, place);
82+
83+
framework::Executor executor(place);
84+
ProgramDesc program;
85+
BlockDesc *block = program.MutableBlock(0);
86+
87+
// Create channel OP
88+
AddOp("channel_create", {}, {{"Out", {"Channel"}}},
89+
{{"capacity", 10}, {"data_type", f::proto::VarType::LOD_TENSOR}},
90+
block);
91+
92+
// Create Go Op routine
93+
BlockDesc *goOpBlock = program.AppendBlock(program.Block(0));
94+
AddOp("channel_send", {{"Channel", {"Channel"}}, {"X", {"x0"}}},
95+
{{"Status", {"Status"}}}, {}, goOpBlock);
96+
97+
// Create Go Op
98+
AddOp("go", {{"X", {"Channel", "x0"}}}, {}, {{"sub_block", goOpBlock}},
99+
block);
100+
101+
// Create Channel Receive Op
102+
AddOp("channel_recv", {{"Channel", {"Channel"}}},
103+
{{"Status", {"Status"}}, {"Out", {"result"}}}, {}, block);
104+
105+
// Create Channel Close Op
106+
AddOp("channel_close", {{"Channel", {"Channel"}}}, {}, {}, block);
107+
108+
// Check the result tensor to make sure it is set to 0
109+
const LoDTensor &tensor = (scope.FindVar("result"))->Get<LoDTensor>();
110+
auto *initialData = tensor.data<int>();
111+
EXPECT_EQ(initialData[0], 0);
112+
113+
executor.Run(program, &scope, 0, true, true);
114+
115+
// After we call executor.run, the Go operator should do a channel_send to set
116+
// the
117+
// "result" variable to 99
118+
auto *finalData = tensor.data<int>();
119+
EXPECT_EQ(finalData[0], 99);
120+
}
121+
} // namespace framework
122+
} // namespace paddle
Lines changed: 71 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,71 @@
1+
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
2+
3+
Licensed under the Apache License, Version 2.0 (the "License");
4+
you may not use this file except in compliance with the License.
5+
You may obtain a copy of the License at
6+
7+
http://www.apache.org/licenses/LICENSE-2.0
8+
9+
Unless required by applicable law or agreed to in writing, software
10+
distributed under the License is distributed on an "AS IS" BASIS,
11+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
See the License for the specific language governing permissions and
13+
limitations under the License. */
14+
15+
#include "paddle/fluid/framework/channel.h"
16+
#include "paddle/fluid/framework/op_registry.h"
17+
18+
namespace pf = paddle::framework;
19+
static constexpr char kChannel[] = "Channel";
20+
21+
namespace paddle {
22+
namespace operators {
23+
24+
class ChannelCloseOp : public framework::OperatorBase {
25+
public:
26+
ChannelCloseOp(const std::string &type,
27+
const framework::VariableNameMap &inputs,
28+
const framework::VariableNameMap &outputs,
29+
const framework::AttributeMap &attrs)
30+
: framework::OperatorBase(type, inputs, outputs, attrs) {}
31+
32+
private:
33+
void RunImpl(const framework::Scope &scope,
34+
const platform::Place &dev_place) const override {
35+
auto &inp = *scope.FindVar(Input(kChannel));
36+
37+
// Get the mutable version of the channel variable and closes it.
38+
pf::ChannelHolder *ch = inp.GetMutable<framework::ChannelHolder>();
39+
ch->close();
40+
}
41+
};
42+
43+
class ChannelCloseOpOpInferShape : public framework::InferShapeBase {
44+
public:
45+
void operator()(framework::InferShapeContext *context) const override {
46+
PADDLE_ENFORCE(context->HasInput("Channel"),
47+
"The input of ChannelClose op must be set");
48+
}
49+
};
50+
51+
class ChannelCloseOpMaker : public framework::OpProtoAndCheckerMaker {
52+
public:
53+
ChannelCloseOpMaker(OpProto *proto, OpAttrChecker *op_checker)
54+
: OpProtoAndCheckerMaker(proto, op_checker) {
55+
AddInput(kChannel,
56+
"The Channel Variable that should be closed by"
57+
" the ChannelClose Op.");
58+
AddComment(R"DOC(
59+
Channel Close Operator.
60+
61+
This operator closes an open channel.
62+
)DOC");
63+
}
64+
};
65+
66+
} // namespace operators
67+
} // namespace paddle
68+
69+
REGISTER_OPERATOR(channel_close, paddle::operators::ChannelCloseOp,
70+
paddle::framework::EmptyGradOpMaker,
71+
paddle::operators::ChannelCloseOpMaker);
Lines changed: 114 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,114 @@
1+
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
2+
3+
Licensed under the Apache License, Version 2.0 (the "License");
4+
you may not use this file except in compliance with the License.
5+
You may obtain a copy of the License at
6+
7+
http://www.apache.org/licenses/LICENSE-2.0
8+
9+
Unless required by applicable law or agreed to in writing, software
10+
distributed under the License is distributed on an "AS IS" BASIS,
11+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
See the License for the specific language governing permissions and
13+
limitations under the License. */
14+
15+
#include "paddle/fluid/framework/channel.h"
16+
#include "paddle/fluid/framework/lod_rank_table.h"
17+
#include "paddle/fluid/framework/lod_tensor_array.h"
18+
#include "paddle/fluid/framework/op_registry.h"
19+
#include "paddle/fluid/framework/reader.h"
20+
21+
namespace pf = paddle::framework;
22+
23+
static constexpr char kOutput[] = "Out";
24+
25+
namespace paddle {
26+
namespace operators {
27+
28+
class ChannelCreateOp : public framework::OperatorBase {
29+
public:
30+
ChannelCreateOp(const std::string &type,
31+
const framework::VariableNameMap &inputs,
32+
const framework::VariableNameMap &outputs,
33+
const framework::AttributeMap &attrs)
34+
: framework::OperatorBase(type, inputs, outputs, attrs) {}
35+
36+
private:
37+
void RunImpl(const framework::Scope &scope,
38+
const platform::Place &dev_place) const override {
39+
auto &out = *scope.FindVar(Output(kOutput));
40+
41+
// Determine the datatype and capacity of the channel to be created
42+
// from the attributes provided.
43+
auto dtype =
44+
static_cast<framework::proto::VarType::Type>(Attr<int>("data_type"));
45+
auto capacity = Attr<int>("capacity");
46+
47+
// Based on the datatype, create a new channel holder initialized with
48+
// the given capacity. When capacity is 0, an unbuffered channel is
49+
// created.
50+
pf::ChannelHolder *ch = out.GetMutable<framework::ChannelHolder>();
51+
if (dtype == framework::proto::VarType::LOD_TENSOR) {
52+
ch->Reset<pf::LoDTensor>(capacity);
53+
} else if (dtype == framework::proto::VarType::SELECTED_ROWS) {
54+
ch->Reset<pf::SelectedRows>(capacity);
55+
} else if (dtype == framework::proto::VarType::LOD_RANK_TABLE) {
56+
ch->Reset<pf::LoDRankTable>(capacity);
57+
} else if (dtype == framework::proto::VarType::LOD_TENSOR_ARRAY) {
58+
ch->Reset<pf::LoDTensorArray>(capacity);
59+
} else if (dtype == framework::proto::VarType::READER) {
60+
ch->Reset<pf::ReaderHolder>(capacity);
61+
} else if (dtype == framework::proto::VarType::CHANNEL) {
62+
ch->Reset<pf::ChannelHolder>(capacity);
63+
} else if (dtype == framework::proto::VarType::BOOL) {
64+
ch->Reset<bool>(capacity);
65+
} else if (dtype == framework::proto::VarType::INT32) {
66+
ch->Reset<int>(capacity);
67+
} else if (dtype == framework::proto::VarType::INT64) {
68+
ch->Reset<int64_t>(capacity);
69+
} else if (dtype == framework::proto::VarType::FP32) {
70+
ch->Reset<float>(capacity);
71+
} else if (dtype == framework::proto::VarType::FP64) {
72+
ch->Reset<double>(capacity);
73+
} else {
74+
PADDLE_THROW(
75+
"Data type %d is not in "
76+
"[LOD_TENSOR, SELECTED_ROWS, LOD_RANK_TABLE, LOD_TENSOR_ARRAY, "
77+
"READER, CHANNEL, BOOL, INT32, INT64, FP32, FP64]",
78+
dtype);
79+
}
80+
}
81+
};
82+
83+
class ChannelCreateOpOpInferShape : public framework::InferShapeBase {
84+
public:
85+
void operator()(framework::InferShapeContext *context) const override {
86+
PADDLE_ENFORCE(context->HasOutput(kOutput),
87+
"The output of ChannelCreate op must be set");
88+
context->SetOutputDim(kOutput, {1});
89+
}
90+
};
91+
92+
class ChannelCreateOpMaker : public framework::OpProtoAndCheckerMaker {
93+
public:
94+
ChannelCreateOpMaker(OpProto *proto, OpAttrChecker *op_checker)
95+
: OpProtoAndCheckerMaker(proto, op_checker) {
96+
AddOutput(kOutput,
97+
"The object of a Channel type created by ChannelCreate Op.");
98+
AddAttr<int>("capacity", "The size of the buffer of Channel.")
99+
.SetDefault(0);
100+
AddAttr<int>("data_type", "The data type of elements inside the Channel.");
101+
AddComment(R"DOC(
102+
Channel Create Operator.
103+
104+
This operator creates an object of the VarType Channel and returns it.
105+
)DOC");
106+
}
107+
};
108+
109+
} // namespace operators
110+
} // namespace paddle
111+
112+
REGISTER_OPERATOR(channel_create, paddle::operators::ChannelCreateOp,
113+
paddle::framework::EmptyGradOpMaker,
114+
paddle::operators::ChannelCreateOpMaker);

0 commit comments

Comments
 (0)