Commit 60f8ebe

Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into hack6_13
2 parents d32eb2d + a14320d commit 60f8ebe

246 files changed: +1756 additions, −4992 deletions

Large commits have some content hidden by default; only a subset of the 246 changed files is shown below.

paddle/cinn/frontend/paddle/cpp/block_desc.cc

Lines changed: 25 additions & 4 deletions
@@ -13,18 +13,29 @@
 // limitations under the License.
 
 #include "paddle/cinn/frontend/paddle/cpp/block_desc.h"
+#include "paddle/common/enforce.h"
 
 namespace cinn::frontend::paddle::cpp {
 
 template <>
 VarDesc* BlockDesc::GetVar<VarDesc>(int32_t idx) {
-  CHECK_LT(idx, VarsSize()) << "idx >= vars.size()";
+  PADDLE_ENFORCE_LT(
+      idx,
+      VarsSize(),
+      phi::errors::InvalidArgument(
+          "The value of idx and vars.size() is incorrect."
+          "Expected idx < vars.size(), but receive idx >= vars.size()."));
   return &vars_[idx];
 }
 
 template <>
 const VarDesc& BlockDesc::GetConstVar<VarDesc>(int32_t idx) const {
-  CHECK_LT(idx, static_cast<int32_t>(VarsSize())) << "idx >= vars.size()";
+  PADDLE_ENFORCE_LT(
+      idx,
+      static_cast<int32_t>(VarsSize()),
+      phi::errors::InvalidArgument(
+          "The value of idx and vars.size() is incorrect."
+          "Expected idx < vars.size(), but receive idx >= vars.size()."));
   return vars_[idx];
 }
 
@@ -36,13 +47,23 @@ VarDesc* BlockDesc::AddVar<VarDesc>() {
 
 template <>
 OpDesc* BlockDesc::GetOp<OpDesc>(int32_t idx) {
-  CHECK_LT(idx, OpsSize()) << "idx >= ops.size()";
+  PADDLE_ENFORCE_LT(
+      idx,
+      OpsSize(),
+      phi::errors::InvalidArgument(
+          "The value of idx and ops.size() is incorrect."
+          "Expected idx < ops.size(), but receive idx >= ops.size()."));
   return &ops_[idx];
 }
 
 template <>
 const OpDesc& BlockDesc::GetConstOp<OpDesc>(int32_t idx) const {
-  CHECK_LT(idx, static_cast<int32_t>(OpsSize())) << "idx >= ops.size()";
+  PADDLE_ENFORCE_LT(
+      idx,
+      static_cast<int32_t>(OpsSize()),
+      phi::errors::InvalidArgument(
+          "The value of idx and ops.size() is incorrect."
+          "Expected idx < ops.size(), but receive idx >= ops.size()."));
   return ops_[idx];
 }
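
The hunks above (and the similar ones in the files below) follow one pattern: glog-style CHECK_* assertions, which abort the process on failure, are replaced with PADDLE_ENFORCE_* checks that raise a catchable error carrying a typed message built via phi::errors::InvalidArgument. The snippet below is a minimal standalone analogue of that pattern, not Paddle's actual macros; the ENFORCE_LT macro and InvalidArgument helper here are simplified stand-ins for illustration:

#include <iostream>
#include <stdexcept>
#include <string>

// Simplified stand-in for phi::errors::InvalidArgument: builds a descriptive error.
inline std::runtime_error InvalidArgument(const std::string& msg) {
  return std::runtime_error("InvalidArgument: " + msg);
}

// Simplified stand-in for PADDLE_ENFORCE_LT: throws instead of aborting,
// so callers can catch the failure and report it.
#define ENFORCE_LT(a, b, err)                  \
  do {                                         \
    if (!((a) < (b))) throw (err);             \
  } while (0)

int main() {
  int idx = 5;
  int vars_size = 3;
  try {
    ENFORCE_LT(idx, vars_size,
               InvalidArgument("Expected idx < vars.size(), but received "
                               "idx >= vars.size()."));
  } catch (const std::exception& e) {
    std::cerr << e.what() << '\n';  // recoverable, unlike an aborting CHECK_LT
  }
  return 0;
}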

paddle/cinn/frontend/paddle/cpp/program_desc.cc

Lines changed: 13 additions & 2 deletions
@@ -13,18 +13,29 @@
 // limitations under the License.
 
 #include "paddle/cinn/frontend/paddle/cpp/program_desc.h"
+#include "paddle/common/enforce.h"
 
 namespace cinn::frontend::paddle::cpp {
 
 template <>
 BlockDesc* ProgramDesc::GetBlock<BlockDesc>(int32_t idx) {
-  CHECK_LT(idx, BlocksSize()) << "idx >= blocks.size()";
+  PADDLE_ENFORCE_LT(
+      idx,
+      BlocksSize(),
+      phi::errors::InvalidArgument(
+          "The value of idx and blocks.size() is incorrect."
+          "Expected idx < blocks.size(), but receive idx >= blocks.size()."));
   return &blocks_[idx];
 }
 
 template <>
 const BlockDesc& ProgramDesc::GetConstBlock<BlockDesc>(int32_t idx) const {
-  CHECK_LT(idx, static_cast<int32_t>(BlocksSize())) << "idx >= blocks.size()";
+  PADDLE_ENFORCE_LT(
+      idx,
+      static_cast<int32_t>(BlocksSize()),
+      phi::errors::InvalidArgument(
+          "The value of idx and blocks.size() is incorrect."
+          "Expected idx < blocks.size(), but receive idx >= blocks.size()."));
   return blocks_[idx];
 }

paddle/cinn/frontend/paddle/model_parser.cc

Lines changed: 3 additions & 1 deletion
@@ -23,6 +23,7 @@
 #include "paddle/cinn/backends/cuda_util.h"
 #include "paddle/cinn/common/common.h"
 #include "paddle/cinn/frontend/paddle/compatible_pb.h"
+#include "paddle/common/enforce.h"
 
 namespace cinn::frontend::paddle {
 
@@ -55,7 +56,8 @@ void TensorFromStream(std::istream &is,
   using Type = framework_proto::VarType::Type;
   uint32_t version;
   is.read(reinterpret_cast<char *>(&version), sizeof(version));
-  CHECK_EQ(version, 0U) << "Only version 0 is supported";
+  PADDLE_ENFORCE_EQ(
+      version, 0U, phi::errors::InvalidArgument("Only version 0 is supported"));
   // read tensor desc
   framework_proto::VarType::TensorDesc desc;
   {

paddle/cinn/frontend/paddle/pb/block_desc.cc

Lines changed: 13 additions & 2 deletions
@@ -13,13 +13,19 @@
 // limitations under the License.
 
 #include "paddle/cinn/frontend/paddle/pb/block_desc.h"
+#include "paddle/common/enforce.h"
 
 namespace cinn::frontend::paddle::pb {
 
 template <>
 framework_proto::VarDesc* BlockDesc::GetVar<framework_proto::VarDesc>(
     int32_t idx) {
-  CHECK_LT(idx, VarsSize()) << "idx >= vars.size()";
+  PADDLE_ENFORCE_LT(
+      idx,
+      VarsSize(),
+      phi::errors::InvalidArgument(
+          "The value of idx and vars.size() is incorrect."
+          "Expected idx < vars.size(), but receive idx >= vars.size()."));
   return desc_->mutable_vars(idx);
 }
 
@@ -31,7 +37,12 @@ framework_proto::VarDesc* BlockDesc::AddVar<framework_proto::VarDesc>() {
 template <>
 framework_proto::OpDesc* BlockDesc::GetOp<framework_proto::OpDesc>(
     int32_t idx) {
-  CHECK_LT(idx, OpsSize()) << "idx >= ops.size()";
+  PADDLE_ENFORCE_LT(
+      idx,
+      OpsSize(),
+      phi::errors::InvalidArgument(
+          "The value of idx and ops.size() is incorrect."
+          "Expected idx < ops.size(), but receive idx >= ops.size()."));
   return desc_->mutable_ops(idx);
 }

paddle/cinn/frontend/paddle/pb/program_desc.cc

Lines changed: 8 additions & 1 deletion
@@ -17,12 +17,19 @@
 #include <algorithm>
 #include <limits>
 
+#include "paddle/common/enforce.h"
+
 namespace cinn::frontend::paddle::pb {
 
 template <>
 framework_proto::BlockDesc* ProgramDesc::GetBlock<framework_proto::BlockDesc>(
     int32_t idx) {
-  CHECK_LT(idx, BlocksSize()) << "idx >= blocks.size()";
+  PADDLE_ENFORCE_LT(
+      idx,
+      BlocksSize(),
+      phi::errors::InvalidArgument(
+          "The value of idx and blocks.size() is incorrect."
+          "Expected idx < blocks.size(), but receive idx >= blocks.size()."));
   return desc_->mutable_blocks(idx);
 }

paddle/cinn/hlir/dialect/operator/transforms/check_infer_symbolic_util.cc

Lines changed: 7 additions & 2 deletions
@@ -37,10 +37,15 @@ DimExprs4ValueT MakeDimExprs4Value(
   std::shared_ptr<pir::PassManager> pass_manager = CreatePassManager();
   pass_manager->AddPass(pir::CreateShapeOptimizationPass());
   pass_manager->Run(program);
-  const auto* shape_analysis =
-      &pir::ShapeAnalysisManager::Instance().Get(program);
+  auto* shape_analysis = &pir::ShapeAnalysisManager::Instance().Get(program);
   return
       [shape_analysis](pir::Value value) -> const symbol::ShapeOrDataDimExprs& {
+        // TODO(Hongqing-work): define a default empty ShapeOrDataDimExprs
+        if (!value) {
+          static symbol::ShapeOrDataDimExprs empty{
+              symbol::TensorShapeOrDataDimExprs{}};
+          return empty;
+        }
         return shape_analysis->GetShapeOrDataForValue(value);
       };
 }
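
The new branch above guards against null pir::Value handles: instead of querying the analysis with a missing value, the lambda returns a reference to a function-local static "empty" object. Returning a const reference to a function-local static is a common way to supply a default when the return type is a reference. A minimal standalone sketch of the same idiom follows; the Shape type and GetShapeOrDefault function are hypothetical, introduced only for illustration:

#include <cstdint>
#include <iostream>
#include <map>
#include <string>
#include <vector>

struct Shape {
  std::vector<int64_t> dims;
};

// Returns a reference to a stored shape, or to a shared empty default when the
// key is missing. The static local is constructed once, on first use, and it
// outlives every caller, so returning a reference to it is safe.
const Shape& GetShapeOrDefault(const std::map<std::string, Shape>& table,
                               const std::string& key) {
  auto it = table.find(key);
  if (it == table.end()) {
    static const Shape empty{};  // analogous to the static empty ShapeOrDataDimExprs above
    return empty;
  }
  return it->second;
}

int main() {
  std::map<std::string, Shape> table{{"x", Shape{{2, 3}}}};
  std::cout << GetShapeOrDefault(table, "x").dims.size() << '\n';        // prints 2
  std::cout << GetShapeOrDefault(table, "missing").dims.size() << '\n';  // prints 0
  return 0;
}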

paddle/cinn/hlir/dialect/operator/transforms/group_merge/op_with_group_merge_pass.cc

Lines changed: 7 additions & 9 deletions
@@ -236,10 +236,8 @@ int GetSharedSize(::pir::Operation* op) {
   return 0;
 }
 
-using ConditionFunction =
-    std::function<bool(::pir::Operation*,
-                       const GroupPtr&,
-                       const ::pir::ShapeConstraintIRAnalysis&)>;
+using ConditionFunction = std::function<bool(
+    ::pir::Operation*, const GroupPtr&, ::pir::ShapeConstraintIRAnalysis*)>;
 
 // Op Fusion Pass which performs Ops fusion, Ops are fused
 // "vertically", meaning producing Ops are fused into their consumers
@@ -354,7 +352,7 @@ class OpFusionPassHelper {
 
  private:
   void DoOpFusion() {
-    const auto& shape_analysis = pir::ShapeAnalysisManager::Instance().Get(
+    auto& shape_analysis = pir::ShapeAnalysisManager::Instance().Get(
         ops_.front()->GetParentProgram());
     for (auto consumer : ops_) {
       auto consumer_kind =
@@ -413,7 +411,7 @@ class OpFusionPassHelper {
        }
      }
 
-      if (!can_fuse || !CanFuse(producer, consumer, shape_analysis)) {
+      if (!can_fuse || !CanFuse(producer, consumer, &shape_analysis)) {
        continue;
      }
 
@@ -443,7 +441,7 @@ class OpFusionPassHelper {
        // VLOG(3) << "Insert Global Output Node : " << producer->id();
        consumer_fusion->output_ops.insert(producer);
      } else if (producer_data_used_num > 1 && producer->num_operands() > 0 &&
-                is_same_size(producer, consumer_fusion, shape_analysis)) {
+                is_same_size(producer, consumer_fusion, &shape_analysis)) {
        // producer is not a const value op.
        consumer_fusion->internal_ops.insert(producer);
      }
@@ -484,7 +482,7 @@ class OpFusionPassHelper {
        {OpPatternKind::kBroadcast,
         [](::pir::Operation* producer,
            const GroupPtr& consumer,
-           const ::pir::ShapeConstraintIRAnalysis& shape_analysis) -> bool {
+           ::pir::ShapeConstraintIRAnalysis* shape_analysis) -> bool {
           // NOTE, producer and consumer NEVER be same size
           if (is_same_size(producer, consumer, shape_analysis)) {
             return true;
@@ -598,7 +596,7 @@ class OpFusionPassHelper {
 
   bool CanFuse(::pir::Operation* producer,
                const ::pir::Operation* consumer,
-               const ::pir::ShapeConstraintIRAnalysis& shape_analysis) {
+               ::pir::ShapeConstraintIRAnalysis* shape_analysis) {
     auto& relation =
         fusion_relation_map_[hlir::framework::pir::CompatibleInfo::OpKind(
             *producer)];

paddle/cinn/hlir/dialect/operator/transforms/group_merge/op_with_group_merge_util.h

Lines changed: 22 additions & 24 deletions
@@ -77,32 +77,30 @@ int GetSharedSize(::pir::Operation* op);
 inline bool always_fuse(
     ::pir::Operation* producer,
     const std::shared_ptr<Group>& consumer,
-    const ::pir::ShapeConstraintIRAnalysis& shape_analysis) {  // NOLINT
+    ::pir::ShapeConstraintIRAnalysis* shape_analysis) {  // NOLINT
   return true;
 }
 
 inline bool no_fuse(::pir::Operation* producer,
                     const std::shared_ptr<Group>& consumer,
-                    const ::pir::ShapeConstraintIRAnalysis& shape_analysis) {
+                    ::pir::ShapeConstraintIRAnalysis* shape_analysis) {
   return false;
 }
 
-inline bool is_same_shape(
-    ::pir::Operation* producer,
-    const std::shared_ptr<Group>& consumer,
-    const ::pir::ShapeConstraintIRAnalysis& shape_analysis) {
+inline bool is_same_shape(::pir::Operation* producer,
+                          const std::shared_ptr<Group>& consumer,
+                          ::pir::ShapeConstraintIRAnalysis* shape_analysis) {
   auto master_op = consumer->master_ops.begin();
-  return shape_analysis.IsShapeEqual(producer->result(0),
-                                     (*master_op)->result(0));
+  return shape_analysis->IsShapeEqual(producer->result(0),
+                                      (*master_op)->result(0));
 }
 
-inline bool is_same_size(
-    ::pir::Operation* producer,
-    const std::shared_ptr<Group>& consumer,
-    const ::pir::ShapeConstraintIRAnalysis& shape_analysis) {
+inline bool is_same_size(::pir::Operation* producer,
+                         const std::shared_ptr<Group>& consumer,
+                         ::pir::ShapeConstraintIRAnalysis* shape_analysis) {
   auto master_op = consumer->master_ops.begin();
-  return shape_analysis.IsSameNumel(producer->result(0),
-                                    (*master_op)->result(0));
+  return shape_analysis->IsSameNumel(producer->result(0),
+                                     (*master_op)->result(0));
 }
 
 inline bool without_last_dimension_in_reduce(
@@ -115,7 +113,7 @@ inline bool without_last_dimension_in_reduce(
 inline bool reduce_fuse_reduce(
     ::pir::Operation* producer,
     const std::shared_ptr<Group>& consumer,
-    const ::pir::ShapeConstraintIRAnalysis& shape_analysis) {
+    ::pir::ShapeConstraintIRAnalysis* shape_analysis) {
   ::pir::Operation* reducer = NULL;
   for (auto* master : consumer->master_ops) {
     if (hlir::framework::pir::CompatibleInfo::OpKind(*master) ==
@@ -227,7 +225,7 @@ inline bool is_horizontal_relation(::pir::Operation* producer,
 inline bool horizontal_or_vertical_reduce_relation(
     ::pir::Operation* producer,
     const std::shared_ptr<Group>& consumer,
-    const ::pir::ShapeConstraintIRAnalysis& shape_analysis) {
+    ::pir::ShapeConstraintIRAnalysis* shape_analysis) {
   // check is same shape with horizontal relation.
   if (is_same_size(producer, consumer, shape_analysis)) {
     return true;
@@ -298,7 +296,7 @@ inline bool horizontal_or_vertical_reduce_relation(
 inline bool horizontal_or_can_inline(
     ::pir::Operation* producer,
     const std::shared_ptr<Group>& consumer,
-    const ::pir::ShapeConstraintIRAnalysis& shape_analysis) {
+    ::pir::ShapeConstraintIRAnalysis* shape_analysis) {
   // horizontal relation.
   if (is_horizontal_relation(producer, consumer)) {
     if (is_same_size(producer, consumer, shape_analysis)) {
@@ -336,22 +334,22 @@ inline bool horizontal_or_can_inline(
 inline bool horizontal_with_same_size(
     ::pir::Operation* producer,
     const std::shared_ptr<Group>& consumer,
-    const ::pir::ShapeConstraintIRAnalysis& shape_analysis) {
+    ::pir::ShapeConstraintIRAnalysis* shape_analysis) {
   return is_horizontal_relation(producer, consumer) &&
          is_same_size(producer, consumer, shape_analysis);
 }
 
 inline std::vector<int64_t> GetBroadcastAxes(
     ::pir::Operation* bcast_op,
-    const ::pir::ShapeConstraintIRAnalysis& shape_analysis) {  // NOLINT
+    ::pir::ShapeConstraintIRAnalysis* shape_analysis) {  // NOLINT
   if (bcast_op->isa<cinn::dialect::BroadcastOp>()) {
     return GetVectorAttr(bcast_op, "broadcast_axes");
   } else if (bcast_op->isa<paddle::dialect::ExpandOp>()) {
     const auto& input_shape =
-        shape_analysis.GetShapeOrDataForValue(bcast_op->operand_source(0))
+        shape_analysis->GetShapeOrDataForValue(bcast_op->operand_source(0))
             .shape();
     const auto& output_shape =
-        shape_analysis.GetShapeOrDataForValue(bcast_op->result(0)).shape();
+        shape_analysis->GetShapeOrDataForValue(bcast_op->result(0)).shape();
     std::vector<int64_t> broadcast_axes(input_shape.size(), 0);
     size_t index_gap = output_shape.size() - input_shape.size();
     for (size_t i = 0; i < input_shape.size(); ++i) {
@@ -366,7 +364,7 @@ inline std::vector<int64_t> GetBroadcastAxes(
 inline bool reduce_fuse_broadcast(
     ::pir::Operation* producer,
     const std::shared_ptr<Group>& consumer,
-    const ::pir::ShapeConstraintIRAnalysis& shape_analysis) {
+    ::pir::ShapeConstraintIRAnalysis* shape_analysis) {
   if (is_horizontal_relation(producer, consumer)) {
     if (is_same_size(producer, consumer, shape_analysis)) {
       return true;
@@ -379,7 +377,7 @@ inline bool reduce_fuse_broadcast(
   // }
 
   const auto& rinput_shape =
-      shape_analysis.GetShapeOrDataForValue(producer->operand_source(0))
+      shape_analysis->GetShapeOrDataForValue(producer->operand_source(0))
           .shape();
   auto reduce_axes = GetVectorAttr(producer, "dim");
   auto keep_dim = producer->attributes()
@@ -429,7 +427,7 @@ inline bool reduce_fuse_broadcast(
       continue;
     }
     const auto& broadcast_shape =
-        shape_analysis.GetShapeOrDataForValue(op->result(0)).shape();
+        shape_analysis->GetShapeOrDataForValue(op->result(0)).shape();
     auto broadcast_axes = GetBroadcastAxes(op, shape_analysis);
 
     for (auto& axis : broadcast_axes) {
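
Across this header and the fusion pass above, every helper's shape_analysis parameter changes from const ::pir::ShapeConstraintIRAnalysis& to a non-const pointer, and the call sites switch from . to ->. A plausible reason is that the methods being queried (e.g. GetShapeOrDataForValue) are non-const, so they cannot be invoked through a const reference; passing a mutable pointer also makes the possible mutation explicit at each call site. A minimal standalone sketch of that distinction follows; the Analysis class here is a hypothetical stand-in, not the real ShapeConstraintIRAnalysis:

#include <iostream>
#include <map>
#include <string>

class Analysis {
 public:
  // Non-const: computes on demand and caches the result, so it mutates state.
  int GetOrCompute(const std::string& key) {
    auto it = cache_.find(key);
    if (it == cache_.end()) {
      it = cache_.emplace(key, static_cast<int>(key.size())).first;
    }
    return it->second;
  }

 private:
  std::map<std::string, int> cache_;
};

// Taking `const Analysis&` would not compile here, because GetOrCompute is
// non-const; taking a pointer mirrors the commit's new signatures.
bool SameSize(Analysis* analysis, const std::string& a, const std::string& b) {
  return analysis->GetOrCompute(a) == analysis->GetOrCompute(b);
}

int main() {
  Analysis analysis;
  std::cout << std::boolalpha << SameSize(&analysis, "abc", "xyz") << '\n';  // prints true
  return 0;
}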

paddle/cinn/hlir/dialect/operator/transforms/lowering_pass/pre_analysis.cc

Lines changed: 7 additions & 0 deletions
@@ -16,6 +16,9 @@
 #include "paddle/cinn/hlir/dialect/operator/ir/manual_op.h"
 #include "paddle/cinn/hlir/dialect/operator/transforms/lowering_pass/utils.h"
 #include "paddle/cinn/hlir/framework/pir_compiler.h"
+#include "paddle/common/flags.h"
+
+PD_DECLARE_bool(enable_cinn_compile_cache);
 
 namespace cinn::dialect::ir::details {
 using cinn::hlir::framework::PirCompiler;
@@ -42,6 +45,10 @@ void FusionOpAnalysis::RunImpl(pir::Operation* op) {
 }
 
 void FusionOpAnalysis::PreCompileGroup() {
+  // Make compilation into lazy mode while
+  // FLAGS_enable_cinn_compile_cache=false.
+  if (!FLAGS_enable_cinn_compile_cache) return;
+
   std::vector<OpLoweringGroupPtr> groups;
   for (auto& group_info : *group_infos_) {
     if (is_dy_shape_ && NeedBroadcastWithCF(group_info.second)) continue;
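
This change gates eager group pre-compilation on FLAGS_enable_cinn_compile_cache: PD_DECLARE_bool makes a flag defined elsewhere in Paddle visible in this translation unit, and the early return skips PreCompileGroup when the flag is off, so groups are presumably compiled lazily when first needed instead. A minimal standalone analogue of that gating pattern follows; the flag and function names below are illustrative only, not Paddle's:

#include <iostream>
#include <vector>

// Stand-in for PD_DECLARE_bool(...): in Paddle the flag is defined in another
// translation unit; here it is defined locally so the example is self-contained.
bool FLAGS_enable_precompile_cache = false;

void PreCompileAll(const std::vector<int>& groups) {
  // Lazy mode: skip eager compilation entirely; work happens on first use instead.
  if (!FLAGS_enable_precompile_cache) return;
  for (int g : groups) {
    std::cout << "precompiling group " << g << '\n';
  }
}

int main() {
  PreCompileAll({1, 2, 3});                // prints nothing: lazy mode
  FLAGS_enable_precompile_cache = true;
  PreCompileAll({1, 2, 3});                // eagerly compiles all groups
  return 0;
}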
