[LAYOUTS] Generic stmatrix lowering #6609

Open · wants to merge 11 commits into main
3 changes: 1 addition & 2 deletions include/triton/Analysis/Utility.h
@@ -225,8 +225,7 @@ bool supportMMA(Value value, int version);
// return nullopt). The output will be such that layout.getInDimNames() ==
// layout.getOutDimNames() and the conversion will not include kBlock (resp.
// kWarp or kLane) if it can be avoided
triton::LinearLayout minimalCvtLayout(RankedTensorType srcTy,
RankedTensorType dstTy);
triton::LinearLayout minimalCvtLayout(Type srcTy, Type dstTy);

// Conversion from `srcTy` to `dstTy` only involves reordering of registers.
// There is no need for data exchange across threads, warps, or blocks.
2 changes: 2 additions & 0 deletions include/triton/Conversion/TritonGPUToLLVM/TargetInfoBase.h
@@ -94,6 +94,8 @@ class TargetInfoBase {

virtual bool supportVectorizedAtomics() const = 0;

virtual bool supportLdStMatrix() const = 0;

// Annotate target specific information to local store operations during
// lowering to LLVM.
virtual void localStoreOpAnnotation(triton::gpu::LocalStoreOp op,
16 changes: 16 additions & 0 deletions include/triton/Tools/LayoutUtils.h
@@ -98,6 +98,22 @@ standardOutDimPairs(MLIRContext *ctx, ArrayRef<int64_t> dstShape);
LinearLayout identityStandardND(StringAttr inDimName, ArrayRef<unsigned> shape,
ArrayRef<unsigned> order);

// Return a layout with the same in/out dimensions as `layout` but with all
// bases set to 0.
LinearLayout zerosLike(const LinearLayout &layout);

// For a layout A with A.hasInDim(kReg), find a permutation of the register
// columns (an action) such that action.apply(A) may be left-divisible by B.
// The action returned here does not always make divideLeft succeed, but if
// any register permutation does, it is the one returned by this function.
std::optional<ColumnAction> regPermForDivideLeft(const LinearLayout &A,
const LinearLayout &B);

// For a layout A with A.hasInDim(kReg), find an action on the register
// columns such that action.apply(A) has the broadcasted (all-zero) register
// bases removed.
ColumnAction actionRemoveBroadcastedRegs(const LinearLayout &layout);

// Compute the supremum of two lists.
// Error out if the supremum does not exist (e.g. [a, b] and [b, a]).
// If the supremum is not unique, we order the elements of the first list first
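The two new helpers are meant to compose during the stmatrix lowering: first strip broadcasted registers, then search for a register permutation that makes the tile a left factor. Below is a minimal sketch of that call pattern, where `cvt` (the conversion layout), `tile` (the stmatrix tile layout), and `regs` (the unpacked register values) are assumed inputs — the actual lowering code lives elsewhere in this PR:

```c++
// Hedged sketch; `cvt`, `tile`, and `regs` are assumed to be in scope.
ColumnAction noBroadcast = actionRemoveBroadcastedRegs(cvt);
cvt = noBroadcast.apply(cvt);
SmallVector<Value> vals = noBroadcast.apply(regs);
// If some register permutation makes `tile` a left factor of `cvt`,
// regPermForDivideLeft returns it; apply it to layout and values alike.
if (std::optional<ColumnAction> perm = regPermForDivideLeft(cvt, tile)) {
  cvt = perm->apply(cvt);
  vals = perm->apply(vals);
  // ... try divideLeft(cvt, tile) and emit stmatrix on success ...
}
```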
66 changes: 56 additions & 10 deletions include/triton/Tools/LinearLayout.h
@@ -9,6 +9,7 @@
#include <vector>

#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/ValueRange.h"
#include "llvm/ADT/Hashing.h"
#include "llvm/ADT/MapVector.h"
#include "llvm/ADT/STLExtras.h"
@@ -538,19 +539,17 @@ class LinearLayout {
return reshapeOuts({{*getOutDimNames().begin(), getTotalOutDimSize()}});
}

// Concatenates two layouts by their input dimensions. The layouts must have
// the same output dimensions and sizes and different input dimensions. The
// input dimensions of this layout are placed before those of 'other'. This
// can be thought of as the opposite of `sublayout`, which slices a layout
// from a larger one.
// Concatenates two layouts by their in (resp. out) dimensions. The layouts
// must have the same out (resp. in) dimensions and sizes, and different
// in (resp. out) dimensions. The in (resp. out) dimensions of this layout
// are placed before those of 'other'. This can be thought of as the opposite
// of `sublayout`, which slices a layout from a larger one.
[[nodiscard]] LinearLayout concatIns(const LinearLayout &other) const;
// Concatenates two layouts by their output dimensions. The layouts must have
// the same input dimensions and sizes and different output dimensions. The
// output dimensions of this layout are placed before those of 'other'. This
// can be thought of as the opposite of `sublayout`, which slices a layout
// from a larger one.
[[nodiscard]] LinearLayout concatOuts(const LinearLayout &other) const;

// Remove all the bases that are equal to 0 for the given input dimension.
[[nodiscard]] LinearLayout unsqueezeIns(StringAttr dim) const;

// Computes the direct sum of two layouts.
// https://en.wikipedia.org/wiki/Direct_sum#Direct_sum_of_matrices
//
@@ -773,6 +772,53 @@ inline std::ostream &operator<<(std::ostream &os, const LinearLayout &layout) {
return os;
}

// Defines a map acting on the columns (i.e. bases) of a given input dimension
// of a layout, where column action[i] of the input becomes column i of the
// result. This action can be:
// - Applied to a layout to get a new layout with the same input dimensions
//   but with the bases permuted (and perhaps some of them dropped).
// - Applied to a range of Values to apply the same transformation to them.
//
// E.g. if action = [2, 0, 1] and basesDim = [1, 2, 4]
// - action.apply(layout) returns a LL with basesDim = [4, 1, 2]
// - action.apply(range) with range.size() == 8 returns the range permuted as
//   [x[0], x[4], x[1], x[5], x[2], x[6], x[3], x[7]]
class ColumnAction {
private:
SmallVector<size_t> action;
StringAttr inDim;
size_t inSizeLog2;
bool isIdentity;

public:
ColumnAction(ArrayRef<size_t> action, StringAttr inDim, size_t inSizeLog2)
: action(action), inDim(inDim), inSizeLog2(inSizeLog2) {
auto it = llvm::max_element(action);
// Check that the action indices are within range (asserting in a
// constructor is not ideal, but it keeps the invariant local).
assert(it == action.end() || *it < inSizeLog2);
// The action is often the identity; we record that here to allow an early
// return in apply().
isIdentity = action.size() == inSizeLog2 && llvm::is_sorted(action);
}

// Act on the columns of a layout
// Examples:
// - if action = [2, 0, 1] and layout.getBases()[inDim] = [[1], [2], [4]]
// - action.apply(layout) returns a LL with bases[inDim] = [[4], [1], [2]]
// - if action = [2, 0] and layout.getBases()[inDim] = [[1], [4], [2]]
// - action.apply(layout) returns a LL with bases[inDim] = [[2], [1]]
LinearLayout apply(const LinearLayout &layout) const;

// Act on a range of values (representing registers)
// e.g. if action = [2, 0, 1] and inSizeLog2 = 3 and inDim.str() = "register"
// - action.apply(range) with range.size() == 8, returns
// [x[0], x[4], x[1], x[5], x[2], x[6], x[3], x[7]]
SmallVector<Value> apply(ValueRange values) const;

std::string toString() const;
};

} // namespace mlir::triton

#endif // TRITON_TOOLS_LINEARLAYOUT_H
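To make the two `apply` overloads concrete, here is a small illustrative snippet (only `ColumnAction`, `LinearLayout::identity1D`, and an existing `MLIRContext *ctx` are assumed):

```c++
StringAttr kReg = StringAttr::get(ctx, "register");
// identity1D(8, kReg, kReg) has bases[kReg] = [[1], [2], [4]].
LinearLayout layout = LinearLayout::identity1D(8, kReg, kReg);
ColumnAction act({2, 0, 1}, kReg, /*inSizeLog2=*/3);
LinearLayout permuted = act.apply(layout); // bases[kReg] = [[4], [1], [2]]
// Applied to 8 values, the same action returns
// [x[0], x[4], x[1], x[5], x[2], x[6], x[3], x[7]].
```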
31 changes: 19 additions & 12 deletions lib/Analysis/Utility.cpp
@@ -737,27 +737,34 @@ bool matchMFMAAndDotOperandShuffleCase(RankedTensorType srcTy,
}

// We get the smallest submap of srcTy^{-1} * dstTy that is not the identity
// under kBlock, kWarp or kLane (in that order). The idea here is that if we
// have a transformation that's the identity on kBlock, we don't need to use
// under the common dimensions. The idea here is that if we have a
// transformation that's the identity on kBlock, we don't need to use
// distributed shared memory. If it's also the identity on kWarp, we can
// transfer via warp-shuffles, and if it's the identity on kLane we just have to
// reorder the registers
LinearLayout minimalCvtLayout(RankedTensorType srcTy, RankedTensorType dstTy) {
MLIRContext *ctx = srcTy.getContext();
// reorder the registers.
LinearLayout minimalCvtLayout(Type srcTy_, Type dstTy_) {
auto srcTy = cast<triton::gpu::TensorOrMemDesc>(srcTy_);
auto dstTy = cast<triton::gpu::TensorOrMemDesc>(dstTy_);
LinearLayout srcLayout =
toLinearLayout(srcTy.getShape(), srcTy.getEncoding());
LinearLayout dstLayout =
toLinearLayout(dstTy.getShape(), dstTy.getEncoding());
StringAttr kRegister = StringAttr::get(ctx, "register");
StringAttr kLane = StringAttr::get(ctx, "lane");
StringAttr kWarp = StringAttr::get(ctx, "warp");
StringAttr kBlock = StringAttr::get(ctx, "block");
auto sDims = to_vector(srcLayout.getInDimNames());
auto dDims = to_vector(dstLayout.getInDimNames());
SmallVector<StringAttr> dims;
for (int i = 0; i < std::min(sDims.size(), dDims.size()); ++i) {
auto srcDim = sDims[sDims.size() - i - 1];
auto dstDim = dDims[dDims.size() - i - 1];
if (srcDim != dstDim) {
break;
}
dims.push_back(srcDim);
}

auto comp = dstLayout.invertAndCompose(srcLayout);
// We try to quotient by the largest subspace first
auto dims = SmallVector<StringRef>{"block", "warp", "lane", "register"};
// We try to quotient by the slowest-moving subspace first
for (auto dim : dims) {
auto quotient = comp.quotient(StringAttr::get(ctx, dim));
auto quotient = comp.quotient(dim);
if (!quotient.has_value()) {
break;
}
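A hedged usage sketch of the relaxed signature (the caller context is assumed; `srcTy`/`dstTy` may now be tensor or memdesc types since the parameter is `mlir::Type`):

```c++
LinearLayout cvt = minimalCvtLayout(srcTy, dstTy);
// If "block" and "warp" quotient away, the remaining in-dims are at most
// {"lane", "register"}: the conversion needs no cross-warp data movement.
// If "lane" also quotients away, it is a pure register permutation.
SmallVector<StringAttr> inDims = llvm::to_vector(cvt.getInDimNames());
```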
75 changes: 75 additions & 0 deletions lib/Tools/LayoutUtils.cpp
@@ -185,6 +185,81 @@ LinearLayout identityStandardND(StringAttr inDimName, ArrayRef<unsigned> shape,
return ret;
}

LinearLayout zerosLike(const LinearLayout &layout) {
auto bases = layout.getBases();
for (auto &basis : bases) {
for (auto &vec : basis.second) {
for (auto &val : vec) {
val = 0;
}
}
}

SmallVector<std::pair<StringAttr, int32_t>> outDims;
for (auto outDim : layout.getOutDimNames()) {
outDims.emplace_back(outDim, layout.getOutDimSize(outDim));
}
return LinearLayout(std::move(bases), std::move(outDims),
/*requireSurjective=*/false);
}

std::optional<ColumnAction> regPermForDivideLeft(const LinearLayout &A,
const LinearLayout &B) {
// We can implement this generically for any dimension, but for now we only
// do it for regs to keep the API simpler
assert(A.getNumInDims() != 0);
auto kReg = *A.getInDimNames().begin();
assert(kReg.str() == "register");
assert(B.getNumInDims() != 0);
assert(kReg == *B.getInDimNames().begin());
// Retrieve the register bases from A and B.
const auto &ARegBases = A.getBases().lookup(kReg);
const auto &BRegBases = B.getBases().lookup(kReg);

// Compute the permutation order:
// For each basis in B (in order), find its index in A (using each index at
// most once). We make sure we use each index at most once in case B
// broadcasts (weird case, but better safe than sorry).
SmallVector<size_t> permOrder;
permOrder.reserve(ARegBases.size());
SmallVector<bool> used(ARegBases.size(), false);
for (const auto &bB : BRegBases) {
bool found = false;
for (size_t j = 0; j < ARegBases.size(); ++j) {
found = !used[j] && (ARegBases[j] == bB);
if (found) {
permOrder.push_back(j);
used[j] = true;
break;
}
}
if (!found)
return std::nullopt; // A basis from B not found in A.
}
// Append remaining indices from A (preserving their original order).
for (size_t i = 0; i < ARegBases.size(); ++i) {
if (!used[i])
permOrder.push_back(i);
}
return ColumnAction(permOrder, kReg, ARegBases.size());
}

ColumnAction actionRemoveBroadcastedRegs(const LinearLayout &layout) {
assert(layout.getNumInDims() != 0);
auto kReg = *layout.getInDimNames().begin();
assert(kReg.str() == "register");

// Drop the bases that are zero
const auto &bases = layout.getBases().lookup(kReg);
SmallVector<size_t> permOrder;
for (size_t i = 0; i < bases.size(); ++i) {
if (!llvm::all_of(bases[i], [](size_t x) { return x == 0; })) {
permOrder.push_back(i);
}
}
return ColumnAction(permOrder, kReg, bases.size());
}

// Compute the supremum of two lists.
// If the supremum is not unique, we order the elements of the first list first
// Error out if the supremum does not exist
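A worked example of the matching loop in `regPermForDivideLeft`, modeled as a standalone program with scalar stand-ins for the basis vectors (illustrative only, not code from this PR):

```c++
#include <cassert>
#include <optional>
#include <vector>

// For each basis of B (in order), find an unused equal basis in A, then
// append A's leftover indices in their original order.
std::optional<std::vector<size_t>> permOrder(const std::vector<int> &A,
                                             const std::vector<int> &B) {
  std::vector<size_t> order;
  std::vector<bool> used(A.size(), false);
  for (int b : B) {
    bool found = false;
    for (size_t j = 0; j < A.size(); ++j) {
      if (!used[j] && A[j] == b) {
        order.push_back(j);
        used[j] = true;
        found = true;
        break;
      }
    }
    if (!found)
      return std::nullopt; // a basis of B does not appear in A
  }
  for (size_t i = 0; i < A.size(); ++i)
    if (!used[i])
      order.push_back(i);
  return order;
}

int main() {
  // A has register bases [1, 4, 2]; the tile B has [1, 2].
  // B's bases sit at indices 0 and 2 of A, and the leftover index 1 is
  // appended, so the action is [0, 2, 1].
  auto p = permOrder({1, 4, 2}, {1, 2});
  assert(p && *p == (std::vector<size_t>{0, 2, 1}));
}
```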
50 changes: 50 additions & 0 deletions lib/Tools/LinearLayout.cpp
@@ -1183,4 +1183,54 @@ std::string LinearLayout::toString() const {
return ret;
}

LinearLayout ColumnAction::apply(const LinearLayout &layout) const {
assert(layout.hasInDim(inDim));
assert(layout.getInDimSizeLog2(inDim) == inSizeLog2 &&
"Layout has a different size than the ColumnAction");
if (isIdentity) {
return layout;
}

auto bases = layout.getBases();
const auto &basesInDim = bases[inDim];
std::vector<std::vector<int32_t>> newBases;
newBases.reserve(action.size());
for (size_t a : action) {
newBases.push_back(basesInDim[a]);
}
bases[inDim] = std::move(newBases);

SmallVector<std::pair<StringAttr, int32_t>> outDims;
for (auto outDim : layout.getOutDimNames()) {
outDims.emplace_back(outDim, layout.getOutDimSize(outDim));
}
return LinearLayout(std::move(bases), std::move(outDims),
/*requireSurjective=*/false);
}

SmallVector<Value> ColumnAction::apply(ValueRange values) const {
assert(values.size() == (1 << inSizeLog2) &&
"Values have a different size than the ColumnAction");
assert(inDim.str() == "register" && "Values are in registers, so we can only "
"apply ColumnAction to registers");
if (isIdentity) {
return values;
}
auto permLL = apply(LinearLayout::identity1D(values.size(), inDim, inDim));
SmallVector<Value> ret;
ret.reserve(permLL.getInDimSize(inDim));
for (int i = 0; i < permLL.getInDimSize(inDim); i++) {
int32_t srcIdx = permLL.apply({{inDim, i}}).begin()->second;
ret.push_back(values[srcIdx]);
}
return ret;
}

std::string ColumnAction::toString() const {
std::string ret = "ColumnAction([";
ret += join(action, ", ");
ret += "], " + inDim.str() + ", " + std::to_string(inSizeLog2) + ")";
return ret;
}

} // namespace mlir::triton
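The permutation computed by `apply(ValueRange)` can be checked with a standalone model of the layout arithmetic (again illustrative, not PR code):

```c++
#include <cassert>
#include <vector>

// Model of ColumnAction({2, 0, 1}, "register", 3) acting on 8 values:
// the identity bases [1, 2, 4] become [4, 1, 2], and the source index for
// output i is the XOR of the bases selected by i's bits.
std::vector<int> applyAction(const std::vector<int> &xs) {
  const int bases[] = {4, 1, 2};
  std::vector<int> out;
  for (size_t i = 0; i < xs.size(); ++i) {
    int src = 0;
    for (int b = 0; b < 3; ++b)
      if (i & (1u << b))
        src ^= bases[b];
    out.push_back(xs[src]);
  }
  return out;
}

int main() {
  auto r = applyAction({0, 1, 2, 3, 4, 5, 6, 7});
  // Matches the documented permutation [x0, x4, x1, x5, x2, x6, x3, x7].
  assert(r == (std::vector<int>{0, 4, 1, 5, 2, 6, 3, 7}));
}
```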
54 changes: 54 additions & 0 deletions test/Conversion/tritongpu_to_llvm_hopper.mlir
@@ -279,6 +279,60 @@

// -----

#shared = #ttg.swizzled_shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [1, 0]}>
#mma = #ttg.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [16, 32, 16]}>
#smem = #ttg.shared_memory
// CHECK-LABEL: distribute_to_swizzled_st_matrix_local_store
module attributes {"ttg.target" = "cuda:90", "ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 32 : i32} {
tt.func @distribute_to_swizzled_st_matrix_local_store(%a: tensor<8x64xf16, #mma>) {
// CHECK-COUNT-2: nvgpu.stmatrix
// CHECK: llvm.return
%b = ttg.local_alloc {allocation.offset = 0 : i32} : () -> !ttg.memdesc<8x64xf16, #shared, #smem, mutable>
ttg.local_store %a, %b : tensor<8x64xf16, #mma> -> !ttg.memdesc<8x64xf16, #shared, #smem, mutable>
tt.return
}
}

// -----

#shared = #ttg.swizzled_shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [1, 0]}>
#linear = #ttg.linear<{register = [[0, 1], [8, 0], [0, 8], [0, 16]], lane = [[0, 2], [0, 4], [1, 0], [2, 0], [4, 0]], warp = [[16, 0], [32, 0]], block = []}>
#smem = #ttg.shared_memory
// CHECK-LABEL: linear_to_swizzled_st_matrix_local_store
module attributes {"ttg.target" = "cuda:90", "ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 32 : i32} {
tt.func @linear_to_swizzled_st_matrix_local_store(%a: tensor<64x32xf16, #linear>) {
// CHECK-COUNT-2: nvgpu.stmatrix
// CHECK: llvm.return
%b = ttg.local_alloc {allocation.offset = 0 : i32} : () -> !ttg.memdesc<64x32xf16, #shared, #smem, mutable>
ttg.local_store %a, %b : tensor<64x32xf16, #linear> -> !ttg.memdesc<64x32xf16, #shared, #smem, mutable>
tt.return
}
}

// -----

// This test stretches the lowering a bit; feel free to remove it if the
// lowering is restricted later on.
// These layouts will have plenty of bank conflicts, so it would make sense
// not to lower them via stmatrix. It is of course possible to design a
// shared memory layout that avoids bank conflicts under stmatrix, but we do
// not do so here.
#shared = #ttg.swizzled_shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [1, 0]}>
#linear = #ttg.linear<{register = [[0, 1], [4, 0], [0, 0], [0, 16], [2, 0]], lane = [[0, 2], [0, 4], [0, 0], [8, 0], [0, 8]], warp = [[1, 0], [16, 0]], block = []}>
#smem = #ttg.shared_memory
// CHECK-LABEL: linear_to_swizzled_st_matrix_local_store
module attributes {"ttg.target" = "cuda:90", "ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 32 : i32} {
tt.func @linear_to_swizzled_st_matrix_local_store(%a: tensor<32x32xf16, #linear>) {
// CHECK-COUNT-2: nvgpu.stmatrix
// CHECK: llvm.return
%b = ttg.local_alloc {allocation.offset = 0 : i32} : () -> !ttg.memdesc<32x32xf16, #shared, #smem, mutable>
ttg.local_store %a, %b : tensor<32x32xf16, #linear> -> !ttg.memdesc<32x32xf16, #shared, #smem, mutable>
tt.return
}
}

// -----

#blocked = #ttg.blocked<{sizePerThread = [8], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}>
module attributes {"ttg.target" = "cuda:90", "ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 32 : i32} {
tt.func @fp8_const(%arg0: tensor<1024xi1, #blocked>, %arg1: tensor<1024xf8E4M3FNUZ, #blocked>) attributes {noinline = false} {
2 changes: 2 additions & 0 deletions third_party/amd/lib/TritonAMDGPUToLLVM/TargetInfo.h
@@ -79,6 +79,8 @@ class TargetInfo : public mlir::triton::TargetInfoBase {

bool supportVectorizedAtomics() const override;

bool supportLdStMatrix() const override { return false; }

void localStoreOpAnnotation(triton::gpu::LocalStoreOp op,
size_t localStoreOpCount,
Type type) const override;