 #include "paddle/fluid/pir/transforms/general/common_subexpression_elimination_pass.h"
 #include "paddle/fluid/pir/transforms/gpu/fused_bn_add_act_pass.h"
 #include "paddle/fluid/pir/transforms/passes.h"
+#include "paddle/fluid/pir/utils/general_functions.h"
 #include "paddle/fluid/pybind/control_flow_api.h"
 #include "paddle/fluid/pybind/eager_utils.h"
 #include "paddle/fluid/pybind/pybind_variant_caster.h"
@@ -326,6 +327,34 @@ std::string GetValueName(Value value) {
         "DataOp/ParameterOp/BlockArgument and ShadowOutputOp."));
 }
 
+phi::DataType GetTensorDtype(Type type) {
+  if (!type) {
+    PADDLE_THROW(phi::errors::InvalidArgument("The type of value is nullptr."));
+  }
+  if (auto dense_tensor_type = type.dyn_cast<DenseTensorType>()) {
+    return dialect::TransToPhiDataType(dense_tensor_type.dtype());
+  } else if (auto sparse_coo_tensor_type =
+                 type.dyn_cast<SparseCooTensorType>()) {
+    return dialect::TransToPhiDataType(sparse_coo_tensor_type.dtype());
+  } else if (auto sparse_csr_tensor_type =
+                 type.dyn_cast<SparseCsrTensorType>()) {
+    return dialect::TransToPhiDataType(sparse_csr_tensor_type.dtype());
+  } else if (auto select_rows = type.dyn_cast<SelectedRowsType>()) {
+    return dialect::TransToPhiDataType(select_rows.dtype());
+  } else if (auto dense_array = type.dyn_cast<DenseTensorArrayType>()) {
+    return dialect::TransToPhiDataType(dense_array.dtype());
+  } else {
+    PADDLE_THROW(phi::errors::InvalidArgument(
+        "Currently, we can only get phi::DataType from DenseTensorType and "
+        "SelectedRowsType, DenseTensorArrayType, SparseCooTensorType or "
+        "SparseCsrTensorType."));
+  }
+}
+
+phi::DataType GetValueDtype(Value value) {
+  return GetTensorDtype(value.type());
+}
+
 py::object Clone(const Program &self, IrMapping *p_mapper = nullptr) {
   IrMapping mapper;
   if (p_mapper == nullptr) {
@@ -339,6 +368,122 @@ py::object Clone(const Program &self, IrMapping *p_mapper = nullptr) {
   return new_obj;
 }
 
+bool SomeInSet(const std::vector<pir::Value> &vec,
+               const std::set<pir::Value> &set) {
+  for (auto &v : vec) {
+    if (set.find(v) != set.end()) {
+      return true;
+    }
+  }
+  return false;
+}
+
+// Build a data op with the same shape/dtype/type as `value` and insert it
+// into `block` right before `origin_op`, so it can serve as a new feed.
+pir::Value AppendDataOp(pir::Block *block,
+                        const pir::Value &value,
+                        const std::string &name,
+                        const pir::Operation &origin_op) {
+  pir::IrContext *ctx = pir::IrContext::Instance();
+  auto op_info = ctx->GetRegisteredOpInfo(paddle::dialect::DataOp::name());
+  pir::AttributeMap attribute_map = {
+      {"name", StrAttribute::get(ctx, name)},
+      {"shape",
+       paddle::dialect::IntArrayAttribute::get(
+           ctx, phi::IntArray(phi::vectorize(GetValueDims(value))))},
+      {"dtype",
+       paddle::dialect::DataTypeAttribute::get(ctx, GetValueDtype(value))},
+      {"place", paddle::dialect::PlaceAttribute::get(ctx, phi::Place())}};
+  std::vector<pir::Type> output_types{value.type()};
+  pir::Operation *operation =
+      pir::Operation::Create({}, attribute_map, output_types, op_info);
+
+  block->insert(origin_op, operation);
+  return operation->result(0);
+}
+
+// Collect the values an op really depends on: control-flow ops also depend
+// on outside values used inside their sub-blocks, not just their operands.
+std::vector<pir::Value> GetRealOpInputs(pir::Operation *op) {
+  if (op->isa<paddle::dialect::IfOp>() ||
+      op->isa<paddle::dialect::PyLayerOp>()) {
+    return pir::GetUsedExternalValue(*op);
+  } else if (op->isa<paddle::dialect::WhileOp>()) {
+    paddle::dialect::WhileOp whileop = op->dyn_cast<paddle::dialect::WhileOp>();
+    auto value_vector = op->operands_source();
+    auto value_vector2 = pir::GetUsedExternalValue(whileop.body());
+    value_vector.insert(
+        value_vector.end(), value_vector2.begin(), value_vector2.end());
+    return value_vector;
+  } else {
+    return op->operands_source();
+  }
+}
+
+/*
+  Values in input_vars become the pruned program's inputs, and values in
+  output_vars become the pruned program's outputs. Pruning therefore works
+  in two steps: each value in input_vars is first replaced by a newly
+  inserted data op, and then only the ops reachable backwards from
+  output_vars are kept; every other op is erased.
+
+  Note: the program is pruned in place and returned as-is. If you do not
+  want the original program to be modified, pass in a cloned program.
+*/
+void PruneWithInput(const std::vector<pir::Value> &input_vars,
+                    const std::vector<pir::Value> &output_vars,
+                    Program *prog) {
+  auto global_block = prog->block();
+  std::vector<pir::Value> new_input_vars;
+  if (!input_vars.empty()) {
+    for (uint64_t idx = 0; idx < input_vars.size(); idx++) {
+      auto input = input_vars[idx];
+      auto origin_op = input.defining_op();
+      std::string name = "input_" + std::to_string(idx);
+      if (HasValueName(input)) {
+        name = GetValueName(input);
+      }
+      auto new_input = AppendDataOp(global_block, input, name, *origin_op);
+      input.ReplaceAllUsesWith(new_input);
+      new_input_vars.push_back(new_input);
+    }
+  }
+  VLOG(6) << "program after adding new feed ops = " << *prog;
+  auto total_ops_list = global_block->ops();
+  std::vector<pir::Operation *> total_ops(total_ops_list.begin(),
+                                          total_ops_list.end());
+  std::vector<bool> intersection_op_flags(total_ops.size(), true);
+  std::set<pir::Value> output_vars_set(output_vars.begin(), output_vars.end());
+  // Backward pass: keep an op if any of its results is needed, and mark all
+  // of its real inputs as needed in turn.
+  for (uint32_t index = total_ops.size() - 1; index != (uint32_t)(-1);
+       --index) {
+    auto op = total_ops[index];
+    auto op_results = op->results();
+    if (SomeInSet(op_results, output_vars_set)) {
+      for (auto &operand : GetRealOpInputs(op)) {
+        output_vars_set.insert(operand);
+      }
+    } else {
+      VLOG(6) << "delete op " << index << ", name is " << op->name();
+      intersection_op_flags[index] = false;
+    }
+  }
+
+  // Erase every op that was not marked as needed. A data op just created for
+  // input_vars must never be erased; if it would be, the caller passed an
+  // input that the outputs do not depend on.
+  std::set<pir::Value> input_vars_set(new_input_vars.begin(),
+                                      new_input_vars.end());
+  for (uint32_t index = total_ops.size() - 1; index != (uint32_t)(-1);
+       --index) {
+    auto op = total_ops[index];
+    if (!intersection_op_flags[index]) {
+      auto op_results = op->results();
+      if (!input_vars_set.empty() && SomeInSet(op_results, input_vars_set)) {
+        PADDLE_THROW(phi::errors::InvalidArgument(
+            "The input_var created by '%s' is not involved in the "
+            "computation of output_vars. Please remove it from input_vars.",
+            op->name()));
+      }
+      global_block->erase(*op);
+    }
+  }
+}
+
 void BindProgram(py::module *m) {
   static int64_t global_prog_seed = 0;
   py::class_<Program, std::shared_ptr<Program>> program(
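To make the pruning logic above concrete, here is a minimal, self-contained sketch of the same backward mark-and-sweep idea. It deliberately does not use Paddle's real PIR API: ToyOp, its integer result ids, and the Prune helper are hypothetical stand-ins for pir::Operation, pir::Value, SomeInSet and GetRealOpInputs, and it skips the data-op insertion step for input_vars.

#include <iostream>
#include <set>
#include <string>
#include <vector>

// Toy stand-ins for pir::Operation / pir::Value: each op defines one result
// (an integer id) and reads the results of earlier ops.
struct ToyOp {
  std::string name;
  std::vector<int> inputs;  // result ids this op reads
  int result;               // result id this op defines
};

// Reverse walk over the ops: keep an op if its result is needed, then mark
// its inputs as needed too (the analogue of the SomeInSet / GetRealOpInputs
// loop in PruneWithInput).
std::vector<ToyOp> Prune(const std::vector<ToyOp> &ops, std::set<int> needed) {
  std::vector<bool> keep(ops.size(), false);
  for (int i = static_cast<int>(ops.size()) - 1; i >= 0; --i) {
    if (needed.count(ops[i].result)) {
      keep[i] = true;
      needed.insert(ops[i].inputs.begin(), ops[i].inputs.end());
    }
  }
  std::vector<ToyOp> pruned;
  for (size_t i = 0; i < ops.size(); ++i) {
    if (keep[i]) pruned.push_back(ops[i]);
  }
  return pruned;
}

int main() {
  // a = data(); b = data(); c = add(a, b); d = mul(a, a); e = relu(c)
  std::vector<ToyOp> ops = {{"data_a", {}, 0},
                            {"data_b", {}, 1},
                            {"add", {0, 1}, 2},
                            {"mul", {0, 0}, 3},
                            {"relu", {2}, 4}};
  // Prune with result 4 (the relu output) as the only target.
  for (const auto &op : Prune(ops, {4})) {
    std::cout << op.name << "\n";  // data_a, data_b, add, relu; mul is gone
  }
  return 0;
}

Running the sketch keeps data_a, data_b, add and relu and drops mul, which mirrors how ops not reachable backwards from output_vars are erased from the global block.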
@@ -567,6 +712,25 @@ void BindProgram(py::module *m) {
           }
         }
       })
+      .def(
+          "_prune",
+          [](Program &self, std::vector<pir::Value> output_vars) {
+            std::vector<pir::Value> input_vars;
+            PruneWithInput(input_vars, output_vars, &self);
+            return &self;
+          },
+          py::arg("targets"),
+          "Prune the program in place, keeping only ops needed to compute targets.")
+      .def(
+          "_prune_with_input",
+          [](Program &self,
+             std::vector<pir::Value> input_vars,
+             std::vector<pir::Value> output_vars) {
+            PruneWithInput(input_vars, output_vars, &self);
+            return &self;
+          },
+          py::arg("feeded_vars"),
+          py::arg("targets"))
       .def("_sync_with_cpp", [](const std::shared_ptr<Program> &self) {
         // _sync_with_cpp is not needed in pir, but it is necessary in the
         // old static graph. Add an empty function to avoid python call error.
@@ -1031,33 +1195,6 @@ py::str Value2String(Value self) {
   return print_stream.str();
 }
 
-phi::DataType GetTensorDtype(Type type) {
-  if (!type) {
-    PADDLE_THROW(phi::errors::InvalidArgument("The type of value is nullptr."));
-  }
-  if (auto dense_tensor_type = type.dyn_cast<DenseTensorType>()) {
-    return dialect::TransToPhiDataType(dense_tensor_type.dtype());
-  } else if (auto sparse_coo_tensor_type =
-                 type.dyn_cast<SparseCooTensorType>()) {
-    return dialect::TransToPhiDataType(sparse_coo_tensor_type.dtype());
-  } else if (auto sparse_csr_tensor_type =
-                 type.dyn_cast<SparseCsrTensorType>()) {
-    return dialect::TransToPhiDataType(sparse_csr_tensor_type.dtype());
-  } else if (auto select_rows = type.dyn_cast<SelectedRowsType>()) {
-    return dialect::TransToPhiDataType(select_rows.dtype());
-  } else if (auto dense_array = type.dyn_cast<DenseTensorArrayType>()) {
-    return dialect::TransToPhiDataType(dense_array.dtype());
-  } else {
-    PADDLE_THROW(phi::errors::InvalidArgument(
-        "Currently, we can only get phi::DataType from DenseTensorType and "
-        "SelectedRowsType, DenseTensorArrayType,SparseCooTensorType or "
-        "SparseCsrTensorType."));
-  }
-}
-phi::DataType GetValueDtype(Value value) {
-  return GetTensorDtype(value.type());
-}
-
 const phi::DDim &GetTensorDims(Type type) {
   if (!type) {
     PADDLE_THROW(common::errors::InvalidArgument(