//===- Pipeline.cpp - Graph Compiler all-in-one pipeline --------*- C++ -*-===//
//
// This file is licensed under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Conversion/Passes.h"
#include "mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h"
#include "mlir/Dialect/Bufferization/Transforms/Passes.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/Linalg/Passes.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/MemRef/Transforms/Passes.h"
#include "mlir/Dialect/OpenMP/OpenMPDialect.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/InitAllPasses.h"
#include "mlir/Pass/PassManager.h"
#include "mlir/Support/LogicalResult.h"
#include "mlir/Transforms/Passes.h"

#include "gc/Dialect/CPURuntime/Transforms/CPURuntimePasses.h"
#include "gc/Dialect/OneDNNGraph/OneDNNGraphDialect.h"
#include "gc/Transforms/Passes.h"

namespace mlir::gc {

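// Each populate*Passes function below appends one stage of the pipeline to the
// given PassManager. The comment above each function lists the dialect mix the
// IR is expected to consist of once that stage has run.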
// linalg + linalgX + tensor
void populateFrontendPasses(mlir::PassManager &pm) {
  // pm.addPass(onednn_graph::createConvertOneDNNGraphToLinalg());
}

// scf + arith + math + vector + tensor + linalg.brgemm + tensor.pack/unpack
void populateTensorPasses(mlir::PassManager &pm) {
  // todo: padding propagation pass
  // todo: layout propagation pass
  // todo: tensor constant propagation pass
  // todo: linalg.matmul lowering to (scf.loop + linalg.brgemm) pass
  // todo: fine-grain fusion pass
  // todo: lower linalg to arith/math on virtual vector pass

  // REMOVE this pass once the passes above are added; it is only here to keep
  // the pipeline working in the meantime.
  pm.addNestedPass<func::FuncOp>(createLinalgGeneralizeNamedOpsPass());
}

// scf + arith + math + vector + tensor + linalg.brgemm
void populateVectorPasses(mlir::PassManager &pm) {
  // todo: bf16 promotion pass (device-dependent)
  // todo: bf16 cast elimination pass and fast-math kind pass, designed to
  // support the oneDNN graph spec
  pm.addNestedPass<func::FuncOp>(arith::createArithExpandOpsPass());
  // todo: lower to physical vector pass (device-dependent)
}

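// Bufferization strategy: one-shot-bufferize rewrites tensor ops into memref
// ops within function bodies; func-bufferize and finalizing-bufferize then
// convert function signatures and remove leftover to_tensor/to_memref
// materializations. buffer-results-to-out-params turns returned buffers into
// output arguments, and the hoisting, deallocation, and
// bufferization-to-memref passes place allocations, insert the matching
// deallocations, and lower the remaining bufferization ops onto memref.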
// scf + arith + math + vector + memref + linalg.brgemm
void populateBufferizationPasses(mlir::PassManager &pm) {
  bufferization::OneShotBufferizationOptions options;
  pm.addPass(bufferization::createOneShotBufferizePass(options));
  pm.addPass(createCSEPass());
  pm.addPass(mlir::func::createFuncBufferizePass());
  pm.addNestedPass<func::FuncOp>(
      bufferization::createBufferizationBufferizePass());
  pm.addNestedPass<func::FuncOp>(
      bufferization::createFinalizingBufferizePass());
  bufferization::BufferResultsToOutParamsOpts opt{};
  opt.hoistStaticAllocs = true;
  pm.addPass(bufferization::createBufferResultsToOutParamsPass(opt));
  // todo: buffer schedule pass
  // todo: this pass needs to be improved to support nested parallel loops
  pm.addNestedPass<func::FuncOp>(bufferization::createBufferHoistingPass());
  pm.addNestedPass<func::FuncOp>(bufferization::createBufferLoopHoistingPass());
  pm.addNestedPass<func::FuncOp>(bufferization::createBufferDeallocationPass());
  pm.addPass(createBufferizationToMemRefPass());
}

// scf + arith + math + vector + memref + func/microkernel
void populateMicroKernelPasses(mlir::PassManager &pm) {
  // todo: ConvertLinalgToMicrokernel pass
  // todo: CleanupInvalidMicrokernel pass
  // todo: InvariantMicrokernelMotion pass
  // todo: ConvertMicrokernelToDnnlFunc to lower brgemm to dnnl call
  // todo: ConvertMicrokernelToXsmm, to lower brgemm to libxsmm call
  // todo: LowerMicrokernel pass
  // todo: DispatchMicrokernel
}

void populateCPURuntimePasses(mlir::PassManager &pm) {
  // todo: flatten-nested-parallel pass to support coarse-grain fusion
  // REMOVE this pass once FlattenNestedParallel is added
  pm.addPass(createConvertSCFToOpenMPPass());
}

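// Lowering to the LLVM dialect: structured control flow is first lowered to
// the cf dialect, then the remaining dialects (cpuruntime, openmp, math,
// memref, arith, func, cf) are converted to LLVM, with unsupported math ops
// lowered to libm calls. CSE and canonicalization clean up the result before
// leftover unrealized_conversion_cast ops are reconciled and dead symbols are
// removed.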
void populateLoweringToLLVMPasses(mlir::PassManager &pm) {
  pm.addPass(createConvertSCFToCFPass());
  pm.addPass(cpuruntime::createCPURuntimeToLLVM());
  pm.addPass(createConvertOpenMPToLLVMPass());
  pm.addNestedPass<func::FuncOp>(createConvertMathToLLVMPass());
  pm.addPass(createConvertMathToLibmPass());
  pm.addPass(createFinalizeMemRefToLLVMConversionPass());
  pm.addNestedPass<func::FuncOp>(createArithToLLVMConversionPass());
  pm.addPass(createConvertFuncToLLVMPass());
  pm.addPass(createConvertControlFlowToLLVMPass());
  pm.addPass(createCSEPass());
  pm.addPass(createCanonicalizerPass());
  pm.addPass(createReconcileUnrealizedCastsPass());
  pm.addPass(createSymbolDCEPass());
}

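// The memref expansion passes rewrite memref ops that manipulate buffer
// metadata (e.g. reshapes and subviews) into simpler ops with explicit
// offset/size/stride arithmetic, so the subsequent LLVM conversion can handle
// them.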
void populateLLVMPasses(mlir::PassManager &pm) {
  pm.addPass(memref::createExpandOpsPass());
  pm.addPass(memref::createExpandStridedMetadataPass());
  populateLoweringToLLVMPasses(pm);
}

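// Assembles the end-to-end CPU pipeline by chaining the stage-specific
// populate functions above in lowering order.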
void populateCPUPipeline(mlir::PassManager &pm) {
  // front-end, oneDNN graph dialect
  populateFrontendPasses(pm);
  // middle-end, LinalgX/Linalg/tensor dialects
  populateTensorPasses(pm);
  // middle-end, arith/math/vector dialects
  populateVectorPasses(pm);
  // back-end, arith/math/vector/memref dialects
  populateBufferizationPasses(pm);
  // REMOVE this pass once the tensor passes are added; it is only here to keep
  // the pipeline working in the meantime.
  pm.addNestedPass<func::FuncOp>(createConvertLinalgToParallelLoopsPass());
  populateMicroKernelPasses(pm);
  populateCPURuntimePasses(pm);
  // back-end, LLVM dialect
  populateLLVMPasses(pm);
}

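// The GCCPUPipeline pass generated via GEN_PASS_DEF_GCCPUPIPELINE wraps the
// pipeline above: it builds a fresh PassManager, populates it with
// populateCPUPipeline, and runs it on the operation the pass is scheduled on.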
#define GEN_PASS_DEF_GCCPUPIPELINE
#include "gc/Transforms/Passes.h.inc"
namespace {

class GCCPUPipeline : public impl::GCCPUPipelineBase<GCCPUPipeline> {
public:
  friend struct PassHelper;
  using impl::GCCPUPipelineBase<GCCPUPipeline>::GCCPUPipelineBase;
  void runOnOperation() final {
    auto op = getOperation();
    PassManager pm{op->getContext()};
    populateCPUPipeline(pm);
    if (failed(pm.run(op)))
      signalPassFailure();
  }
};

} // namespace
} // namespace mlir::gc