@@ -105,7 +105,7 @@ bool HasDynamicShape(const ::pir::Value& tensor) {
105105 PADDLE_ENFORCE_EQ (
106106 dim,
107107 -1UL ,
108- phi ::errors::InvalidArgument (
108+ ::common ::errors::InvalidArgument (
109109 " The dynamic shape dim should be -1, but got %d." , dim));
110110 return true ;
111111 }
@@ -170,7 +170,11 @@ hlir::framework::OpPatternKind GetOpPatternKind(const ::pir::Operation* node) {
170170bool CollectRewrittenReductionOpStmts (const OpStmt& op_stmt,
171171 List<OpStmt>* ret) {
172172 const auto & [op, inputs, outputs] = op_stmt.tuple ();
173- CHECK (op.Has <const ::pir::Operation*>());
173+ PADDLE_ENFORCE_EQ (
174+ op.Has <const ::pir::Operation*>(),
175+ true ,
176+ phi::errors::InvalidArgument (
177+ " The op should have a value of type ::pir::Operation*" ));
174178 if (GetOpPatternKind (op.Get <const ::pir::Operation*>()) ==
175179 hlir::framework::OpPatternKind::kReduction ) {
176180 tReduceInit<const ::pir::Operation*> init_op{
@@ -234,7 +238,10 @@ std::vector<std::shared_ptr<IGroup>> GenerateIGroups(
234238 std::vector<std::shared_ptr<IGroup>> ret{};
235239
236240 List<OpStmt> op_stmts = MakeOpStmts (group);
237- CHECK (!op_stmts->empty ());
241+ PADDLE_ENFORCE_EQ (
242+ !op_stmts->empty (),
243+ true ,
244+ phi::errors::InvalidArgument (" The op_stmts should not be empty" ));
238245
239246 PartitionIGroupOpStmts (op_stmts, [&](const auto & igroup_spec) {
240247 ret.push_back (MakeIGroup (igroup_spec));
@@ -249,7 +256,7 @@ std::shared_ptr<KGroup> GenerateKGroups(
249256 PADDLE_ENFORCE_EQ (
250257 igroups.size (),
251258 1UL ,
252- phi ::errors::InvalidArgument (
259+ ::common ::errors::InvalidArgument (
253260 " The size of igroups should be 1, but got %d." , igroups.size()));
254261 return std::make_shared<KGroup>(group, igroups);
255262}
@@ -271,9 +278,12 @@ std::unordered_map<Variable, const Value> MakeSdIterator2Iterator(
271278 std::unordered_map<Variable, const Value> ret{};
272279
273280 for (std::size_t i = 0 ; i < igroup.loop_iterators ()->size (); ++i) {
274- CHECK (ret.emplace (igroup.loop_iterators ()->at (i),
275- igroup.loop_iterators ()->at (i))
276- .second );
281+ PADDLE_ENFORCE_EQ (
282+ ret.emplace (igroup.loop_iterators ()->at (i),
283+ igroup.loop_iterators ()->at (i))
284+ .second ,
285+ true ,
286+ phi::errors::InvalidArgument (" The loop iterator should be unique" ));
277287 }
278288
279289 return ret;
@@ -326,15 +336,18 @@ LoopDescriptor4IterVarT MakeGetterLoopDescriptor4IterVar(
326336 PADDLE_ENFORCE_EQ (
327337 loop_iters->size (),
328338 sd->size (),
329- phi ::errors::InvalidArgument (
339+ ::common ::errors::InvalidArgument (
330340 " The size of loop iterators and loop descriptors should be equal, "
331341 " but got loop iterators size = %d, loop descriptors size = %d." ,
332342 loop_iters->size (),
333343 sd->size()));
334344 using Cache = std::unordered_map<Iterator, LoopDescriptor>;
335345 const auto & sd_iter2sd = std::make_shared<Cache>();
336346 for (std::size_t i = 0 ; i < loop_iters->size (); ++i) {
337- CHECK (sd_iter2sd->emplace (loop_iters->at (i), sd->at (i)).second );
347+ PADDLE_ENFORCE_EQ (
348+ sd_iter2sd->emplace (loop_iters->at (i), sd->at (i)).second ,
349+ true ,
350+ phi::errors::InvalidArgument (" The loop iterator should be unique" ));
338351 }
339352 return [sd_iter2sd](const auto & sd_iter) { return sd_iter2sd->at (sd_iter); };
340353}
@@ -343,7 +356,10 @@ TreeMerger<Stmt> MakeTreeMerger(const MapIr& map_ir) {
343356 using Cache = std::unordered_map<OpStmt, LoopIterators>;
344357 auto cache = std::make_shared<Cache>();
345358 for (const auto & op_stmt : *(map_ir.op_stmts ())) {
346- CHECK (cache->emplace (op_stmt, map_ir.loop_iterators ()).second );
359+ PADDLE_ENFORCE_EQ (
360+ cache->emplace (op_stmt, map_ir.loop_iterators ()).second ,
361+ true ,
362+ phi::errors::InvalidArgument (" The op_stmt should be unique" ));
347363 }
348364
349365 TreeMerger<Stmt> tree_merger{};
@@ -363,9 +379,12 @@ MapStmt<Stmt> MakeMapStmt(const MapIrList& map_irs) {
363379 PADDLE_ENFORCE_EQ (
364380 stmts->size (),
365381 1UL ,
366- phi::errors::InvalidArgument (" The size of stmts should be 1, but got %d." ,
367- stmts->size ()));
368- CHECK (stmts->at (0 ).Has <MapStmt<Stmt>>());
382+ ::common::errors::InvalidArgument (
383+ " The size of stmts should be 1, but got %d." , stmts->size ()));
384+ PADDLE_ENFORCE_EQ (stmts->at (0 ).Has<MapStmt<Stmt>>(),
385+ true,
386+ phi::errors::InvalidArgument(
387+ " The stmts should have a value of type MapStmt<Stmt>" ));
369388 return stmts->at (0 ).Get<MapStmt<Stmt>>();
370389}
371390
0 commit comments