8 changes: 4 additions & 4 deletions paddle/fluid/inference/api/analysis_predictor.cc
@@ -1595,7 +1595,7 @@ void AnalysisPredictor::PrepareArgument() {
   }
 
 #ifdef PADDLE_WITH_IPU
-  argument_->SetUseIpu(config_.use_ipu_);
+  argument_->SetUseIpu(config_.use_ipu());
   argument_->SetIpuDeviceNum(config_.ipu_device_num());
   argument_->SetIpuMicroBatchSize(config_.ipu_micro_batch_size_);
   argument_->SetIpuEnablePipelining(config_.ipu_enable_pipelining_);
@@ -1611,7 +1611,7 @@ void AnalysisPredictor::PrepareArgument() {
   argument_->SetIpuCustomPatterns(config_.ipu_custom_patterns_);
 #endif
 
-  if (config_.use_mkldnn_) {
+  if (config_.mkldnn_enabled() && !config_.use_gpu()) {
Contributor: Are enabling GPU and enabling MKLDNN mutually exclusive?

Contributor (Author): If we run on GPU, we won't run on CPU anyway.
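For context on that check, a minimal usage sketch against the public paddle::AnalysisConfig API (the method names below are the documented public ones, not part of this diff; the model path is hypothetical):

// Sketch: the two configurations the new guard distinguishes.
#include "paddle/include/paddle_inference_api.h"  // install path may vary

int main() {
  // CPU path: mkldnn_enabled() becomes true, use_gpu() stays false,
  // so the MKLDNN op-type list is passed to the argument.
  paddle::AnalysisConfig cpu_config;
  cpu_config.SetModel("model_dir");  // hypothetical model directory
  cpu_config.EnableMKLDNN();

  // GPU path: use_gpu() becomes true, so even if EnableMKLDNN() were
  // also called, `mkldnn_enabled() && !use_gpu()` keeps the branch off.
  paddle::AnalysisConfig gpu_config;
  gpu_config.SetModel("model_dir");
  gpu_config.EnableUseGpu(100 /*memory MB*/, 0 /*device id*/);
  return 0;
}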

     LOG(INFO) << "MKLDNN is enabled";
     argument_->SetMKLDNNEnabledOpTypes(config_.mkldnn_enabled_op_types_);
   }
@@ -1628,12 +1628,12 @@ void AnalysisPredictor::PrepareArgument() {
     argument_->SetQuantizeExcludedOpIds(
         config_.mkldnn_quantizer_config()->excluded_op_ids());
   }
-  if (config_.use_mkldnn_bfloat16_) {
+  if (config_.mkldnn_bfloat16_enabled()) {
     LOG(INFO) << "Bfloat16 is enabled";
     argument_->SetBfloat16EnabledOpTypes(config_.bfloat16_enabled_op_types_);
   }
 
-  if (config_.use_mkldnn_int8_) {
+  if (config_.mkldnn_int8_enabled()) {
     LOG(INFO) << "Int8 is enabled";
     argument_->SetQuantizeEnabledOpTypes(config_.quantize_enabled_op_types_);
     argument_->SetQuantizeExcludedOpIds(config_.quantize_excluded_op_ids_);
23 changes: 15 additions & 8 deletions paddle/fluid/pir/transforms/constant_folding_pass.cc
@@ -100,8 +100,9 @@ class ConstantFoldingPattern : public pir::RewritePattern {
   void Rewrite(pir::Operation* op,
                pir::PatternRewriter& rewriter) const override {  // NOLINT
     VLOG(4) << "constant_folding_pass applys on [" << op->name() << "] op";
-    pir::Program new_program(ir_context());
-    auto output_var_name = BuildProgramFromOperation(op, &new_program);
+    pir::Program new_program(rewriter.ir_context());
+    auto output_var_name =
+        BuildProgramFromOperation(op, &new_program, rewriter);

     // execute program
     exe_config_->skip_gc_vars.insert(output_var_name);
@@ -163,9 +164,12 @@ class ConstantFoldingPattern : public pir::RewritePattern {
     return true;
   }

-  std::string BuildProgramFromOperation(pir::Operation* op,
-                                        pir::Program* new_program) const {
-    pir::Builder builder = pir::Builder(ir_context(), new_program->block());
+  std::string BuildProgramFromOperation(
+      pir::Operation* op,
+      pir::Program* new_program,
+      pir::PatternRewriter& rewriter) const {  // NOLINT
+    pir::Builder builder =
+        pir::Builder(rewriter.ir_context(), new_program->block());

     // prepare op inputs
     std::vector<pir::Value> op_inputs;
@@ -176,12 +180,15 @@ class ConstantFoldingPattern : public pir::RewritePattern {
       PADDLE_ENFORCE_NOT_NULL(
           param_var,
           phi::errors::InvalidArgument("Parameter var not in scope."));
-      if (op->operand_source(i).use_count() == 1) {
-        deleted_vars_->push_back(param_name);
-      }

       auto parameter_op = builder.Build<pir::ParameterOp>(
           param_name, op->operand_source(i).type());
+      if (op->operand_source(i).use_count() <= 1) {
+        deleted_vars_->push_back(param_name);
+      } else {
+        parameter_op->set_attribute(
+            kAttrIsPersisable, rewriter.array_attr({rewriter.bool_attr(true)}));
+      }
       op_inputs.push_back(parameter_op->result(0));
     }
