
Commit 4e8d5a0

Light mem reuse strategy for inference. (#17925)
* fix: high RAM usage when loading the model from memory  test=develop
* light mem reuse  test=develop
* fix cpplint  test=develop
1 parent 83942c3 commit 4e8d5a0
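This commit adds a lightweight, graph-coloring style reuse plan for the dynamic path (no cached batch shapes): tensor lifetimes are collected, tensors whose lifetimes overlap are marked as conflicting, and non-conflicting tensors are greedily packed into clusters that share one allocation. Below is a minimal standalone sketch of that strategy; the variable names, sizes, and lifetimes are made up for illustration and this is not the PaddlePaddle API.

// Standalone sketch of lifetime-based memory clustering; names and numbers
// below are hypothetical and only illustrate the strategy in this commit.
#include <algorithm>
#include <cstddef>
#include <iostream>
#include <string>
#include <unordered_set>
#include <utility>
#include <vector>

struct Var {
  std::string name;
  size_t size;                   // bytes
  std::pair<int, int> lifetime;  // [first use, last use] in topological order
  int cluster;
  std::unordered_set<std::string> adj;  // vars whose lifetimes overlap
};

// Two closed intervals overlap iff neither ends before the other starts.
static bool Overlap(std::pair<int, int> a, std::pair<int, int> b) {
  return b.second >= a.first && a.second >= b.first;
}

int main() {
  std::vector<Var> vars = {{"conv_out", 4096, {0, 2}, -1, {}},
                           {"relu_out", 2048, {2, 4}, -1, {}},
                           {"fc_out", 1024, {5, 6}, -1, {}}};
  // Build the interference graph.
  for (size_t i = 0; i < vars.size(); ++i)
    for (size_t j = i + 1; j < vars.size(); ++j)
      if (Overlap(vars[i].lifetime, vars[j].lifetime)) {
        vars[i].adj.insert(vars[j].name);
        vars[j].adj.insert(vars[i].name);
      }
  // Greedy clustering, largest tensor first: a cluster absorbs every later
  // tensor that does not conflict with anything already in the cluster.
  std::sort(vars.begin(), vars.end(),
            [](const Var& a, const Var& b) { return a.size > b.size; });
  int next_cluster = 0;
  for (size_t i = 0; i < vars.size(); ++i) {
    if (vars[i].cluster >= 0) continue;
    vars[i].cluster = next_cluster++;
    std::unordered_set<std::string> cluster_adj = vars[i].adj;
    for (size_t j = i + 1; j < vars.size(); ++j)
      if (vars[j].cluster < 0 && !cluster_adj.count(vars[j].name)) {
        vars[j].cluster = vars[i].cluster;
        cluster_adj.insert(vars[j].adj.begin(), vars[j].adj.end());
      }
  }
  for (const auto& v : vars)
    std::cout << v.name << " -> cluster " << v.cluster << "\n";
  // Expected: conv_out and fc_out share a cluster; relu_out gets its own.
  return 0;
}

Because nodes are visited largest-first, each cluster's representative is also its largest member, which is the size the pass logs per cluster.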

File tree: 2 files changed (+134 −4 lines)

paddle/fluid/inference/analysis/passes/memory_optimize_pass.cc

Lines changed: 130 additions & 4 deletions
@@ -15,6 +15,7 @@
 #include "paddle/fluid/inference/analysis/passes/memory_optimize_pass.h"
 #include <algorithm>
 #include <fstream>
+#include <functional>
 #include <limits>
 #include <map>
 #include <string>
@@ -38,6 +39,14 @@ using framework::ir::Node;
 using framework::ir::TopologyVarientSort;
 using space_table_t = MemoryOptimizePass::space_table_t;
 
+typedef struct {
+  std::string name;
+  size_t size;
+  int cluster;
+  std::pair<int, int> lifetime;
+  std::unordered_set<std::string> adj;
+} MemNode;
+
 // Collect the lifecycles of the tensors.
 // Traverse the graph in topological order.
 // The traversal order also affect the lifecycles, so different sort_kind is
@@ -96,6 +105,89 @@ int DataTypeToSpace(framework::proto::VarType_Type type) {
   }
 }
 
+void MemoryOptimizePass::CollectVarMemorySize(
+    space_table_t* space_table) const {
+  const int fake_batch_size = 1;
+  // Collect tensors from graph.
+  for (auto* node : graph_->Nodes()) {
+    if (node->IsVar() &&
+        node->Var()->GetType() ==
+            framework::proto::VarType::Type::VarType_Type_LOD_TENSOR) {
+      // Parameters will not be reused.
+      if (node->Var()->Persistable()) continue;
+      auto shape = node->Var()->GetShape();
+      for (auto& v : shape) {
+        if (v < 0) v = fake_batch_size;
+      }
+
+      int size = std::accumulate(shape.begin(), shape.end(), 1,
+                                 std::multiplies<int>());
+      (*space_table)[node->Var()->Name()] =
+          size * DataTypeToSpace(node->Var()->GetDataType());
+    }
+  }
+}
+
+void MakeSimpleReusePlan(
+    const std::unordered_map<std::string, std::pair<int, int>>& lifecycles,
+    const std::unordered_map<std::string, size_t>& space_table,
+    std::unordered_map<std::string, std::string>* node2cluster,
+    std::unordered_map<std::string, int>* cluster_size) {
+  std::vector<MemNode> mem_nodes;
+  for (auto& data : lifecycles) {
+    MemNode temp_node;
+    temp_node.name = data.first;
+    PADDLE_ENFORCE(
+        space_table.count(data.first),
+        "%s variable should be in the spacetable during memory optimize",
+        data.first);
+    temp_node.size = space_table.at(data.first);
+    temp_node.cluster = -1;
+    temp_node.lifetime = data.second;
+    mem_nodes.push_back(temp_node);
+  }
+  auto overlap = [](std::pair<int, int> a, std::pair<int, int> b) -> bool {
+    return b.second >= a.first && a.second >= b.first;
+  };
+  // If the lifetimes of two nodes overlap, mark them as adjacent nodes.
+  for (size_t i = 0; i < mem_nodes.size(); i++) {
+    for (size_t j = i + 1; j < mem_nodes.size(); j++) {
+      if (overlap(mem_nodes[i].lifetime, mem_nodes[j].lifetime)) {
+        mem_nodes[i].adj.insert(mem_nodes[j].name);
+        mem_nodes[j].adj.insert(mem_nodes[i].name);
+      }
+    }
+  }
+
+  // Sort the nodes according to memory size, largest first.
+  auto sort_func = [](MemNode a, MemNode b) { return a.size > b.size; };
+  std::sort(mem_nodes.begin(), mem_nodes.end(), sort_func);
+
+  // Generate the memory reuse strategy greedily, largest tensors first.
+  for (size_t i = 0; i < mem_nodes.size(); i++) {
+    if (mem_nodes[i].cluster >= 0) continue;
+    int cluster_index = cluster_size->size();
+    mem_nodes[i].cluster = cluster_index;
+    (*cluster_size)[mem_nodes[i].name] = mem_nodes[i].size;
+    (*node2cluster)[mem_nodes[i].name] = mem_nodes[i].name;
+    std::unordered_set<std::string> cluster_adj = mem_nodes[i].adj;
+    for (size_t j = i + 1; j < mem_nodes.size(); j++) {
+      if (mem_nodes[j].cluster < 0 &&
+          (cluster_adj.find(mem_nodes[j].name) == cluster_adj.end())) {
+        (*node2cluster)[mem_nodes[j].name] = mem_nodes[i].name;
+        mem_nodes[j].cluster = cluster_index;
+        for (auto& n : mem_nodes[j].adj) {
+          cluster_adj.insert(n);
+        }
+      }
+    }
+  }
+  for (auto& cluster : *cluster_size) {
+    LOG(INFO) << "Cluster name : " << cluster.first
+              << " size: " << cluster.second;
+  }
+}
+
 // Collect the memory size of the tensors.
 void MemoryOptimizePass::CollectVarMemorySize(
     const std::unordered_map<std::string, size_t>& batch_var_ave_dim,
@@ -377,6 +469,17 @@ void UpdateOpDescsByReuse(
        }
      }
 
+      // Modify the graph: rename reused input var nodes to their cluster head.
+      for (auto input_node : node->inputs) {
+        PADDLE_ENFORCE(input_node->IsVar());
+        std::string input_node_name = input_node->Name();
+        if (reuse_table.count(input_node_name) &&
+            reuse_table.at(input_node_name) != input_node_name) {
+          auto name = reuse_table.at(input_node_name);
+          input_node->RenameVar(name);
+        }
+      }
+
       for (auto argument : node->Op()->Outputs()) {
         for (const auto& x : argument.second) {
           auto name = x;
@@ -388,6 +491,17 @@ void UpdateOpDescsByReuse(
        }
      }
 
+      // Modify the graph: rename reused output var nodes to their cluster head.
+      for (auto out_node : node->outputs) {
+        PADDLE_ENFORCE(out_node->IsVar());
+        std::string out_node_name = out_node->Name();
+        if (reuse_table.count(out_node_name) &&
+            reuse_table.at(out_node_name) != out_node_name) {
+          auto name = reuse_table.at(out_node_name);
+          out_node->RenameVar(name);
+        }
+      }
+
       // Update arguments.
       for (auto& arg : in_args) {
         node->Op()->SetInput(arg.first, arg.second);
@@ -589,12 +703,24 @@ void MemoryOptimizePass::RunImpl(Argument* argument) {
   VLOG(3) << "Load memory cache from " << path;
   std::vector<std::map<std::string, std::vector<int>>> batches;
 
-  if (argument->static_memory_optim() && inference::IsFileExists(path)) {
+  if (!(argument->static_memory_optim() && inference::IsFileExists(path))) {
+    string::PrettyLogInfo("--- Performing dynamic memory optimize");
+    // batches = FakeBatchVarShapes(argument->main_program());
+    int sort_kind = 0;
+    std::unordered_map<std::string, lifecycle_t> lifecycles;
+    space_table_t space_table;
+    std::unordered_map<std::string, std::string> node2cluster;
+    std::unordered_map<std::string, int> cluster_size;
+
+    CollectLifeCycle(&lifecycles, sort_kind);
+    CollectVarMemorySize(&space_table);
+    MakeSimpleReusePlan(lifecycles, space_table, &node2cluster, &cluster_size);
+    UpdateOpDescsByReuse(graph_, node2cluster, sort_kind);
+    return;
+
+  } else {
     string::PrettyLogInfo("--- Performing static memory optimize");
     batches = DeseralizeBatchVarShapes(path);
-  } else {
-    string::PrettyLogInfo("--- Performing dynamic memory optimize");
-    batches = FakeBatchVarShapes(argument->main_program());
   }
   auto var_batch_ave_size = GetBatchAverageSize(batches);
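Once the plan exists, the dynamic branch above hands node2cluster to UpdateOpDescsByReuse, which rewrites op inputs/outputs (and, with this commit, also renames the graph's var nodes) to the cluster representative. A hypothetical reuse table and argument list, outside the real pass, behave like this:

// Hypothetical illustration of the renaming step; reuse_table here mimics the
// node2cluster map produced by MakeSimpleReusePlan (representative as value).
#include <iostream>
#include <string>
#include <unordered_map>
#include <vector>

int main() {
  std::unordered_map<std::string, std::string> reuse_table = {
      {"conv_out", "conv_out"},   // representative of its own cluster
      {"fc_out", "conv_out"},     // reuses conv_out's buffer
      {"relu_out", "relu_out"}};  // separate cluster
  std::vector<std::string> op_inputs = {"relu_out", "fc_out"};
  for (auto& name : op_inputs) {
    auto it = reuse_table.find(name);
    if (it != reuse_table.end() && it->second != name) name = it->second;
  }
  for (const auto& n : op_inputs) std::cout << n << "\n";  // relu_out, conv_out
  return 0;
}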

paddle/fluid/inference/analysis/passes/memory_optimize_pass.h

Lines changed: 4 additions & 0 deletions
@@ -14,6 +14,8 @@
 
 #pragma once
 #include <string>
+#include <unordered_map>
+#include <unordered_set>
 #include <utility>
 #include <vector>
 #include "paddle/fluid/inference/analysis/analysis_pass.h"
@@ -72,6 +74,8 @@ class MemoryOptimizePass : public AnalysisPass {
       std::unordered_map<std::string, lifecycle_t> *lifecycles,
       int sort_kind) const;
 
+  void CollectVarMemorySize(space_table_t *space_table) const;
+
   void CollectVarMemorySize(
       const std::unordered_map<std::string, size_t> &batch_var_ave_dim,
       std::unordered_map<std::string, framework::ir::Node *> *tensor_nodes,
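For orientation only: the pass is reached when memory optimization is enabled on the inference config before the predictor is built. The sketch below is an assumption-laden outline; the header path and the exact SetModel, EnableMemoryOptim, and CreatePaddlePredictor signatures may differ across Paddle versions.

// Rough sketch only: header path and method signatures are assumptions and may
// not match the Paddle version this commit targets.
#include "paddle_inference_api.h"  // assumed public inference header

int main() {
  paddle::AnalysisConfig config;
  config.SetModel("model_dir");  // assumed: load model from a directory
  config.EnableMemoryOptim();    // enables this memory optimize pass
  auto predictor = paddle::CreatePaddlePredictor(config);
  return predictor != nullptr ? 0 : 1;
}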

0 commit comments
