File tree Expand file tree Collapse file tree 2 files changed +11
-0
lines changed
paddle/fluid/pir/transforms/general Expand file tree Collapse file tree 2 files changed +11
-0
lines changed Original file line number Diff line number Diff line change @@ -227,6 +227,13 @@ class ConstantFoldingPattern : public pir::RewritePattern {
227227 }
228228 }
229229 rewriter.EraseOp (op);
230+
231+ // NOTE(liuyuanle): Here, we release one useless variable after another to
232+ // effectively reduce peak memory usage.
233+ if (deleted_vars_.size () > 0 ) {
234+ scope_->EraseVars (deleted_vars_);
235+ deleted_vars_.clear ();
236+ }
230237 VLOG (4 ) << " constant_folding_pass applied rewrite on [" << op->name ()
231238 << " ] op" ;
232239 }
@@ -306,6 +313,8 @@ class ConstantFoldingPattern : public pir::RewritePattern {
306313 if (op->operand_source (index).use_count () > 1 ) {
307314 from_op->set_attribute (kAttrIsPersistable ,
308315 rewriter.array_attr ({rewriter.bool_attr (true )}));
316+ } else {
317+ deleted_vars_.push_back (var_name);
309318 }
310319 return from_op;
311320 }
@@ -397,6 +406,7 @@ class ConstantFoldingPattern : public pir::RewritePattern {
397406 phi::Place place_;
398407 paddle::framework::Scope* scope_;
399408 paddle::framework::interpreter::ExecutionConfig* exe_config_;
409+ mutable std::vector<std::string> deleted_vars_;
400410};
401411
402412class ConstantFoldingPatternForTrain : public ConstantFoldingPattern {
Original file line number Diff line number Diff line change @@ -75,6 +75,7 @@ class DeadCodeEliminationPass : public pir::Pass {
7575 deleted_vars->push_back (constant_tensor_op.tensor_name ());
7676 }
7777 op->Erase ();
78+ VLOG (4 ) << " erase op: " << op->name ();
7879 (*num_erasers)++;
7980 }
8081
You can’t perform that action at this time.
0 commit comments