Skip to content

Commit afb2d36

Browse files
authored
[Inference] Fix pir load bug (PaddlePaddle#65180)
* fix pir load bug * delete
1 parent bac4bd5 commit afb2d36

File tree

1 file changed

+18
-6
lines changed

1 file changed

+18
-6
lines changed

paddle/fluid/inference/api/analysis_predictor.cc

+18-6
Original file line numberDiff line numberDiff line change
@@ -802,8 +802,8 @@ void AnalysisPredictor::OptimizeInferencePirProgram() {
802802
pir::IrContext *ctx = pir::IrContext::Instance();
803803
ctx->GetOrRegisterDialect<cinn::dialect::OperatorDialect>();
804804
ctx->GetOrRegisterDialect<pir::shape::ShapeDialect>();
805-
auto pass_manager =
806-
std::make_shared<::pir::PassManager>(::pir::IrContext::Instance(), 2);
805+
auto pass_manager = std::make_shared<::pir::PassManager>(
806+
::pir::IrContext::Instance(), config_.pm_opt_level_);
807807
if (!config_.glog_info_disabled()) {
808808
pass_manager->EnablePrintStatistics();
809809
}
@@ -882,12 +882,20 @@ void AnalysisPredictor::OptimizeInferencePirProgram() {
882882
std::make_unique<pir::PassManager::IRPrinterOption>(
883883
ir_printing_conditions, ir_printing_conditions));
884884
}
885+
// set attr
886+
for (const auto &pass : pass_pm.passes()) {
887+
if (pass->name() == "matmul_add_act_fuse_pass" ||
888+
pass->name() == "conv2d_add_act_fuse_pass" ||
889+
pass->name() == "conv2d_add_fuse_pass") {
890+
pass->Set("use_cutlass", new bool(config_.use_cutlass_));
891+
}
892+
}
885893
pass_pm.Run(pir_program_.get());
886894

887895
// Apply some basic passes required by the framework
888896
::pir::PassManager basic_pass_pm(::pir::IrContext::Instance(),
889897
config_.pm_opt_level_);
890-
898+
basic_pass_pm.AddPass(::pir::CreateCommonSubexpressionEliminationPass());
891899
auto params_sync_among_devices_pass =
892900
::pir::CreateParamsSyncAmongDevicesPass();
893901
params_sync_among_devices_pass->SetNotOwned(pir::Pass::kPlaceAttr, &place_);
@@ -918,6 +926,9 @@ void AnalysisPredictor::OptimizeInferencePirProgram() {
918926
paddle::dialect::PdOpLowerToKernelPass(pir_program_.get(), place_);
919927

920928
::pir::PassManager lowered_pm(::pir::IrContext::Instance(), 3);
929+
auto remove_shadow_feed_pass = ::pir::CreateRemoveShadowFeedPass();
930+
remove_shadow_feed_pass->Set("used_for_inference", new bool(true));
931+
lowered_pm.AddPass(std::move(remove_shadow_feed_pass));
921932
if (FLAGS_pir_apply_inplace_pass) {
922933
lowered_pm.AddPass(::pir::CreateInplacePass());
923934
}
@@ -1081,9 +1092,10 @@ bool AnalysisPredictor::PrepareProgram(
10811092
executor_->CreateVariables(*inference_program_, 0, false, sub_scope_);
10821093

10831094
if (config_.new_ir_enabled()) {
1084-
if (pir_program_ != nullptr) {
1085-
PADDLE_FATAL("pir_program_ must be nullptr");
1086-
}
1095+
PADDLE_ENFORCE_EQ(
1096+
pir_program_,
1097+
nullptr,
1098+
platform::errors::Fatal("Here, pir_program must be a nullptr!"));
10871099
pir_program_ = paddle::TranslateLegacyProgramToProgram(*inference_program_);
10881100
OptimizeInferencePirProgram();
10891101
}

0 commit comments

Comments
 (0)