[WIP] [AMD] Emit AMD specific intrinsics for dot #4594

Draft
wants to merge 12 commits into main
33 changes: 33 additions & 0 deletions include/triton/Conversion/TritonGPUToLLVM/FMADotUtility.h
@@ -0,0 +1,33 @@
#ifndef TRITON_CONVERSION_FMA_DOT_UTILITY_H
#define TRITON_CONVERSION_FMA_DOT_UTILITY_H

#include "mlir/Conversion/LLVMCommon/TypeConverter.h"
#include "mlir/Support/LLVM.h"
#include "mlir/Transforms/DialectConversion.h"
#include "triton/Dialect/Triton/IR/Dialect.h"

namespace mlir::triton::gpu {

/// \brief Abstract interface for computing the scalar (dot) product of two
/// vectors of Values.
///
/// Enables generation of hardware-specific code in different backends.
class FMAVectorMultiplier {
public:
/// \returns scalar product of two arrays, plus c: a·b + c
virtual mlir::Value multiplyVectors(mlir::ArrayRef<mlir::Value> a,
mlir::ArrayRef<mlir::Value> b,
mlir::Value c) = 0;

virtual ~FMAVectorMultiplier() = default;
};

/// \brief Lowers a DotOp to LLVM using an abstract dot-conversion framework;
/// the hardware-specific multiply-accumulate is delegated to \p multiplier.
LogicalResult parametricConvertFMADot(triton::DotOp op,
triton::DotOp::Adaptor adaptor,
const LLVMTypeConverter *typeConverter,
ConversionPatternRewriter &rewriter,
FMAVectorMultiplier &multiplier);

} // namespace mlir::triton::gpu

#endif // TRITON_CONVERSION_FMA_DOT_UTILITY_H
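
A minimal sketch of how a backend is expected to consume this header: subclass `FMAVectorMultiplier` and hand the instance to `parametricConvertFMADot`, which owns the traversal over repetitions and registers. The class and function names below are hypothetical; the in-tree generic implementation is in `DotOpToLLVM/FMA.cpp` later in this diff. The sketch assumes the LLVM dialect headers and `llvm::zip` from `llvm/ADT/STLExtras.h`.

```cpp
// Illustrative usage sketch, not part of this header.
#include "triton/Conversion/TritonGPUToLLVM/FMADotUtility.h"

#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "llvm/ADT/STLExtras.h"

using namespace mlir;
using namespace mlir::triton::gpu;

namespace {
class ScalarFMAMultiplier : public FMAVectorMultiplier {
  ConversionPatternRewriter &rewriter;
  Location loc;

public:
  ScalarFMAMultiplier(ConversionPatternRewriter &rewriter, Location loc)
      : rewriter(rewriter), loc(loc) {}

  // a·b + c, emitted as a chain of scalar fused multiply-adds; a hardware
  // specific backend would emit packed dot intrinsics here instead.
  Value multiplyVectors(ArrayRef<Value> a, ArrayRef<Value> b,
                        Value c) override {
    Value accum = c;
    for (auto [x, y] : llvm::zip(a, b))
      accum = rewriter.create<LLVM::FMulAddOp>(loc, x, y, accum);
    return accum;
  }
};
} // namespace

// Hypothetical backend entry point: reuse the shared traversal and only
// swap the multiplier.
LogicalResult convertDotForMyBackend(triton::DotOp op,
                                     triton::DotOp::Adaptor adaptor,
                                     const LLVMTypeConverter *typeConverter,
                                     ConversionPatternRewriter &rewriter) {
  ScalarFMAMultiplier multiplier(rewriter, op.getLoc());
  return parametricConvertFMADot(op, adaptor, typeConverter, rewriter,
                                 multiplier);
}
```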
2 changes: 2 additions & 0 deletions include/triton/Conversion/TritonGPUToLLVM/Utility.h
@@ -106,6 +106,8 @@ using namespace mlir::triton;
#define undef(...) rewriter.create<LLVM::UndefOp>(loc, __VA_ARGS__)
#define null(...) rewriter.create<LLVM::ZeroOp>(loc, __VA_ARGS__)
#define call(...) LLVM::createLLVMCallOp(rewriter, loc, __VA_ARGS__)
#define call_intrinsic(...) \
rewriter.create<LLVM::CallIntrinsicOp>(loc, __VA_ARGS__)

// Types
#define int_ty(width) rewriter.getIntegerType(width)
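
The new `call_intrinsic` macro wraps `LLVM::CallIntrinsicOp`, picking up `rewriter` and `loc` from the enclosing scope like the other macros in this header. A hedged usage sketch, assuming the `(result type, intrinsic name, args)` builder of `LLVM::CallIntrinsicOp`, operands `a2`/`b2` already packed as `<2 x half>` values, and an f32 accumulator `acc`; `llvm.amdgcn.fdot2` is one plausible intrinsic for the AMD dot lowering, not necessarily the one this PR ends up emitting.

```cpp
// Hypothetical call site, not part of this diff:
//   float @llvm.amdgcn.fdot2(<2 x half> a, <2 x half> b, float c, i1 clamp)
Value clamp = rewriter.create<LLVM::ConstantOp>(loc, int_ty(1),
                                                rewriter.getBoolAttr(false));
auto dot = call_intrinsic(rewriter.getF32Type(),
                          rewriter.getStringAttr("llvm.amdgcn.fdot2"),
                          ValueRange{a2, b2, acc, clamp});
Value newAcc = dot->getResult(0);
```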
4 changes: 4 additions & 0 deletions include/triton/Dialect/TritonGPU/Transforms/Utility.h
@@ -194,6 +194,9 @@ bool isPureUnaryInlineAsm(Operation *op);
// read the compute capability from the module attributes
int getNVIDIAComputeCapability(Operation *module);

// Read the AMD target from the module attributes
StringRef getAMDArch(Operation *module);

std::optional<mlir::triton::gpu::SharedEncodingAttr>
getSharedEncIfAllUsersAreDotEnc(Value val, bool &incompatible);

@@ -209,6 +212,7 @@ MMALoadType getMMALoadType(Operation *loadOp);
std::optional<triton::LinearLayout>
getRegToSharedLayout(MLIRContext *ctx, ArrayRef<int64_t> shape,
Attribute srcEnc, Attribute dstEnc, int elemBitWidth);

} // namespace mlir

#endif // TRITON_DIALECT_TRITONGPU_TRANSFORMS_UTILITY_H_
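
A small hypothetical consumer of the new helper; the function name and the architecture prefixes below are illustrative only, and the real capability check belongs to the AMD backend.

```cpp
// Hypothetical sketch: read the AMD target from the module attributes and
// decide whether a packed dot intrinsic may be emitted.
bool canUsePackedDot(mlir::Operation *module) {
  llvm::StringRef arch = mlir::getAMDArch(module);
  // Example policy only.
  return arch.starts_with("gfx9") || arch.starts_with("gfx11");
}
```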
1 change: 1 addition & 0 deletions lib/Conversion/TritonGPUToLLVM/CMakeLists.txt
@@ -1,6 +1,7 @@
add_triton_library(TritonGPUToLLVM
ConvertLayoutOpToLLVM/SharedToDotOperandFMA.cpp
DotOpToLLVM/FMA.cpp
DotOpToLLVM/FMADotUtility.cpp
GlobalScratchMemoryAllocation.cpp
TypeConverter.cpp
Utility.cpp
151 changes: 23 additions & 128 deletions lib/Conversion/TritonGPUToLLVM/DotOpToLLVM/FMA.cpp
@@ -1,144 +1,39 @@
#include "mlir/Support/LLVM.h"
#include "triton/Conversion/TritonGPUToLLVM/FMADotUtility.h"
#include "triton/Conversion/TritonGPUToLLVM/Utility.h"
#include "triton/Dialect/TritonGPU/IR/Dialect.h"
#include "triton/Dialect/TritonGPU/Transforms/Utility.h"

using namespace mlir;
using namespace mlir::triton;
using namespace ::mlir::triton::gpu;

using ::mlir::LLVM::linearize;
using ::mlir::triton::gpu::expandMatrixOrderWithBatch;
using ::mlir::triton::gpu::expandMatrixShapeWithBatch;
using ::mlir::triton::gpu::getShapePerCTA;
using ::mlir::triton::gpu::getSizePerThread;

/// \brief spatial position of repetition and register of a given value
struct OperandValueKey {
unsigned bRepIdx, nonKRepIdx;
unsigned bIdx, nonKIdx, kIdx;

bool operator==(const OperandValueKey &other) const {
return (bRepIdx == other.bRepIdx && nonKRepIdx == other.nonKRepIdx &&
bIdx == other.bIdx && nonKIdx == other.nonKIdx &&
kIdx == other.kIdx);
}
};

template <> struct std::hash<OperandValueKey> {
std::size_t operator()(const OperandValueKey &k) const {
return llvm::hash_combine(k.bRepIdx, k.nonKRepIdx, k.bIdx, k.nonKIdx,
k.kIdx);
namespace {
class GenericFMAVectorMultiplier : public FMAVectorMultiplier {
ConversionPatternRewriter &rewriter;
Location loc;

public:
GenericFMAVectorMultiplier(ConversionPatternRewriter &rewriter, Location loc)
: rewriter(rewriter), loc(loc) {}

mlir::Value multiplyVectors(mlir::ArrayRef<mlir::Value> a,
mlir::ArrayRef<mlir::Value> b,
mlir::Value c) override {
auto K = a.size();
assert(b.size() == K);
mlir::Value accum = c;
for (int k = 0; k < K; ++k)
accum = rewriter.create<LLVM::FMulAddOp>(loc, a[k], b[k], accum);
return accum;
}
};

using ValueTableFMA = std::unordered_map<OperandValueKey, Value>;

static ValueTableFMA getValueTableFromStructFMA(
Value val, ArrayRef<unsigned> perRepShape, ArrayRef<unsigned> repetitions,
unsigned kDim, unsigned nonKDim, ConversionPatternRewriter &rewriter,
Location loc, ArrayRef<unsigned> inRepOrder, ArrayRef<unsigned> repOrder) {
ValueTableFMA res;
auto elems = unpackLLElements(loc, val, rewriter);
assert(perRepShape.size() == 3);
auto numElemsRep = product(perRepShape);
assert(elems.size() == numElemsRep * product(repetitions));
assert(kDim == 1 || kDim == 2);
assert(nonKDim == 1 || nonKDim == 2);
const unsigned bDim = 0;

for (unsigned idx = 0; idx < elems.size(); ++idx) {
auto inRepLinearIdx = idx % numElemsRep;
auto repLinearIdx = idx / numElemsRep;
auto inRepSpatialIdx =
mlir::LLVM::delinearize(inRepLinearIdx, perRepShape, inRepOrder);
auto repSpatialIdx =
mlir::LLVM::delinearize(repLinearIdx, repetitions, repOrder);
OperandValueKey key{repSpatialIdx[0], repSpatialIdx[nonKDim],
inRepSpatialIdx[0], inRepSpatialIdx[nonKDim],
inRepSpatialIdx[kDim]};
res[key] = elems[idx];
}
return res;
}
} // namespace

LogicalResult convertFMADot(triton::DotOp op, triton::DotOp::Adaptor adaptor,
const LLVMTypeConverter *typeConverter,
ConversionPatternRewriter &rewriter) {
auto *ctx = rewriter.getContext();
auto loc = op.getLoc();

auto A = op.getA();
auto D = op.getResult();

auto aTensorTy = cast<RankedTensorType>(A.getType());
auto dTensorTy = cast<RankedTensorType>(D.getType());

SmallVector<int64_t> aShapePerCTA =
expandMatrixShapeWithBatch(ArrayRef(getShapePerCTA(aTensorTy)));
auto dShapePerCTA =
expandMatrixShapeWithBatch(ArrayRef(getShapePerCTA(dTensorTy)));

BlockedEncodingAttr dLayout =
cast<BlockedEncodingAttr>(dTensorTy.getEncoding());
// TODO process A and B operand separately
auto inRepOrder = expandMatrixOrderWithBatch(dLayout.getOrder());
auto repOrder = expandMatrixOrderWithBatch(dLayout.getRepOrder());
auto cc = unpackLLElements(loc, adaptor.getC(), rewriter);

Value llA = adaptor.getA();
Value llB = adaptor.getB();

auto sizePerThread =
expandMatrixShapeWithBatch(ArrayRef(getSizePerThread(dLayout)));
auto numElemsPerThread = product(sizePerThread);
auto shapePerCTATile =
expandMatrixShapeWithBatch(ArrayRef(getShapePerCTATile(dLayout)));

unsigned K = aShapePerCTA[2];

unsigned threadTileShape[3];
unsigned repetitions[3];
for (int i = 0; i < 3; ++i) {
repetitions[i] =
ceil(dShapePerCTA[i], static_cast<int64_t>(shapePerCTATile[i]));
}

auto has = getValueTableFromStructFMA(
llA, {sizePerThread[0], sizePerThread[1], K},
{repetitions[0], repetitions[1], 1},
/*kDim*/ 2, /*nonKDim*/ 1, rewriter, loc, inRepOrder, repOrder);
auto hbs = getValueTableFromStructFMA(
llB, {sizePerThread[0], K, sizePerThread[2]},
{repetitions[0], 1, repetitions[2]},
/*kDim*/ 1, /*nonKDim*/ 2, rewriter, loc, inRepOrder, repOrder);

SmallVector<Value> acc = cc;

for (unsigned bRep = 0; bRep < repetitions[0]; ++bRep)
for (unsigned mRep = 0; mRep < repetitions[1]; ++mRep)
for (unsigned nRep = 0; nRep < repetitions[2]; ++nRep)
for (unsigned b = 0; b < sizePerThread[0]; ++b)
for (unsigned m = 0; m < sizePerThread[1]; ++m)
for (unsigned n = 0; n < sizePerThread[2]; ++n) {
SmallVector<unsigned> multiDimAccumIdx = {b, m, n};
unsigned linearInRepIdx =
linearize(multiDimAccumIdx, sizePerThread, inRepOrder);
SmallVector<unsigned> multiDimRepIdx = {bRep, mRep, nRep};
unsigned linearRepIdx =
linearize(multiDimRepIdx, repetitions, repOrder);
unsigned linearAccumIdx =
linearInRepIdx + linearRepIdx * numElemsPerThread;
for (unsigned k = 0; k < K; ++k) {
auto aOp = has[{bRep, mRep, b, m, k}];
auto bOp = hbs[{bRep, nRep, b, n, k}];
acc[linearAccumIdx] = rewriter.create<LLVM::FMulAddOp>(
loc, aOp, bOp, acc[linearAccumIdx]);
}
}

auto res = packLLElements(loc, typeConverter, acc, rewriter, dTensorTy);
rewriter.replaceOp(op, res);

return success();
GenericFMAVectorMultiplier multiplier(rewriter, loc);
return parametricConvertFMADot(op, adaptor, typeConverter, rewriter,
multiplier);
}
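
With the traversal factored out into `parametricConvertFMADot`, a backend only has to supply its own `FMAVectorMultiplier`. Below is a hedged sketch of what an AMD-flavoured multiplier could look like for f16 operands accumulated in f32: the class name, the choice of `llvm.amdgcn.fdot2`, and the pairwise packing are assumptions for illustration, not necessarily what this PR's AMD backend implements. It reuses the `undef`, `int_ty`, and `call_intrinsic` macros from `Utility.h`, which `FMA.cpp` already includes.

```cpp
// Hypothetical AMD-flavoured multiplier (illustrative only): consume the K
// dimension two f16 elements at a time with a packed dot intrinsic and fall
// back to scalar FMA for an odd tail.
namespace {
class PackedDotMultiplierSketch : public FMAVectorMultiplier {
  ConversionPatternRewriter &rewriter;
  Location loc;

  // Pack two f16 scalars into a <2 x half> vector.
  Value packPair(Value lo, Value hi) {
    auto vecTy = VectorType::get({2}, rewriter.getF16Type());
    Value vec = undef(vecTy);
    Value idx0 = rewriter.create<LLVM::ConstantOp>(
        loc, int_ty(32), rewriter.getI32IntegerAttr(0));
    Value idx1 = rewriter.create<LLVM::ConstantOp>(
        loc, int_ty(32), rewriter.getI32IntegerAttr(1));
    vec = rewriter.create<LLVM::InsertElementOp>(loc, vecTy, vec, lo, idx0);
    vec = rewriter.create<LLVM::InsertElementOp>(loc, vecTy, vec, hi, idx1);
    return vec;
  }

public:
  PackedDotMultiplierSketch(ConversionPatternRewriter &rewriter, Location loc)
      : rewriter(rewriter), loc(loc) {}

  Value multiplyVectors(ArrayRef<Value> a, ArrayRef<Value> b,
                        Value c) override {
    Value accum = c;
    size_t k = 0;
    // Pairs of K elements go through the packed dot intrinsic.
    for (; k + 1 < a.size(); k += 2) {
      Value clamp = rewriter.create<LLVM::ConstantOp>(
          loc, int_ty(1), rewriter.getBoolAttr(false));
      auto dot = call_intrinsic(
          rewriter.getF32Type(), rewriter.getStringAttr("llvm.amdgcn.fdot2"),
          ValueRange{packPair(a[k], a[k + 1]), packPair(b[k], b[k + 1]), accum,
                     clamp});
      accum = dot->getResult(0);
    }
    // Odd remainder of K: plain fused multiply-add.
    for (; k < a.size(); ++k)
      accum = rewriter.create<LLVM::FMulAddOp>(loc, a[k], b[k], accum);
    return accum;
  }
};
} // namespace
```

Swapping such a multiplier into a backend entry point is then a two-line change, exactly as `convertFMADot` does above with `GenericFMAVectorMultiplier`.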
153 changes: 153 additions & 0 deletions lib/Conversion/TritonGPUToLLVM/DotOpToLLVM/FMADotUtility.cpp
@@ -0,0 +1,153 @@
#include "triton/Conversion/TritonGPUToLLVM/FMADotUtility.h"
#include "triton/Conversion/TritonGPUToLLVM/Utility.h"

using namespace mlir;

namespace {

/// \brief Spatial position of the repetition and register of a given value
struct OperandValueKey {
unsigned bRepIdx, nonKRepIdx;
unsigned bIdx, nonKIdx, kIdx;

bool operator==(const OperandValueKey &other) const {
return (bRepIdx == other.bRepIdx && nonKRepIdx == other.nonKRepIdx &&
bIdx == other.bIdx && nonKIdx == other.nonKIdx &&
kIdx == other.kIdx);
}
};

} // namespace

template <> struct std::hash<OperandValueKey> {
std::size_t operator()(const OperandValueKey &k) const {
return llvm::hash_combine(k.bRepIdx, k.nonKRepIdx, k.bIdx, k.nonKIdx,
k.kIdx);
}
};

namespace {

using ValueTableFMA = std::unordered_map<OperandValueKey, Value>;

ValueTableFMA getValueTableFromStructFMA(
Value val, ArrayRef<unsigned> perRepShape, ArrayRef<unsigned> repetitions,
unsigned kDim, unsigned nonKDim, ConversionPatternRewriter &rewriter,
Location loc, ArrayRef<unsigned> inRepOrder, ArrayRef<unsigned> repOrder) {
ValueTableFMA res;
auto elems = unpackLLElements(loc, val, rewriter);
assert(perRepShape.size() == 3);
auto numElemsRep = product(perRepShape);
assert(elems.size() == numElemsRep * product(repetitions));
assert(kDim == 1 || kDim == 2);
assert(nonKDim == 1 || nonKDim == 2);
const unsigned bDim = 0;

for (unsigned idx = 0; idx < elems.size(); ++idx) {
auto inRepLinearIdx = idx % numElemsRep;
auto repLinearIdx = idx / numElemsRep;
auto inRepSpatialIdx =
mlir::LLVM::delinearize(inRepLinearIdx, perRepShape, inRepOrder);
auto repSpatialIdx =
mlir::LLVM::delinearize(repLinearIdx, repetitions, repOrder);
OperandValueKey key{repSpatialIdx[0], repSpatialIdx[nonKDim],
inRepSpatialIdx[0], inRepSpatialIdx[nonKDim],
inRepSpatialIdx[kDim]};
res[key] = elems[idx];
}
return res;
}

} // namespace

namespace mlir::triton::gpu {

LogicalResult parametricConvertFMADot(triton::DotOp op,
triton::DotOp::Adaptor adaptor,
const LLVMTypeConverter *typeConverter,
ConversionPatternRewriter &rewriter,
FMAVectorMultiplier &multiplier) {
auto *ctx = rewriter.getContext();
auto loc = op.getLoc();

auto A = op.getA();
auto D = op.getResult();

auto aTensorTy = cast<RankedTensorType>(A.getType());
auto dTensorTy = cast<RankedTensorType>(D.getType());

SmallVector<int64_t> aShapePerCTA =
expandMatrixShapeWithBatch(ArrayRef(getShapePerCTA(aTensorTy)));
auto dShapePerCTA =
expandMatrixShapeWithBatch(ArrayRef(getShapePerCTA(dTensorTy)));

BlockedEncodingAttr dLayout =
cast<BlockedEncodingAttr>(dTensorTy.getEncoding());
// TODO: process A and B operands separately
auto inRepOrder = expandMatrixOrderWithBatch(dLayout.getOrder());
auto repOrder = expandMatrixOrderWithBatch(dLayout.getRepOrder());
auto cc = unpackLLElements(loc, adaptor.getC(), rewriter);

Value llA = adaptor.getA();
Value llB = adaptor.getB();

auto sizePerThread =
expandMatrixShapeWithBatch(ArrayRef(getSizePerThread(dLayout)));
auto numElemsPerThread = product(sizePerThread);
auto shapePerCTATile =
expandMatrixShapeWithBatch(ArrayRef(getShapePerCTATile(dLayout)));

unsigned K = aShapePerCTA[2];

unsigned threadTileShape[3];
unsigned repetitions[3];
for (int i = 0; i < 3; ++i) {
repetitions[i] =
ceil(dShapePerCTA[i], static_cast<int64_t>(shapePerCTATile[i]));
}

auto has = getValueTableFromStructFMA(
llA, {sizePerThread[0], sizePerThread[1], K},
{repetitions[0], repetitions[1], 1},
/*kDim*/ 2, /*nonKDim*/ 1, rewriter, loc, inRepOrder, repOrder);
auto hbs = getValueTableFromStructFMA(
llB, {sizePerThread[0], K, sizePerThread[2]},
{repetitions[0], 1, repetitions[2]},
/*kDim*/ 1, /*nonKDim*/ 2, rewriter, loc, inRepOrder, repOrder);

SmallVector<Value> acc = cc;

for (unsigned bRep = 0; bRep < repetitions[0]; ++bRep)
for (unsigned mRep = 0; mRep < repetitions[1]; ++mRep)
for (unsigned nRep = 0; nRep < repetitions[2]; ++nRep)
for (unsigned b = 0; b < sizePerThread[0]; ++b)
for (unsigned m = 0; m < sizePerThread[1]; ++m)
for (unsigned n = 0; n < sizePerThread[2]; ++n) {
SmallVector<unsigned> multiDimAccumIdx = {b, m, n};
unsigned linearInRepIdx = mlir::LLVM::linearize(
multiDimAccumIdx, sizePerThread, inRepOrder);
SmallVector<unsigned> multiDimRepIdx = {bRep, mRep, nRep};
unsigned linearRepIdx =
mlir::LLVM::linearize(multiDimRepIdx, repetitions, repOrder);
unsigned linearAccumIdx =
linearInRepIdx + linearRepIdx * numElemsPerThread;

SmallVector<Value> aOpVector;
SmallVector<Value> bOpVector;

for (unsigned k = 0; k < K; ++k) {
aOpVector.push_back(has.at({bRep, mRep, b, m, k}));
bOpVector.push_back(hbs.at({bRep, nRep, b, n, k}));
}

acc[linearAccumIdx] = multiplier.multiplyVectors(
aOpVector, bOpVector, acc[linearAccumIdx]);
}

auto res = packLLElements(loc, typeConverter, acc, rewriter, dTensorTy);
rewriter.replaceOp(op, res);

return success();
}

} // namespace mlir::triton::gpu
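
For reference, the traversal above leans on the `order` convention of `mlir::LLVM::linearize`/`delinearize`, where `order` lists dimensions from fastest- to slowest-varying. A self-contained plain-C++ illustration of that convention (it mirrors the intent of those helpers, not their exact implementation):

```cpp
// Plain C++ demo of the linearize/delinearize convention: order[0] is the
// fastest-varying dimension, order[2] the slowest.
#include <array>
#include <cassert>
#include <cstdio>

using Idx3 = std::array<unsigned, 3>;

unsigned linearize(const Idx3 &idx, const Idx3 &shape, const Idx3 &order) {
  unsigned linear = 0;
  // Walk from the slowest-varying dimension down to the fastest one.
  for (int i = 2; i >= 0; --i) {
    unsigned dim = order[i];
    linear = linear * shape[dim] + idx[dim];
  }
  return linear;
}

Idx3 delinearize(unsigned linear, const Idx3 &shape, const Idx3 &order) {
  Idx3 idx{};
  // Peel off the fastest-varying dimension first.
  for (int i = 0; i < 3; ++i) {
    unsigned dim = order[i];
    idx[dim] = linear % shape[dim];
    linear /= shape[dim];
  }
  return idx;
}

int main() {
  Idx3 shape{2, 4, 8}; // {batch, M, N} elements per thread
  Idx3 order{2, 1, 0}; // N fastest, then M, then batch
  Idx3 idx{1, 3, 5};
  unsigned linear = linearize(idx, shape, order);
  assert(delinearize(linear, shape, order) == idx);
  std::printf("linear index: %u\n", linear); // 1*32 + 3*8 + 5 = 61
  return 0;
}
```

With `sizePerThread` playing the role of `shape` and `inRepOrder` the role of `order`, `linearInRepIdx` above is exactly this linearization, and `linearAccumIdx` offsets it by the repetition block: `linearInRepIdx + linearRepIdx * numElemsPerThread`.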