25
25
#include " cinn/auto_schedule/tuning.h"
26
26
#include " cinn/common/target.h"
27
27
#include " cinn/common/type.h"
28
- #include " cinn/frontend/decomposer/use_decomposer.h"
29
- #include " cinn/frontend/pass/use_program_pass.h"
30
- #include " cinn/frontend/program_pass.h"
28
+ #include " cinn/frontend/optimize.h"
31
29
#include " cinn/frontend/syntax.h"
32
30
#include " cinn/hlir/framework/graph.h"
33
31
#include " cinn/hlir/framework/graph_compiler.h"
34
- #include " cinn/hlir/framework/pass.h"
35
- #include " cinn/hlir/pass/use_pass.h"
36
32
#include " gflags/gflags.h"
37
33
#include " paddle/fluid/framework/framework.pb.h"
38
34
#include " paddle/fluid/framework/ir/graph.h"
@@ -58,13 +54,11 @@ namespace paddle2cinn {
58
54
using ir::Graph;
59
55
using ir::Node;
60
56
using inference::analysis::Dot;
61
- using ::cinn::common::Target;
62
- using ::cinn::common::Float;
63
- using ::cinn::hlir::framework::GraphCompiler;
64
57
using ::cinn::auto_schedule::AutoTuner;
58
+ using ::cinn::common::Target;
59
+ using ::cinn::frontend::Optimize;
65
60
using ::cinn::hlir::framework::BuildScope;
66
- using ::cinn::frontend::ProgramPass;
67
- using ::cinn::hlir::framework::ApplyPass;
61
+ using ::cinn::hlir::framework::GraphCompiler;
68
62
69
63
CinnCompiler* CinnCompiler::GetInstance () {
70
64
static CinnCompiler instance;
@@ -75,7 +69,7 @@ const CinnCompiledObject& CinnCompiler::Compile(
75
69
const Graph& graph,
76
70
const std::map<std::string, const LoDTensor*>& input_tensors,
77
71
const Target& target, void * stream) {
78
- VLOG (1 ) << " -- The graph to be compiled is:\n " << VizGraph (graph);
72
+ VLOG (4 ) << " -- The graph to be compiled is:\n " << VizGraph (graph);
79
73
CinnCacheKeyByAddress cur_key_by_address (graph, input_tensors,
80
74
target.arch_str ());
81
75
CinnCacheKeyByStructure cur_key_by_struct;
@@ -258,22 +252,15 @@ std::unique_ptr<CinnCompiledObject> CinnCompiler::CompileGraph(
258
252
CinnGraphSymbolization symbol{compiled_num, graph, target, input_tensors};
259
253
auto frontend_program = symbol ();
260
254
auto fetch_ids = symbol.GetFetchIds ();
261
- ProgramPass::Apply (&frontend_program, fetch_ids, target, {" Decomposer" });
262
- ::cinn::frontend::ApplyPass (&frontend_program, fetch_ids, " RemoveIdentity" );
263
- ::cinn::frontend::ApplyPass (&frontend_program, fetch_ids, " TransposeFolding" );
264
- ProgramPass::Apply (&frontend_program, fetch_ids, target, {" GemmRewriter" });
255
+ VLOG (4 ) << " All fetch var ids in CINN: "
256
+ << string::join_strings (fetch_ids, ' ,' );
265
257
266
- auto cinn_graph = std::make_shared<::cinn::hlir::framework::Graph>(
267
- frontend_program, target);
268
- VLOG (1 ) << " -- The " << compiled_num << " -th compilation ("
258
+ auto cinn_graph = Optimize (&frontend_program, fetch_ids, target);
259
+ VLOG (4 ) << " -- The " << compiled_num << " -th compilation ("
269
260
<< target.arch_str () << " ), and its related graph:\n "
270
261
<< cinn_graph->Visualize ();
271
- ApplyPass (cinn_graph.get (), " OpFusion" );
272
- auto scope = BuildScope (target, cinn_graph);
273
-
274
- VLOG (4 ) << " All fetch var ids in CINN: "
275
- << string::join_strings (fetch_ids, ' ,' );
276
262
263
+ auto scope = BuildScope (target, cinn_graph);
277
264
auto graph_compiler =
278
265
std::make_unique<GraphCompiler>(target, scope, cinn_graph);
279
266
GraphCompiler::CompileOptions options;
0 commit comments