
Commit

[Inference] Use cse pass (PaddlePaddle#64523)
* use cse pass

* fix
yuanlehome authored May 29, 2024
1 parent 442c7d9 commit fe30f9f
Showing 3 changed files with 5 additions and 6 deletions.
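
For context on the pass being enabled here, below is a minimal C++ sketch (illustrative only, not Paddle's PIR representation) of the redundancy that common subexpression elimination removes:

// Before CSE: the subexpression a * b is computed twice.
float before_cse(float a, float b) {
  float x = a * b + 1.0f;
  float y = a * b + 2.0f;  // redundant recomputation of a * b
  return x + y;
}

// After CSE (conceptually): the common subexpression is computed once and reused.
float after_cse(float a, float b) {
  float t = a * b;  // hoisted common subexpression
  return (t + 1.0f) + (t + 2.0f);
}

The new pass applies the same idea to operations in the PIR graph rather than to local variables.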
paddle/fluid/inference/api/analysis_predictor.cc (5 changes: 3 additions & 2 deletions)
@@ -114,6 +114,7 @@

 #include "paddle/common/flags.h"
 #include "paddle/fluid/ir_adaptor/translator/translate.h"
+#include "paddle/fluid/pir/transforms/general/common_subexpression_elimination_pass.h"
 #include "paddle/fluid/pir/transforms/general/constant_folding_pass.h"
 #include "paddle/fluid/pir/transforms/general/dead_code_elimination_pass.h"
 #include "paddle/fluid/pir/transforms/general/inplace_pass.h"
@@ -906,7 +907,7 @@ bool AnalysisPredictor::PrepareExecutor() {
     ctx->GetOrRegisterDialect<cinn::dialect::OperatorDialect>();
     ctx->GetOrRegisterDialect<pir::shape::ShapeDialect>();
     auto pass_manager = std::make_shared<::pir::PassManager>(
-        ::pir::IrContext::Instance(), 2);
+        ::pir::IrContext::Instance(), config_.pm_opt_level_);
     if (!config_.glog_info_disabled()) {
       pass_manager->EnablePrintStatistics();
     }
@@ -999,7 +1000,7 @@ bool AnalysisPredictor::PrepareExecutor() {
   // Apply some basic passes required by the framework
   ::pir::PassManager basic_pass_pm(::pir::IrContext::Instance(),
                                    config_.pm_opt_level_);
-
+  basic_pass_pm.AddPass(::pir::CreateCommonSubexpressionEliminationPass());
   auto params_sync_among_devices_pass =
       ::pir::CreateParamsSyncAmongDevicesPass();
   params_sync_among_devices_pass->SetNotOwned(pir::Pass::kPlaceAttr,
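
Read together, the hunks above do two things: the pass manager built alongside the CINN dialect registration now honors config_.pm_opt_level_ instead of a hard-coded level 2, and the framework's basic pass pipeline gains CSE as its first pass. A condensed sketch of the latter, with names copied from the diff (the remaining AddPass calls are cut off in the hunk and therefore omitted here):

  // Basic passes required by the framework, as assembled after this commit.
  ::pir::PassManager basic_pass_pm(::pir::IrContext::Instance(),
                                   config_.pm_opt_level_);
  basic_pass_pm.AddPass(::pir::CreateCommonSubexpressionEliminationPass());
  auto params_sync_among_devices_pass =
      ::pir::CreateParamsSyncAmongDevicesPass();
  // ... kPlaceAttr is attached and further passes are added below (not shown) ...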
paddle/fluid/inference/api/paddle_pass_builder.cc (1 change: 1 addition & 0 deletions)
@@ -610,6 +610,7 @@ const std::vector<std::string> kPirGpuPasses{
     "embedding_eltwise_layernorm_fuse_pass",
     "fused_flash_attn_pass",
     "multihead_matmul_fuse_pass",
+    "fused_weight_only_linear_pass",
     "matmul_add_act_fuse_pass",
     "fc_elementwise_layernorm_fuse_pass",
     "matmul_scale_fuse_pass",
Third changed file: the source file defining FusedWeightOnlyLinearPass (5 changes: 1 addition & 4 deletions; path not shown)
@@ -32,9 +32,6 @@ int getSMVersion() {
 #if defined(PADDLE_WITH_CUDA) && defined(PADDLE_WITH_CUTLASS)
   sm_version = paddle::platform::GetGPUComputeCapability(
       paddle::platform::GetCurrentDeviceId());
-#else
-  PADDLE_THROW(common::errors::Unavailable(
-      "fused_weight_only_linear_pass needs paddle compiled with CUDA."));
 #endif
   return sm_version;
 }
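
Removing the #else branch changes behavior on builds without CUDA and CUTLASS: getSMVersion() no longer throws, it simply returns its initial value, so the pass can at least be constructed there. A sketch of the resulting function, assuming the initializer (outside the visible hunk) is a sentinel such as -1:

int getSMVersion() {
  int sm_version = -1;  // assumed initializer; the real one is outside the shown hunk
#if defined(PADDLE_WITH_CUDA) && defined(PADDLE_WITH_CUTLASS)
  sm_version = paddle::platform::GetGPUComputeCapability(
      paddle::platform::GetCurrentDeviceId());
#endif
  // No PADDLE_THROW here any more on non-CUDA builds.
  return sm_version;
}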
@@ -280,7 +277,7 @@ class FusedWeightOnlyLinearPass : public pir::PatternRewritePass {
         sm_version_(getSMVersion()) {}

   pir::RewritePatternSet InitializePatterns(pir::IrContext *context) override {
-    std::string algo = "weight_only_int4";
+    std::string algo = "weight_only_int8";
     if (Has("weight_only_algo")) {
       algo = Get<std::string>("weight_only_algo");
     }
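
The last hunk only changes the fallback: when the pass carries no "weight_only_algo" attribute, the default weight-only algorithm is now int8 instead of int4. The same logic from the hunk above, annotated:

    std::string algo = "weight_only_int8";            // new default when nothing is configured
    if (Has("weight_only_algo")) {
      algo = Get<std::string>("weight_only_algo");    // an explicit attribute still takes precedence
    }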
