Skip to content

Commit 3567769

Browse files
committed
init sdpa op and flash attention pass
1 parent 02f519b commit 3567769

File tree

10 files changed

+588
-7
lines changed

10 files changed

+588
-7
lines changed

include/gc/Dialect/Arith/Utils/EasyBuild.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -264,7 +264,7 @@ inline EBFloatPoint operator-(const EBFloatPoint &a) {
264264
}
265265

266266
#define DEF_EASYBUILD_CMP_OPERATOR(OP, OPCLASS, TYPE, PRED) \
267-
EBUnsigned operator OP(const TYPE &a, const TYPE &b) { \
267+
inline EBUnsigned operator OP(const TYPE &a, const TYPE &b) { \
268268
return OperatorHandlers::handleCmp<OPCLASS>(a, b, PRED); \
269269
} \
270270
template <typename T> EBUnsigned operator OP(const TYPE &a, T b) { \

include/gc/Dialect/Linalgx/LinalgxStructuredOps.td

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,9 @@ include "mlir/Interfaces/InferTypeOpInterface.td"
2323
include "mlir/Interfaces/SideEffectInterfaces.td"
2424
include "mlir/IR/OpAsmInterface.td"
2525

26+
// Base class for plain (non-structured) ops registered directly in the
// Linalgx dialect; structured ops use LinalgxStructuredBase_Op instead.
class Linalgx_Op<string mnemonic, list<Trait> traits = []> :
    Op<LinalgxDialect, mnemonic, traits>;
28+
2629
// Base Tablegen class for Linalg ops.
2730
// Linalg ops that correspond to library calls operate on ShapedType as their
2831
// first operands. These may be optionally followed by non-view operands
@@ -312,4 +315,27 @@ def Linalgx_MultiBatchMatmulOp : LinalgxStructuredBase_Op<"multi_batch_matmul",
312315
}];
313316
}
314317

318+
// Aggregated scaled-dot-product-attention op. It implements
// AggregatedOpInterface so it can be lowered by calling decomposeOperation
// (see LinalgxOps.cpp), which expands it into transpose/matmul/scale/add/
// softmax/matmul linalg ops.
def Linalgx_ScaledDotProductAttentionOp
    : Linalgx_Op<"scaled_dot_product_attention",
        [AttrSizedOperandSegments,
         DeclareOpInterfaceMethods<AggregatedOpInterface, ["decomposeOperation"]>]> {
  let summary = "Attention structure.";
  let description = [{
    Q, K, V, attention_mask.
    Output = SoftMax(Q @ K.transpose(-2, -1) + attention_mask) @ V.
  }];
  // inputs: query, key, value, attention_mask — in that order, as indexed by
  // decomposeOperation. outputs: the destination operand(s).
  let arguments = (ins
    Variadic<TensorOrMemref>:$inputs,
    Variadic<TensorOrMemref>:$outputs);
  let results = (outs Variadic<TensorOrMemref>:$results);

  let hasVerifier = 1;
  let assemblyFormat = [{
    attr-dict
    `ins` `(` $inputs `:` type($inputs) `)`
    `outs` `(` $outputs `:` type($outputs) `)`
    (`->` type($results)^)?
  }];
}
340+
315341
#endif // LINALGX_STRUCTURED_OPS

include/gc/IR/EasyBuildSCF.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -150,7 +150,7 @@ inline int IfIterator::operator*() const {
150150

151151
} // namespace impl
152152

153-
/// Creates an IfSimulator for `op`, borrowing the builder carried by `s`.
/// Declared inline because this helper is defined in a header.
inline impl::IfSimulator makeIfRange(const EasyBuilder &s, Operation *op) {
  impl::IfSimulator range{s.builder, op};
  return range;
}
156156

include/gc/Transforms/Passes.td

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,18 @@ def DeepTileContractionNamedOp
3434
];
3535
}
3636

