Commit f3dc311

add split ids op (#9370)
* add split_ids_op
* add TestSplitIdsOp
* fix comment
* add test for empty tensor
* clean code
* rm unused code
1 parent 2e4a398 commit f3dc311

3 files changed: +176, -0 lines changed
Lines changed: 76 additions & 0 deletions
@@ -0,0 +1,76 @@
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/fluid/operators/split_ids_op.h"

namespace paddle {
namespace operators {

class SplitIdsOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  SplitIdsOpMaker(OpProto *proto, OpAttrChecker *op_checker)
      : OpProtoAndCheckerMaker(proto, op_checker) {
    AddInput("Ids", "(LoDTensor) the input ids with shape{batch_num, 1}");
    AddOutput("Out", "(LoDTensor) The outputs of the input Ids.")
        .AsDuplicable();

    AddComment(R"DOC(
Split a LoDTensor of Ids into multi LoDTensors, the number is pserver's number
Example:
  Input:
    X = [1,2,3,4,5,6]

  Out(3 output):
    out0 = [3, 6]
    out1 = [1, 4]
    out2 = [2, 5]
)DOC");
  }
};

class SplitIdsOp : public framework::OperatorWithKernel {
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;

  void InferShape(framework::InferShapeContext *ctx) const override {
    PADDLE_ENFORCE(ctx->HasInput("Ids"), "SplitIdsOp must has input Ids.");
    PADDLE_ENFORCE(ctx->HasOutputs("Out"), "SplitIdsOp must has output Out.");

    auto ids_var_type = ctx->GetInputsVarType("Ids").front();
    PADDLE_ENFORCE_EQ(ids_var_type, framework::proto::VarType::LOD_TENSOR);

    auto ids_dims = ctx->GetInputDim("Ids");
    PADDLE_ENFORCE_EQ(ids_dims.size(), 2);
    PADDLE_ENFORCE_EQ(ids_dims[1], 1);
  }
};

class SplitIdsOpInferVarType : public framework::VarTypeInference {
 public:
  void operator()(const framework::OpDesc &op_desc,
                  framework::BlockDesc *block) const override {
    for (auto &out_var : op_desc.Output("Out")) {
      block->Var(out_var)->SetType(framework::proto::VarType::LOD_TENSOR);
    }
  }
};

}  // namespace operators
}  // namespace paddle

namespace ops = paddle::operators;
REGISTER_OPERATOR(split_ids, ops::SplitIdsOp, ops::SplitIdsOpMaker,
                  ops::SplitIdsOpInferVarType);
REGISTER_OP_CPU_KERNEL(
    split_ids, ops::SplitIdsOpKernel<paddle::platform::CPUPlace, int64_t>);
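
The DOC block above describes the split only by example; the rule the kernel applies is that each id is routed to output number id % (number of outputs). A minimal Python sketch, illustrative only and not part of this commit or of the Paddle API, reproducing the documented example with three outputs:

def split_by_modulo(ids, num_outputs):
    # Each id goes to shard `id % num_outputs`, matching the DOC example.
    shards = [[] for _ in range(num_outputs)]
    for i in ids:
        shards[i % num_outputs].append(i)
    return shards

print(split_by_modulo([1, 2, 3, 4, 5, 6], 3))
# [[3, 6], [1, 4], [2, 5]]  -> out0, out1, out2 from the DOC example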
Lines changed: 65 additions & 0 deletions
@@ -0,0 +1,65 @@
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once

#include <vector>
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/math/selected_rows_functor.h"

namespace paddle {
namespace operators {

template <typename DeviceContext, typename T>
class SplitIdsOpKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& ctx) const override {
    auto place = ctx.GetPlace();
    if (!platform::is_cpu_place(place)) {
      PADDLE_THROW("SplitIds do not support GPU kernel");
    }

    const auto* ids_t = ctx.Input<framework::LoDTensor>("Ids");
    auto& ids_dims = ids_t->dims();
    auto outs = ctx.MultiOutput<framework::LoDTensor>("Out");

    const T* ids = ids_t->data<T>();

    const size_t shard_num = outs.size();

    std::vector<std::vector<T>> out_ids;
    out_ids.resize(outs.size());

    // split id by their shard_num.
    for (size_t i = 0; i < ids_dims[0]; ++i) {
      T id = ids[i];
      size_t shard_id = static_cast<size_t>(id) % shard_num;
      out_ids[shard_id].push_back(id);
    }

    // create tensor for each shard and send to parameter server
    for (size_t i = 0; i < out_ids.size(); ++i) {
      auto* shard_t = outs[i];
      std::vector<T> ids = out_ids[i];
      auto* shard_data = shard_t->mutable_data<T>(
          framework::make_ddim({static_cast<int64_t>(ids.size()), 1}), place);
      for (size_t i = 0; i < ids.size(); ++i) {
        shard_data[i] = ids[i];
      }
    }
  }
};

}  // namespace operators
}  // namespace paddle
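
The CPU kernel above works in two passes: it buckets every id into a per-shard std::vector by id % shard_num, then copies each bucket into its output tensor with shape {bucket_size, 1}. A NumPy reference of the same two-pass logic, useful for predicting output values and shapes; the function name is illustrative, not a Paddle API:

import numpy as np

def split_ids_reference(ids, shard_num):
    # ids: int64 array of shape (batch, 1), mirroring the "Ids" input.
    buckets = [[] for _ in range(shard_num)]
    for id_ in ids.reshape(-1):
        buckets[int(id_) % shard_num].append(int(id_))
    # Each bucket becomes a (len, 1) array, like the kernel's mutable_data call.
    return [np.array(b, dtype=np.int64).reshape(-1, 1) for b in buckets]

outs = split_ids_reference(np.array([[1], [2], [3], [4], [5], [6]], dtype=np.int64), 3)
print([o.ravel().tolist() for o in outs])  # [[3, 6], [1, 4], [2, 5]]
print([o.shape for o in outs])             # [(2, 1), (2, 1), (2, 1)]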
Lines changed: 35 additions & 0 deletions
@@ -0,0 +1,35 @@
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import unittest
import numpy as np
from op_test import OpTest


class TestSplitIdsOp(OpTest):
    def setUp(self):
        self.op_type = "split_ids"
        ids = np.array([[0], [2], [2], [3], [5], [5], [6]]).astype('int64')
        out0 = np.array([[0], [3], [6]]).astype('int64')
        out1 = np.array([[]]).astype('int64')
        out2 = np.array([[2], [2], [5], [5]]).astype('int64')
        self.inputs = {'Ids': ids}
        self.outputs = {'Out': [('out0', out0), ('out1', out1), ('out2', out2)]}

    def test_check_output(self):
        self.check_output()


if __name__ == '__main__':
    unittest.main()
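
The fixture exercises the empty-tensor case named in the commit message: with three outputs, none of the ids 0, 2, 3, 5, 6 is congruent to 1 mod 3, so out1 stays empty while out0 and out2 collect the rest. A quick illustrative check (plain Python, not part of the test):

ids = [0, 2, 2, 3, 5, 5, 6]
shards = [[], [], []]
for i in ids:
    shards[i % 3].append(i)
print(shards)  # [[0, 3, 6], [], [2, 2, 5, 5]] -> out0, the empty out1, out2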
