 #include "paddle/fluid/inference/analysis/passes/memory_optimize_pass.h"
 #include <algorithm>
 #include <fstream>
+#include <functional>
 #include <limits>
 #include <map>
 #include <string>
@@ -38,6 +39,14 @@ using framework::ir::Node;
 using framework::ir::TopologyVarientSort;
 using space_table_t = MemoryOptimizePass::space_table_t;
 
+// One vertex of the tensor interference graph: a non-persistable tensor's
+// name, its byte size, the reuse cluster it is assigned to (-1 while
+// unassigned), its lifetime as a [first_use, last_use] op-index interval,
+// and the names of tensors whose lifetimes overlap its own.
+struct MemNode {
+  std::string name;
+  size_t size;
+  int cluster;
+  std::pair<int, int> lifetime;
+  std::unordered_set<std::string> adj;
+};
+
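Taken together, lifetime and adj make each MemNode a vertex in an interference graph: two tensors interfere exactly when their closed [first_use, last_use] intervals intersect, and interfering tensors must never share a buffer. The interval test the plan builder uses below can be checked in isolation; this is a standalone sketch mirroring that predicate, not code from the pass:

    #include <cassert>
    #include <utility>

    // Closed intervals [a.first, a.second] and [b.first, b.second] intersect
    // iff each one starts no later than the other ends (the same test as the
    // overlap lambda in MakeSimpleReusePlan below).
    bool Overlap(std::pair<int, int> a, std::pair<int, int> b) {
      return b.second >= a.first && a.second >= b.first;
    }

    int main() {
      assert(Overlap({0, 3}, {2, 5}));   // alive together during ops 2..3
      assert(Overlap({0, 3}, {3, 4}));   // touching endpoints still conflict
      assert(!Overlap({0, 3}, {4, 9}));  // disjoint lifetimes: reuse is safe
    }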
 // Collect the lifecycles of the tensors.
 // Traverse the graph in topological order.
 // The traversal order also affects the lifecycles, so different sort_kind is
@@ -96,6 +105,89 @@ int DataTypeToSpace(framework::proto::VarType_Type type) {
   }
 }
 
+void MemoryOptimizePass::CollectVarMemorySize(
+    space_table_t* space_table) const {
+  const int fake_batch_size = 1;
+  // Collect tensors from the graph.
+  for (auto* node : graph_->Nodes()) {
+    if (node->IsVar() &&
+        node->Var()->GetType() ==
+            framework::proto::VarType::Type::VarType_Type_LOD_TENSOR) {
+      // Parameters will not be reused.
+      if (node->Var()->Persistable()) continue;
+      auto shape = node->Var()->GetShape();
+      // Pin unknown (negative) dimensions to a fake batch size of one.
+      for (auto& v : shape) {
+        if (v < 0) v = fake_batch_size;
+      }
+
+      int size = std::accumulate(shape.begin(), shape.end(), 1,
+                                 std::multiplies<int>());
+      (*space_table)[node->Var()->Name()] =
+          size * DataTypeToSpace(node->Var()->GetDataType());
+    }
+  }
+}
+
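The accumulate above is a plain fold over the shape: negative (batch-dependent) dimensions are pinned to a fake batch of one, the dimensions are multiplied together, and the element count is scaled by the per-type byte width from DataTypeToSpace. A minimal standalone sketch of the same arithmetic follows; BytesFor and its bytes_per_elem parameter are hypothetical stand-ins, not pass APIs. Note the sketch accumulates in int64_t, whereas the pass accumulates in int, which could overflow for very large static shapes:

    #include <cstdint>
    #include <functional>
    #include <numeric>
    #include <vector>

    // Shape -> byte count, treating unknown dims as a batch of one.
    int64_t BytesFor(std::vector<int64_t> shape, int64_t bytes_per_elem,
                     int64_t fake_batch = 1) {
      for (auto& d : shape) {
        if (d < 0) d = fake_batch;  // e.g. {-1, 128} becomes {1, 128}
      }
      int64_t elems = std::accumulate(shape.begin(), shape.end(), int64_t{1},
                                      std::multiplies<int64_t>());
      return elems * bytes_per_elem;
    }
    // BytesFor({-1, 128}, 4) == 512: one fp32 row of 128 elements.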
+void MakeSimpleReusePlan(
+    const std::unordered_map<std::string, std::pair<int, int>>& lifecycles,
+    const std::unordered_map<std::string, size_t>& space_table,
+    std::unordered_map<std::string, std::string>* node2cluster,
+    std::unordered_map<std::string, int>* cluster_size) {
+  std::vector<MemNode> mem_nodes;
+  for (auto& data : lifecycles) {
+    MemNode temp_node;
+    temp_node.name = data.first;
+    PADDLE_ENFORCE(
+        space_table.count(data.first),
+        "%s variable should be in the space table during memory optimize",
+        data.first);
+    temp_node.size = space_table.at(data.first);
+    temp_node.cluster = -1;
+    temp_node.lifetime = data.second;
+    mem_nodes.push_back(temp_node);
+  }
+  auto overlap = [](std::pair<int, int> a, std::pair<int, int> b) -> bool {
+    return b.second >= a.first && a.second >= b.first;
+  };
+  // If the lifetimes of two nodes overlap, mark them as adjacent: they are
+  // alive at the same time and must not share memory.
+  for (size_t i = 0; i < mem_nodes.size(); i++) {
+    for (size_t j = i + 1; j < mem_nodes.size(); j++) {
+      if (overlap(mem_nodes[i].lifetime, mem_nodes[j].lifetime)) {
+        mem_nodes[i].adj.insert(mem_nodes[j].name);
+        mem_nodes[j].adj.insert(mem_nodes[i].name);
+      }
+    }
+  }
+
+  // Sort the nodes by memory size, largest first.
+  auto sort_func = [](const MemNode& a, const MemNode& b) {
+    return a.size > b.size;
+  };
+  std::sort(mem_nodes.begin(), mem_nodes.end(), sort_func);
+
+  // Greedily generate the reuse strategy: each unassigned node opens a new
+  // cluster, then absorbs every later node that conflicts with no member.
+  for (size_t i = 0; i < mem_nodes.size(); i++) {
+    if (mem_nodes[i].cluster >= 0) continue;
+    int cluster_index = cluster_size->size();
+    mem_nodes[i].cluster = cluster_index;
+    (*cluster_size)[mem_nodes[i].name] = mem_nodes[i].size;
+    (*node2cluster)[mem_nodes[i].name] = mem_nodes[i].name;
+    std::unordered_set<std::string> cluster_adj = mem_nodes[i].adj;
+    for (size_t j = i + 1; j < mem_nodes.size(); j++) {
+      if (mem_nodes[j].cluster < 0 &&
+          (cluster_adj.find(mem_nodes[j].name) == cluster_adj.end())) {
+        (*node2cluster)[mem_nodes[j].name] = mem_nodes[i].name;
+        mem_nodes[j].cluster = cluster_index;
+        for (auto& n : mem_nodes[j].adj) {
+          cluster_adj.insert(n);
+        }
+      }
+    }
+  }
+  for (auto& cluster : *cluster_size) {
+    LOG(INFO) << "Cluster name: " << cluster.first
+              << " size: " << cluster.second;
+  }
+}
+
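MakeSimpleReusePlan is a greedy first-fit coloring of the interference graph: visit tensors from largest to smallest, open a new cluster for each still-unassigned tensor, and absorb every later tensor that conflicts with nothing absorbed so far (cluster_adj tracks the union of the members' neighbor sets). Reserved memory is then the sum of the cluster heads' sizes rather than the sum over all tensors. A toy driver, assuming the function above is visible to the caller:

    #include <iostream>
    #include <string>
    #include <unordered_map>
    #include <utility>

    int main() {
      // Three tensors: "a" and "b" overlap in time; "c" runs after both.
      std::unordered_map<std::string, std::pair<int, int>> lifecycles = {
          {"a", {0, 3}}, {"b", {2, 5}}, {"c", {6, 9}}};
      std::unordered_map<std::string, size_t> space_table = {
          {"a", 4096}, {"b", 1024}, {"c", 512}};
      std::unordered_map<std::string, std::string> node2cluster;
      std::unordered_map<std::string, int> cluster_size;

      MakeSimpleReusePlan(lifecycles, space_table, &node2cluster,
                          &cluster_size);

      // Expected: "a" and "b" get separate clusters; "c" reuses "a"
      // (nodes are visited largest-first). Reserved bytes: 4096 + 1024.
      for (const auto& kv : node2cluster) {
        std::cout << kv.first << " -> " << kv.second << "\n";
      }
    }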
 // Collect the memory size of the tensors.
 void MemoryOptimizePass::CollectVarMemorySize(
     const std::unordered_map<std::string, size_t>& batch_var_ave_dim,
@@ -377,6 +469,17 @@ void UpdateOpDescsByReuse(
         }
       }
 
+      // Modify the graph: rename reused input var nodes to their cluster
+      // representatives so the graph stays consistent with the op descs.
+      for (auto* input_node : node->inputs) {
+        PADDLE_ENFORCE(input_node->IsVar());
+        std::string input_node_name = input_node->Name();
+        if (reuse_table.count(input_node_name) &&
+            reuse_table.at(input_node_name) != input_node_name) {
+          auto name = reuse_table.at(input_node_name);
+          input_node->RenameVar(name);
+        }
+      }
+
       for (auto argument : node->Op()->Outputs()) {
         for (const auto& x : argument.second) {
           auto name = x;
@@ -388,6 +491,17 @@ void UpdateOpDescsByReuse(
         }
       }
 
+      // Modify the graph: rename reused output var nodes as well.
+      for (auto* out_node : node->outputs) {
+        PADDLE_ENFORCE(out_node->IsVar());
+        std::string out_node_name = out_node->Name();
+        if (reuse_table.count(out_node_name) &&
+            reuse_table.at(out_node_name) != out_node_name) {
+          auto name = reuse_table.at(out_node_name);
+          out_node->RenameVar(name);
+        }
+      }
+
       // Update arguments.
       for (auto& arg : in_args) {
         node->Op()->SetInput(arg.first, arg.second);
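Both rename loops apply the same lookup as the desc rewriting around them: a name is replaced by its cluster representative when the reuse table maps it elsewhere, and left alone otherwise. Renaming the ir::Node objects in addition to the OpDesc arguments keeps the graph consistent with the rewritten descs, which only the descs-side update handled before. The lookup in isolation, as a hypothetical standalone helper:

    #include <string>
    #include <unordered_map>

    // Name -> cluster representative; identity when the name is not reused.
    std::string Resolve(
        const std::unordered_map<std::string, std::string>& reuse_table,
        const std::string& name) {
      auto it = reuse_table.find(name);
      return (it != reuse_table.end() && it->second != name) ? it->second
                                                             : name;
    }
    // Given reuse_table t = {{"b", "a"}}: Resolve(t, "b") == "a",
    // Resolve(t, "x") == "x".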
@@ -589,12 +703,24 @@ void MemoryOptimizePass::RunImpl(Argument* argument) {
   VLOG(3) << "Load memory cache from " << path;
   std::vector<std::map<std::string, std::vector<int>>> batches;
 
-  if (argument->static_memory_optim() && inference::IsFileExists(path)) {
+  if (!(argument->static_memory_optim() && inference::IsFileExists(path))) {
+    string::PrettyLogInfo("--- Performing dynamic memory optimize");
+    // batches = FakeBatchVarShapes(argument->main_program());
+    int sort_kind = 0;
+    std::unordered_map<std::string, lifecycle_t> lifecycles;
+    space_table_t space_table;
+    std::unordered_map<std::string, std::string> node2cluster;
+    std::unordered_map<std::string, int> cluster_size;
+
+    CollectLifeCycle(&lifecycles, sort_kind);
+    CollectVarMemorySize(&space_table);
+    MakeSimpleReusePlan(lifecycles, space_table, &node2cluster, &cluster_size);
+    UpdateOpDescsByReuse(graph_, node2cluster, sort_kind);
+    return;
+
+  } else {
     string::PrettyLogInfo("--- Performing static memory optimize");
     batches = DeseralizeBatchVarShapes(path);
-  } else {
-    string::PrettyLogInfo("--- Performing dynamic memory optimize");
-    batches = FakeBatchVarShapes(argument->main_program());
   }
   auto var_batch_ave_size = GetBatchAverageSize(batches);
 
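The inverted condition changes the flow of RunImpl: the dynamic path now runs the full pipeline (CollectLifeCycle, CollectVarMemorySize, MakeSimpleReusePlan, UpdateOpDescsByReuse) and returns early, so only the static path, which needs the batch shapes recorded at path, falls through to the batch-statistics code beginning at GetBatchAverageSize. The commented-out FakeBatchVarShapes call marks the batch-faking approach this pipeline replaces.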