Skip to content

Commit

Permalink
add graph_key to specific graph's varmap (#60567)
Browse files Browse the repository at this point in the history
* add graph_key to specific graph's varmap

* fix inpalce case

* fix inpalce case
  • Loading branch information
GGBond8488 authored Jan 12, 2024
1 parent 823b94e commit 600fc2f
Show file tree
Hide file tree
Showing 6 changed files with 39 additions and 12 deletions.
3 changes: 3 additions & 0 deletions paddle/fluid/framework/io/save_paddle2cinn_varmap.cc
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,11 @@ namespace framework {

void save_paddle2cinn_varmap(
std::unordered_map<std::string, std::string> paddle2cinn_var_map,
int64_t graph_compilation_key,
std::string save_path) {
std::stringstream ss;
ss << "graph_compilation_key:" << std::to_string(graph_compilation_key)
<< "\n";
for (const auto& kv : paddle2cinn_var_map) {
ss << kv.first << ":" << kv.second << "\n";
}
Expand Down
1 change: 1 addition & 0 deletions paddle/fluid/framework/io/save_paddle2cinn_varmap.h
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ namespace framework {

void save_paddle2cinn_varmap(
std::unordered_map<std::string, std::string> paddle2cinn_var_map,
int64_t graph_compilation_key,
std::string save_path);

}
Expand Down
20 changes: 19 additions & 1 deletion paddle/fluid/framework/io/save_runtime_graph.cc
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,21 @@ void save_string(std::string content,
fout.close();
}

void save_graph_compilation_key(int64_t graph_compilation_key,
std::string type,
std::string saved_path) {
VLOG(6) << type << " will be saved to " << saved_path;
MkDirRecursively(DirName(saved_path).c_str());

std::ofstream fout(saved_path);
PADDLE_ENFORCE_EQ(
static_cast<bool>(fout),
true,
phi::errors::Unavailable("Cannot open %s to save ", saved_path));
fout << std::to_string(graph_compilation_key);
fout.close();
}

std::string node_format(const ir::Node& node, int number) {
return "node_" + std::to_string(number) + " : " + "[" + node.Name() + ", " +
(node.IsOp() ? "op" : "var") + "]";
Expand Down Expand Up @@ -78,6 +93,7 @@ void save_graph(const ir::Graph& graph,
}

void save_runtime_cinn_graph(const ir::Graph& graph,
int64_t graph_compilation_key,
std::string clusters_ops,
std::string clusters_inputs,
std::string cluster_outputs,
Expand All @@ -91,7 +107,9 @@ void save_runtime_cinn_graph(const ir::Graph& graph,
save_string(cluster_intervals,
"cluster_intervals",
saved_path + "/cluster_intervals.txt");

save_graph_compilation_key(graph_compilation_key,
"graph_compilation_key",
saved_path + "/graph_compilation_key.txt");
save_graph(graph, "graph", saved_path + "/subgraph.txt");
}

Expand Down
1 change: 1 addition & 0 deletions paddle/fluid/framework/io/save_runtime_graph.h
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ limitations under the License. */
namespace paddle {
namespace framework {
void save_runtime_cinn_graph(const ir::Graph& graph,
int64_t graph_compilation_key,
std::string clusters_ops,
std::string clusters_inputs,
std::string cluster_outputs,
Expand Down
11 changes: 6 additions & 5 deletions paddle/fluid/framework/paddle2cinn/build_cinn_pass.cc
Original file line number Diff line number Diff line change
Expand Up @@ -753,20 +753,21 @@ void SearchAllSubgraphs(Graph* graph, bool is_inference_stage) {
subgraph->GetOrInit<std::unordered_set<std::string>>(kSkipGcVarNames);
sub_skip_gc_vars = all_skip_gc_vars;
}
auto compilation_key = cinn_compiler->AddGraph(std::move(subgraph));
VLOG(4) << "Compilation Key:\n"
<< cinn_compiler->ReadableKey(compilation_key);

if (FLAGS_save_static_runtime_data) {
paddle::framework::save_runtime_cinn_graph(
*subgraph,
cinn_compiler->FindGraph(compilation_key),
compilation_key,
cluster_debug_info(cluster_set),
cluster_debug_info(cluster_inputs),
cluster_debug_info(cluster_outputs),
cluster_debug_info(cluster_internals),
FLAGS_static_runtime_data_save_path + "/cluster_" +
std::to_string(++i));
}
auto compilation_key = cinn_compiler->AddGraph(std::move(subgraph));
VLOG(4) << "Compilation Key:\n"
<< cinn_compiler->ReadableKey(compilation_key);

// Replace the found cluster to a new cinn op node
ReplaceSubGraphWithCinnOpNode(cluster_set,
cluster_inputs,
Expand Down
15 changes: 9 additions & 6 deletions paddle/fluid/operators/cinn/cinn_launch_context.cc
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,15 @@ CinnLaunchContext::CinnLaunchContext(const framework::ir::Graph& graph,
[](const auto& name_view) { return std::string(name_view.data()); });
// build name map between the original variables and compiled ones
BuildVarNameMap(compiled_obj.paddle2cinn_varmap, cinn_argument_names_);
if (FLAGS_save_static_runtime_data) {
auto graph_compilation_key =
std::hash<const framework::ir::Graph*>()((&graph));
paddle::framework::save_paddle2cinn_varmap(
paddle2cinn_varmap_,
graph_compilation_key,
FLAGS_static_runtime_data_save_path +
"/paddle2cinn_varmap/paddle2cinn_varmap.txt");
}

const auto& input_var_names =
graph.Get<std::vector<std::string>>(framework::paddle2cinn::kInputVars);
Expand Down Expand Up @@ -193,12 +202,6 @@ void CinnLaunchContext::BuildVarNameMap(
"Size of variables is not euqal, paddle[%ld] vs cinn[%ld]",
paddle2cinn_varmap_.size(),
cinn2paddle_varmap_.size()));
if (FLAGS_save_static_runtime_data) {
paddle::framework::save_paddle2cinn_varmap(
paddle2cinn_varmap_,
FLAGS_static_runtime_data_save_path +
"/paddle2cinn_varmap/paddle2cinn_varmap.txt");
}
}

std::unordered_set<std::string> CinnLaunchContext::GetVisibleVarNames() const {
Expand Down

0 comments on commit 600fc2f

Please sign in to comment.