Skip to content

Commit f99b300

Browse files
Jeff114514lixcli
authored andcommitted
【Error Message No. 23 BUAA】rewrite error message (PaddlePaddle#66455)
* rewrite err msg * change to phi namespace * retest * fix bug * fix
1 parent 6e45ac5 commit f99b300

File tree

1 file changed

+24
-6
lines changed

1 file changed

+24
-6
lines changed

paddle/cinn/hlir/op/contrib/argmax.cc

Lines changed: 24 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -48,14 +48,25 @@ std::vector<ir::Tensor> Argmax(const Tensor &in_tensor,
4848
const std::string &name) {
4949
auto shape = in_tensor->shape;
5050
auto ndim = shape.size();
51-
CHECK_GT(ndim, 0) << "tensor's dim must be more than 0";
51+
PADDLE_ENFORCE_GT(
52+
ndim,
53+
0,
54+
phi::errors::InvalidArgument(
55+
"The dimension of input tensor must be greater than 0."));
5256

5357
int pos_axis = axis;
5458
if (axis < 0) {
5559
pos_axis = static_cast<int>(ndim) + axis;
5660
}
57-
CHECK_LT(pos_axis, ndim) << "Axis must be less than tensor's dim";
58-
CHECK_GE(pos_axis, 0) << "Axis must be more than 0";
61+
PADDLE_ENFORCE_LT(
62+
pos_axis,
63+
ndim,
64+
phi::errors::InvalidArgument(
65+
"The axis must be less than the dimension of input tensor."));
66+
PADDLE_ENFORCE_GE(pos_axis,
67+
0,
68+
phi::errors::InvalidArgument(
69+
"The axis must be greater than or equal to 0."));
5970

6071
std::vector<Expr> output_shape;
6172
for (int i = 0; i < shape.size(); ++i) {
@@ -114,12 +125,19 @@ std::shared_ptr<framework::OpStrategy> StrategyForArgmax(
114125
<< "The input argument of argmax compute is empty! Please check.";
115126
cinn::common::CINNValuePack pack_args = args[0];
116127
std::string tensor_name = UniqName("Argmax_out");
117-
CHECK_GE(pack_args.size(), 1U)
118-
<< "There should be 1 input args for argmax compute";
128+
PADDLE_ENFORCE_GE(
129+
pack_args.size(),
130+
1U,
131+
phi::errors::InvalidArgument(
132+
"There should be 1 input args for argmax compute"));
119133
Expr in_expr = pack_args[0];
120134
CHECK(in_expr.as_tensor());
121135
Tensor in_tensor = in_expr.as_tensor_ref();
122-
CHECK_EQ(pack_args.size(), 2U);
136+
PADDLE_ENFORCE_EQ(
137+
pack_args.size(),
138+
2U,
139+
phi::errors::InvalidArgument(
140+
"The input argument of argmax compute must be 2."));
123141
CHECK(pack_args[1].is_string());
124142
tensor_name = pack_args[1].operator std::string();
125143
std::vector<ir::Tensor> out_tensor =

0 commit comments

Comments
 (0)