@@ -802,8 +802,8 @@ void AnalysisPredictor::OptimizeInferencePirProgram() {
802
802
pir::IrContext *ctx = pir::IrContext::Instance ();
803
803
ctx->GetOrRegisterDialect <cinn::dialect::OperatorDialect>();
804
804
ctx->GetOrRegisterDialect <pir::shape::ShapeDialect>();
805
- auto pass_manager =
806
- std::make_shared<:: pir::PassManager>(:: pir:: IrContext::Instance (), 2 );
805
+ auto pass_manager = std::make_shared<::pir::PassManager>(
806
+ :: pir::IrContext::Instance (), config_.pm_opt_level_ );
807
807
if (!config_.glog_info_disabled ()) {
808
808
pass_manager->EnablePrintStatistics ();
809
809
}
@@ -882,12 +882,20 @@ void AnalysisPredictor::OptimizeInferencePirProgram() {
882
882
std::make_unique<pir::PassManager::IRPrinterOption>(
883
883
ir_printing_conditions, ir_printing_conditions));
884
884
}
885
+ // set attr
886
+ for (const auto &pass : pass_pm.passes ()) {
887
+ if (pass->name () == " matmul_add_act_fuse_pass" ||
888
+ pass->name () == " conv2d_add_act_fuse_pass" ||
889
+ pass->name () == " conv2d_add_fuse_pass" ) {
890
+ pass->Set (" use_cutlass" , new bool (config_.use_cutlass_ ));
891
+ }
892
+ }
885
893
pass_pm.Run (pir_program_.get ());
886
894
887
895
// Apply some basic passes required by the framework
888
896
::pir::PassManager basic_pass_pm (::pir::IrContext::Instance (),
889
897
config_.pm_opt_level_ );
890
-
898
+ basic_pass_pm. AddPass (:: pir::CreateCommonSubexpressionEliminationPass ());
891
899
auto params_sync_among_devices_pass =
892
900
::pir::CreateParamsSyncAmongDevicesPass ();
893
901
params_sync_among_devices_pass->SetNotOwned (pir::Pass::kPlaceAttr , &place_);
@@ -918,6 +926,9 @@ void AnalysisPredictor::OptimizeInferencePirProgram() {
918
926
paddle::dialect::PdOpLowerToKernelPass (pir_program_.get (), place_);
919
927
920
928
::pir::PassManager lowered_pm (::pir::IrContext::Instance (), 3 );
929
+ auto remove_shadow_feed_pass = ::pir::CreateRemoveShadowFeedPass ();
930
+ remove_shadow_feed_pass->Set (" used_for_inference" , new bool (true ));
931
+ lowered_pm.AddPass (std::move (remove_shadow_feed_pass));
921
932
if (FLAGS_pir_apply_inplace_pass) {
922
933
lowered_pm.AddPass (::pir::CreateInplacePass ());
923
934
}
@@ -1081,9 +1092,10 @@ bool AnalysisPredictor::PrepareProgram(
1081
1092
executor_->CreateVariables (*inference_program_, 0 , false , sub_scope_);
1082
1093
1083
1094
if (config_.new_ir_enabled ()) {
1084
- if (pir_program_ != nullptr ) {
1085
- PADDLE_FATAL (" pir_program_ must be nullptr" );
1086
- }
1095
+ PADDLE_ENFORCE_EQ (
1096
+ pir_program_,
1097
+ nullptr ,
1098
+ platform::errors::Fatal (" Here, pir_program must be a nullptr!" ));
1087
1099
pir_program_ = paddle::TranslateLegacyProgramToProgram (*inference_program_);
1088
1100
OptimizeInferencePirProgram ();
1089
1101
}
0 commit comments