Commit 7c426be

Merge pull request #11342 from jacquesqiao/add-merge-splited-ids

Add merge_ids_op

2 parents: d3b0129 + e6f54d5

File tree (6 files changed: +297, -9 lines)

  paddle/fluid/framework/executor.cc                         (+5, -1)
  paddle/fluid/framework/operator.cc                         (+21)
  paddle/fluid/operators/merge_ids_op.cc                     (+128, new)
  paddle/fluid/operators/merge_ids_op.h                      (+92, new)
  python/paddle/fluid/tests/unittests/test_merge_ids_op.py   (+38, new)
  python/paddle/fluid/transpiler/distribute_transpiler.py    (+13, -8)

paddle/fluid/framework/executor.cc (+5, -1)

@@ -330,8 +330,12 @@ void Executor::RunPreparedContext(ExecutorPrepareContext* ctx, Scope* scope,
   }

   for (auto& op : ctx->ops_) {
-    VLOG(3) << place_ << " " << op->DebugStringEx(local_scope);
+    VLOG(4) << place_ << " " << op->DebugStringEx(local_scope);
     op->Run(*local_scope, place_);
+    // NOTE! Please do not delete this line: it is useful because the debug
+    // string before and after op->Run() differs; after the run, the outputs
+    // have their final shapes, which is useful for debugging.
+    VLOG(3) << place_ << " " << op->DebugStringEx(local_scope);

     if (FLAGS_benchmark) {
       VLOG(2) << "Memory used after operator " + op->Type() + " running: "

paddle/fluid/framework/operator.cc (+21)

@@ -69,6 +69,19 @@ static DDim GetDims(const Scope& scope, const std::string& name,
   }
 }

+static int GetRowSize(const Scope& scope, const std::string& name) {
+  Variable* var = scope.FindVar(name);
+  if (var == nullptr) {
+    return -1;
+  }
+
+  if (var->IsType<SelectedRows>()) {
+    return var->Get<SelectedRows>().rows().size();
+  }
+
+  return -1;
+}
+
 static LoD GetLoD(const Scope& scope, const std::string& name) {
   Variable* var = scope.FindVar(name);
   auto default_lod = LoD({{}});
@@ -153,6 +166,10 @@ std::string OperatorBase::DebugStringEx(const Scope* scope) const {
     for (size_t i = 0; i < input.second.size(); ++i) {
       ss << input.second[i];
       if (scope) {
+        int row_size = GetRowSize(*scope, input.second[i]);
+        if (row_size >= 0) {
+          ss << "[row_size=" << row_size << "]";
+        }
         ss << "[" << GetDims(*scope, input.second[i], true) << "]";
         ss << "(" << GetLoD(*scope, input.second[i]) << ")";
       }
@@ -173,6 +190,10 @@ std::string OperatorBase::DebugStringEx(const Scope* scope) const {
     for (size_t i = 0; i < output.second.size(); ++i) {
       ss << output.second[i];
       if (scope) {
+        int row_size = GetRowSize(*scope, output.second[i]);
+        if (row_size >= 0) {
+          ss << "[row_size=" << row_size << "]";
+        }
         ss << "[" << GetDims(*scope, output.second[i], true) << "]";
         ss << "(" << GetLoD(*scope, output.second[i]) << ")";
       }
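
Note: for intuition, here is a small Python rendering of what the new
annotation does. This sketch is not part of the commit, and the variable
names in it are made up. A SelectedRows variable reports a non-negative row
count and gains a [row_size=N] tag in front of the usual [dims](lod)
suffixes, while dense variables (GetRowSize returns -1) print as before.

def annotate(name, row_size, dims, lod):
    # Mirrors the DebugStringEx change above: only SelectedRows vars
    # (row_size >= 0) get the extra [row_size=N] tag; dims and LoD
    # follow exactly as before.
    s = name
    if row_size >= 0:
        s += "[row_size=%d]" % row_size
    return s + "[%s](%s)" % (dims, lod)

print(annotate("embedding@GRAD", 4, "4, 128", "{}"))  # SelectedRows var
print(annotate("fc_0.tmp_0", -1, "32, 128", "{}"))    # dense LoDTensor
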
paddle/fluid/operators/merge_ids_op.cc (new file, +128)

/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/fluid/operators/merge_ids_op.h"

namespace paddle {
namespace operators {

class MergeIdsOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  void Make() override {
    AddInput("Ids", "(LoDTensor) the input ids with shape {batch_num, 1}");
    AddInput(
        "X",
        "(LoDTensors) multiple input tensors with shape {batch_num, N}, N is "
        "the size of the embedding table")
        .AsDuplicable();
    AddOutput("Out", "(LoDTensor) The merged outputs of the input tensors.");

    AddComment(R"DOC(
Merge multiple LoDTensors into one according to the shard num of Ids:

  split_ids_op -> prefetch_op -> merge_ids_op

merge_ids_op should be used after split_ids_op and prefetch_op. split_ids_op
splits the input Ids into multiple tensors according to each Id's shard
number, and prefetch_op sends them to the parameter server to prefetch the
embedding values back. The split scrambles the order of the ids, so in
merge_ids_op we use the original Ids to restore the order of the fetched
embedding values and also pass the LoD information to the merged output.

Example:

  Ids = [1,2,3,4,5,6]  # 3 shards

  split_ids_op ->

    Id0 = [3, 6]  # id % 3 == 0
    Id1 = [1, 4]  # id % 3 == 1
    Id2 = [2, 5]  # id % 3 == 2

  prefetch_op ->

    X0 = [[0.3 0.3]   # 3
          [0.6 0.6]]  # 6
    X1 = [[0.1 0.1]   # 1
          [0.4 0.4]]  # 4
    X2 = [[0.2 0.2]   # 2
          [0.5 0.5]]  # 5

  merge_ids_op ->

    Out = [[0.1 0.1]   # 1
           [0.2 0.2]   # 2
           [0.3 0.3]   # 3
           [0.4 0.4]   # 4
           [0.5 0.5]   # 5
           [0.6 0.6]]  # 6
)DOC");
  }
};

class MergeIdsOp : public framework::OperatorWithKernel {
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;

  void InferShape(framework::InferShapeContext *ctx) const override {
    PADDLE_ENFORCE(ctx->HasInput("Ids"), "MergeIdsOp must have input Ids.");
    PADDLE_ENFORCE(ctx->HasInputs("X"), "MergeIdsOp must have input X.");
    PADDLE_ENFORCE(ctx->HasOutput("Out"), "MergeIdsOp must have output Out.");

    auto ids_var_type = ctx->GetInputsVarType("Ids").front();
    auto ids_dims = ctx->GetInputDim("Ids");
    if (ids_var_type == framework::proto::VarType::LOD_TENSOR) {
      PADDLE_ENFORCE_EQ(ids_dims.size(), 2);
      PADDLE_ENFORCE_EQ(ids_dims[1], 1);
    }
    auto x_var_type = ctx->GetInputsVarType("X");
    for (auto &var_type : x_var_type) {
      PADDLE_ENFORCE_EQ(var_type, framework::proto::VarType::LOD_TENSOR,
                        "input X only supports LoDTensors");
    }
    ctx->ShareLoD("Ids", "Out");
  }

 private:
  framework::OpKernelType GetExpectedKernelType(
      const framework::ExecutionContext &ctx) const override {
    return framework::OpKernelType(
        framework::ToDataType(
            ctx.MultiInput<framework::Tensor>("X").front()->type()),
        ctx.GetPlace());
  }
};

class MergeIdsOpInferVarType : public framework::VarTypeInference {
 public:
  void operator()(const framework::OpDesc &op_desc,
                  framework::BlockDesc *block) const override {
    auto *input_var = block->Var(op_desc.Input("Ids")[0]);
    for (auto &out_var : op_desc.Output("Out")) {
      block->Var(out_var)->SetType(input_var->GetType());
    }
  }
};

}  // namespace operators
}  // namespace paddle

namespace ops = paddle::operators;
REGISTER_OPERATOR(merge_ids, ops::MergeIdsOp, ops::MergeIdsOpMaker,
                  ops::MergeIdsOpInferVarType);
REGISTER_OP_CPU_KERNEL(
    merge_ids, ops::MergeIdsOpKernel<paddle::platform::CPUPlace, float>);
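
Note: the DOC example above can be reproduced with a short NumPy sketch of
the merge rule (an illustration only, not the shipped kernel; merge_ids
below is a hypothetical helper). split_ids_op routes each id to shard
id % shard_num and preserves order within a shard, so replaying the original
ids with one cursor per shard restores the original row order.

import numpy as np

def merge_ids(ids, shard_outputs):
    shard_num = len(shard_outputs)
    # embedding width, taken from the first non-empty shard
    emb_size = next(x.shape[1] for x in shard_outputs if x.size)
    cursors = [0] * shard_num  # next unread row in each shard
    out = np.empty((len(ids), emb_size), dtype='float32')
    for i, ident in enumerate(ids):
        shard = int(ident) % shard_num
        out[i] = shard_outputs[shard][cursors[shard]]
        cursors[shard] += 1
    return out

x0 = np.array([[0.3, 0.3], [0.6, 0.6]], dtype='float32')  # ids 3, 6
x1 = np.array([[0.1, 0.1], [0.4, 0.4]], dtype='float32')  # ids 1, 4
x2 = np.array([[0.2, 0.2], [0.5, 0.5]], dtype='float32')  # ids 2, 5
print(merge_ids([1, 2, 3, 4, 5, 6], [x0, x1, x2]))
# rows for ids 1..6, back in their original order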

paddle/fluid/operators/merge_ids_op.h (new file, +92)

/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once

#include <vector>
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/tensor_util.h"
#include "paddle/fluid/operators/math/selected_rows_functor.h"

namespace paddle {
namespace operators {

template <typename DeviceContext, typename T>
class MergeIdsOpKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext &ctx) const override {
    auto place = ctx.GetPlace();
    if (!platform::is_cpu_place(place)) {
      PADDLE_THROW("MergeIds does not support GPU kernel");
    }
    VLOG(3) << "run in MergeIdsOpKernel";

    const auto *ids_var = ctx.InputVar("Ids");
    PADDLE_ENFORCE(ids_var->IsType<framework::LoDTensor>(),
                   "only support merging Ids of LoDTensor");

    const auto &ids_tensor = ids_var->Get<framework::LoDTensor>();
    const auto &ids_dims = ids_tensor.dims();
    const int64_t *ids = ids_tensor.data<int64_t>();

    auto x_tensors = ctx.MultiInput<framework::LoDTensor>("X");

    auto *out = ctx.Output<framework::LoDTensor>("Out");

    int batch_size = 0;
    int embedding_size = 0;
    for (auto &input : x_tensors) {
      if (framework::product(input->dims()) != 0) {
        if (embedding_size == 0) {
          embedding_size = input->dims()[1];
        }
        PADDLE_ENFORCE_EQ(embedding_size, input->dims()[1],
                          "embedding size of all inputs should be the same");
        batch_size += input->dims()[0];
      }
    }
    PADDLE_ENFORCE_EQ(
        batch_size, ids_dims[0],
        "the batch size of ids and the merged embedding values should be "
        "the same");

    const size_t shard_num = x_tensors.size();

    if (shard_num == 1) {
      VLOG(3) << "only one shard, we can copy the data directly";
      TensorCopy(*x_tensors[0], place, out);
    } else {
      std::vector<int> in_indexs(shard_num, 0);
      auto *out_data = out->mutable_data<T>(
          framework::make_ddim({batch_size, embedding_size}), place);
      // copy data from ins[shard_num] to out.
      for (int i = 0; i < ids_dims[0]; ++i) {
        int64_t id = ids[i];
        size_t shard_id = static_cast<size_t>(id) % shard_num;
        int index = in_indexs[shard_id];
        memcpy(out_data + embedding_size * i,
               x_tensors[shard_id]->data<T>() + index * embedding_size,
               sizeof(T) * embedding_size);
        in_indexs[shard_id] += 1;
      }

      for (size_t i = 0; i < shard_num; ++i) {
        PADDLE_ENFORCE_EQ(in_indexs[i], x_tensors[i]->dims()[0],
                          "after the merge, all data in x_tensors should "
                          "be used");
      }
    }
  }
};

}  // namespace operators
}  // namespace paddle

python/paddle/fluid/tests/unittests/test_merge_ids_op.py (new file, +38)

# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import unittest
import numpy as np
from op_test import OpTest


class TestMergeIdsOp(OpTest):
    def setUp(self):
        self.op_type = "merge_ids"
        ids = np.array([[0], [2], [2], [3], [5], [5], [6]]).astype('int64')
        x0 = np.array([[0.1, 0.2], [0.2, 0.3], [0.3, 0.4]]).astype('float32')
        x1 = np.array([]).astype('float32')
        x2 = np.array([[0.4, 0.5], [0.4, 0.5], [0.5, 0.6],
                       [0.5, 0.6]]).astype('float32')
        out = np.array([[0.1, 0.2], [0.4, 0.5], [0.4, 0.5], [0.2, 0.3],
                        [0.5, 0.6], [0.5, 0.6], [0.3, 0.4]]).astype('float32')
        self.inputs = {'Ids': ids, "X": [('x0', x0), ('x1', x1), ('x2', x2)]}
        self.outputs = {'Out': out}

    def test_check_output(self):
        self.check_output()


if __name__ == '__main__':
    unittest.main()
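
Note: the expected out above can be derived by hand with the same
per-shard-cursor rule (a sanity check, not part of the commit): ids map to
shards via id % 3, duplicate ids (2, 2 and 5, 5) each consume their own
prefetched row, and shard 1 stays empty because no id is congruent to
1 mod 3.

import numpy as np

ids = [0, 2, 2, 3, 5, 5, 6]
xs = [np.array([[0.1, 0.2], [0.2, 0.3], [0.3, 0.4]], dtype='float32'),
      np.empty((0, 2), dtype='float32'),  # shard 1: no id % 3 == 1
      np.array([[0.4, 0.5], [0.4, 0.5], [0.5, 0.6], [0.5, 0.6]],
               dtype='float32')]
cursors = [0, 0, 0]
rows = []
for ident in ids:
    shard = ident % 3
    rows.append(xs[shard][cursors[shard]])  # duplicates take separate rows
    cursors[shard] += 1
print(np.stack(rows))  # equals the test's expected `out`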

python/paddle/fluid/transpiler/distribute_transpiler.py (+13, -8)

@@ -629,7 +629,7 @@ def _replace_lookup_table_op_with_prefetch(self, program,
                 if op.type == LOOKUP_TABLE_TYPE:
                     continue_search_lookup_table_op = True

-                    op_index = list(all_ops).index(op)
+                    lookup_table_op_index = list(all_ops).index(op)
                     ids_name = op.input("Ids")
                     out_name = op.output("Out")

@@ -649,7 +649,7 @@ def _replace_lookup_table_op_with_prefetch(self, program,

                     # insert split_ids_op
                     program.global_block().insert_op(
-                        index=op_index,
+                        index=lookup_table_op_index,
                         type="split_ids",
                         inputs={
                             'Ids': [
@@ -661,7 +661,7 @@ def _replace_lookup_table_op_with_prefetch(self, program,

                     # insert prefetch_op
                     program.global_block().insert_op(
-                        index=op_index + 1,
+                        index=lookup_table_op_index + 1,
                         type="prefetch",
                         inputs={'X': prefetch_input_vars},
                         outputs={"Out": prefetch_output_vars},
@@ -672,16 +672,21 @@ def _replace_lookup_table_op_with_prefetch(self, program,

                     # insert concat_op
                     program.global_block().insert_op(
-                        index=op_index + 2,
-                        type="concat",
-                        inputs={'X': prefetch_output_vars},
+                        index=lookup_table_op_index + 2,
+                        type="merge_ids",
+                        inputs={
+                            'Ids': [
+                                program.global_block().vars[varname]
+                                for varname in ids_name
+                            ],
+                            'X': prefetch_output_vars
+                        },
                         outputs={
                             "Out": [
                                 program.global_block().vars[varname]
                                 for varname in out_name
                             ]
-                        },
-                        attrs={"axis": 0})
+                        })

                     # delete lookup_table_op
                     delete_ops(program.global_block(), [op])
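
Note: the index arithmetic can be checked with a toy model of the block (a
sketch with made-up op names, not the transpiler API): split_ids lands at
the old lookup_table position, prefetch and merge_ids land directly after
it, and the lookup_table op itself is deleted last.

# Toy model of the block edit; the list stands in for global_block().ops.
ops = ["fill_constant", "lookup_table", "mean"]
lookup_table_op_index = ops.index("lookup_table")
for offset, new_op in enumerate(["split_ids", "prefetch", "merge_ids"]):
    ops.insert(lookup_table_op_index + offset, new_op)
ops.remove("lookup_table")  # the delete_ops(...) step
print(ops)  # ['fill_constant', 'split_ids', 'prefetch', 'merge_ids', 'mean']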
