
Commit a977ebe

xiaoguoguo626807 authored and yuanlehome committed
【pir_save_load】add program._prune to support paddledetection (PaddlePaddle#66231)
* add new test
* add new test
* nouse
* add predictor(pir) delete pass api
* add delete pass
* recover reset test
* modify delete pass
* modify static_save_load
* add test_build_strategy
* move static save load
* move static save load
* add program._prune

---------

Co-authored-by: yuanlehome <yuanlehome@163.com>
1 parent d96cdb0 commit a977ebe
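The commit's practical effect is two new Python-facing methods on the PIR `Program`: `_prune(targets)` and `_prune_with_input(feeded_vars, targets)`. A minimal usage sketch, assuming a build where `paddle.static.Program` yields the PIR `Program` bound below (e.g. with `FLAGS_enable_pir_api=1`); variable names are illustrative:

```python
# Minimal sketch, assuming FLAGS_enable_pir_api=1 so that paddle.static
# builds the PIR Program whose pybind class gains _prune/_prune_with_input.
import paddle

paddle.enable_static()
main = paddle.static.Program()
with paddle.static.program_guard(main):
    x = paddle.static.data("x", [4, 8], "float32")
    hidden = paddle.nn.functional.relu(x)
    dead = hidden * 2.0          # feeds nothing below, so it can be pruned
    out = paddle.mean(hidden)

# _prune keeps only the ops needed to compute `out`. Pruning happens in
# place (the method returns the same program), so clone first if the
# original must be preserved.
main._prune([out])

# _prune_with_input additionally replaces `hidden` with a new data op,
# turning it into a feed of the pruned program:
# main._prune_with_input([hidden], [out])
```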

File tree

3 files changed: +389, -28 lines

paddle/fluid/pybind/pir.cc

Lines changed: 164 additions & 27 deletions
```diff
@@ -51,6 +51,7 @@
 #include "paddle/fluid/pir/transforms/general/common_subexpression_elimination_pass.h"
 #include "paddle/fluid/pir/transforms/gpu/fused_bn_add_act_pass.h"
 #include "paddle/fluid/pir/transforms/passes.h"
+#include "paddle/fluid/pir/utils/general_functions.h"
 #include "paddle/fluid/pybind/control_flow_api.h"
 #include "paddle/fluid/pybind/eager_utils.h"
 #include "paddle/fluid/pybind/pybind_variant_caster.h"
@@ -326,6 +327,34 @@ std::string GetValueName(Value value) {
         "DataOp/ParameterOp/BlockArgument and ShadowOutputOp."));
   }
 }
 
+phi::DataType GetTensorDtype(Type type) {
+  if (!type) {
+    PADDLE_THROW(phi::errors::InvalidArgument("The type of value is nullptr."));
+  }
+  if (auto dense_tensor_type = type.dyn_cast<DenseTensorType>()) {
+    return dialect::TransToPhiDataType(dense_tensor_type.dtype());
+  } else if (auto sparse_coo_tensor_type =
+                 type.dyn_cast<SparseCooTensorType>()) {
+    return dialect::TransToPhiDataType(sparse_coo_tensor_type.dtype());
+  } else if (auto sparse_csr_tensor_type =
+                 type.dyn_cast<SparseCsrTensorType>()) {
+    return dialect::TransToPhiDataType(sparse_csr_tensor_type.dtype());
+  } else if (auto select_rows = type.dyn_cast<SelectedRowsType>()) {
+    return dialect::TransToPhiDataType(select_rows.dtype());
+  } else if (auto dense_array = type.dyn_cast<DenseTensorArrayType>()) {
+    return dialect::TransToPhiDataType(dense_array.dtype());
+  } else {
+    PADDLE_THROW(phi::errors::InvalidArgument(
+        "Currently, we can only get phi::DataType from DenseTensorType, "
+        "SelectedRowsType, DenseTensorArrayType, SparseCooTensorType or "
+        "SparseCsrTensorType."));
+  }
+}
+
+phi::DataType GetValueDtype(Value value) {
+  return GetTensorDtype(value.type());
+}
+
 py::object Clone(const Program &self, IrMapping *p_mapper = nullptr) {
   IrMapping mapper;
   if (p_mapper == nullptr) {
@@ -339,6 +368,122 @@ py::object Clone(const Program &self, IrMapping *p_mapper = nullptr) {
   return new_obj;
 }
 
+bool SomeInSet(const std::vector<pir::Value> &vec,
+               const std::set<pir::Value> &set) {
+  for (auto &v : vec) {
+    if (set.find(v) != set.end()) {
+      return true;
+    }
+  }
+  return false;
+}
+
+pir::Value AppendDataOp(pir::Block *block,
+                        const pir::Value &value,
+                        const std::string &name,
+                        const pir::Operation &origin_op) {
+  pir::IrContext *ctx = pir::IrContext::Instance();
+  auto op_info = ctx->GetRegisteredOpInfo(paddle::dialect::DataOp::name());
+  pir::AttributeMap attribute_map = {
+      {"name", StrAttribute::get(ctx, name)},
+      {"shape",
+       paddle::dialect::IntArrayAttribute::get(
+           ctx, phi::IntArray(phi::vectorize(GetValueDims(value))))},
+      {"dtype",
+       paddle::dialect::DataTypeAttribute::get(ctx, GetValueDtype(value))},
+      {"place", paddle::dialect::PlaceAttribute::get(ctx, phi::Place())}};
+  std::vector<pir::Type> output_types{value.type()};
+  pir::Operation *operation =
+      pir::Operation::Create({}, attribute_map, output_types, op_info);
+
+  block->insert(origin_op, operation);
+  return operation->result(0);
+}
+
+std::vector<pir::Value> GetRealOpInputs(pir::Operation *op) {
+  if (op->isa<paddle::dialect::IfOp>() ||
+      op->isa<paddle::dialect::PyLayerOp>()) {
+    return pir::GetUsedExternalValue(*op);
+  } else if (op->isa<paddle::dialect::WhileOp>()) {
+    paddle::dialect::WhileOp whileop = op->dyn_cast<paddle::dialect::WhileOp>();
+    auto value_vector = op->operands_source();
+    auto value_vector2 = pir::GetUsedExternalValue(whileop.body());
+    value_vector.insert(
+        value_vector.end(), value_vector2.begin(), value_vector2.end());
+    return value_vector;
+  } else {
+    return op->operands_source();
+  }
+}
+
+/*
+  Values in input_vars become the pruned program's inputs, and values in
+  output_vars become its outputs. Pruning therefore replaces each value
+  in input_vars with a fresh data op, then keeps only the ops reachable
+  backwards from output_vars.
+
+  Note: the program is pruned in place and the returned program is the
+  original one. If the original must not be modified, pass in a clone.
+*/
+void PruneWithInput(const std::vector<pir::Value> &input_vars,
+                    const std::vector<pir::Value> &output_vars,
+                    Program *prog) {
+  auto global_block = prog->block();
+  std::vector<pir::Value> new_input_vars;
+  for (uint64_t idx = 0; idx < input_vars.size(); idx++) {
+    auto input = input_vars[idx];
+    auto origin_op = input.defining_op();
+    std::string name = "input_" + std::to_string(idx);
+    if (HasValueName(input)) {
+      name = GetValueName(input);
+    }
+    auto new_input = AppendDataOp(global_block, input, name, *origin_op);
+    input.ReplaceAllUsesWith(new_input);
+    new_input_vars.push_back(new_input);
+  }
+  VLOG(6) << "program after adding new feed ops = " << *prog;
+  auto total_ops_list = global_block->ops();
+  std::vector<pir::Operation *> total_ops(total_ops_list.begin(),
+                                          total_ops_list.end());
+  std::vector<bool> intersection_op_flags(total_ops.size(), true);
+  std::set<pir::Value> output_vars_set(output_vars.begin(), output_vars.end());
+  // Reverse sweep: keep an op if any of its results is needed, then mark
+  // its (real) inputs as needed in turn.
+  for (uint32_t index = total_ops.size() - 1; index != (uint32_t)(-1);
+       --index) {
+    auto op = total_ops[index];
+    auto op_results = op->results();
+    if (SomeInSet(op_results, output_vars_set)) {
+      for (auto &operand : GetRealOpInputs(op)) {
+        output_vars_set.insert(operand);
+      }
+    } else {
+      VLOG(6) << "delete op " << index << ", name is " << op->name();
+      intersection_op_flags[index] = false;
+    }
+  }
+
+  std::set<pir::Value> input_vars_set(new_input_vars.begin(),
+                                      new_input_vars.end());
+  for (uint32_t index = total_ops.size() - 1; index != (uint32_t)(-1);
+       --index) {
+    auto op = total_ops[index];
+    if (!intersection_op_flags[index]) {
+      auto op_results = op->results();
+      if (!input_vars_set.empty() && SomeInSet(op_results, input_vars_set)) {
+        PADDLE_THROW(phi::errors::InvalidArgument(
+            "The input_var created by '%s' is not involved in the "
+            "output_vars calculation. Please remove it from input_vars.",
+            op->name()));
+      }
+      global_block->erase(*op);
+    }
+  }
+}
+
 void BindProgram(py::module *m) {
   static int64_t global_prog_seed = 0;
   py::class_<Program, std::shared_ptr<Program>> program(
@@ -567,6 +712,25 @@ void BindProgram(py::module *m) {
             }
           }
         })
+        .def(
+            "_prune",
+            [](Program &self, std::vector<pir::Value> output_vars) {
+              std::vector<pir::Value> input_vars;
+              PruneWithInput(input_vars, output_vars, &self);
+              return &self;
+            },
+            py::arg("targets"),
+            "Prune the program in place so that only the ops needed to "
+            "compute the values in `targets` remain.")
+        .def(
+            "_prune_with_input",
+            [](Program &self,
+               std::vector<pir::Value> input_vars,
+               std::vector<pir::Value> output_vars) {
+              PruneWithInput(input_vars, output_vars, &self);
+              return &self;
+            },
+            py::arg("feeded_vars"),
+            py::arg("targets"))
         .def("_sync_with_cpp", [](const std::shared_ptr<Program> &self) {
           // _sync_with_cpp is not needed in PIR, but it is necessary in the
           // old static graph; keep an empty function to avoid Python call
           // errors.
@@ -1031,33 +1195,6 @@ py::str Value2String(Value self) {
   return print_stream.str();
 }
 
-phi::DataType GetTensorDtype(Type type) {
-  if (!type) {
-    PADDLE_THROW(phi::errors::InvalidArgument("The type of value is nullptr."));
-  }
-  if (auto dense_tensor_type = type.dyn_cast<DenseTensorType>()) {
-    return dialect::TransToPhiDataType(dense_tensor_type.dtype());
-  } else if (auto sparse_coo_tensor_type =
-                 type.dyn_cast<SparseCooTensorType>()) {
-    return dialect::TransToPhiDataType(sparse_coo_tensor_type.dtype());
-  } else if (auto sparse_csr_tensor_type =
-                 type.dyn_cast<SparseCsrTensorType>()) {
-    return dialect::TransToPhiDataType(sparse_csr_tensor_type.dtype());
-  } else if (auto select_rows = type.dyn_cast<SelectedRowsType>()) {
-    return dialect::TransToPhiDataType(select_rows.dtype());
-  } else if (auto dense_array = type.dyn_cast<DenseTensorArrayType>()) {
-    return dialect::TransToPhiDataType(dense_array.dtype());
-  } else {
-    PADDLE_THROW(phi::errors::InvalidArgument(
-        "Currently, we can only get phi::DataType from DenseTensorType and "
-        "SelectedRowsType, DenseTensorArrayType,SparseCooTensorType or "
-        "SparseCsrTensorType."));
-  }
-}
-
-phi::DataType GetValueDtype(Value value) {
-  return GetTensorDtype(value.type());
-}
-
 const phi::DDim &GetTensorDims(Type type) {
   if (!type) {
     PADDLE_THROW(common::errors::InvalidArgument(
```
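Taken together, `SomeInSet`, `GetRealOpInputs`, and the two reverse loops in `PruneWithInput` implement a backward liveness sweep over the block. The standalone sketch below mirrors that sweep on a toy op list to show the idea; it is illustrative Python, not Paddle API:

```python
# Illustrative re-implementation of the backward sweep in PruneWithInput,
# on a toy IR where each op is (name, inputs, outputs). Not Paddle API.
def prune(ops, output_vars):
    needed = set(output_vars)
    keep = [True] * len(ops)
    # Walk the ops backwards: keep an op if any of its outputs is needed,
    # then mark its inputs as needed in turn.
    for i in range(len(ops) - 1, -1, -1):
        name, inputs, outputs = ops[i]
        if needed.intersection(outputs):   # mirrors SomeInSet
            needed.update(inputs)          # mirrors GetRealOpInputs
        else:
            keep[i] = False                # dead op: erase in second pass
    return [op for i, op in enumerate(ops) if keep[i]]

ops = [
    ("data",  [],    ["x"]),
    ("relu",  ["x"], ["h"]),
    ("scale", ["h"], ["u"]),   # dead: "u" feeds nothing we asked for
    ("sum",   ["h"], ["out"]),
]
assert [op[0] for op in prune(ops, ["out"])] == ["data", "relu", "sum"]
```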

test/legacy_test/CMakeLists.txt

Lines changed: 1 addition & 1 deletion
```diff
@@ -1139,7 +1139,7 @@ foreach(PIR_COVERAGE_TEST ${PIR_COVERAGE_TESTS})
   message(STATUS "PIR Copied OpTest: ${PIR_COVERAGE_TEST}_pir in legacy_test")
 endforeach()
 
-set(PIR_ONLY_TEST_FILES test_pir_translated_layer)
+set(PIR_ONLY_TEST_FILES test_pir_translated_layer test_prune)
 foreach(ITEST ${PIR_ONLY_TEST_FILES})
   if(TEST ${ITEST})
     set_tests_properties(${ITEST} PROPERTIES ENVIRONMENT
```
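The commit's third changed file (not shown in this excerpt) is the `test_prune` suite registered above. A hypothetical shape for such a test, under the same PIR-mode assumption as the usage sketch earlier; the real test's contents are not reproduced here:

```python
# Hypothetical test sketch; assumes PIR mode and the bindings added above.
import unittest
import paddle

class TestPrune(unittest.TestCase):
    def test_prune_removes_dead_ops(self):
        paddle.enable_static()
        main = paddle.static.Program()
        with paddle.static.program_guard(main):
            x = paddle.static.data("x", [2, 4], "float32")
            h = paddle.nn.functional.relu(x)
            _ = h * 3.0                  # dead branch
            out = paddle.mean(h)
        n_before = len(main.global_block().ops)
        main._prune([out])
        # The dead scale op should be gone; the relu/mean chain remains.
        self.assertLess(len(main.global_block().ops), n_before)

if __name__ == "__main__":
    unittest.main()
```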
