Skip to content

Commit 2f2f987

Browse files
authored
[Cherry-Pick]Move pass optimizations into CINN. (#42047) (#42070)
* Move pass optimizations into CINN.
1 parent dbdb56d commit 2f2f987

File tree

2 files changed

+11
-24
lines changed

2 files changed

+11
-24
lines changed

cmake/external/cinn.cmake

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@ add_definitions(-w)
2626
######################################
2727
include(ExternalProject)
2828
set(CINN_PREFIX_DIR ${THIRD_PARTY_PATH}/CINN)
29-
set(CINN_GIT_TAG 08d7680dd91dfaa65787969050eb8f1143654f10)
29+
set(CINN_GIT_TAG release/v0.2)
3030
set(CINN_OPTIONAL_ARGS -DPY_VERSION=${PY_VERSION}
3131
-DWITH_CUDA=${WITH_GPU}
3232
-DWITH_CUDNN=${WITH_GPU}

paddle/fluid/framework/paddle2cinn/cinn_compiler.cc

Lines changed: 10 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -25,14 +25,10 @@
2525
#include "cinn/auto_schedule/tuning.h"
2626
#include "cinn/common/target.h"
2727
#include "cinn/common/type.h"
28-
#include "cinn/frontend/decomposer/use_decomposer.h"
29-
#include "cinn/frontend/pass/use_program_pass.h"
30-
#include "cinn/frontend/program_pass.h"
28+
#include "cinn/frontend/optimize.h"
3129
#include "cinn/frontend/syntax.h"
3230
#include "cinn/hlir/framework/graph.h"
3331
#include "cinn/hlir/framework/graph_compiler.h"
34-
#include "cinn/hlir/framework/pass.h"
35-
#include "cinn/hlir/pass/use_pass.h"
3632
#include "gflags/gflags.h"
3733
#include "paddle/fluid/framework/framework.pb.h"
3834
#include "paddle/fluid/framework/ir/graph.h"
@@ -58,13 +54,11 @@ namespace paddle2cinn {
5854
using ir::Graph;
5955
using ir::Node;
6056
using inference::analysis::Dot;
61-
using ::cinn::common::Target;
62-
using ::cinn::common::Float;
63-
using ::cinn::hlir::framework::GraphCompiler;
6457
using ::cinn::auto_schedule::AutoTuner;
58+
using ::cinn::common::Target;
59+
using ::cinn::frontend::Optimize;
6560
using ::cinn::hlir::framework::BuildScope;
66-
using ::cinn::frontend::ProgramPass;
67-
using ::cinn::hlir::framework::ApplyPass;
61+
using ::cinn::hlir::framework::GraphCompiler;
6862

6963
CinnCompiler* CinnCompiler::GetInstance() {
7064
static CinnCompiler instance;
@@ -75,7 +69,7 @@ const CinnCompiledObject& CinnCompiler::Compile(
7569
const Graph& graph,
7670
const std::map<std::string, const LoDTensor*>& input_tensors,
7771
const Target& target, void* stream) {
78-
VLOG(1) << "-- The graph to be compiled is:\n" << VizGraph(graph);
72+
VLOG(4) << "-- The graph to be compiled is:\n" << VizGraph(graph);
7973
CinnCacheKeyByAddress cur_key_by_address(graph, input_tensors,
8074
target.arch_str());
8175
CinnCacheKeyByStructure cur_key_by_struct;
@@ -258,22 +252,15 @@ std::unique_ptr<CinnCompiledObject> CinnCompiler::CompileGraph(
258252
CinnGraphSymbolization symbol{compiled_num, graph, target, input_tensors};
259253
auto frontend_program = symbol();
260254
auto fetch_ids = symbol.GetFetchIds();
261-
ProgramPass::Apply(&frontend_program, fetch_ids, target, {"Decomposer"});
262-
::cinn::frontend::ApplyPass(&frontend_program, fetch_ids, "RemoveIdentity");
263-
::cinn::frontend::ApplyPass(&frontend_program, fetch_ids, "TransposeFolding");
264-
ProgramPass::Apply(&frontend_program, fetch_ids, target, {"GemmRewriter"});
255+
VLOG(4) << "All fetch var ids in CINN: "
256+
<< string::join_strings(fetch_ids, ',');
265257

266-
auto cinn_graph = std::make_shared<::cinn::hlir::framework::Graph>(
267-
frontend_program, target);
268-
VLOG(1) << "-- The " << compiled_num << "-th compilation ("
258+
auto cinn_graph = Optimize(&frontend_program, fetch_ids, target);
259+
VLOG(4) << "-- The " << compiled_num << "-th compilation ("
269260
<< target.arch_str() << "), and its related graph:\n"
270261
<< cinn_graph->Visualize();
271-
ApplyPass(cinn_graph.get(), "OpFusion");
272-
auto scope = BuildScope(target, cinn_graph);
273-
274-
VLOG(4) << "All fetch var ids in CINN: "
275-
<< string::join_strings(fetch_ids, ',');
276262

263+
auto scope = BuildScope(target, cinn_graph);
277264
auto graph_compiler =
278265
std::make_unique<GraphCompiler>(target, scope, cinn_graph);
279266
GraphCompiler::CompileOptions options;

0 commit comments

Comments
 (0)