|
| 1 | +//===-- OneDNNGraphToLinalg.cpp - OneDNN Graph To Linalg Lowering -*- C++ -*-===//
| 3 | +// |
| 4 | +// This file is licensed under the Apache License v2.0 with LLVM Exceptions. |
| 5 | +// See https://llvm.org/LICENSE.txt for license information. |
| 6 | +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
| 7 | +// |
| 8 | +//===----------------------------------------------------------------------===// |
| 9 | + |
| 10 | +#include <numeric> |
| 11 | +#include <vector> |
| 12 | + |
| 13 | +#include "gc/Dialect/OneDNNGraph/OneDNNGraphDialect.h" |
| 14 | +#include "gc/Dialect/OneDNNGraph/OneDNNGraphOps.h" |
| 15 | +#include "gc/Transforms/Passes.h" |
| 16 | +#include "mlir/Dialect/Func/IR/FuncOps.h" |
| 17 | +#include "mlir/Dialect/LLVMIR/LLVMDialect.h" |
| 18 | +#include "mlir/Dialect/Linalg/IR/Linalg.h" |
| 19 | +#include "mlir/Dialect/Math/IR/Math.h" |
| 20 | +#include "mlir/Dialect/Tensor/IR/Tensor.h" |
| 21 | +#include "mlir/IR/PatternMatch.h" |
| 22 | +#include "mlir/Rewrite/FrozenRewritePatternSet.h" |
| 23 | +#include "mlir/Support/LogicalResult.h" |
| 24 | +#include "mlir/Transforms/DialectConversion.h" |
| 25 | +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" |
| 26 | + |
| 27 | +using namespace mlir::onednn_graph; |
| 28 | + |
| 29 | +namespace mlir { |
| 30 | +namespace gc { |
| 31 | +#define GEN_PASS_DEF_CONVERTONEDNNGRAPHTOLINALG |
| 32 | +#include "gc/Transforms/Passes.h.inc" |
| 33 | + |
| 34 | +namespace { |
| 35 | +//===----------------------------------------------------------------------===// |
| 36 | +// Util funcs |
| 37 | +//===----------------------------------------------------------------------===// |
| 38 | + |
| 39 | +Value createBroadcastOperand(Location loc, PatternRewriter &rewriter, |
| 40 | + TensorType ty, Value op) { |
| 41 | + auto opTy = dyn_cast<TensorType>(op.getType()); |
| 42 | + llvm::ArrayRef<int64_t> bcastShape = ty.getShape(); |
| 43 | + llvm::ArrayRef<int64_t> opShape = opTy.getShape(); |
| 44 | + int64_t diff = bcastShape.size() - opShape.size(); |
| 45 | + |
| 46 | + if (bcastShape.equals(opShape)) { |
| 47 | + return op; |
| 48 | + } else { |
| 49 | + // get broadcast dimensions |
| 50 | + llvm::SmallVector<int64_t> bcastDims; |
| 51 | + for (int64_t i = 0; i < (int64_t)bcastShape.size(); i++) { |
| 52 | + int64_t idxOp = i - diff; |
| 53 | + if (idxOp < 0) { |
| 54 | + bcastDims.push_back(i); |
| 55 | + } else if (bcastShape[i] != opShape[idxOp]) { |
| 56 | + bcastDims.push_back(i); |
| 57 | + } |
| 58 | + } |
| 59 | + // create a new output tensor |
| 60 | + Value initTensor = |
| 61 | + rewriter.create<tensor::EmptyOp>(loc, bcastShape, ty.getElementType()); |
| 62 | + return rewriter |
| 63 | + .create<linalg::BroadcastOp>( |
| 64 | + /*location=*/loc, |
| 65 | + /*inputs=*/op, |
| 66 | + /*inits=*/initTensor, |
| 67 | + /*dimensions=*/bcastDims) |
| 68 | + .getResults() |
| 69 | + .front(); |
| 70 | + } |
| 71 | +} |
| 72 | + |
| 73 | +// Typedef for function to get operands for transformed op |
| 74 | +typedef mlir::Value (*GetOperandFn)(Operation *, PatternRewriter &, TensorType); |
| 75 | + |
// Functions to get operands for from original op
struct OriginalOperand {
  // Returns operand I of `op`, broadcast to the target type `ty`.
  // Emits an error and returns nullptr when `op` has no operand I.
  template <unsigned I>
  static Value getIdx(Operation *op, PatternRewriter &b, TensorType ty) {
    if (I >= op->getNumOperands()) {
      op->emitError("Index exceeds operand num.\n");
      return nullptr;
    }
    return createBroadcastOperand(op->getLoc(), b, ty, op->getOperand(I));
  }
};
| 87 | + |
// Functions to get constant operands
struct ConstantOperand {
  // Materializes a splat arith.constant with value I of tensor type `ty`.
  // Only integer and float element types are supported; any other element
  // type emits an error and returns nullptr.
  template <int64_t I>
  static Value getConst(Operation *op, PatternRewriter &b, TensorType ty) {
    const auto loc = op->getLoc();
    const auto elemTy = ty.getElementType();
    if (llvm::isa<IntegerType>(elemTy)) {
      return b.create<arith::ConstantOp>(
          loc, DenseElementsAttr::get(ty, b.getIntegerAttr(elemTy, I)));
    } else if (llvm::isa<FloatType>(elemTy)) {
      return b.create<arith::ConstantOp>(
          loc, DenseElementsAttr::get(ty, b.getFloatAttr(elemTy, I)));
    } else {
      op->emitError("Not a supported element type for constant.\n");
      return nullptr;
    }
  }
};
| 106 | + |
| 107 | +//===----------------------------------------------------------------------===// |
| 108 | +// Elemwise lowering |
| 109 | +//===----------------------------------------------------------------------===// |
| 110 | + |
| 111 | +// Generate elementwise op using linalg named ops |
| 112 | +template <typename LoweredOp> |
| 113 | +Value createElemwiseOp(Location loc, PatternRewriter &rewriter, TensorType ty, |
| 114 | + llvm::ArrayRef<Value> inputs) { |
| 115 | + // create a new output tensor |
| 116 | + Value outTensor = |
| 117 | + rewriter.create<tensor::EmptyOp>(loc, ty.getShape(), ty.getElementType()); |
| 118 | + |
| 119 | + auto elemwiseOp = rewriter.create<LoweredOp>( |
| 120 | + /*location=*/loc, |
| 121 | + /*resultTensorTypes=*/outTensor.getType(), |
| 122 | + /*inputs=*/inputs, |
| 123 | + /*outputs=*/outTensor); |
| 124 | + |
| 125 | + return elemwiseOp.getResult(0); |
| 126 | +} |
| 127 | + |
| 128 | +template <typename UnaryOp, typename LoweredOp, GetOperandFn GetOperand> |
| 129 | +struct UnaryElemwiseLowering : public OpRewritePattern<UnaryOp> { |
| 130 | + using OpRewritePattern<UnaryOp>::OpRewritePattern; |
| 131 | + LogicalResult matchAndRewrite(UnaryOp op, |
| 132 | + PatternRewriter &rewriter) const final { |
| 133 | + auto loc = op->getLoc(); |
| 134 | + auto resultTy = dyn_cast<TensorType>(op->getResultTypes().front()); |
| 135 | + auto inOp = GetOperand(op, rewriter, resultTy); |
| 136 | + if (!inOp) { |
| 137 | + return rewriter.notifyMatchFailure(op, "Fail to get operand."); |
| 138 | + } |
| 139 | + auto unaryOp = createElemwiseOp<LoweredOp>(loc, rewriter, resultTy, {inOp}); |
| 140 | + rewriter.replaceOp(op, unaryOp); |
| 141 | + return success(); |
| 142 | + } |
| 143 | +}; |
| 144 | + |
| 145 | +template <typename BinaryOp, typename LoweredOp, GetOperandFn GetOperandLHS, |
| 146 | + GetOperandFn GetOperandRHS> |
| 147 | +struct BinaryElemwiseLowering : public OpRewritePattern<BinaryOp> { |
| 148 | + using OpRewritePattern<BinaryOp>::OpRewritePattern; |
| 149 | + LogicalResult matchAndRewrite(BinaryOp op, |
| 150 | + PatternRewriter &rewriter) const final { |
| 151 | + auto loc = op->getLoc(); |
| 152 | + auto resultTy = dyn_cast<TensorType>(op->getResultTypes().front()); |
| 153 | + auto lhsOp = GetOperandLHS(op, rewriter, resultTy); |
| 154 | + auto rhsOp = GetOperandRHS(op, rewriter, resultTy); |
| 155 | + if (!lhsOp || !rhsOp) { |
| 156 | + return rewriter.notifyMatchFailure(op, "Fail to get operand."); |
| 157 | + } |
| 158 | + auto binaryOp = createElemwiseOp<LoweredOp>(loc, rewriter, resultTy, // |
| 159 | + {lhsOp, rhsOp}); |
| 160 | + rewriter.replaceOp(op, binaryOp); |
| 161 | + return success(); |
| 162 | + } |
| 163 | +}; |
| 164 | + |
| 165 | +//===----------------------------------------------------------------------===// |
| 166 | +// Op lowering |
| 167 | +//===----------------------------------------------------------------------===// |
| 168 | + |
// ReLU(x) is encoded as the binary max(x, 0): the first operand is the
// op's input, the second is a splat zero constant of the result type.
using ReLUOpLowering =
    BinaryElemwiseLowering<onednn_graph::ReLUOp, linalg::MaxOp, //
                           OriginalOperand::getIdx<0>,
                           ConstantOperand::getConst<0>>;

// Elementwise add: both operands come (broadcast if needed) from the op.
using AddOpLowering =
    BinaryElemwiseLowering<onednn_graph::AddOp, linalg::AddOp, //
                           OriginalOperand::getIdx<0>,
                           OriginalOperand::getIdx<1>>;
| 178 | + |
| 179 | +//===----------------------------------------------------------------------===// |
| 180 | +// MatMulOp lowering |
| 181 | +//===----------------------------------------------------------------------===// |
| 182 | + |
| 183 | +struct MatMulOpLowering : public OpRewritePattern<MatMulOp> { |
| 184 | + using OpRewritePattern<MatMulOp>::OpRewritePattern; |
| 185 | + LogicalResult matchAndRewrite(MatMulOp op, |
| 186 | + PatternRewriter &rewriter) const final { |
| 187 | + auto loc = op->getLoc(); |
| 188 | + auto resultTy = dyn_cast<TensorType>(op->getResultTypes().front()); |
| 189 | + auto typeA = dyn_cast<TensorType>(op.getInputA().getType()); |
| 190 | + auto typeB = dyn_cast<TensorType>(op.getInputB().getType()); |
| 191 | + // |
| 192 | + auto getEmptyTensor = [&](TensorType tensorTy) -> Value { |
| 193 | + Value zero = rewriter.create<arith::ConstantOp>( |
| 194 | + loc, rewriter.getZeroAttr(tensorTy.getElementType())); |
| 195 | + Value newTensor = rewriter.create<tensor::EmptyOp>( |
| 196 | + loc, tensorTy.getShape(), tensorTy.getElementType()); |
| 197 | + return rewriter.create<linalg::FillOp>(loc, zero, newTensor).getResult(0); |
| 198 | + }; |
| 199 | + |
| 200 | + if (typeA.getRank() != 2 || typeB.getRank() != 2) { |
| 201 | + return rewriter.notifyMatchFailure( |
| 202 | + op, "Currently not support multi batch matmul."); |
| 203 | + } |
| 204 | + bool transposeA = op.getTransposeA(); |
| 205 | + bool transposeB = op.getTransposeB(); |
| 206 | + Operation *newOp = nullptr; |
| 207 | + if (!transposeA && !transposeB) { |
| 208 | + // (A * B) |
| 209 | + auto outTensor = getEmptyTensor(resultTy); |
| 210 | + newOp = rewriter.create<linalg::MatmulOp>( |
| 211 | + /*location=*/loc, |
| 212 | + /*resultTensorTypes=*/resultTy, |
| 213 | + /*inputs=*/ValueRange{op.getInputA(), op.getInputB()}, |
| 214 | + /*outputs=*/outTensor); |
| 215 | + } else if (transposeA && !transposeB) { |
| 216 | + // T(A) * B |
| 217 | + auto outTensor = getEmptyTensor(resultTy); |
| 218 | + newOp = rewriter.create<linalg::MatmulTransposeAOp>( |
| 219 | + /*location=*/loc, |
| 220 | + /*resultTensorTypes=*/resultTy, |
| 221 | + /*inputs=*/ValueRange{op.getInputA(), op.getInputB()}, |
| 222 | + /*outputs=*/outTensor); |
| 223 | + } else if (!transposeA && transposeB) { |
| 224 | + // A * T(B) |
| 225 | + auto outTensor = getEmptyTensor(resultTy); |
| 226 | + newOp = rewriter.create<linalg::MatmulTransposeBOp>( |
| 227 | + /*location=*/loc, |
| 228 | + /*resultTensorTypes=*/resultTy, |
| 229 | + /*inputs=*/ValueRange{op.getInputA(), op.getInputB()}, |
| 230 | + /*outputs=*/outTensor); |
| 231 | + } else { |
| 232 | + // T(B * A) |
| 233 | + const auto &resultShape = resultTy.getShape(); |
| 234 | + SmallVector<int64_t> transShape{resultShape[1], resultShape[0]}; |
| 235 | + SmallVector<int64_t> permutation{1, 0}; |
| 236 | + auto transTy = resultTy.clone(transShape); |
| 237 | + auto transTensor = getEmptyTensor(transTy); |
| 238 | + auto matmulOp = rewriter.create<linalg::MatmulOp>( |
| 239 | + /*location=*/loc, |
| 240 | + /*resultTensorTypes=*/transTy, |
| 241 | + /*inputs=*/ValueRange{op.getInputB(), op.getInputA()}, |
| 242 | + /*outputs=*/transTensor); |
| 243 | + auto outTensor = getEmptyTensor(resultTy); |
| 244 | + newOp = rewriter.create<linalg::TransposeOp>( |
| 245 | + /*location=*/loc, |
| 246 | + /*inputs=*/matmulOp.getResult(0), |
| 247 | + /*outputs=*/outTensor, |
| 248 | + /*permutation=*/permutation); |
| 249 | + } |
| 250 | + |
| 251 | + if (op.getBias()) { |
| 252 | + Value bias = |
| 253 | + createBroadcastOperand(loc, rewriter, resultTy, op.getBias()); |
| 254 | + Value outBias = rewriter.create<tensor::EmptyOp>( |
| 255 | + loc, resultTy.getShape(), resultTy.getElementType()); |
| 256 | + newOp = rewriter.create<linalg::AddOp>( |
| 257 | + /*location=*/loc, |
| 258 | + /*resultTensorTypes=*/outBias.getType(), |
| 259 | + /*inputs=*/ValueRange{newOp->getResult(0), bias}, |
| 260 | + /*outputs=*/outBias); |
| 261 | + } |
| 262 | + |
| 263 | + rewriter.replaceOp(op, newOp); |
| 264 | + return success(); |
| 265 | + } |
| 266 | +}; |
| 267 | + |
| 268 | +//===----------------------------------------------------------------------===// |
| 269 | +// Pass define |
| 270 | +//===----------------------------------------------------------------------===// |
| 271 | + |
| 272 | +struct ConvertOneDNNGraphToLinalg |
| 273 | + : public impl::ConvertOneDNNGraphToLinalgBase<ConvertOneDNNGraphToLinalg> { |
| 274 | + |
| 275 | + void runOnOperation() final { |
| 276 | + auto *ctx = &getContext(); |
| 277 | + // add lowering target |
| 278 | + ConversionTarget target(getContext()); |
| 279 | + target.addIllegalDialect<onednn_graph::OneDNNGraphDialect>(); |
| 280 | + target.addLegalDialect<BuiltinDialect, arith::ArithDialect, |
| 281 | + linalg::LinalgDialect, func::FuncDialect, |
| 282 | + tensor::TensorDialect>(); |
| 283 | + // set pattern |
| 284 | + RewritePatternSet patterns(ctx); |
| 285 | + patterns.add<AddOpLowering>(ctx); |
| 286 | + patterns.add<ReLUOpLowering>(ctx); |
| 287 | + patterns.add<MatMulOpLowering>(ctx); |
| 288 | + // perform conversion |
| 289 | + if (failed( |
| 290 | + applyFullConversion(getOperation(), target, std::move(patterns)))) { |
| 291 | + signalPassFailure(); |
| 292 | + } |
| 293 | + } |
| 294 | +}; |
| 295 | + |
| 296 | +} // namespace |
| 297 | +} // namespace gc |
| 298 | +} // namespace mlir |
0 commit comments