Skip to content

Commit 17daf99

Browse files
committed
fix conflict
2 parents 9e10005 + 0663608 commit 17daf99

File tree

130 files changed

+4477
-1800
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

130 files changed

+4477
-1800
lines changed

paddle/cinn/ast_gen_ius/ast_gen.cc

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,10 @@ PD_DECLARE_bool(cinn_bucket_compile);
2828
namespace cinn {
2929
namespace ast_gen_ius {
3030

31+
bool IsReduceBool(const ir::Expr& lhs, const ir::Expr& rhs) {
32+
return lhs.type().is_bool() || rhs.type().is_bool();
33+
}
34+
3135
ir::Expr ConvertReduceBody(ir::Expr body,
3236
ir::Tensor tensor,
3337
const std::vector<Expr>& axis_exprs) {
@@ -38,9 +42,17 @@ ir::Expr ConvertReduceBody(ir::Expr body,
3842

3943
switch (reduce_node->reduce_type) {
4044
case ir::Reduce::kSum:
45+
if (IsReduceBool(tensor(axis_exprs), reduce_node->body)) {
46+
return ir::Store::Make(
47+
tensor, tensor(axis_exprs) || reduce_node->body, axis_exprs);
48+
}
4149
return ir::Store::Make(
4250
tensor, tensor(axis_exprs) + reduce_node->body, axis_exprs);
4351
case ir::Reduce::kMul:
52+
if (IsReduceBool(tensor(axis_exprs), reduce_node->body)) {
53+
return ir::Store::Make(
54+
tensor, tensor(axis_exprs) && reduce_node->body, axis_exprs);
55+
}
4456
return ir::Store::Make(
4557
tensor, tensor(axis_exprs) * reduce_node->body, axis_exprs);
4658
case ir::Reduce::kMax:

paddle/cinn/frontend/paddle/cpp/block_desc.cc

Lines changed: 25 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -13,18 +13,29 @@
1313
// limitations under the License.
1414

1515
#include "paddle/cinn/frontend/paddle/cpp/block_desc.h"
16+
#include "paddle/common/enforce.h"
1617

1718
namespace cinn::frontend::paddle::cpp {
1819

1920
template <>
2021
VarDesc* BlockDesc::GetVar<VarDesc>(int32_t idx) {
21-
CHECK_LT(idx, VarsSize()) << "idx >= vars.size()";
22+
PADDLE_ENFORCE_LT(
23+
idx,
24+
VarsSize(),
25+
phi::errors::InvalidArgument(
26+
"The value of idx and vars.size() is incorrect."
27+
"Expected idx < vars.size(), but receive idx >= vars.size()."));
2228
return &vars_[idx];
2329
}
2430

2531
template <>
2632
const VarDesc& BlockDesc::GetConstVar<VarDesc>(int32_t idx) const {
27-
CHECK_LT(idx, static_cast<int32_t>(VarsSize())) << "idx >= vars.size()";
33+
PADDLE_ENFORCE_LT(
34+
idx,
35+
static_cast<int32_t>(VarsSize()),
36+
phi::errors::InvalidArgument(
37+
"The value of idx and vars.size() is incorrect."
38+
"Expected idx < vars.size(), but receive idx >= vars.size()."));
2839
return vars_[idx];
2940
}
3041

@@ -36,13 +47,23 @@ VarDesc* BlockDesc::AddVar<VarDesc>() {
3647

3748
template <>
3849
OpDesc* BlockDesc::GetOp<OpDesc>(int32_t idx) {
39-
CHECK_LT(idx, OpsSize()) << "idx >= ops.size()";
50+
PADDLE_ENFORCE_LT(
51+
idx,
52+
OpsSize(),
53+
phi::errors::InvalidArgument(
54+
"The value of idx and ops.size() is incorrect."
55+
"Expected idx < ops.size(), but receive idx >= ops.size()."));
4056
return &ops_[idx];
4157
}
4258

4359
template <>
4460
const OpDesc& BlockDesc::GetConstOp<OpDesc>(int32_t idx) const {
45-
CHECK_LT(idx, static_cast<int32_t>(OpsSize())) << "idx >= ops.size()";
61+
PADDLE_ENFORCE_LT(
62+
idx,
63+
static_cast<int32_t>(OpsSize()),
64+
phi::errors::InvalidArgument(
65+
"The value of idx and ops.size() is incorrect."
66+
"Expected idx < ops.size(), but receive idx >= ops.size()."));
4667
return ops_[idx];
4768
}
4869

paddle/cinn/frontend/paddle/cpp/program_desc.cc

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,18 +13,29 @@
1313
// limitations under the License.
1414

1515
#include "paddle/cinn/frontend/paddle/cpp/program_desc.h"
16+
#include "paddle/common/enforce.h"
1617

1718
namespace cinn::frontend::paddle::cpp {
1819

1920
template <>
2021
BlockDesc* ProgramDesc::GetBlock<BlockDesc>(int32_t idx) {
21-
CHECK_LT(idx, BlocksSize()) << "idx >= blocks.size()";
22+
PADDLE_ENFORCE_LT(
23+
idx,
24+
BlocksSize(),
25+
phi::errors::InvalidArgument(
26+
"The value of idx and blocks.size() is incorrect."
27+
"Expected idx < blocks.size(), but receive idx >= blocks.size()."));
2228
return &blocks_[idx];
2329
}
2430

2531
template <>
2632
const BlockDesc& ProgramDesc::GetConstBlock<BlockDesc>(int32_t idx) const {
27-
CHECK_LT(idx, static_cast<int32_t>(BlocksSize())) << "idx >= blocks.size()";
33+
PADDLE_ENFORCE_LT(
34+
idx,
35+
static_cast<int32_t>(BlocksSize()),
36+
phi::errors::InvalidArgument(
37+
"The value of idx and blocks.size() is incorrect."
38+
"Expected idx < blocks.size(), but receive idx >= blocks.size()."));
2839
return blocks_[idx];
2940
}
3041

paddle/cinn/frontend/paddle/model_parser.cc

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
#include "paddle/cinn/backends/cuda_util.h"
2424
#include "paddle/cinn/common/common.h"
2525
#include "paddle/cinn/frontend/paddle/compatible_pb.h"
26+
#include "paddle/common/enforce.h"
2627

2728
namespace cinn::frontend::paddle {
2829

@@ -55,7 +56,8 @@ void TensorFromStream(std::istream &is,
5556
using Type = framework_proto::VarType::Type;
5657
uint32_t version;
5758
is.read(reinterpret_cast<char *>(&version), sizeof(version));
58-
CHECK_EQ(version, 0U) << "Only version 0 is supported";
59+
PADDLE_ENFORCE_EQ(
60+
version, 0U, phi::errors::InvalidArgument("Only version 0 is supported"));
5961
// read tensor desc
6062
framework_proto::VarType::TensorDesc desc;
6163
{

paddle/cinn/frontend/paddle/pb/block_desc.cc

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,13 +13,19 @@
1313
// limitations under the License.
1414

1515
#include "paddle/cinn/frontend/paddle/pb/block_desc.h"
16+
#include "paddle/common/enforce.h"
1617

1718
namespace cinn::frontend::paddle::pb {
1819

1920
template <>
2021
framework_proto::VarDesc* BlockDesc::GetVar<framework_proto::VarDesc>(
2122
int32_t idx) {
22-
CHECK_LT(idx, VarsSize()) << "idx >= vars.size()";
23+
PADDLE_ENFORCE_LT(
24+
idx,
25+
VarsSize(),
26+
phi::errors::InvalidArgument(
27+
"The value of idx and vars.size() is incorrect."
28+
"Expected idx < vars.size(), but receive idx >= vars.size()."));
2329
return desc_->mutable_vars(idx);
2430
}
2531

@@ -31,7 +37,12 @@ framework_proto::VarDesc* BlockDesc::AddVar<framework_proto::VarDesc>() {
3137
template <>
3238
framework_proto::OpDesc* BlockDesc::GetOp<framework_proto::OpDesc>(
3339
int32_t idx) {
34-
CHECK_LT(idx, OpsSize()) << "idx >= ops.size()";
40+
PADDLE_ENFORCE_LT(
41+
idx,
42+
OpsSize(),
43+
phi::errors::InvalidArgument(
44+
"The value of idx and ops.size() is incorrect."
45+
"Expected idx < ops.size(), but receive idx >= ops.size()."));
3546
return desc_->mutable_ops(idx);
3647
}
3748

paddle/cinn/frontend/paddle/pb/program_desc.cc

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,12 +17,19 @@
1717
#include <algorithm>
1818
#include <limits>
1919

20+
#include "paddle/common/enforce.h"
21+
2022
namespace cinn::frontend::paddle::pb {
2123

2224
template <>
2325
framework_proto::BlockDesc* ProgramDesc::GetBlock<framework_proto::BlockDesc>(
2426
int32_t idx) {
25-
CHECK_LT(idx, BlocksSize()) << "idx >= blocks.size()";
27+
PADDLE_ENFORCE_LT(
28+
idx,
29+
BlocksSize(),
30+
phi::errors::InvalidArgument(
31+
"The value of idx and blocks.size() is incorrect."
32+
"Expected idx < blocks.size(), but receive idx >= blocks.size()."));
2633
return desc_->mutable_blocks(idx);
2734
}
2835

0 commit comments

Comments
 (0)