37-
def GCCPUPipeline : Pass<"gc-cpu-pipeline"> {
37+
// Function-scoped pass that rewrites multi-head-attention (MHA) patterns
// into a flash-attention style implementation.
def FlashAttentionConversion
    : Pass<"flash-attention-conversion", "func::FuncOp"> {
  let summary = "Flash Attention Conversion";
  let description =
      [{The pass converts MHA to flash attention implementation.}];
  // Dialects whose ops the rewrite may create, so they are loaded up front.
  let dependentDialects = [
    "func::FuncDialect", "linalg::LinalgDialect", "scf::SCFDialect",
    "tensor::TensorDialect"
  ];
}
47+
48+
def GCCPUPipeline: Pass<"gc-cpu-pipeline"> {
3849
let summary = "All-in-one pipeline for GC for CPU";
3950
let dependentDialects = [
4051
"onednn_graph::OneDNNGraphDialect", "tensor::TensorDialect",

lib/gc/Dialect/Linalgx/LinalgxOps.cpp

Lines changed: 75 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
#include "gc/Dialect/Linalgx/LinalgxOps.h"
#include "gc/Dialect/Linalgx/LinalgxDialect.h"
#include "mlir/IR/OpImplementation.h"

#include <cmath>
#include <utility>
1213

1314
//===----------------------------------------------------------------------===//
1415
// Builder helper from mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -608,6 +609,80 @@ void MultiBatchMatmulOp::getEffects(
608609
getGenericEffectsImpl(effects, cast<LinalgOp>(getOperation()));
609610
}
610611

612+
//===----------------------------------------------------------------------===//
613+
// ScaledDotProductAttentionOp
614+
//===----------------------------------------------------------------------===//
615+
616+
/// Verifies the structural assumptions that decomposeOperation relies on:
/// it unconditionally indexes inputs [0..3] (query/key/value/mask) and reads
/// dim 3 of the query shape, so reject malformed ops here instead of
/// crashing during decomposition.
LogicalResult ScaledDotProductAttentionOp::verify() {
  if (getInputs().size() != 4)
    return emitOpError(
        "expected 4 inputs: query, key, value, attention_mask");
  if (getOutputs().size() != 1)
    return emitOpError("expected exactly 1 output");
  auto queryType = dyn_cast<ShapedType>(getInputs()[0].getType());
  if (!queryType || !queryType.hasRank() || queryType.getRank() != 4)
    return emitOpError("expected query to be a rank-4 shaped type");
  return success();
}
617+
618+
/// This method converts ScaledDotProductAttention into the following
619+
/// sequence of operations:
620+
/// output = softmax(ins[0] @ transpose(ins[1]) * scale + ins[3]) @ ins[2]
621+
FailureOr<SmallVector<Value>>
622+
ScaledDotProductAttentionOp::decomposeOperation(OpBuilder &b) {
623+
OpBuilder::InsertionGuard guard(b);
624+
b.setInsertionPoint(*this);
625+
Location loc = getLoc();
626+
Value query = getInputs()[0], key = getInputs()[1], value = getInputs()[2],
627+
mask = getInputs()[3];
628+
auto dtype = cast<RankedTensorType>(query.getType()).getElementType();
629+
auto shape = cast<RankedTensorType>(query.getType()).getShape();
630+
float rsqrt_head = 1 / sqrt(shape[3]);
631+
632+
SmallVector<int64_t> permutation{0, 1, 3, 2};
633+
SmallVector<int64_t> transposeShape{shape[0], shape[1], shape[3], shape[2]};
634+
auto transposeOut = b.create<tensor::EmptyOp>(loc, transposeShape, dtype);
635+
auto transpose = b.create<linalg::TransposeOp>(
636+
/*location=*/loc,
637+
/*inputs=*/key,
638+
/*outputs=*/transposeOut,
639+
/*permutation=*/permutation);
640+
641+
SmallVector<int64_t> matmulQKShape{shape[0], shape[1], shape[2], shape[2]};
642+
auto matmulQKOut = b.create<tensor::EmptyOp>(loc, matmulQKShape, dtype);
643+
auto matmulQK = b.create<linalgx::MultiBatchMatmulOp>(
644+
/*location=*/loc, matmulQKOut.getResult().getType(),
645+
/*inputs=*/ValueRange{query, transpose->getResult(0)},
646+
/*outputs=*/ValueRange{matmulQKOut.getResult()});
647+
648+
auto mulOut = b.create<tensor::EmptyOp>(loc, matmulQKShape, dtype);
649+
// Broadcast the initial value to the output tensor before convolving.
650+
SmallVector<AffineMap, 4> indexingMaps;
651+
indexingMaps.push_back(b.getMultiDimIdentityMap(4));
652+
indexingMaps.push_back(b.getMultiDimIdentityMap(4));
653+
auto mul = b.create<linalg::GenericOp>(
654+
/*location=*/loc, matmulQKOut.getResult().getType(),
655+
/*inputs=*/ValueRange{matmulQK->getResult(0)},
656+
/*outputs=*/ValueRange{mulOut.getResult()}, indexingMaps,
657+
SmallVector<utils::IteratorType>(4, utils::IteratorType::parallel),
658+
[&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange args) {
659+
Value constant = b.create<arith::ConstantOp>(
660+
loc, nestedBuilder.getFloatAttr(dtype, rsqrt_head));
661+
Value added =
662+
nestedBuilder.create<arith::MulFOp>(loc, args[0], constant);
663+
nestedBuilder.create<linalg::YieldOp>(nestedLoc, added);
664+
});
665+
666+
auto addOut = b.create<tensor::EmptyOp>(loc, matmulQKShape, dtype);
667+
auto add = b.create<linalg::AddOp>(
668+
/*location=*/loc, addOut.getResult().getType(),
669+
/*inputs=*/ValueRange{mul->getResult(0), mask},
670+
/*outputs=*/ValueRange{addOut.getResult()});
671+
672+
auto softmaxOut = b.create<tensor::EmptyOp>(loc, matmulQKShape, dtype);
673+
auto softmax = b.create<linalg::SoftmaxOp>(
674+
/*location=*/loc, softmaxOut.getResult().getType(),
675+
/*inputs=*/add->getResult(0),
676+
/*outputs=*/softmaxOut.getResult(), 3);
677+
678+
auto matmulVOut = b.create<tensor::EmptyOp>(loc, shape, dtype);
679+
auto matmulV = b.create<linalgx::MultiBatchMatmulOp>(
680+
/*location=*/loc, matmulVOut.getResult().getType(),
681+
/*inputs=*/ValueRange{softmax->getResult(0), value},
682+
/*outputs=*/ValueRange{matmulVOut.getResult()});
683+
return SmallVector<Value>{matmulV.getResults()[0]};
684+
}
685+
611686
/////// Operations corresponding to library calls defined with Tablegen ////////
612687

613688
#define GET_OP_CLASSES

lib/gc/Transforms/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@ add_mlir_library(GCPasses
1111
OneDNNGraphToLinalg.cpp
1212
Pipeline.cpp
1313
DeepTileContractionNamedOp.cpp
14+
FlashAttentionConversion.cpp
1415
Tiling.cpp
1516

1617
ADDITIONAL_HEADER_DIRS

0 commit comments

Comments
 (0)