Skip to content
Merged
Original file line number Diff line number Diff line change
Expand Up @@ -64,8 +64,7 @@ void GetInputIds(pir::Operation* op,
"input should in name map, [%d] 'th input of [%s] op",
i,
"if op"));
std::vector<int> inputs_id = GetValueIds(value, value_exec_info);
input_ids->emplace(value, inputs_id);
input_ids->emplace(value, GetValueIds(value, value_exec_info));
}
}
}
Expand All @@ -92,9 +91,7 @@ void GetOutsideOpInputs(
"input should in name map, [%d] 'th input of [%s] op",
i,
op->name()));
std::vector<int> inputs_id = GetValueIds(value, value_exec_info);

input_ids->emplace(value, inputs_id);
input_ids->emplace(value, GetValueIds(value, value_exec_info));
}
}
}
Expand Down Expand Up @@ -181,14 +178,22 @@ CondInstruction::CondInstruction(size_t id,
"input should in name map, [%d] 'th input of [%s] op",
i,
"if op"));
std::vector<int> outputs_id = GetValueIds(value, *value_exec_info);
outputs.emplace(value, outputs_id);
outputs.emplace(value, GetValueIds(value, *value_exec_info));
}
}
SetOutputs(outputs);
VLOG(6) << "finish process inputs outputs index";
}

CondInstruction::~CondInstruction() {
if (true_branch_inter_ != nullptr) {
delete true_branch_inter_;
}
if (false_branch_inter_ != nullptr) {
delete false_branch_inter_;
}
}

void CondInstruction::CopyBranchOutput(
const std::vector<std::string>& var_names, const NewIRInterpreter* inter) {
for (size_t i = 0; i < var_names.size(); ++i) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,8 @@ class CondInstruction : public InstructionBase {
::pir::Operation* op,
ValueExecutionInfo* value_exe_info);

~CondInstruction();

void Run() override;

const std::string& Name() const override { return cond_name_; }
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,8 @@
#include <vector>

#include "paddle/fluid/framework/new_executor/new_executor_defs.h"
#include "paddle/fluid/pir/dialect/kernel/ir/kernel_dialect.h"
#include "paddle/fluid/pir/dialect/operator/ir/pd_op.h"
#include "paddle/fluid/platform/device_context.h"
#include "paddle/fluid/platform/event.h"
#include "paddle/pir/core/builtin_attribute.h"
Expand Down Expand Up @@ -148,7 +150,8 @@ OpFuncType AnalyseOpFuncType(pir::Operation* op, const platform::Place& place) {

auto& op_attributes = op->attributes();

if ((op->dialect()->name() == "pd_kernel") &&
if ((op->dialect()->name().compare(paddle::dialect::KernelDialect::name()) ==
0) &&
(op_attributes.count("kernel_key") > 0)) {
auto kernel_key = op_attributes.at("kernel_key")
.dyn_cast<dialect::KernelAttribute>()
Expand Down Expand Up @@ -179,7 +182,7 @@ OpFuncType AnalyseOpFuncType(pir::Operation* op, const platform::Place& place) {
return OpFuncType::kGpuSync;
}

if (op_name == "pd_op.shape") {
if (op_name.compare(paddle::dialect::ShapeOp::name()) == 0) {
return OpFuncType::kGpuSync;
}
}
Expand Down
30 changes: 23 additions & 7 deletions paddle/fluid/pir/transforms/inplace_pass.cc
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
#include "paddle/fluid/pir/dialect/kernel/ir/kernel_dialect.h"
#include "paddle/fluid/pir/dialect/kernel/ir/kernel_type.h"
#include "paddle/fluid/pir/dialect/operator/interface/op_yaml_info.h"
#include "paddle/fluid/pir/dialect/operator/ir/control_flow_op.h"
#include "paddle/fluid/pir/dialect/operator/ir/op_type.h"
#include "paddle/fluid/pir/dialect/operator/trait/inplace.h"
#include "paddle/fluid/pir/dialect/operator/utils/op_yaml_info_parser.h"
Expand Down Expand Up @@ -121,11 +122,10 @@ static std::unordered_set<pir::Value> GetSkipDeletionValues(pir::Block* block) {
// NOTE(zhangbo): For inplace Pass, currently only the kernel_dialect operator
// is supported. Therefore, this function only returns the values in the
// kernel_dialect operator that can be eager deleted.
static std::unordered_map<pir::Operation*, std::unordered_set<pir::Value>>
GetEagerDeletionValues(pir::Block* block) {
std::unordered_set<pir::Value> skip_dels = GetSkipDeletionValues(block);

std::unordered_map<pir::Value, pir::Operation*> del_value_2_op;
static void GetEagerDelValueOfOp(
pir::Block* block,
const std::unordered_set<pir::Value>& skip_dels,
std::unordered_map<pir::Value, pir::Operation*>* del_value_2_op) {
for (auto& op : *block) {
std::string upper_op_name = op->name();
if (op->dialect()->name().compare(paddle::dialect::KernelDialect::name()) ==
Expand All @@ -150,16 +150,32 @@ GetEagerDeletionValues(pir::Block* block) {
VLOG(8) << " -- is no_need_buffer: " << IsNoNeedBuffer(op, input);
continue;
}
del_value_2_op[input] = op;
(*del_value_2_op)[input] = op;
}

for (size_t i = 0; i < op->num_results(); ++i) {
pir::Value output = op->result(i);
if (output && CanBeDeleted(output)) {
del_value_2_op[output] = op;
(*del_value_2_op)[output] = op;
}
}

if (op->isa<paddle::dialect::IfOp>()) {
auto if_op = op->dyn_cast<paddle::dialect::IfOp>();
GetEagerDelValueOfOp(if_op.true_block(), skip_dels, del_value_2_op);
VLOG(8) << "GetEagerDelValueOfOp for IfOp true block";
GetEagerDelValueOfOp(if_op.false_block(), skip_dels, del_value_2_op);
VLOG(8) << "GetEagerDelValueOfOp for IfOp false block";
}
}
}

static std::unordered_map<pir::Operation*, std::unordered_set<pir::Value>>
GetEagerDeletionValues(pir::Block* block) {
std::unordered_set<pir::Value> skip_dels = GetSkipDeletionValues(block);

std::unordered_map<pir::Value, pir::Operation*> del_value_2_op;
GetEagerDelValueOfOp(block, skip_dels, &del_value_2_op);

std::unordered_map<pir::Operation*, std::unordered_set<pir::Value>>
eager_dels;
Expand Down
1 change: 1 addition & 0 deletions test/legacy_test/test_cond.py
Original file line number Diff line number Diff line change
Expand Up @@ -265,6 +265,7 @@ def false_func():
np.asarray(ret[1]), np.full((2, 3), True, bool), rtol=1e-05
)

@test_and_compare_with_new_ir()
def test_pass_and_modify_var(self):
"""
pseudocode:
Expand Down