Skip to content

Commit 9fbaf87

Browse files
committed
Fix
2 parents 4282664 + 7f29bea commit 9fbaf87

File tree

893 files changed

+8246
-5788
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

893 files changed

+8246
-5788
lines changed

.pre-commit-config.yaml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -50,8 +50,8 @@ repos:
5050
paddle/cinn/utils/registry.h
5151
)$
5252
# For Python files
53-
- repo: https://github.com/psf/black.git
54-
rev: 23.3.0
53+
- repo: https://github.com/psf/black-pre-commit-mirror
54+
rev: 24.8.0
5555
hooks:
5656
- id: black
5757
- repo: https://github.com/astral-sh/ruff-pre-commit

paddle/cinn/adt/equation_util.h

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -75,12 +75,18 @@ EquationGraphTopoWalker<VT, FT> GetSubgraph(
7575
};
7676
const auto& VisitInputVariables =
7777
[graph, IsSelected](FT function, const std::function<void(VT)>& Visit) {
78-
CHECK(IsSelected(function));
78+
PADDLE_ENFORCE_EQ(
79+
IsSelected(function),
80+
true,
81+
phi::errors::PreconditionNotMet("The function must be selected."));
7982
graph.VisitInputVariables(function, Visit);
8083
};
8184
const auto& VisitOutputVariables =
8285
[graph, IsSelected](FT function, const std::function<void(VT)>& Visit) {
83-
CHECK(IsSelected(function));
86+
PADDLE_ENFORCE_EQ(
87+
IsSelected(function),
88+
true,
89+
phi::errors::PreconditionNotMet("The function must be selected."));
8490
graph.VisitOutputVariables(function, Visit);
8591
};
8692
return EquationGraphTopoWalker<VT, FT>(

paddle/cinn/adt/generate_map_expr.cc

Lines changed: 14 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -173,7 +173,7 @@ bool CollectRewrittenReductionOpStmts(const OpStmt& op_stmt,
173173
PADDLE_ENFORCE_EQ(
174174
op.Has<const ::pir::Operation*>(),
175175
true,
176-
phi::errors::InvalidArgument(
176+
::common::errors::InvalidArgument(
177177
"The op should have a value of type ::pir::Operation*"));
178178
if (GetOpPatternKind(op.Get<const ::pir::Operation*>()) ==
179179
hlir::framework::OpPatternKind::kReduction) {
@@ -241,7 +241,7 @@ std::vector<std::shared_ptr<IGroup>> GenerateIGroups(
241241
PADDLE_ENFORCE_EQ(
242242
!op_stmts->empty(),
243243
true,
244-
phi::errors::InvalidArgument("The op_stmts should not be empty"));
244+
::common::errors::InvalidArgument("The op_stmts should not be empty"));
245245

246246
PartitionIGroupOpStmts(op_stmts, [&](const auto& igroup_spec) {
247247
ret.push_back(MakeIGroup(igroup_spec));
@@ -278,12 +278,12 @@ std::unordered_map<Variable, const Value> MakeSdIterator2Iterator(
278278
std::unordered_map<Variable, const Value> ret{};
279279

280280
for (std::size_t i = 0; i < igroup.loop_iterators()->size(); ++i) {
281-
PADDLE_ENFORCE_EQ(
282-
ret.emplace(igroup.loop_iterators()->at(i),
283-
igroup.loop_iterators()->at(i))
284-
.second,
285-
true,
286-
phi::errors::InvalidArgument("The loop iterator should be unique"));
281+
PADDLE_ENFORCE_EQ(ret.emplace(igroup.loop_iterators()->at(i),
282+
igroup.loop_iterators()->at(i))
283+
.second,
284+
true,
285+
::common::errors::InvalidArgument(
286+
"The loop iterator should be unique"));
287287
}
288288

289289
return ret;
@@ -344,10 +344,10 @@ LoopDescriptor4IterVarT MakeGetterLoopDescriptor4IterVar(
344344
using Cache = std::unordered_map<Iterator, LoopDescriptor>;
345345
const auto& sd_iter2sd = std::make_shared<Cache>();
346346
for (std::size_t i = 0; i < loop_iters->size(); ++i) {
347-
PADDLE_ENFORCE_EQ(
348-
sd_iter2sd->emplace(loop_iters->at(i), sd->at(i)).second,
349-
true,
350-
phi::errors::InvalidArgument("The loop iterator should be unique"));
347+
PADDLE_ENFORCE_EQ(sd_iter2sd->emplace(loop_iters->at(i), sd->at(i)).second,
348+
true,
349+
::common::errors::InvalidArgument(
350+
"The loop iterator should be unique"));
351351
}
352352
return [sd_iter2sd](const auto& sd_iter) { return sd_iter2sd->at(sd_iter); };
353353
}
@@ -359,7 +359,7 @@ TreeMerger<Stmt> MakeTreeMerger(const MapIr& map_ir) {
359359
PADDLE_ENFORCE_EQ(
360360
cache->emplace(op_stmt, map_ir.loop_iterators()).second,
361361
true,
362-
phi::errors::InvalidArgument("The op_stmt should be unique"));
362+
::common::errors::InvalidArgument("The op_stmt should be unique"));
363363
}
364364

365365
TreeMerger<Stmt> tree_merger{};
@@ -383,7 +383,7 @@ MapStmt<Stmt> MakeMapStmt(const MapIrList& map_irs) {
383383
"The size of stmts should be 1, but got %d.", stmts->size()));
384384
PADDLE_ENFORCE_EQ(stmts->at(0).Has<MapStmt<Stmt>>(),
385385
true,
386-
phi::errors::InvalidArgument(
386+
::common::errors::InvalidArgument(
387387
"The stmts should have a value of type MapStmt<Stmt>"));
388388
return stmts->at(0).Get<MapStmt<Stmt>>();
389389
}

paddle/cinn/adt/get_sub_reshape_dim_ranges.cc

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -40,10 +40,11 @@ GetSubReshapeDimRanges(const List<DimExpr>& lhs_dims,
4040
PADDLE_ENFORCE_EQ(
4141
!lhs_dims->empty(),
4242
true,
43-
phi::errors::InvalidArgument("Sorry,but lhs_dims is empty"));
44-
PADDLE_ENFORCE_EQ(!rhs_dims->empty(),
45-
true,
46-
phi::errors::InvalidArgument("Sory,but rhs_dims is empty"));
43+
::common::errors::InvalidArgument("Sorry,but lhs_dims is empty"));
44+
PADDLE_ENFORCE_EQ(
45+
!rhs_dims->empty(),
46+
true,
47+
::common::errors::InvalidArgument("Sory,but rhs_dims is empty"));
4748
std::vector<std::pair<int, int>> lhs_ranges{};
4849
std::vector<std::pair<int, int>> rhs_ranges{};
4950
int lhs_start = 0;
@@ -59,7 +60,7 @@ GetSubReshapeDimRanges(const List<DimExpr>& lhs_dims,
5960
PADDLE_ENFORCE_EQ(
6061
dims->at(i).Has<std::int64_t>(),
6162
true,
62-
phi::errors::InvalidArgument("dims->at(i) is not int64_t"));
63+
::common::errors::InvalidArgument("dims->at(i) is not int64_t"));
6364
ret *= dims->at(i).Get<std::int64_t>();
6465
}
6566
return ret;
@@ -95,7 +96,7 @@ GetSubReshapeDimRanges(const List<DimExpr>& lhs_dims,
9596
}
9697
PADDLE_ENFORCE_EQ(lhs_end == lhs_dims->size() && rhs_end == rhs_dims->size(),
9798
true,
98-
phi::errors::InvalidArgument(
99+
::common::errors::InvalidArgument(
99100
"lhs_end is not equal to lhs_dims->size() and rhs_end "
100101
"is not equal to rhs_dims->size()"));
101102
if (lhs_start < lhs_end && rhs_start < rhs_end) {

paddle/cinn/adt/igroup.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@ std::shared_ptr<IndexExprInferContext> MakeIndexExprInferContext(
3232
.emplace(anchor_iterators->at(i), anchor_iterators->at(i))
3333
.second,
3434
true,
35-
phi::errors::InvalidArgument(
35+
::common::errors::InvalidArgument(
3636
"The element in anchor iterators failed to insert in anchor "
3737
"iterator2value! Please check."));
3838
}

paddle/cinn/adt/igroup.h

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -94,7 +94,7 @@ class IGroup final {
9494
const List<Iterator>& loop_iterators() const {
9595
PADDLE_ENFORCE_EQ(anchor_sd_equation_ctx_.has_value(),
9696
true,
97-
phi::errors::InvalidArgument(
97+
::common::errors::InvalidArgument(
9898
"The anchor_sd_equation_ctx_ has no value."));
9999
return anchor_sd_equation_ctx_.value().sd_iterators();
100100
}
@@ -128,7 +128,7 @@ class IGroup final {
128128
PADDLE_ENFORCE_EQ(
129129
index2tensor->emplace(index, tensor).second,
130130
true,
131-
phi::errors::InvalidArgument(
131+
::common::errors::InvalidArgument(
132132
"The index2tensor map has already contained the index."));
133133
(*tensor2indexes)[tensor].emplace_back(index);
134134
}
@@ -138,7 +138,7 @@ class IGroup final {
138138
PADDLE_ENFORCE_EQ(
139139
index2tensor->emplace(index, tensor).second,
140140
true,
141-
phi::errors::InvalidArgument(
141+
::common::errors::InvalidArgument(
142142
"The index2tensor map has already contained the index."));
143143
(*tensor2indexes)[tensor].emplace_back(index);
144144
}

paddle/cinn/adt/inline_translator.h

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@ struct InlineTranslator final {
3535
static DstTree Call(const SrcTree& src_tree) {
3636
PADDLE_ENFORCE_EQ((src_tree.template Has<MapT<SrcTree>>()),
3737
true,
38-
phi::errors::InvalidArgument(
38+
::common::errors::InvalidArgument(
3939
"src_tree.template should have <MapT<SrcTree>>()"));
4040
const MapT<DstTree> dst_tree =
4141
CallMap(src_tree.template Get<MapT<SrcTree>>());
@@ -102,7 +102,7 @@ struct InlineTranslator final {
102102
const auto& [arg_tensor] = arg_leaf.tuple();
103103
PADDLE_ENFORCE_EQ(producer_tensor == arg_tensor,
104104
true,
105-
phi::errors::InvalidArgument(
105+
::common::errors::InvalidArgument(
106106
"producer_tensor should be equal to arg_tensor"));
107107
List<OpExpr> ret{};
108108
ret->assign(op_call_children->begin(), op_call_children->end());
@@ -117,7 +117,7 @@ struct InlineTranslator final {
117117
PADDLE_ENFORCE_EQ(
118118
(consumer_tree.template Has<OpCallT<OpExpr>>()),
119119
true,
120-
phi::errors::InvalidArgument(
120+
::common::errors::InvalidArgument(
121121
"consumer_tree.template should have <OpCallT<OpExpr>>()"));
122122
const auto& op_call = consumer_tree.template Get<OpCallT<OpExpr>>();
123123
const auto& op_call_children =
@@ -126,7 +126,7 @@ struct InlineTranslator final {
126126
PADDLE_ENFORCE_EQ(
127127
(op_call_child.template Has<Load<TensorT>>()),
128128
true,
129-
phi::errors::InvalidArgument(
129+
::common::errors::InvalidArgument(
130130
"op_call_child.template should have <Load<TensorT>>()"));
131131
}
132132

@@ -181,7 +181,7 @@ struct InlineTranslator final {
181181
index2dst_leaf.emplace(i, NaiveTranslateLeaf(*std::next(begin, i)))
182182
.second,
183183
true,
184-
phi::errors::InvalidArgument(
184+
::common::errors::InvalidArgument(
185185
"index2dst_leaf.emplace should return true"));
186186
}
187187
// Inline dst leaves
@@ -215,7 +215,7 @@ struct InlineTranslator final {
215215
static DstLeaf NaiveTranslateLeaf(const SrcTree& src_tree) {
216216
PADDLE_ENFORCE_EQ(src_tree.template Has<SrcLeaf>(),
217217
true,
218-
phi::errors::InvalidArgument(
218+
::common::errors::InvalidArgument(
219219
"src_tree.template should have <SrcLeaf>()"));
220220
const auto& [tensor, op_call] = src_tree.template Get<SrcLeaf>().tuple();
221221
const List<Load<TensorT>>& src_loads =

paddle/cinn/adt/m_ir.cc

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -38,12 +38,16 @@ void CollectTensorIndexIterators(const TensorIndexExpr& tensor_index_expr,
3838

3939
void CollectTensorIndexIteratorsImpl(const Undefined& tensor_index_expr,
4040
std::unordered_set<Iterator>* ret) {
41-
PADDLE_THROW(::common::errors::Unimplemented("Not Implemented"));
41+
PADDLE_THROW(::common::errors::Unimplemented(
42+
"CollectTensorIndexIteratorsImpl is not implemented for Undefined tensor "
43+
"index expression. Please check your input."));
4244
}
4345

4446
void CollectTensorIndexIteratorsImpl(const Ok& ok,
4547
std::unordered_set<Iterator>* ret) {
46-
PADDLE_THROW(::common::errors::Unimplemented("Not Implemented"));
48+
PADDLE_THROW(::common::errors::Unimplemented(
49+
"CollectTensorIndexIteratorsImpl is not implemented for Ok state. Please "
50+
"ensure the function is correctly called."));
4751
}
4852

4953
void CollectTensorIndexIteratorsImpl(const Iterator& iterator,

paddle/cinn/adt/map_expr_ctx.h

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,12 @@ class MapExprCtx final {
4141
::pir::Operation* node,
4242
const std::vector<ir::LoweredFunc>& lowered_funcs) {
4343
Node2LoweredFuncs* map = &node2lowered_funcs_;
44-
CHECK(map->emplace(node, ir::ir_utils::IRCopy(lowered_funcs)).second);
44+
PADDLE_ENFORCE_EQ(
45+
map->emplace(node, ir::ir_utils::IRCopy(lowered_funcs)).second,
46+
true,
47+
::common::errors::InvalidArgument(
48+
"Failed to emplace the node in the map. Ensure that the node is "
49+
"valid and the operation is correct."));
4550
}
4651

4752
const Node2LoweredFuncs& node2lowered_funcs() const {

paddle/cinn/adt/naive_bidirection_equation_generator.cc

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,7 @@ OpArgIndexes<std::optional<Index>> MakeOutMsgOpArgIndexes(
5353
for (const auto& out_msg_in_index : *opt_out_msg_in_indexes) {
5454
PADDLE_ENFORCE_EQ(out_msg_in_index.has_value(),
5555
true,
56-
phi::errors::InvalidArgument(
56+
::common::errors::InvalidArgument(
5757
"The out_msg_in_index should have value."));
5858
out_msg_in_indexes->emplace_back(out_msg_in_index.value());
5959
}
@@ -118,7 +118,7 @@ void NaiveBidirectionEquationGenerator::InitInMsgIndex2OutMsgIndex() {
118118
this->in_msg_index2out_msg_index_.emplace(in_index, out_index)
119119
.second,
120120
true,
121-
phi::errors::InvalidArgument(
121+
::common::errors::InvalidArgument(
122122
"The out_msg_index2in_msg_index_ map has already "
123123
"contained the out_index."));
124124
});
@@ -172,7 +172,7 @@ NaiveBidirectionEquationGenerator::MakeGetterOpStmt4OpPlaceHolder() const {
172172
->emplace(fake_op_placeholders_->at(i), op_stmts_->at(i))
173173
.second,
174174
true,
175-
phi::errors::InvalidArgument(
175+
::common::errors::InvalidArgument(
176176
"The fake_op_placeholder2op_stmt map has already contained the "
177177
"fake_op_placeholder."));
178178
}

0 commit comments

Comments
 (0)