Commit 9aa870b (mfma_16 support)
ilia-cher committed Nov 20, 2024
Parent: 4e127ae
Showing 4 changed files with 159 additions and 50 deletions.
8 changes: 5 additions & 3 deletions lib/Analysis/Utility.cpp
@@ -639,13 +639,15 @@ bool matchMFMAAndDotOperandShuffleCase(RankedTensorType srcTy,
if (!mfmaLayout || !dotOperandLayout)
return false;

// Currently supporting 32x32 FP8 MFMA -> dot operand case
// Currently supporting 32x32 and 16x16 FP8 MFMA -> dot operand case
return dotOperandLayout.getParent() == mfmaLayout &&
dotOperandLayout.getOpIdx() == 0 && mfmaLayout.getIsTransposed() &&
dotOperandLayout.getKWidth() == 8 &&
getContigPerThread(mfmaLayout)[1] == 4 && mfmaLayout.getMDim() == 32 &&
mfmaLayout.getNDim() == 32 &&
getContigPerThread(mfmaLayout)[1] == 4 &&
((mfmaLayout.getMDim() == 16 && mfmaLayout.getNDim() == 16) ||
(mfmaLayout.getMDim() == 32 && mfmaLayout.getNDim() == 32)) &&
triton::type::isFloat8(srcTy.getElementType()) &&
triton::type::isFloat8(dstTy.getElementType()) &&
mfmaLayout.getWarpsPerCTA()[1] == 1;
}

52 changes: 50 additions & 2 deletions test/Conversion/amd/mfma-shortcut.mlir
@@ -32,12 +32,60 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 :
#dotop0 = #triton_gpu.dot_op<{opIdx = 0, parent = #mfma, kWidth=8}>

module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, "triton_gpu.threads-per-warp" = 64 : i32} {
// CHECK-LABEL: mfma_dot_cvt_f8
tt.func public @mfma_dot_cvt_f8(%arg0: tensor<128x32xf8E4M3FNUZ, #mfma>) {
// CHECK-LABEL: mfma_dot_cvt_f8_mfma32
tt.func public @mfma_dot_cvt_f8_mfma32(%arg0: tensor<128x32xf8E4M3FNUZ, #mfma>) {
// CHECK-NOT: store
// CHECK-NOT: load
// CHECK: rocdl.ds_bpermute
%0 = triton_gpu.convert_layout %arg0 : tensor<128x32xf8E4M3FNUZ, #mfma> -> tensor<128x32xf8E4M3FNUZ, #dotop0>
tt.return
}
}

// -----

#mfma = #triton_gpu.amd_mfma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [32, 32], isTransposed = true}>
#dotop0 = #triton_gpu.dot_op<{opIdx = 0, parent = #mfma, kWidth=8}>

module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, "triton_gpu.threads-per-warp" = 64 : i32} {
// CHECK-LABEL: mfma_dot_cvt_bf8_mfma32
tt.func public @mfma_dot_cvt_bf8_mfma32(%arg0: tensor<128x32xf8E5M2, #mfma>) {
// CHECK-NOT: store
// CHECK-NOT: load
// CHECK: rocdl.ds_bpermute
%0 = triton_gpu.convert_layout %arg0 : tensor<128x32xf8E5M2, #mfma> -> tensor<128x32xf8E5M2, #dotop0>
tt.return
}
}

// -----

#mfma = #triton_gpu.amd_mfma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [16, 16], isTransposed = true}>
#dotop0 = #triton_gpu.dot_op<{opIdx = 0, parent = #mfma, kWidth=8}>

module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, "triton_gpu.threads-per-warp" = 64 : i32} {
// CHECK-LABEL: mfma_dot_cvt_f8_mfma16
tt.func public @mfma_dot_cvt_f8_mfma16(%arg0: tensor<128x32xf8E4M3FNUZ, #mfma>) {
// CHECK-NOT: store
// CHECK-NOT: load
// CHECK: rocdl.ds_bpermute
%0 = triton_gpu.convert_layout %arg0 : tensor<128x32xf8E4M3FNUZ, #mfma> -> tensor<128x32xf8E4M3FNUZ, #dotop0>
tt.return
}
}

// -----

#mfma = #triton_gpu.amd_mfma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [16, 16], isTransposed = true}>
#dotop0 = #triton_gpu.dot_op<{opIdx = 0, parent = #mfma, kWidth=8}>

module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, "triton_gpu.threads-per-warp" = 64 : i32} {
// CHECK-LABEL: mfma_dot_cvt_bf8_mfma16
tt.func public @mfma_dot_cvt_bf8_mfma16(%arg0: tensor<128x32xf8E5M2, #mfma>) {
// CHECK-NOT: store
// CHECK-NOT: load
// CHECK: rocdl.ds_bpermute
%0 = triton_gpu.convert_layout %arg0 : tensor<128x32xf8E5M2, #mfma> -> tensor<128x32xf8E5M2, #dotop0>
tt.return
}
}
120 changes: 85 additions & 35 deletions third_party/amd/lib/TritonAMDGPUToLLVM/ConvertLayoutOpToLLVM.cpp
@@ -134,23 +134,6 @@ struct ConvertLayoutOpMFMAToDotOpConversion
if (!matchMFMAAndDotOperandShuffleCase(srcType, dstType))
return failure();

/*
Using wave shuffle to convert layouts:
1) Input MMA layout (32x32, fp8, 16 values):
_____________________________________________________________
|(t0 v0 v1 v2 v3) (t32 v0 v1 v2 v3) ... (t32 v12 v13 v14 v15)|
| ... ... |
|(t31 v0 v1 v2 v3) (t63 v0 v1 v2 v3) ... (t63 v12 v13 v14 v15)|
|_____________________________________________________________|
2) Output Dot operand layout (two 32x16 tiles, fp8, 8 values each):
____________________________________________________________ ___
|(t0 v0 v1 v2 v3 v4 v5 v6 v7) (t32 v0 v1 v2 v3 v4 v5 v6 v7) | |
| ... ... | |...
|(t31 v0 v1 v2 v3 v4 v5 v6 v7) (t63 v0 v1 v2 v3 v4 v5 v6 v7) | |
|____________________________________________________________| |___
*/

auto loc = op.getLoc();

SmallVector<Value> inVals =
@@ -160,16 +143,24 @@ struct ConvertLayoutOpMFMAToDotOpConversion

Value threadId = tid_val();
auto mfmaLayout = dyn_cast<AMDMfmaEncodingAttr>(srcType.getEncoding());
Value warpSize = i32_val(triton::gpu::getWarpSize(mfmaLayout));
assert((mfmaLayout.getMDim() == 16 || mfmaLayout.getMDim() == 32) &&
"Expected MFMA size 16 or 32");
assert(triton::gpu::getWarpSize(mfmaLayout) == 64 &&
"Expected warp size 64 for MFMA");
Value warpSize = i32_val(64);
Value laneId = urem(threadId, warpSize);
Value laneOffset = i32_val(32);
Value mask = icmp_slt(laneId, laneOffset);
Value addr0 = select(mask, add(laneId, laneOffset), laneId);
Value addr1 = select(mask, laneId, sub(laneId, laneOffset));

SmallVector<Value> outVals;
auto elemTy = int_ty(8);
auto vecTy = vec_ty(elemTy, 4);

Value mask0 = icmp_slt(laneId, i32_val(32));
Value mask1 = icmp_slt(urem(laneId, i32_val(32)), i32_val(16));

Value addrShift16 = urem(add(laneId, i32_val(16)), warpSize);
Value addrShift32 = urem(add(laneId, i32_val(32)), warpSize);
Value addrShift48 = urem(add(laneId, i32_val(48)), warpSize);

SmallVector<Value> outVals;
for (size_t startIdx = 0; startIdx < inVals.size(); startIdx += 8) {
Value vec0 = undef(vecTy);
for (size_t vIdx = 0; vIdx < 4; ++vIdx) {
@@ -182,23 +173,82 @@ struct ConvertLayoutOpMFMAToDotOpConversion
i32_val(vIdx));
}

Value shflVec0 =
bitcast(targetInfo.shuffleIdx(rewriter, loc,
bitcast(vec0, int_ty(32)), addr0),
vecTy);
Value shflVec1 =
bitcast(targetInfo.shuffleIdx(rewriter, loc,
bitcast(vec1, int_ty(32)), addr1),
vecTy);
Value firstVec = select(mask, vec0, shflVec1);
Value secondVec = select(mask, shflVec0, vec1);

Value resVec0, resVec1;
if (mfmaLayout.getMDim() == 32) {
/*
Using wave shuffle to convert layouts (32x32x16 case):
1) Input MMA layout (32x32, fp8, 16 values):
_____________________________________________________________
|(t0 v0 v1 v2 v3) (t32 v0 v1 v2 v3) ... (t32 v12 v13 v14 v15)|
| ... ... |
|(t31 v0 v1 v2 v3) (t63 v0 v1 v2 v3) ... (t63 v12 v13 v14 v15)|
|_____________________________________________________________|
2) Output Dot operand layout (two 32x16 tiles, fp8, 8 values each):
____________________________________________________________ ___
|(t0 v0 v1 v2 v3 v4 v5 v6 v7) (t32 v0 v1 v2 v3 v4 v5 v6 v7) | |
| ... ... | |...
|(t31 v0 v1 v2 v3 v4 v5 v6 v7) (t63 v0 v1 v2 v3 v4 v5 v6 v7) | |
|____________________________________________________________| |___
*/

Value shflVec0 =
bitcast(targetInfo.shuffleIdx(
rewriter, loc, bitcast(vec0, int_ty(32)), addrShift32),
vecTy);
Value shflVec1 =
bitcast(targetInfo.shuffleIdx(
rewriter, loc, bitcast(vec1, int_ty(32)), addrShift32),
vecTy);

resVec0 = select(mask0, vec0, shflVec1);
resVec1 = select(mask0, shflVec0, vec1);
} else if (mfmaLayout.getMDim() == 16) {
/*
16x16x32 case:
1) Input MMA layout (two 16x16, fp8, 4 values each):
_________________________________________________________ ___________
|(t0 v0 v1 v2 v3) (t16 v0 v1 v2 v3) ... (t48 v0 v1 v2 v3)||(t0 v4 ...
| ... ... || ...
|(t15 v0 v1 v2 v3) (t31 v0 v1 v2 v3) ... (t63 v0 v1 v2 v3)||(t15 v4 ...
|_________________________________________________________||___________
2) Output Dot operand layout (16x32 tile, fp8, 8 values):
________________________________________________________________
|(t0 v0 v1 v2 v3 v4 v5 v6 v7) ... (t48 v0 v1 v2 v3 v4 v5 v6 v7) |
| ... ... |
|(t15 v0 v1 v2 v3 v4 v5 v6 v7) ... (t63 v0 v1 v2 v3 v4 v5 v6 v7) |
|________________________________________________________________|
*/

Value shflVec0_16 =
bitcast(targetInfo.shuffleIdx(
rewriter, loc, bitcast(vec0, int_ty(32)), addrShift16),
vecTy);
Value shflVec0_32 =
bitcast(targetInfo.shuffleIdx(
rewriter, loc, bitcast(vec0, int_ty(32)), addrShift32),
vecTy);
Value shflVec1_32 =
bitcast(targetInfo.shuffleIdx(
rewriter, loc, bitcast(vec1, int_ty(32)), addrShift32),
vecTy);
Value shflVec1_48 =
bitcast(targetInfo.shuffleIdx(
rewriter, loc, bitcast(vec1, int_ty(32)), addrShift48),
vecTy);

resVec0 = select(mask0, select(mask1, vec0, shflVec0_16),
select(mask1, shflVec1_32, shflVec1_48));
resVec1 = select(mask0, select(mask1, shflVec0_16, shflVec0_32),
select(mask1, shflVec1_48, vec1));
}

for (size_t vIdx = 0; vIdx < 4; ++vIdx) {
outVals.push_back(extract_element(elemTy, firstVec, i32_val(vIdx)));
outVals.push_back(extract_element(elemTy, resVec0, i32_val(vIdx)));
}
for (size_t vIdx = 0; vIdx < 4; ++vIdx) {
outVals.push_back(extract_element(elemTy, secondVec, i32_val(vIdx)));
outVals.push_back(extract_element(elemTy, resVec1, i32_val(vIdx)));
}
}

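For reference, the lane routing implemented by the `select`/`shuffleIdx` combinations in the hunk above can be modeled on the host. The Python sketch below is not part of the commit; it only mirrors the mask and shuffle-offset arithmetic (names such as `mfma32_routing` are illustrative) and prints, for every lane, which source lane and which 4-byte half (`vec0`/`vec1`) land in `resVec0` and `resVec1` for the 32x32 and 16x16 paths.

```python
# Minimal sketch (not part of the commit) of the per-lane routing performed by
# the ds_bpermute shuffles and nested selects in the conversion above.

WARP_SIZE = 64  # the pass asserts a wave size of 64 for MFMA

def mfma32_routing(lane):
    """32x32 path: mask0 = lane < 32, a single shuffle by 32 lanes."""
    s32 = (lane + 32) % WARP_SIZE
    if lane < 32:
        # resVec0 = own vec0, resVec1 = vec0 shuffled in from lane+32
        return (lane, "vec0"), (s32, "vec0")
    # resVec0 = vec1 shuffled in from lane-32, resVec1 = own vec1
    return (s32, "vec1"), (lane, "vec1")

def mfma16_routing(lane):
    """16x16 path: mask0 = lane < 32, mask1 = lane % 32 < 16, shuffles by 16/32/48."""
    s16 = (lane + 16) % WARP_SIZE
    s32 = (lane + 32) % WARP_SIZE
    s48 = (lane + 48) % WARP_SIZE
    low16 = (lane % 32) < 16               # mask1 in the C++ above
    if lane < 32:                          # mask0 in the C++ above
        res0 = (lane, "vec0") if low16 else (s16, "vec0")
        res1 = (s16, "vec0") if low16 else (s32, "vec0")
    else:
        res0 = (s32, "vec1") if low16 else (s48, "vec1")
        res1 = (s48, "vec1") if low16 else (lane, "vec1")
    return res0, res1

if __name__ == "__main__":
    # Print the routing tables so the shuffle/select combination can be inspected.
    for lane in range(WARP_SIZE):
        print(lane, mfma32_routing(lane), mfma16_routing(lane))
```

The sketch mirrors the nested select structure exactly, so it can be handy when checking a new MFMA shape against the expected dot-operand layout before touching the pass itself.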
29 changes: 19 additions & 10 deletions third_party/amd/python/test/test_chained_dot_fp8.py
@@ -42,6 +42,7 @@ def _chained_dot(
stride_om,
stride_od,
Z,
M,
N,
BLOCK_D: tl.constexpr,
BLOCK_M: tl.constexpr,
@@ -92,26 +93,31 @@ def _chained_dot(
class chained_dot_fn(torch.autograd.Function):

@staticmethod
def forward(ctx, q, k, v, q_desc=1.0, k_desc=1.0, v_desc=1.0, s_sc=1.0, s_desc=1.0, o_sc=1.0):
def forward(ctx, q, k, v, msize=32, q_desc=1.0, k_desc=1.0, v_desc=1.0, s_sc=1.0, s_desc=1.0, o_sc=1.0):
Lq, Lk, Lv = q.shape[-1], k.shape[-1], v.shape[-2]
assert Lq == Lk and Lk == Lv
assert Lk in {16, 32, 64, 128}
assert msize in {16, 32}
o = torch.empty_like(q, dtype=v.dtype)

BLOCK_M = 128 if q.dtype == float8 else 256
if BLOCK_M > q.shape[1]:
BLOCK_M = int(math.pow(2, math.floor(math.log2(q.shape[1]))))
BLOCK_N = 32
if BLOCK_N > k.shape[1]:
BLOCK_N = int(math.pow(2, math.floor(math.log2(k.shape[1]))))
waves_per_eu = 2
num_warps = BLOCK_M // 32
num_warps = 4 if q.dtype == float8 else 8
num_stages = 1

grid = (triton.cdiv(q.shape[1], BLOCK_M), q.shape[0], 1)

_chained_dot[grid](q, k, v, o, q_desc,
k_desc, v_desc, s_sc, s_desc, o_sc, q.stride(0), q.stride(1), q.stride(2), k.stride(0),
k.stride(1), k.stride(2), v.stride(0), v.stride(1), v.stride(2), o.stride(0), o.stride(1),
o.stride(2), Z=q.shape[0], N=q.shape[1], BLOCK_D=Lk, BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N,
USE_FP8=(q.dtype == float8), waves_per_eu=waves_per_eu, num_warps=num_warps,
num_stages=num_stages)
o.stride(2), Z=q.shape[0], M=q.shape[1], N=k.shape[1], BLOCK_D=Lk, BLOCK_M=BLOCK_M,
BLOCK_N=BLOCK_N, USE_FP8=(q.dtype == float8), waves_per_eu=waves_per_eu, num_warps=num_warps,
num_stages=num_stages, matrix_instr_nonkdim=msize)

return o

@@ -127,13 +133,16 @@ def to_float8(x, dtype=float8, margin: float = 1.0):
return x_scaled.to(dtype), scale, 1.0 / scale


@pytest.mark.parametrize('N, D, dtype', [(*shape, dtype) for shape in [(128, 32), (256, 128)] for dtype in ['fp8']])
def test_chained_dot(N, D, dtype):
@pytest.mark.parametrize('M, N, D, dtype, msize', [(*shape, dtype, msize)
for shape in [(128, 64, 32), (256, 128, 128)]
for dtype in ['fp8']
for msize in [16, 32]])
def test_chained_dot(M, N, D, dtype, msize):
if dtype == 'fp8':
assert float8 is not None

BATCH = 1
q = torch.empty((BATCH, N, D), dtype=torch.float16, device="cuda").normal_(mean=0., std=0.5)
q = torch.empty((BATCH, M, D), dtype=torch.float16, device="cuda").normal_(mean=0., std=0.5)
k = torch.empty((BATCH, N, D), dtype=torch.float16, device="cuda").normal_(mean=0., std=0.5)
v = torch.empty((BATCH, D, N), dtype=torch.float16, device="cuda").normal_(mean=0., std=0.5)

@@ -151,7 +160,7 @@ def test_chained_dot(N, D, dtype):
scale_b=torch.tensor(v_desc, dtype=torch.float32, device="cuda"))
ref_f8, ref_sc, _ = to_float8(ref)

tri_out = chained_dot(q_f8, k_f8, v_f8, q_desc, k_desc, v_desc, s_sc, s_desc, ref_sc)
tri_out = chained_dot(q_f8, k_f8, v_f8, msize, q_desc, k_desc, v_desc, s_sc, s_desc, ref_sc)

assert tri_out.isnan().sum() == 0
torch.testing.assert_close(tri_out[0].float(), ref_f8.float(), atol=1e-2, rtol=0)
@@ -160,5 +169,5 @@ def test_chained_dot(N, D, dtype):
s = torch.matmul(q, k.transpose(1, 2))
ref = torch.matmul(s, v.transpose(1, 2))

tri_out = chained_dot(q, k, v)
tri_out = chained_dot(q, k, v, msize)
torch.testing.assert_close(tri_out, ref, atol=1e-2, rtol=0)
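Since the new `msize` argument is simply forwarded to the kernel launch as `matrix_instr_nonkdim`, the same test kernel now exercises both MFMA shapes. A minimal usage sketch, assuming `chained_dot` is the module-level `chained_dot_fn.apply` wrapper defined earlier in this file (not shown in the hunk) and a ROCm GPU with MFMA support:

```python
# Usage sketch under assumptions: chained_dot = chained_dot_fn.apply exists at
# module level in test_chained_dot_fp8.py and a ROCm MFMA-capable device is present.
import torch
from test_chained_dot_fp8 import chained_dot  # assumed import path

BATCH, M, N, D = 1, 128, 64, 32
q = torch.randn(BATCH, M, D, dtype=torch.float16, device="cuda")
k = torch.randn(BATCH, N, D, dtype=torch.float16, device="cuda")
v = torch.randn(BATCH, D, N, dtype=torch.float16, device="cuda")

out_mfma16 = chained_dot(q, k, v, 16)  # launches with matrix_instr_nonkdim=16 (16x16 MFMA)
out_mfma32 = chained_dot(q, k, v, 32)  # launches with matrix_instr_nonkdim=32 (32x32 MFMA)
```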
