@@ -26,18 +26,91 @@ void VarDesc::SetShape(const std::vector<int64_t> &dims) {
   VectorToRepeated(dims, mutable_tensor_desc()->mutable_dims());
 }

+void VarDesc::SetTensorDescNum(size_t num) {
+  switch (desc_.type()) {
+    case proto::VarDesc::READER: {
+      auto *lod_tensors_ptr = desc_.mutable_reader()->mutable_lod_tensor();
+      lod_tensors_ptr->Clear();
+      for (size_t i = 0; i < num; ++i) {
+        lod_tensors_ptr->Add();
+      }
+      return;
+    } break;
+    default:
+      PADDLE_THROW(
+          "Setting 'sub_tensor_number' is not supported by the type of var %s.",
+          this->Name());
+  }
+}
+
+size_t VarDesc::GetTensorDescNum() const {
+  switch (desc_.type()) {
+    case proto::VarDesc::READER:
+      return desc_.reader().lod_tensor_size();
+      break;
+    default:
+      PADDLE_THROW(
+          "Getting 'sub_tensor_number' is not supported by the type of var %s.",
+          this->Name());
+  }
+}
+
+void VarDesc::SetShapes(
+    const std::vector<std::vector<int64_t>> &multiple_dims) {
+  PADDLE_ENFORCE_EQ(multiple_dims.size(), GetTensorDescNum(),
+                    "The number of given shapes(%d) doesn't equal to the "
+                    "number of sub tensor.",
+                    multiple_dims.size(), GetTensorDescNum());
+  std::vector<proto::TensorDesc *> tensors = mutable_tensor_descs();
+  for (size_t i = 0; i < multiple_dims.size(); ++i) {
+    VectorToRepeated(multiple_dims[i], tensors[i]->mutable_dims());
+  }
+}
+
+std::vector<int64_t> VarDesc::GetShape() const {
+  return RepeatedToVector(tensor_desc().dims());
+}
+
+std::vector<std::vector<int64_t>> VarDesc::GetShapes() const {
+  std::vector<proto::TensorDesc> descs = tensor_descs();
+  std::vector<std::vector<int64_t>> res;
+  res.reserve(descs.size());
+  for (const auto &tensor_desc : descs) {
+    res.push_back(RepeatedToVector(tensor_desc.dims()));
+  }
+  return res;
+}
+
 void VarDesc::SetDataType(proto::DataType data_type) {
   mutable_tensor_desc()->set_data_type(data_type);
 }

-std::vector<int64_t> VarDesc::Shape() const {
-  return RepeatedToVector(tensor_desc().dims());
+void VarDesc::SetDataTypes(
+    const std::vector<proto::DataType> &multiple_data_type) {
+  PADDLE_ENFORCE_EQ(multiple_data_type.size(), GetTensorDescNum(),
+                    "The number of given data types(%d) doesn't equal to the "
+                    "number of sub tensor.",
+                    multiple_data_type.size(), GetTensorDescNum());
+  std::vector<proto::TensorDesc *> tensor_descs = mutable_tensor_descs();
+  for (size_t i = 0; i < multiple_data_type.size(); ++i) {
+    tensor_descs[i]->set_data_type(multiple_data_type[i]);
+  }
 }

 proto::DataType VarDesc::GetDataType() const {
   return tensor_desc().data_type();
 }

+std::vector<proto::DataType> VarDesc::GetDataTypes() const {
+  std::vector<proto::TensorDesc> descs = tensor_descs();
+  std::vector<proto::DataType> res;
+  res.reserve(descs.size());
+  for (const auto &tensor_desc : descs) {
+    res.push_back(tensor_desc.data_type());
+  }
+  return res;
+}
+
 void VarDesc::SetLoDLevel(int32_t lod_level) {
   switch (desc_.type()) {
     case proto::VarDesc::LOD_TENSOR:
@@ -47,8 +120,28 @@ void VarDesc::SetLoDLevel(int32_t lod_level) {
       desc_.mutable_tensor_array()->set_lod_level(lod_level);
       break;
     default:
-      PADDLE_THROW("Tensor type=%d does not support LoDLevel",
-                   desc_.tensor_array().lod_level());
+      PADDLE_THROW(
+          "Setting 'lod_level' is not supported by the type of var %s.",
+          this->Name());
+  }
+}
+
+void VarDesc::SetLoDLevels(const std::vector<int32_t> &multiple_lod_level) {
+  PADDLE_ENFORCE_EQ(multiple_lod_level.size(), GetTensorDescNum(),
+                    "The number of given lod levels(%d) doesn't equal to the "
+                    "number of sub tensor.",
+                    multiple_lod_level.size(), GetTensorDescNum());
+  switch (desc_.type()) {
+    case proto::VarDesc::READER: {
+      size_t i = 0;
+      for (auto &lod_tensor : *desc_.mutable_reader()->mutable_lod_tensor()) {
+        lod_tensor.set_lod_level(multiple_lod_level[i++]);
+      }
+    } break;
+    default:
+      PADDLE_THROW(
+          "Setting 'lod_levels' is not supported by the type of var %s.",
+          this->Name());
   }
 }

@@ -59,13 +152,31 @@ int32_t VarDesc::GetLoDLevel() const {
     case proto::VarDesc::LOD_TENSOR_ARRAY:
       return desc_.tensor_array().lod_level();
     default:
-      PADDLE_THROW("Tensor type=%d does not support LoDLevel",
-                   desc_.tensor_array().lod_level());
+      PADDLE_THROW(
+          "Getting 'lod_level' is not supported by the type of var %s.",
+          this->Name());
+  }
+}
+
+std::vector<int32_t> VarDesc::GetLoDLevels() const {
+  std::vector<int32_t> res;
+  switch (desc_.type()) {
+    case proto::VarDesc::READER:
+      res.reserve(desc_.reader().lod_tensor_size());
+      for (auto &lod_tensor : desc_.reader().lod_tensor()) {
+        res.push_back(lod_tensor.lod_level());
+      }
+      return res;
+      break;
+    default:
+      PADDLE_THROW(
+          "Getting 'lod_levels' is not supported by the type of var %s.",
+          this->Name());
   }
 }

 const proto::TensorDesc &VarDesc::tensor_desc() const {
-  PADDLE_ENFORCE(desc_.has_type(), "invoke TensorDesc must after set type");
+  PADDLE_ENFORCE(desc_.has_type(), "The var's type hasn't been set.");
   switch (desc_.type()) {
     case proto::VarDesc::SELECTED_ROWS:
       return desc_.selected_rows();
@@ -74,13 +185,32 @@ const proto::TensorDesc &VarDesc::tensor_desc() const {
     case proto::VarDesc::LOD_TENSOR_ARRAY:
       return desc_.tensor_array().tensor();
     default:
-      PADDLE_THROW("The type of var %s is unsupported.", this->Name());
+      PADDLE_THROW(
+          "Getting 'tensor_desc' is not supported by the type of var %s.",
+          this->Name());
+  }
+}
+
+std::vector<proto::TensorDesc> VarDesc::tensor_descs() const {
+  PADDLE_ENFORCE(desc_.has_type(), "The var type hasn't been set.");
+  std::vector<proto::TensorDesc> res;
+  res.reserve(GetTensorDescNum());
+  switch (desc_.type()) {
+    case proto::VarDesc::READER:
+      for (const auto &lod_tensor : desc_.reader().lod_tensor()) {
+        res.push_back(lod_tensor.tensor());
+      }
+      return res;
+    default:
+      PADDLE_THROW(
+          "Getting 'tensor_descs' is not supported by the type of var "
+          "%s.",
+          this->Name());
   }
 }

 proto::TensorDesc *VarDesc::mutable_tensor_desc() {
-  PADDLE_ENFORCE(desc_.has_type(),
-                 "invoke MutableTensorDesc must after set type");
+  PADDLE_ENFORCE(desc_.has_type(), "The var type hasn't been set.");
   switch (desc_.type()) {
     case proto::VarDesc::SELECTED_ROWS:
       return desc_.mutable_selected_rows();
@@ -89,8 +219,30 @@ proto::TensorDesc *VarDesc::mutable_tensor_desc() {
     case proto::VarDesc::LOD_TENSOR_ARRAY:
       return desc_.mutable_tensor_array()->mutable_tensor();
     default:
-      PADDLE_THROW("Unexpected branch.");
+      PADDLE_THROW(
+          "Getting 'mutable_tensor_desc' is not supported by the type of var "
+          "%s.",
+          this->Name());
   }
 }
+
+std::vector<proto::TensorDesc *> VarDesc::mutable_tensor_descs() {
+  PADDLE_ENFORCE(desc_.has_type(), "The var type hasn't been set.");
+  std::vector<proto::TensorDesc *> res;
+  res.reserve(GetTensorDescNum());
+  switch (desc_.type()) {
+    case proto::VarDesc::READER:
+      for (auto &lod_tensor : *desc_.mutable_reader()->mutable_lod_tensor()) {
+        res.push_back(lod_tensor.mutable_tensor());
+      }
+      return res;
+    default:
+      PADDLE_THROW(
+          "Getting 'tensor_descs' is not supported by the type of var "
+          "%s.",
+          this->Name());
+  }
+}
+
 }  // namespace framework
 }  // namespace paddle
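Usage note (not part of the diff): a minimal sketch of how the new READER accessors fit together. It assumes the VarDesc(const std::string &name) constructor and SetType() setter already declared in var_desc.h; the include path, shapes, and data types below are illustrative only. Note that SetTensorDescNum() must be called before the other plural setters, since each of them checks its argument count against GetTensorDescNum().

// Minimal sketch, not part of this commit. Assumes the VarDesc(name)
// constructor and SetType() setter from var_desc.h; the include path,
// shapes, and data types are illustrative.
#include "paddle/framework/var_desc.h"

namespace paddle {
namespace framework {

void ConfigureReaderVarExample() {
  VarDesc reader_var("batch_reader");
  reader_var.SetType(proto::VarDesc::READER);

  // A reader producing two sub-tensors per batch, e.g. image and label.
  reader_var.SetTensorDescNum(2);
  reader_var.SetShapes({{64, 3, 224, 224}, {64, 1}});
  reader_var.SetDataTypes({proto::DataType::FP32, proto::DataType::INT64});
  reader_var.SetLoDLevels({0, 0});

  // The getters mirror the setters, returning one entry per sub-tensor.
  std::vector<std::vector<int64_t>> shapes = reader_var.GetShapes();
  std::vector<proto::DataType> dtypes = reader_var.GetDataTypes();
  (void)shapes;
  (void)dtypes;
}

}  // namespace framework
}  // namespace paddle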