Skip to content

Commit b3cf28f

Browse files
authored
Remove boost::variant (#43100)
* boost::variant -> paddle::variant * boost::variant.apply_visit -> paddle::visit * Update pybind_boost_hraders.h * Fix CINN compilation errors * Revert FetchResultType
1 parent 369b2b1 commit b3cf28f

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

46 files changed

+225
-218
lines changed

paddle/fluid/distributed/collective/HCCLTools.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,6 @@
1818

1919
#include <string>
2020

21-
#include "boost/variant.hpp"
2221
#include "paddle/fluid/distributed/collective/Types.h"
2322
#include "paddle/fluid/framework/data_type.h"
2423
#include "paddle/fluid/framework/variable.h"
@@ -27,6 +26,7 @@
2726
#include "paddle/fluid/platform/device/npu/npu_info.h"
2827
#include "paddle/fluid/platform/device_context.h"
2928
#include "paddle/fluid/platform/enforce.h"
29+
#include "paddle/utils/variant.h"
3030

3131
namespace paddle {
3232
namespace distributed {

paddle/fluid/distributed/collective/NCCLTools.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,6 @@
2525

2626
#include <string>
2727

28-
#include "boost/variant.hpp"
2928
#include "paddle/fluid/distributed/collective/Types.h"
3029
#include "paddle/fluid/framework/data_type.h"
3130
#include "paddle/fluid/framework/variable.h"
@@ -43,6 +42,7 @@
4342
#endif
4443

4544
#include "paddle/fluid/platform/enforce.h"
45+
#include "paddle/utils/variant.h"
4646

4747
namespace paddle {
4848
namespace distributed {

paddle/fluid/eager/auto_code_generator/eager_generator.cc

Lines changed: 26 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -336,45 +336,47 @@ static std::string AttrTypeToString(const proto::AttrType& type) {
336336
}
337337
default: {
338338
PADDLE_THROW(platform::errors::Fatal(
339-
"AttrType of type boost::variant only supports specific data types."
339+
"AttrType of type paddle::variant only supports specific data types."
340340
"However, detected unrecognized AttrType: %d",
341341
type));
342342
}
343343
}
344344
return ret;
345345
}
346346

347-
template <typename T>
348-
static std::string GetAttrValue(const framework::Attribute& attr,
349-
bool is_vector) {
347+
template <typename T, bool IsVector>
348+
static typename std::enable_if<IsVector, std::string>::type GetAttrValue(
349+
const framework::Attribute& attr) {
350350
std::string val = "";
351-
if (is_vector) {
352-
val += "{";
353-
for (auto x : BOOST_GET_CONST(std::vector<T>, attr)) {
354-
val += std::to_string(x) + ",";
355-
}
356-
if (val.size() > 1) val.pop_back();
357-
val += "}";
358-
} else {
359-
val = std::to_string(BOOST_GET_CONST(T, attr));
351+
val += "{";
352+
for (auto x : BOOST_GET_CONST(std::vector<T>, attr)) {
353+
val += std::to_string(x) + ",";
360354
}
355+
if (val.size() > 1) val.pop_back();
356+
val += "}";
361357
return val;
362358
}
363359

360+
template <typename T, bool IsVector>
361+
static typename std::enable_if<!IsVector, std::string>::type GetAttrValue(
362+
const framework::Attribute& attr) {
363+
return std::to_string(BOOST_GET_CONST(T, attr));
364+
}
365+
364366
static std::pair<std::string, std::string> GetAttrType(
365367
const framework::Attribute& attr, bool is_arg) {
366368
std::string ret = "";
367369
std::string val = "";
368-
size_t variant_pos = attr.which();
370+
size_t variant_pos = attr.index();
369371
switch (variant_pos) {
370372
case (1): {
371373
ret = "int";
372-
val = GetAttrValue<int>(attr, false);
374+
val = GetAttrValue<int, false>(attr);
373375
break;
374376
}
375377
case (2): {
376378
ret = "float";
377-
val = GetAttrValue<float>(attr, false);
379+
val = GetAttrValue<float, false>(attr);
378380
break;
379381
}
380382
case (3): {
@@ -386,13 +388,13 @@ static std::pair<std::string, std::string> GetAttrType(
386388
case (4): {
387389
ret = "std::vector<int>";
388390
if (is_arg) ret += "&";
389-
val = GetAttrValue<int>(attr, true);
391+
val = GetAttrValue<int, true>(attr);
390392
break;
391393
}
392394
case (5): {
393395
ret = "std::vector<float>";
394396
if (is_arg) ret += "&";
395-
val = GetAttrValue<float>(attr, true);
397+
val = GetAttrValue<float, true>(attr);
396398
break;
397399
}
398400
case (6): {
@@ -408,13 +410,13 @@ static std::pair<std::string, std::string> GetAttrType(
408410
}
409411
case (7): {
410412
ret = "bool";
411-
val = GetAttrValue<bool>(attr, false);
413+
val = GetAttrValue<bool, false>(attr);
412414
break;
413415
}
414416
case (8): {
415417
ret = "std::vector<bool>";
416418
if (is_arg) ret += "&";
417-
val = GetAttrValue<bool>(attr, true);
419+
val = GetAttrValue<bool, true>(attr);
418420
break;
419421
}
420422
case (9): {
@@ -423,7 +425,7 @@ static std::pair<std::string, std::string> GetAttrType(
423425
}
424426
case (10): {
425427
ret = "int64_t";
426-
val = GetAttrValue<int64_t>(attr, false);
428+
val = GetAttrValue<int64_t, false>(attr);
427429
break;
428430
}
429431
case (11): {
@@ -434,18 +436,18 @@ static std::pair<std::string, std::string> GetAttrType(
434436
case (12): {
435437
ret = "std::vector<int64_t>";
436438
if (is_arg) ret += "&";
437-
val = GetAttrValue<int64_t>(attr, true);
439+
val = GetAttrValue<int64_t, true>(attr);
438440
break;
439441
}
440442
case (13): {
441443
ret = "std::vector<double>";
442444
if (is_arg) ret += "&";
443-
val = GetAttrValue<double>(attr, true);
445+
val = GetAttrValue<double, true>(attr);
444446
break;
445447
}
446448
default: {
447449
PADDLE_THROW(platform::errors::Fatal(
448-
"AttrType of type boost::variant only supports specific data types."
450+
"AttrType of type paddle::variant only supports specific data types."
449451
"However, detected unrecognized AttrType: %d",
450452
variant_pos));
451453
}

paddle/fluid/framework/attribute.cc

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
1313
limitations under the License. */
1414

1515
#include "paddle/fluid/framework/attribute.h"
16+
#include "boost/blank.hpp"
1617

1718
namespace paddle {
1819
namespace framework {

paddle/fluid/framework/attribute.h

Lines changed: 15 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -23,12 +23,12 @@ limitations under the License. */
2323
#include <unordered_set>
2424
#include <vector>
2525

26-
#include "boost/variant/get.hpp"
2726
#include "paddle/fluid/framework/framework.pb.h"
2827
#include "paddle/fluid/framework/type_defs.h"
2928
#include "paddle/fluid/platform/enforce.h"
3029
#include "paddle/fluid/platform/errors.h"
3130
#include "paddle/utils/any.h"
31+
#include "paddle/utils/variant.h"
3232

3333
namespace paddle {
3434
namespace framework {
@@ -45,8 +45,8 @@ struct ExtractAttribute {
4545
T* operator()(Attribute& attr) const {
4646
T* attr_value = nullptr;
4747
try {
48-
attr_value = &boost::get<T>(attr);
49-
} catch (boost::bad_get& bad_get) {
48+
attr_value = &paddle::get<T>(attr);
49+
} catch (paddle::bad_variant_access const& bad_get) {
5050
PADDLE_THROW(platform::errors::InvalidArgument(
5151
"Cannot get attribute (%s) by type %s, its type is %s.",
5252
attr_name_,
@@ -80,8 +80,8 @@ struct ExtractAttribute<bool> {
8080
}
8181
bool* attr_value = nullptr;
8282
try {
83-
attr_value = &boost::get<bool>(attr);
84-
} catch (boost::bad_get& bad_get) {
83+
attr_value = &paddle::get<bool>(attr);
84+
} catch (paddle::bad_variant_access const& bad_get) {
8585
PADDLE_THROW(platform::errors::InvalidArgument(
8686
"Cannot get attribute (%s) by type bool, its type is %s.",
8787
attr_name_,
@@ -108,8 +108,8 @@ struct ExtractAttribute<int64_t> {
108108
}
109109
int64_t* attr_value = nullptr;
110110
try {
111-
attr_value = &boost::get<int64_t>(attr);
112-
} catch (boost::bad_get& bad_get) {
111+
attr_value = &paddle::get<int64_t>(attr);
112+
} catch (paddle::bad_variant_access const& bad_get) {
113113
PADDLE_THROW(platform::errors::InvalidArgument(
114114
"Cannot get attribute (%s) by type int64_t, its type is %s.",
115115
attr_name_,
@@ -138,8 +138,8 @@ struct ExtractAttribute<std::vector<int64_t>> {
138138
}
139139
std::vector<int64_t>* attr_value = nullptr;
140140
try {
141-
attr_value = &boost::get<std::vector<int64_t>>(attr);
142-
} catch (boost::bad_get& bad_get) {
141+
attr_value = &paddle::get<std::vector<int64_t>>(attr);
142+
} catch (paddle::bad_variant_access const& bad_get) {
143143
PADDLE_THROW(platform::errors::InvalidArgument(
144144
"Cannot get attribute (%s) by type std::vector<int64_t>, its type is "
145145
"%s.",
@@ -167,8 +167,8 @@ struct ExtractAttribute<float> {
167167
}
168168
float* attr_value = nullptr;
169169
try {
170-
attr_value = &boost::get<float>(attr);
171-
} catch (boost::bad_get& bad_get) {
170+
attr_value = &paddle::get<float>(attr);
171+
} catch (paddle::bad_variant_access const& bad_get) {
172172
PADDLE_THROW(platform::errors::InvalidArgument(
173173
"Cannot get attribute (%s) by type float, its type is %s.",
174174
attr_name_,
@@ -197,8 +197,8 @@ struct ExtractAttribute<std::vector<double>> {
197197
}
198198
std::vector<double>* attr_value = nullptr;
199199
try {
200-
attr_value = &boost::get<std::vector<double>>(attr);
201-
} catch (boost::bad_get& bad_get) {
200+
attr_value = &paddle::get<std::vector<double>>(attr);
201+
} catch (paddle::bad_variant_access const& bad_get) {
202202
PADDLE_THROW(platform::errors::InvalidArgument(
203203
"Cannot get attribute (%s) by type std::vector<double>, its type is "
204204
"%s.",
@@ -214,11 +214,11 @@ struct ExtractAttribute<std::vector<double>> {
214214
template <typename T>
215215
inline proto::AttrType AttrTypeID() {
216216
Attribute tmp = T();
217-
return static_cast<proto::AttrType>(tmp.which() - 1);
217+
return static_cast<proto::AttrType>(tmp.index() - 1);
218218
}
219219

220220
inline proto::AttrType AttrTypeID(const Attribute& attr) {
221-
return static_cast<proto::AttrType>(attr.which() - 1);
221+
return static_cast<proto::AttrType>(attr.index() - 1);
222222
}
223223

224224
class AttrReader {

paddle/fluid/framework/block_desc.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -272,7 +272,7 @@ void BlockDesc::MoveFrom(BlockDesc *block) {
272272
for (const auto &pair : src_op->GetAttrMap()) {
273273
const auto &attr_name = pair.first;
274274
const auto &attr_value = pair.second;
275-
auto attr_type = static_cast<proto::AttrType>(attr_value.which() - 1);
275+
auto attr_type = static_cast<proto::AttrType>(attr_value.index() - 1);
276276
if (attr_type == proto::AttrType::BLOCK) {
277277
auto block_id = BOOST_GET_CONST(BlockDesc *, attr_value)->ID();
278278
dst_op->SetBlockAttr(attr_name, prog_->MutableBlock(block_id));

paddle/fluid/framework/details/async_ssa_graph_executor.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -174,7 +174,7 @@ FetchResultType AsyncSSAGraphExecutor::Run(
174174
HandleException();
175175

176176
FetchList ret;
177-
auto &val = BOOST_GET(FetchList, fetch_data);
177+
auto &val = boost::get<FetchList>(fetch_data);
178178
for (size_t fetch_idx = 0; fetch_idx < fetch_tensors.size(); ++fetch_idx) {
179179
if (data_is_lod_tensor(val.at(fetch_idx))) {
180180
std::vector<const LoDTensor *> lodtensor_ptrs;

paddle/fluid/framework/details/fetch_async_op_handle.cc

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -228,7 +228,7 @@ void FetchAsyncOpHandle::RunImpl() {
228228
}
229229

230230
if (return_merged_) {
231-
auto &val = BOOST_GET(FetchList, *data_);
231+
auto &val = boost::get<FetchList>(*data_);
232232
if (src_vars[0]->IsType<LoDTensor>()) {
233233
// to lodtensor type
234234
std::vector<const LoDTensor *> src_lodtensors;
@@ -263,7 +263,7 @@ void FetchAsyncOpHandle::RunImpl() {
263263
val.at(offset_) = std::move(dst_lodtensor_array);
264264
}
265265
} else {
266-
auto &val = BOOST_GET(FetchUnmergedList, *data_);
266+
auto &val = boost::get<FetchUnmergedList>(*data_);
267267
auto &dst_tensors = val.at(offset_);
268268
dst_tensors.reserve(src_vars.size());
269269

paddle/fluid/framework/details/fetch_op_handle.cc

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -84,7 +84,7 @@ void FetchOpHandle::WaitAndMergeCPUFetchVars() const {
8484
for (auto &t : tensors_) {
8585
tensors_ptr.emplace_back(&BOOST_GET_CONST(LoDTensor, t));
8686
}
87-
auto &val = BOOST_GET(FetchList, *data_);
87+
auto &val = boost::get<FetchList>(*data_);
8888
LoDTensor var;
8989
MergeLoDTensor(&var, tensors_ptr, platform::CPUPlace());
9090
val.at(offset_) = std::move(var);
@@ -106,11 +106,11 @@ void FetchOpHandle::WaitAndMergeCPUFetchVars() const {
106106
tmp_array.emplace_back();
107107
MergeLoDTensor(&(tmp_array.back()), tensors_ptr, platform::CPUPlace());
108108
}
109-
auto &val = BOOST_GET(FetchList, *data_);
109+
auto &val = boost::get<FetchList>(*data_);
110110
val.at(offset_) = std::move(tmp_array);
111111
}
112112
} else {
113-
auto &val = BOOST_GET(FetchUnmergedList, *data_);
113+
auto &val = boost::get<FetchUnmergedList>(*data_);
114114
val.at(offset_) = std::move(tensors_);
115115
}
116116
}

paddle/fluid/framework/details/parallel_ssa_graph_executor.cc

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -278,8 +278,7 @@ FetchResultType ParallelSSAGraphExecutor::Run(
278278
if (!is_valid[scope_idx]) {
279279
continue;
280280
}
281-
const auto &fetch_list =
282-
BOOST_GET_CONST(FetchList, fetch_data[scope_idx]);
281+
const auto &fetch_list = boost::get<FetchList>(fetch_data[scope_idx]);
283282
if (data_is_lod_tensor(fetch_list[fetch_idx])) {
284283
lodtensor_ptrs.push_back(
285284
&(BOOST_GET_CONST(LoDTensor, fetch_list[fetch_idx])));
@@ -318,7 +317,7 @@ FetchResultType ParallelSSAGraphExecutor::Run(
318317
continue;
319318
}
320319
const auto &fetch_list =
321-
BOOST_GET_CONST(FetchUnmergedList, fetch_data[scope_idx]);
320+
boost::get<FetchUnmergedList>(fetch_data[scope_idx]);
322321
PADDLE_ENFORCE_EQ(
323322
fetch_list[fetch_idx].size(),
324323
1,

0 commit comments

Comments
 (0)