
Commit 259c918

Author: Longsheng Du
[Transform] Add basic onednn_graph dialect lowering (#73)
[Transform] Add basic onednn_graph dialect lowering (#61)
1 parent a713c16 commit 259c918


4 files changed: +427 -0 lines changed


include/gc/Transforms/Passes.td

Lines changed: 14 additions & 0 deletions
@@ -17,4 +17,18 @@ def TileLinalgNamed : Pass<"tile-named-linalg", "func::FuncOp"> {
     ["linalg::LinalgDialect", "scf::SCFDialect", "tensor::TensorDialect"];
 }
 
+def ConvertOneDNNGraphToLinalg : Pass<"convert-onednn-graph-to-linalg"> {
+  let summary = "Lower the operations from the oneDNN Graph dialect into Linalg";
+  let description = [{
+    Lowers the `onednn_graph` ops to `linalg` ops.
+  }];
+  let dependentDialects = [
+    "func::FuncDialect",
+    "math::MathDialect",
+    "arith::ArithDialect",
+    "tensor::TensorDialect",
+    "linalg::LinalgDialect"
+  ];
+}
+
 #endif // GC_DIALECT_GC_PASSES
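
For context on how the new pass would be used: a minimal C++ sketch, assuming the standard TableGen-generated factory createConvertOneDNNGraphToLinalg() is declared in gc/Transforms/Passes.h (the factory name and the pipeline function below are illustrative, not part of this commit):

```cpp
#include "gc/Transforms/Passes.h"
#include "mlir/Pass/PassManager.h"

// Add the onednn_graph -> linalg lowering to a pass pipeline.
static void buildGraphLoweringPipeline(mlir::PassManager &pm) {
  pm.addPass(mlir::gc::createConvertOneDNNGraphToLinalg());
}
```

Any opt-style driver that registers the GC passes can also run the lowering directly via the `--convert-onednn-graph-to-linalg` flag defined above.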

lib/gc/Transforms/CMakeLists.txt

Lines changed: 2 additions & 0 deletions
@@ -1,4 +1,5 @@
 add_mlir_library(GCPasses
+  OneDNNGraphToLinalg.cpp
   TileNamed.cpp
 
   ADDITIONAL_HEADER_DIRS
@@ -9,6 +10,7 @@ add_mlir_library(GCPasses
 
   LINK_LIBS PUBLIC
     ${mlir_dialect_libs}
+    MLIROneDNNGraph
     MLIRIR
     MLIRSupport
     MLIRBufferizationToMemRef
lib/gc/Transforms/OneDNNGraphToLinalg.cpp

Lines changed: 298 additions & 0 deletions
@@ -0,0 +1,298 @@
//===- OneDNNGraphToLinalg.cpp - OneDNN Graph To Linalg Lowering --*- C++ -*-=//
//-*-===//
//
// This file is licensed under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include <numeric>
#include <vector>

#include "gc/Dialect/OneDNNGraph/OneDNNGraphDialect.h"
#include "gc/Dialect/OneDNNGraph/OneDNNGraphOps.h"
#include "gc/Transforms/Passes.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Math/IR/Math.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Rewrite/FrozenRewritePatternSet.h"
#include "mlir/Support/LogicalResult.h"
#include "mlir/Transforms/DialectConversion.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

using namespace mlir::onednn_graph;

namespace mlir {
namespace gc {
#define GEN_PASS_DEF_CONVERTONEDNNGRAPHTOLINALG
#include "gc/Transforms/Passes.h.inc"

namespace {
//===----------------------------------------------------------------------===//
// Util funcs
//===----------------------------------------------------------------------===//

Value createBroadcastOperand(Location loc, PatternRewriter &rewriter,
                             TensorType ty, Value op) {
  auto opTy = dyn_cast<TensorType>(op.getType());
  llvm::ArrayRef<int64_t> bcastShape = ty.getShape();
  llvm::ArrayRef<int64_t> opShape = opTy.getShape();
  int64_t diff = bcastShape.size() - opShape.size();

  if (bcastShape.equals(opShape)) {
    return op;
  } else {
    // get broadcast dimensions
    llvm::SmallVector<int64_t> bcastDims;
    for (int64_t i = 0; i < (int64_t)bcastShape.size(); i++) {
      int64_t idxOp = i - diff;
      if (idxOp < 0) {
        bcastDims.push_back(i);
      } else if (bcastShape[i] != opShape[idxOp]) {
        bcastDims.push_back(i);
      }
    }
    // create a new output tensor
    Value initTensor =
        rewriter.create<tensor::EmptyOp>(loc, bcastShape, ty.getElementType());
    return rewriter
        .create<linalg::BroadcastOp>(
            /*location=*/loc,
            /*inputs=*/op,
            /*inits=*/initTensor,
            /*dimensions=*/bcastDims)
        .getResults()
        .front();
  }
}
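
To make the broadcast-dimension computation above concrete, here is a small standalone sketch of the same index logic in plain C++ (no MLIR types; the function name and example shapes are illustrative only):

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

// Mirrors the bcastDims loop in createBroadcastOperand: dimensions that the
// operand is missing (leading dims) or whose sizes differ from the target
// shape are the dimensions linalg.broadcast has to add.
std::vector<int64_t> broadcastDims(const std::vector<int64_t> &bcastShape,
                                   const std::vector<int64_t> &opShape) {
  std::vector<int64_t> dims;
  int64_t diff = (int64_t)bcastShape.size() - (int64_t)opShape.size();
  for (int64_t i = 0; i < (int64_t)bcastShape.size(); ++i) {
    int64_t idxOp = i - diff;
    if (idxOp < 0 || bcastShape[i] != opShape[idxOp])
      dims.push_back(i);
  }
  return dims;
}

int main() {
  // Broadcasting a [128] operand into a [32, 128] result adds dimension 0.
  assert((broadcastDims({32, 128}, {128}) == std::vector<int64_t>{0}));
  return 0;
}
```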

// Typedef for the function used to get an operand for the transformed op
typedef mlir::Value (*GetOperandFn)(Operation *, PatternRewriter &, TensorType);

// Functions to get operands from the original op
struct OriginalOperand {
  template <unsigned I>
  static Value getIdx(Operation *op, PatternRewriter &b, TensorType ty) {
    if (I >= op->getNumOperands()) {
      op->emitError("Index exceeds operand count.\n");
      return nullptr;
    }
    return createBroadcastOperand(op->getLoc(), b, ty, op->getOperand(I));
  }
};

// Functions to get constant operands
struct ConstantOperand {
  template <int64_t I>
  static Value getConst(Operation *op, PatternRewriter &b, TensorType ty) {
    const auto loc = op->getLoc();
    const auto elemTy = ty.getElementType();
    if (llvm::isa<IntegerType>(elemTy)) {
      return b.create<arith::ConstantOp>(
          loc, DenseElementsAttr::get(ty, b.getIntegerAttr(elemTy, I)));
    } else if (llvm::isa<FloatType>(elemTy)) {
      return b.create<arith::ConstantOp>(
          loc, DenseElementsAttr::get(ty, b.getFloatAttr(elemTy, I)));
    } else {
      op->emitError("Not a supported element type for constant.\n");
      return nullptr;
    }
  }
};

//===----------------------------------------------------------------------===//
// Elemwise lowering
//===----------------------------------------------------------------------===//

// Generate elementwise op using linalg named ops
template <typename LoweredOp>
Value createElemwiseOp(Location loc, PatternRewriter &rewriter, TensorType ty,
                       llvm::ArrayRef<Value> inputs) {
  // create a new output tensor
  Value outTensor =
      rewriter.create<tensor::EmptyOp>(loc, ty.getShape(), ty.getElementType());

  auto elemwiseOp = rewriter.create<LoweredOp>(
      /*location=*/loc,
      /*resultTensorTypes=*/outTensor.getType(),
      /*inputs=*/inputs,
      /*outputs=*/outTensor);

  return elemwiseOp.getResult(0);
}

template <typename UnaryOp, typename LoweredOp, GetOperandFn GetOperand>
struct UnaryElemwiseLowering : public OpRewritePattern<UnaryOp> {
  using OpRewritePattern<UnaryOp>::OpRewritePattern;
  LogicalResult matchAndRewrite(UnaryOp op,
                                PatternRewriter &rewriter) const final {
    auto loc = op->getLoc();
    auto resultTy = dyn_cast<TensorType>(op->getResultTypes().front());
    auto inOp = GetOperand(op, rewriter, resultTy);
    if (!inOp) {
      return rewriter.notifyMatchFailure(op, "Failed to get operand.");
    }
    auto unaryOp = createElemwiseOp<LoweredOp>(loc, rewriter, resultTy, {inOp});
    rewriter.replaceOp(op, unaryOp);
    return success();
  }
};

template <typename BinaryOp, typename LoweredOp, GetOperandFn GetOperandLHS,
          GetOperandFn GetOperandRHS>
struct BinaryElemwiseLowering : public OpRewritePattern<BinaryOp> {
  using OpRewritePattern<BinaryOp>::OpRewritePattern;
  LogicalResult matchAndRewrite(BinaryOp op,
                                PatternRewriter &rewriter) const final {
    auto loc = op->getLoc();
    auto resultTy = dyn_cast<TensorType>(op->getResultTypes().front());
    auto lhsOp = GetOperandLHS(op, rewriter, resultTy);
    auto rhsOp = GetOperandRHS(op, rewriter, resultTy);
    if (!lhsOp || !rhsOp) {
      return rewriter.notifyMatchFailure(op, "Failed to get operand.");
    }
    auto binaryOp = createElemwiseOp<LoweredOp>(loc, rewriter, resultTy, //
                                                {lhsOp, rhsOp});
    rewriter.replaceOp(op, binaryOp);
    return success();
  }
};

//===----------------------------------------------------------------------===//
// Op lowering
//===----------------------------------------------------------------------===//

using ReLUOpLowering =
    BinaryElemwiseLowering<onednn_graph::ReLUOp, linalg::MaxOp, //
                           OriginalOperand::getIdx<0>,
                           ConstantOperand::getConst<0>>;

using AddOpLowering =
    BinaryElemwiseLowering<onednn_graph::AddOp, linalg::AddOp, //
                           OriginalOperand::getIdx<0>,
                           OriginalOperand::getIdx<1>>;

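Note how ReLU reuses the binary machinery: relu(x) = max(x, 0), so ReLUOpLowering pairs the original operand with a splat-zero constant (ConstantOperand::getConst<0>) and emits linalg.max. Additional elementwise ops would follow the same pattern; a hypothetical sketch (onednn_graph::MulOp is an assumed op name, not part of this commit):

```cpp
// Hypothetical: lower an elementwise multiply the same way Add is lowered.
using MulOpLowering =
    BinaryElemwiseLowering<onednn_graph::MulOp, linalg::MulOp, //
                           OriginalOperand::getIdx<0>,
                           OriginalOperand::getIdx<1>>;
```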
//===----------------------------------------------------------------------===//
// MatMulOp lowering
//===----------------------------------------------------------------------===//

struct MatMulOpLowering : public OpRewritePattern<MatMulOp> {
  using OpRewritePattern<MatMulOp>::OpRewritePattern;
  LogicalResult matchAndRewrite(MatMulOp op,
                                PatternRewriter &rewriter) const final {
    auto loc = op->getLoc();
    auto resultTy = dyn_cast<TensorType>(op->getResultTypes().front());
    auto typeA = dyn_cast<TensorType>(op.getInputA().getType());
    auto typeB = dyn_cast<TensorType>(op.getInputB().getType());
    //
    auto getEmptyTensor = [&](TensorType tensorTy) -> Value {
      Value zero = rewriter.create<arith::ConstantOp>(
          loc, rewriter.getZeroAttr(tensorTy.getElementType()));
      Value newTensor = rewriter.create<tensor::EmptyOp>(
          loc, tensorTy.getShape(), tensorTy.getElementType());
      return rewriter.create<linalg::FillOp>(loc, zero, newTensor).getResult(0);
    };

    if (typeA.getRank() != 2 || typeB.getRank() != 2) {
      return rewriter.notifyMatchFailure(
          op, "Multi-batch matmul is currently not supported.");
    }
    bool transposeA = op.getTransposeA();
    bool transposeB = op.getTransposeB();
    Operation *newOp = nullptr;
    if (!transposeA && !transposeB) {
      // (A * B)
      auto outTensor = getEmptyTensor(resultTy);
      newOp = rewriter.create<linalg::MatmulOp>(
          /*location=*/loc,
          /*resultTensorTypes=*/resultTy,
          /*inputs=*/ValueRange{op.getInputA(), op.getInputB()},
          /*outputs=*/outTensor);
    } else if (transposeA && !transposeB) {
      // T(A) * B
      auto outTensor = getEmptyTensor(resultTy);
      newOp = rewriter.create<linalg::MatmulTransposeAOp>(
          /*location=*/loc,
          /*resultTensorTypes=*/resultTy,
          /*inputs=*/ValueRange{op.getInputA(), op.getInputB()},
          /*outputs=*/outTensor);
    } else if (!transposeA && transposeB) {
      // A * T(B)
      auto outTensor = getEmptyTensor(resultTy);
      newOp = rewriter.create<linalg::MatmulTransposeBOp>(
          /*location=*/loc,
          /*resultTensorTypes=*/resultTy,
          /*inputs=*/ValueRange{op.getInputA(), op.getInputB()},
          /*outputs=*/outTensor);
    } else {
      // T(A) * T(B) == T(B * A)
      const auto &resultShape = resultTy.getShape();
      SmallVector<int64_t> transShape{resultShape[1], resultShape[0]};
      SmallVector<int64_t> permutation{1, 0};
      auto transTy = resultTy.clone(transShape);
      auto transTensor = getEmptyTensor(transTy);
      auto matmulOp = rewriter.create<linalg::MatmulOp>(
          /*location=*/loc,
          /*resultTensorTypes=*/transTy,
          /*inputs=*/ValueRange{op.getInputB(), op.getInputA()},
          /*outputs=*/transTensor);
      auto outTensor = getEmptyTensor(resultTy);
      newOp = rewriter.create<linalg::TransposeOp>(
          /*location=*/loc,
          /*inputs=*/matmulOp.getResult(0),
          /*outputs=*/outTensor,
          /*permutation=*/permutation);
    }

    if (op.getBias()) {
      Value bias =
          createBroadcastOperand(loc, rewriter, resultTy, op.getBias());
      Value outBias = rewriter.create<tensor::EmptyOp>(
          loc, resultTy.getShape(), resultTy.getElementType());
      newOp = rewriter.create<linalg::AddOp>(
          /*location=*/loc,
          /*resultTensorTypes=*/outBias.getType(),
          /*inputs=*/ValueRange{newOp->getResult(0), bias},
          /*outputs=*/outBias);
    }

    rewriter.replaceOp(op, newOp);
    return success();
  }
};
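
The final branch above relies on the standard transpose identity

\[ A^{T} B^{T} = (B A)^{T} \]

so instead of materializing both transposed inputs it computes B * A into a temporary with the transposed result shape, then emits a single linalg.transpose with permutation {1, 0} to produce the final result.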

//===----------------------------------------------------------------------===//
// Pass definition
//===----------------------------------------------------------------------===//

struct ConvertOneDNNGraphToLinalg
    : public impl::ConvertOneDNNGraphToLinalgBase<ConvertOneDNNGraphToLinalg> {

  void runOnOperation() final {
    auto *ctx = &getContext();
    // add lowering target
    ConversionTarget target(getContext());
    target.addIllegalDialect<onednn_graph::OneDNNGraphDialect>();
    target.addLegalDialect<BuiltinDialect, arith::ArithDialect,
                           linalg::LinalgDialect, func::FuncDialect,
                           tensor::TensorDialect>();
    // set pattern
    RewritePatternSet patterns(ctx);
    patterns.add<AddOpLowering>(ctx);
    patterns.add<ReLUOpLowering>(ctx);
    patterns.add<MatMulOpLowering>(ctx);
    // perform conversion
    if (failed(
            applyFullConversion(getOperation(), target, std::move(patterns)))) {
      signalPassFailure();
    }
  }
};

} // namespace
} // namespace gc
} // namespace mlir
