Skip to content

Commit c8ba6d5

Browse files
authored
Merge pull request #8135 from JiayiFeng/dev_make_VarDesc_supporting_multiple_tensor
Add type `Reader` for `VarDesc`
2 parents 445c74c + e5227c2 commit c8ba6d5

File tree

9 files changed

+244
-24
lines changed

9 files changed

+244
-24
lines changed

paddle/framework/backward.cc

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -534,7 +534,7 @@ ParamGradInfoMap AppendBackward(
534534
auto root_block = program_desc.MutableBlock(root_block_idx);
535535

536536
std::string fill_one_op_out = GradVarName(target.Name());
537-
bool is_scalar = target.Shape() == std::vector<int64_t>{1};
537+
bool is_scalar = target.GetShape() == std::vector<int64_t>{1};
538538
PADDLE_ENFORCE(is_scalar, "target should be scalar");
539539
VLOG(3) << "backward from loss=" << target.Name()
540540
<< " data_type=" << target.GetDataType();
@@ -565,7 +565,7 @@ ParamGradInfoMap AppendBackward(
565565

566566
auto var = root_block->Var(fill_one_op_out);
567567
var->SetDataType(target.GetDataType());
568-
var->SetShape(target.Shape());
568+
var->SetShape(target.GetShape());
569569
auto& target_grad = retv[target.Name()];
570570
target_grad.name_ = fill_one_op_out;
571571
target_grad.block_idx_ = root_block_idx;

paddle/framework/framework.proto

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -116,6 +116,8 @@ message LoDTensorArrayDesc {
116116
optional int32 lod_level = 2 [ default = 0 ];
117117
}
118118

119+
// Description carried by a READER variable: a list of LoDTensorDesc entries.
message Reader {
  repeated LoDTensorDesc lod_tensor = 1;
}
120+
119121
message VarDesc {
120122
enum VarType {
121123
LOD_TENSOR = 1;
@@ -126,13 +128,15 @@ message VarDesc {
126128
LOD_RANK_TABLE = 6;
127129
LOD_TENSOR_ARRAY = 7;
128130
PLACE_LIST = 8;
131+
READER = 9;
129132
}
130133
required string name = 1;
131134
required VarType type = 2;
132-
optional LoDTensorDesc lod_tensor = 3;
133-
optional TensorDesc selected_rows = 4;
135+
optional bool persistable = 3 [ default = false ];
136+
optional LoDTensorDesc lod_tensor = 4;
137+
optional TensorDesc selected_rows = 5;
134138
optional LoDTensorArrayDesc tensor_array = 6;
135-
optional bool persistable = 5 [ default = false ];
139+
optional Reader reader = 7;
136140
}
137141

138142
message BlockDesc {

paddle/framework/op_desc.cc

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -458,11 +458,11 @@ DDim CompileTimeInferShapeContext::GetDim(const std::string &name) const {
458458
auto var = block_.FindVarRecursive(name);
459459
PADDLE_ENFORCE(var != nullptr, "Cannot find variable %s", name);
460460
try {
461-
auto shape = var->Shape();
461+
auto shape = var->GetShape();
462462
if (shape.empty()) {
463463
return framework::make_ddim({0UL});
464464
} else {
465-
return framework::make_ddim(var->Shape());
465+
return framework::make_ddim(var->GetShape());
466466
}
467467
} catch (...) {
468468
VLOG(5) << "GetDim of variable " << name << " error";

paddle/framework/program_desc_test.cc

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,7 @@ TEST(ProgramDesc, copy_ctor) {
5353
ASSERT_NE(copy, var_before);
5454
ASSERT_EQ(copy->Name(), var_before->Name());
5555
ASSERT_EQ(copy->GetType(), var_before->GetType());
56-
ASSERT_EQ(copy->Shape(), var_before->Shape());
56+
ASSERT_EQ(copy->GetShape(), var_before->GetShape());
5757
ASSERT_EQ(copy->Proto()->SerializeAsString(),
5858
var_before->Proto()->SerializeAsString());
5959
};
@@ -117,7 +117,7 @@ TEST(ProgramDescBind, serialize_and_deserialize) {
117117
ASSERT_NE(restored, var_before);
118118
ASSERT_EQ(restored->Name(), var_before->Name());
119119
ASSERT_EQ(restored->GetType(), var_before->GetType());
120-
ASSERT_EQ(restored->Shape(), var_before->Shape());
120+
ASSERT_EQ(restored->GetShape(), var_before->GetShape());
121121
ASSERT_EQ(restored->Proto()->SerializeAsString(),
122122
var_before->Proto()->SerializeAsString());
123123
};

paddle/framework/var_desc.cc

Lines changed: 163 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -26,18 +26,91 @@ void VarDesc::SetShape(const std::vector<int64_t> &dims) {
2626
VectorToRepeated(dims, mutable_tensor_desc()->mutable_dims());
2727
}
2828

29+
void VarDesc::SetTensorDescNum(size_t num) {
30+
switch (desc_.type()) {
31+
case proto::VarDesc::READER: {
32+
auto *lod_tensors_ptr = desc_.mutable_reader()->mutable_lod_tensor();
33+
lod_tensors_ptr->Clear();
34+
for (size_t i = 0; i < num; ++i) {
35+
lod_tensors_ptr->Add();
36+
}
37+
return;
38+
} break;
39+
default:
40+
PADDLE_THROW(
41+
"Setting 'sub_tensor_number' is not supported by the type of var %s.",
42+
this->Name());
43+
}
44+
}
45+
46+
size_t VarDesc::GetTensorDescNum() const {
47+
switch (desc_.type()) {
48+
case proto::VarDesc::READER:
49+
return desc_.reader().lod_tensor_size();
50+
break;
51+
default:
52+
PADDLE_THROW(
53+
"Getting 'sub_tensor_number' is not supported by the type of var %s.",
54+
this->Name());
55+
}
56+
}
57+
58+
void VarDesc::SetShapes(
59+
const std::vector<std::vector<int64_t>> &multiple_dims) {
60+
PADDLE_ENFORCE_EQ(multiple_dims.size(), GetTensorDescNum(),
61+
"The number of given shapes(%d) doesn't equal to the "
62+
"number of sub tensor.",
63+
multiple_dims.size(), GetTensorDescNum());
64+
std::vector<proto::TensorDesc *> tensors = mutable_tensor_descs();
65+
for (size_t i = 0; i < multiple_dims.size(); ++i) {
66+
VectorToRepeated(multiple_dims[i], tensors[i]->mutable_dims());
67+
}
68+
}
69+
70+
std::vector<int64_t> VarDesc::GetShape() const {
71+
return RepeatedToVector(tensor_desc().dims());
72+
}
73+
74+
std::vector<std::vector<int64_t>> VarDesc::GetShapes() const {
75+
std::vector<proto::TensorDesc> descs = tensor_descs();
76+
std::vector<std::vector<int64_t>> res;
77+
res.reserve(descs.size());
78+
for (const auto &tensor_desc : descs) {
79+
res.push_back(RepeatedToVector(tensor_desc.dims()));
80+
}
81+
return res;
82+
}
83+
2984
// Stores `data_type` in this variable's tensor description.
void VarDesc::SetDataType(proto::DataType data_type) {
  proto::TensorDesc *desc = mutable_tensor_desc();
  desc->set_data_type(data_type);
}
3287

33-
std::vector<int64_t> VarDesc::Shape() const {
34-
return RepeatedToVector(tensor_desc().dims());
88+
void VarDesc::SetDataTypes(
89+
const std::vector<proto::DataType> &multiple_data_type) {
90+
PADDLE_ENFORCE_EQ(multiple_data_type.size(), GetTensorDescNum(),
91+
"The number of given data types(%d) doesn't equal to the "
92+
"number of sub tensor.",
93+
multiple_data_type.size(), GetTensorDescNum());
94+
std::vector<proto::TensorDesc *> tensor_descs = mutable_tensor_descs();
95+
for (size_t i = 0; i < multiple_data_type.size(); ++i) {
96+
tensor_descs[i]->set_data_type(multiple_data_type[i]);
97+
}
3598
}
3699

37100
// Returns the data type recorded in this variable's tensor description.
proto::DataType VarDesc::GetDataType() const {
  const proto::TensorDesc &desc = tensor_desc();
  return desc.data_type();
}
40103

104+
std::vector<proto::DataType> VarDesc::GetDataTypes() const {
105+
std::vector<proto::TensorDesc> descs = tensor_descs();
106+
std::vector<proto::DataType> res;
107+
res.reserve(descs.size());
108+
for (const auto &tensor_desc : descs) {
109+
res.push_back(tensor_desc.data_type());
110+
}
111+
return res;
112+
}
113+
41114
void VarDesc::SetLoDLevel(int32_t lod_level) {
42115
switch (desc_.type()) {
43116
case proto::VarDesc::LOD_TENSOR:
@@ -47,8 +120,28 @@ void VarDesc::SetLoDLevel(int32_t lod_level) {
47120
desc_.mutable_tensor_array()->set_lod_level(lod_level);
48121
break;
49122
default:
50-
PADDLE_THROW("Tensor type=%d does not support LoDLevel",
51-
desc_.tensor_array().lod_level());
123+
PADDLE_THROW(
124+
"Setting 'lod_level' is not supported by the type of var %s.",
125+
this->Name());
126+
}
127+
}
128+
129+
void VarDesc::SetLoDLevels(const std::vector<int32_t> &multiple_lod_level) {
130+
PADDLE_ENFORCE_EQ(multiple_lod_level.size(), GetTensorDescNum(),
131+
"The number of given data types(%d) doesn't equal to the "
132+
"number of sub tensor.",
133+
multiple_lod_level.size(), GetTensorDescNum());
134+
switch (desc_.type()) {
135+
case proto::VarDesc::READER: {
136+
size_t i = 0;
137+
for (auto &lod_tensor : *desc_.mutable_reader()->mutable_lod_tensor()) {
138+
lod_tensor.set_lod_level(multiple_lod_level[i++]);
139+
}
140+
} break;
141+
default:
142+
PADDLE_THROW(
143+
"Setting 'lod_levels' is not supported by the type of var %s.",
144+
this->Name());
52145
}
53146
}
54147

@@ -59,13 +152,31 @@ int32_t VarDesc::GetLoDLevel() const {
59152
case proto::VarDesc::LOD_TENSOR_ARRAY:
60153
return desc_.tensor_array().lod_level();
61154
default:
62-
PADDLE_THROW("Tensor type=%d does not support LoDLevel",
63-
desc_.tensor_array().lod_level());
155+
PADDLE_THROW(
156+
"Getting 'lod_level' is not supported by the type of var %s.",
157+
this->Name());
158+
}
159+
}
160+
161+
std::vector<int32_t> VarDesc::GetLoDLevels() const {
162+
std::vector<int32_t> res;
163+
switch (desc_.type()) {
164+
case proto::VarDesc::READER:
165+
res.reserve(desc_.reader().lod_tensor_size());
166+
for (auto &lod_tensor : desc_.reader().lod_tensor()) {
167+
res.push_back(lod_tensor.lod_level());
168+
}
169+
return res;
170+
break;
171+
default:
172+
PADDLE_THROW(
173+
"Getting 'lod_levels' is not supported by the type of var %s.",
174+
this->Name());
64175
}
65176
}
66177

67178
const proto::TensorDesc &VarDesc::tensor_desc() const {
68-
PADDLE_ENFORCE(desc_.has_type(), "invoke TensorDesc must after set type");
179+
PADDLE_ENFORCE(desc_.has_type(), "The var's type hasn't been set.");
69180
switch (desc_.type()) {
70181
case proto::VarDesc::SELECTED_ROWS:
71182
return desc_.selected_rows();
@@ -74,13 +185,32 @@ const proto::TensorDesc &VarDesc::tensor_desc() const {
74185
case proto::VarDesc::LOD_TENSOR_ARRAY:
75186
return desc_.tensor_array().tensor();
76187
default:
77-
PADDLE_THROW("The type of var %s is unsupported.", this->Name());
188+
PADDLE_THROW(
189+
"Getting 'tensor_desc' is not supported by the type of var %s.",
190+
this->Name());
191+
}
192+
}
193+
194+
std::vector<proto::TensorDesc> VarDesc::tensor_descs() const {
195+
PADDLE_ENFORCE(desc_.has_type(), "The var type hasn't been set.");
196+
std::vector<proto::TensorDesc> res;
197+
res.reserve(GetTensorDescNum());
198+
switch (desc_.type()) {
199+
case proto::VarDesc::READER:
200+
for (const auto &lod_tensor : desc_.reader().lod_tensor()) {
201+
res.push_back(lod_tensor.tensor());
202+
}
203+
return res;
204+
default:
205+
PADDLE_THROW(
206+
"Getting 'tensor_descs' is not supported by the type of var "
207+
"%s.",
208+
this->Name());
78209
}
79210
}
80211

81212
proto::TensorDesc *VarDesc::mutable_tensor_desc() {
82-
PADDLE_ENFORCE(desc_.has_type(),
83-
"invoke MutableTensorDesc must after set type");
213+
PADDLE_ENFORCE(desc_.has_type(), "The var type hasn't been set.");
84214
switch (desc_.type()) {
85215
case proto::VarDesc::SELECTED_ROWS:
86216
return desc_.mutable_selected_rows();
@@ -89,8 +219,30 @@ proto::TensorDesc *VarDesc::mutable_tensor_desc() {
89219
case proto::VarDesc::LOD_TENSOR_ARRAY:
90220
return desc_.mutable_tensor_array()->mutable_tensor();
91221
default:
92-
PADDLE_THROW("Unexpected branch.");
222+
PADDLE_THROW(
223+
"Getting 'mutable_tensor_desc' is not supported by the type of var "
224+
"%s.",
225+
this->Name());
93226
}
94227
}
228+
229+
std::vector<proto::TensorDesc *> VarDesc::mutable_tensor_descs() {
230+
PADDLE_ENFORCE(desc_.has_type(), "The var type hasn't been set.");
231+
std::vector<proto::TensorDesc *> res;
232+
res.reserve(GetTensorDescNum());
233+
switch (desc_.type()) {
234+
case proto::VarDesc::READER:
235+
for (auto &lod_tensor : *desc_.mutable_reader()->mutable_lod_tensor()) {
236+
res.push_back(lod_tensor.mutable_tensor());
237+
}
238+
return res;
239+
default:
240+
PADDLE_THROW(
241+
"Getting 'tensor_descs' is not supported by the type of var "
242+
"%s.",
243+
this->Name());
244+
}
245+
}
246+
95247
} // namespace framework
96248
} // namespace paddle

paddle/framework/var_desc.h

Lines changed: 19 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -68,18 +68,34 @@ class VarDesc {
6868

6969
void SetName(std::string name) { desc_.set_name(name); }
7070

71+
void SetTensorDescNum(size_t num);
72+
73+
size_t GetTensorDescNum() const;
74+
7175
void SetShape(const std::vector<int64_t> &dims);
7276

77+
void SetShapes(const std::vector<std::vector<int64_t>> &multiple_dims);
78+
79+
std::vector<int64_t> GetShape() const;
80+
81+
std::vector<std::vector<int64_t>> GetShapes() const;
82+
7383
void SetDataType(proto::DataType data_type);
7484

75-
std::vector<int64_t> Shape() const;
85+
void SetDataTypes(const std::vector<proto::DataType> &multiple_data_type);
7686

7787
proto::DataType GetDataType() const;
7888

89+
std::vector<proto::DataType> GetDataTypes() const;
90+
7991
void SetLoDLevel(int32_t lod_level);
8092

93+
void SetLoDLevels(const std::vector<int32_t> &multiple_lod_level);
94+
8195
int32_t GetLoDLevel() const;
8296

97+
std::vector<int32_t> GetLoDLevels() const;
98+
8399
proto::VarDesc::VarType GetType() const;
84100

85101
void SetType(proto::VarDesc::VarType type);
@@ -90,7 +106,9 @@ class VarDesc {
90106

91107
private:
92108
const proto::TensorDesc &tensor_desc() const;
109+
std::vector<proto::TensorDesc> tensor_descs() const;
93110
proto::TensorDesc *mutable_tensor_desc();
111+
std::vector<proto::TensorDesc *> mutable_tensor_descs();
94112

95113
proto::VarDesc desc_;
96114
};

paddle/inference/io.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -55,7 +55,7 @@ void LoadPersistables(framework::Executor& executor,
5555
VLOG(3) << "parameter's name: " << var->Name();
5656

5757
framework::VarDesc* new_var = load_block->Var(var->Name());
58-
new_var->SetShape(var->Shape());
58+
new_var->SetShape(var->GetShape());
5959
new_var->SetDataType(var->GetDataType());
6060
new_var->SetType(var->GetType());
6161
new_var->SetLoDLevel(var->GetLoDLevel());

paddle/pybind/protobuf.cc

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -214,11 +214,20 @@ void BindVarDsec(py::module &m) {
214214
py::return_value_policy::reference)
215215
.def("set_name", &VarDesc::SetName)
216216
.def("set_shape", &VarDesc::SetShape)
217+
.def("set_shapes", &VarDesc::SetShapes)
217218
.def("set_dtype", &VarDesc::SetDataType)
218-
.def("shape", &VarDesc::Shape, py::return_value_policy::reference)
219+
.def("set_dtypes", &VarDesc::SetDataTypes)
220+
.def("set_tensor_num", &VarDesc::SetTensorDescNum)
221+
.def("tensor_num", &VarDesc::GetTensorDescNum)
222+
.def("shape", &VarDesc::GetShape, py::return_value_policy::reference)
223+
.def("shapes", &VarDesc::GetShapes, py::return_value_policy::reference)
219224
.def("dtype", &VarDesc::GetDataType, py::return_value_policy::reference)
225+
.def("dtypes", &VarDesc::GetDataTypes, py::return_value_policy::reference)
220226
.def("lod_level", &VarDesc::GetLoDLevel)
227+
.def("lod_levels", &VarDesc::GetLoDLevels,
228+
py::return_value_policy::reference)
221229
.def("set_lod_level", &VarDesc::SetLoDLevel)
230+
.def("set_lod_levels", &VarDesc::SetLoDLevels)
222231
.def("type", &VarDesc::GetType)
223232
.def("set_type", &VarDesc::SetType)
224233
.def("serialize_to_string", SerializeMessage<VarDesc>)
@@ -233,7 +242,8 @@ void BindVarDsec(py::module &m) {
233242
.value("STEP_SCOPES", proto::VarDesc::STEP_SCOPES)
234243
.value("LOD_RANK_TABLE", proto::VarDesc::LOD_RANK_TABLE)
235244
.value("LOD_TENSOR_ARRAY", proto::VarDesc::LOD_TENSOR_ARRAY)
236-
.value("PLACE_LIST", proto::VarDesc::PLACE_LIST);
245+
.value("PLACE_LIST", proto::VarDesc::PLACE_LIST)
246+
.value("READER", proto::VarDesc::READER);
237247
}
238248

239249
void BindOpDesc(py::module &m) {

0 commit comments

Comments
 (0)