Skip to content

Commit bd3934d

Browse files
authored
[Inference] optimize device memory occupation (#64493)
1 parent e106052 commit bd3934d

File tree

2 files changed

+11
-0
lines changed

2 files changed

+11
-0
lines changed

paddle/fluid/pir/transforms/general/constant_folding_pass.cc

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -227,6 +227,13 @@ class ConstantFoldingPattern : public pir::RewritePattern {
227227
}
228228
}
229229
rewriter.EraseOp(op);
230+
231+
// NOTE(liuyuanle): Here, we release one useless variable after another to
232+
// effectively reduce peak memory usage.
233+
if (deleted_vars_.size() > 0) {
234+
scope_->EraseVars(deleted_vars_);
235+
deleted_vars_.clear();
236+
}
230237
VLOG(4) << "constant_folding_pass applied rewrite on [" << op->name()
231238
<< "] op";
232239
}
@@ -306,6 +313,8 @@ class ConstantFoldingPattern : public pir::RewritePattern {
306313
if (op->operand_source(index).use_count() > 1) {
307314
from_op->set_attribute(kAttrIsPersistable,
308315
rewriter.array_attr({rewriter.bool_attr(true)}));
316+
} else {
317+
deleted_vars_.push_back(var_name);
309318
}
310319
return from_op;
311320
}
@@ -397,6 +406,7 @@ class ConstantFoldingPattern : public pir::RewritePattern {
397406
phi::Place place_;
398407
paddle::framework::Scope* scope_;
399408
paddle::framework::interpreter::ExecutionConfig* exe_config_;
409+
mutable std::vector<std::string> deleted_vars_;
400410
};
401411

402412
class ConstantFoldingPatternForTrain : public ConstantFoldingPattern {

paddle/fluid/pir/transforms/general/dead_code_elimination_pass.cc

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -75,6 +75,7 @@ class DeadCodeEliminationPass : public pir::Pass {
7575
deleted_vars->push_back(constant_tensor_op.tensor_name());
7676
}
7777
op->Erase();
78+
VLOG(4) << "erase op: " << op->name();
7879
(*num_erasers)++;
7980
}
8081

0 commit comments

Comments
 (0)