Skip to content

Commit cb24090

Browse files
author
Menooker
authored
Add all-in-one pass pipeline (#75)
1 parent 05fcc76 commit cb24090

File tree

6 files changed

+230
-0
lines changed

6 files changed

+230
-0
lines changed

include/gc/Transforms/Passes.h

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,8 +12,34 @@
1212
#include "mlir/Pass/Pass.h"
1313

1414
namespace mlir {
15+
16+
// Forward declarations of the dialect classes and of the pass manager, so
// this header does not have to include the dialect headers.

namespace LLVM {
class LLVMDialect;
} // namespace LLVM

namespace scf {
class SCFDialect;
} // namespace scf

// Spelled `omp` (not `openmp`) so this declares the same class that the
// "omp::OpenMPDialect" entry of the pass's dependentDialects refers to.
namespace omp {
class OpenMPDialect;
} // namespace omp

namespace linalg {
class LinalgDialect;
} // namespace linalg

// Spelled `memref` (not `MemRef`) so this declares the same class that the
// "memref::MemRefDialect" entry of the pass's dependentDialects refers to.
namespace memref {
class MemRefDialect;
} // namespace memref

class PassManager;
37+
1538
namespace gc {
1639

40+
void populateFrontendPasses(mlir::PassManager &);
41+
void populateCPUPipeline(mlir::PassManager &);
42+
1743
#define GEN_PASS_DECL
1844
#include "gc/Transforms/Passes.h.inc"
1945

include/gc/Transforms/Passes.td

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,4 +31,17 @@ def ConvertOneDNNGraphToLinalg : Pass<"convert-onednn-graph-to-linalg"> {
3131
];
3232
}
3333

34+
// Declares the `gc-cpu-pipeline` pass together with the dialects it lists as
// dependencies.
def GCCPUPipeline : Pass<"gc-cpu-pipeline"> {
  let summary = "All-in-one pipeline for GC for CPU";
  let dependentDialects = [
    "onednn_graph::OneDNNGraphDialect", "tensor::TensorDialect",
    "memref::MemRefDialect", "linalg::LinalgDialect", "LLVM::LLVMDialect",
    "scf::SCFDialect", "bufferization::BufferizationDialect",
    "omp::OpenMPDialect", "vector::VectorDialect"
  ];
}
46+
3447
#endif // GC_DIALECT_GC_PASSES

lib/gc/Transforms/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ gc_set_mlir_link_components(MLIR_LINK_COMPONENTS
66

77
add_mlir_library(GCPasses
88
OneDNNGraphToLinalg.cpp
9+
Pipeline.cpp
910
TileNamed.cpp
1011

1112
ADDITIONAL_HEADER_DIRS

lib/gc/Transforms/Pipeline.cpp

Lines changed: 154 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,154 @@
1+
//===- Pipeline.cpp - Graph Compiler all-in-one pipeline --------*- C++ -*-===//
2+
//
3+
// This file is licensed under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
9+
#include "mlir/Conversion/Passes.h"
10+
#include "mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h"
11+
#include "mlir/Dialect/Bufferization/Transforms/Passes.h"
12+
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
13+
#include "mlir/Dialect/Linalg/Passes.h"
14+
#include "mlir/Dialect/MemRef/IR/MemRef.h"
15+
#include "mlir/Dialect/MemRef/Transforms/Passes.h"
16+
#include "mlir/Dialect/OpenMP/OpenMPDialect.h"
17+
#include "mlir/Dialect/SCF/IR/SCF.h"
18+
#include "mlir/Dialect/Tensor/IR/Tensor.h"
19+
#include "mlir/InitAllPasses.h"
20+
#include "mlir/Pass/PassManager.h"
21+
#include "mlir/Support/LogicalResult.h"
22+
#include "mlir/Transforms/Passes.h"
23+
24+
#include "gc/Dialect/CPURuntime/Transforms/CPURuntimePasses.h"
25+
#include "gc/Dialect/OneDNNGraph/OneDNNGraphDialect.h"
26+
#include "gc/Transforms/Passes.h"
27+
28+
namespace mlir::gc {
29+
30+
// linalg + linalgX + tensor
31+
void populateFrontendPasses(mlir::PassManager &pm) {
32+
// pm.addPass(onednn_graph::createConvertOneDNNGraphToLinalg());
33+
}
34+
35+
// scf + arith + math + vector + tensor + linalg.brgemm + tensor.pack/unpack
36+
void populateTensorPasses(mlir::PassManager &pm) {
37+
// todo: padding propagation pass
38+
// todo: layout propagation pass
39+
// todo: tensor constant propagation pass
40+
// todo: linalg.matmul lowering to (scf.loop + linalg.brgemm) pass
41+
// todo: fine-grain fusion pass
42+
// todo: lower linalg to arith/math on virtual vector pass
43+
44+
// REMOVE this pass after the above passes are added. Currently we add this
45+
// pass to make the pipeline work properly
46+
pm.addNestedPass<func::FuncOp>(createLinalgGeneralizeNamedOpsPass());
47+
}
48+
49+
// scf + arith + math + vector + tensor + linalg.brgemm
50+
void populateVectorPasses(mlir::PassManager &pm) {
51+
// todo: bf16 promotion pass, device dependent pass
52+
// todo: bf16 cast elimilation pass, fast-math kind pass, designed to support
53+
// oneDNN graph spec
54+
pm.addNestedPass<func::FuncOp>(arith::createArithExpandOpsPass());
55+
// todo: lower to physical vector pass, device dependent pass
56+
}
57+
58+
// scf + arith + math + vector + memref + linalg.brgemm
59+
void populateBufferizationPasses(mlir::PassManager &pm) {
60+
bufferization::OneShotBufferizationOptions options;
61+
pm.addPass(bufferization::createOneShotBufferizePass(options));
62+
pm.addPass(createCSEPass());
63+
pm.addPass(mlir::func::createFuncBufferizePass());
64+
pm.addNestedPass<func::FuncOp>(
65+
bufferization::createBufferizationBufferizePass());
66+
pm.addNestedPass<func::FuncOp>(
67+
bufferization::createFinalizingBufferizePass());
68+
bufferization::BufferResultsToOutParamsOpts opt{};
69+
opt.hoistStaticAllocs = true;
70+
pm.addPass(bufferization::createBufferResultsToOutParamsPass(opt));
71+
// todo: buffer schedule pass
72+
// todo: Need to improve this pass to support nested parallel.
73+
pm.addNestedPass<func::FuncOp>(bufferization::createBufferHoistingPass());
74+
pm.addNestedPass<func::FuncOp>(bufferization::createBufferLoopHoistingPass());
75+
pm.addNestedPass<func::FuncOp>(bufferization::createBufferDeallocationPass());
76+
pm.addPass(createBufferizationToMemRefPass());
77+
}
78+
79+
// scf + arith + math + vector + memref + func/microkernel
80+
void populateMicroKernelPasses(mlir::PassManager &pm) {
81+
// todo: ConvertLinalgToMicrokernel pass
82+
// todo: CleanupInvalidMicrokernel pass
83+
// todo: InvariantMicrokernelMotion pass
84+
// todo: ConvertMicrokernelToDnnlFunc to lower brgemm to dnnl call
85+
// todo: ConvertMicrokernelToXsmm, to lower brgemm to libxsmm call
86+
// todo: LowerMicrokernel pass
87+
// todo: DispatchMicrokernel
88+
}
89+
90+
/// CPU runtime stage: currently only the SCF to OpenMP conversion.
void populateCPURuntimePasses(mlir::PassManager &pm) {
  // TODO: flatten-nested-parallel pass to support coarse-grain fusion.
  // Remove the pass below once FlattenNestedParallel has been added.
  pm.addPass(createConvertSCFToOpenMPPass());
}
95+
96+
/// Appends the conversion passes that lower to the LLVM dialect, followed by
/// the clean-up passes. The passes are appended in exactly this order.
void populateLoweringToLLVMPasses(mlir::PassManager &pm) {
  // Control flow, CPU runtime and OpenMP.
  pm.addPass(createConvertSCFToCFPass());
  pm.addPass(cpuruntime::createCPURuntimeToLLVM());
  pm.addPass(createConvertOpenMPToLLVMPass());

  // Math: per-function conversion to LLVM, then the libm conversion.
  pm.nest<func::FuncOp>().addPass(createConvertMathToLLVMPass());
  pm.addPass(createConvertMathToLibmPass());

  // MemRef, arith, func and cf.
  pm.addPass(createFinalizeMemRefToLLVMConversionPass());
  pm.nest<func::FuncOp>().addPass(createArithToLLVMConversionPass());
  pm.addPass(createConvertFuncToLLVMPass());
  pm.addPass(createConvertControlFlowToLLVMPass());

  // Clean-up.
  pm.addPass(createCSEPass());
  pm.addPass(createCanonicalizerPass());
  pm.addPass(createReconcileUnrealizedCastsPass());
  pm.addPass(createSymbolDCEPass());
}
111+
112+
/// LLVM stage: the two memref expansion passes first, then the whole
/// lowering-to-LLVM sequence.
void populateLLVMPasses(mlir::PassManager &pm) {
  pm.addPass(memref::createExpandOpsPass());
  pm.addPass(memref::createExpandStridedMetadataPass());

  populateLoweringToLLVMPasses(pm);
}
117+
118+
/// Appends the complete CPU pipeline to `pm`, stage by stage, in the order
/// written below.
void populateCPUPipeline(mlir::PassManager &pm) {
  // Front-end: oneDNN graph dialect.
  populateFrontendPasses(pm);

  // Middle-end: LinalgX/Linalg/tensor dialects.
  populateTensorPasses(pm);

  // Middle-end: arith/math/vector dialects.
  populateVectorPasses(pm);

  // Back-end: arith/math/vector/memref dialects.
  populateBufferizationPasses(pm);

  // Temporary pass that keeps the pipeline working; REMOVE it once the
  // TensorPasses have been added.
  pm.nest<func::FuncOp>().addPass(createConvertLinalgToParallelLoopsPass());

  populateMicroKernelPasses(pm);
  populateCPURuntimePasses(pm);

  // Back-end: llvm dialect.
  populateLLVMPasses(pm);
}
135+
136+
#define GEN_PASS_DEF_GCCPUPIPELINE
137+
#include "gc/Transforms/Passes.h.inc"
138+
namespace {
139+
140+
class GCCPUPipeline : public impl::GCCPUPipelineBase<GCCPUPipeline> {
141+
public:
142+
friend struct PassHelper;
143+
using impl::GCCPUPipelineBase<GCCPUPipeline>::GCCPUPipelineBase;
144+
void runOnOperation() final {
145+
auto op = getOperation();
146+
PassManager pm{op->getContext()};
147+
populateCPUPipeline(pm);
148+
if (failed(pm.run(op)))
149+
signalPassFailure();
150+
}
151+
};
152+
153+
} // namespace
154+
} // namespace mlir::gc

test/gc/Transforms/Pipeline/run.mlir

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
// RUN: gc-opt %s --gc-cpu-pipeline | gc-cpu-runner -e main -entry-point-result=void | FileCheck %s
2+
3+
module {
4+
func.func @aaa() -> tensor<128xf32> {
5+
%c2 = arith.constant 2.0 : f32
6+
%a = tensor.empty() : tensor<128xf32>
7+
%2 = linalg.fill ins(%c2 : f32) outs(%a : tensor<128xf32>) -> tensor<128xf32>
8+
return %2 : tensor<128xf32>
9+
}
10+
11+
func.func @main() {
12+
%result = call @aaa() : ()-> tensor<128xf32>
13+
%c0 = arith.constant 0 : index
14+
%c128 = arith.constant 128 : index
15+
%c1 = arith.constant 1 : index
16+
scf.for %iv = %c0 to %c128 step %c1 {
17+
%4 = tensor.extract %result[%iv] : tensor<128xf32>
18+
cpuruntime.printf "%f\n" %4 : f32
19+
}
20+
return
21+
}
22+
// CHECK-COUNT-128: 2.000000
23+
}
Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
// RUN: gc-opt %s --gc-cpu-pipeline | FileCheck %s
2+
3+
module {
4+
// CHECK: aaa
5+
// check that the func returns void
6+
// CHECK-NOT: ) -> !llvm.struct<
7+
func.func @aaa(%a: tensor<128xf32>, %b: tensor<128xf32>) -> tensor<128xf32> {
8+
%out = tensor.empty() : tensor<128xf32>
9+
%2 = linalg.add ins(%a, %b : tensor<128xf32>,tensor<128xf32>) outs(%out : tensor<128xf32>) -> tensor<128xf32>
10+
// CHECK-NOT: memcpy
11+
return %2 : tensor<128xf32>
12+
}
13+
}

0 commit comments

Comments
 (0)