Merge pull request triton-lang#109 from dfukalov/dfukalov/work-3
[ROCM] Enable part of tl.dot operations.
rsanthanam-amd authored Feb 12, 2023
2 parents ec82cd6 + a6596fc commit ec387d5
Showing 4 changed files with 193 additions and 0 deletions.
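
For context, the snippet below is a minimal sketch of the kind of tl.dot call this commit enables on ROCm. It is adapted from the test_dot_without_load test added in this commit (float32 operands built in-register, no global loads); the names dot_kernel and out_ptr are renamed for clarity, and running it assumes a Triton build with ROCm support.

import torch
import triton
import triton.language as tl


@triton.jit
def dot_kernel(out_ptr):
    # Build 32x32 float32 operands in registers; no loads from global memory.
    a = tl.zeros((32, 32), tl.float32)
    b = tl.zeros((32, 32), tl.float32)
    c = tl.dot(a, b)
    # Store the 32x32 result row-major through out_ptr.
    offs = tl.arange(0, 32)[:, None] * 32 + tl.arange(0, 32)[None, :]
    tl.store(out_ptr + offs, c)


out = torch.empty((32, 32), dtype=torch.float32, device="cuda")
dot_kernel[(1,)](out)
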
6 changes: 6 additions & 0 deletions lib/Analysis/Utility.cpp
@@ -121,6 +121,9 @@ bool supportMMA(triton::DotOp op, int version) {
// Refer to mma section for the data type supported by Volta and Hopper
// Tensor Core in
// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#warp-level-matrix-fragment-mma-884-f16
#ifdef USE_ROCM
return false;
#endif
auto aElemTy = op.a().getType().cast<RankedTensorType>().getElementType();
auto bElemTy = op.b().getType().cast<RankedTensorType>().getElementType();
if (aElemTy.isF32() && bElemTy.isF32()) {
@@ -135,6 +138,9 @@ bool supportMMA(Value value, int version) {
// types of both the operands are identical here.
assert((version == 1 || version == 2) &&
"Unexpected MMA layout version found");
#ifdef USE_ROCM
return false;
#endif
auto elemTy = value.getType().cast<RankedTensorType>().getElementType();
return elemTy.isF16() || elemTy.isBF16() ||
(elemTy.isF32() && version >= 2) ||
8 changes: 8 additions & 0 deletions lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVMPass.cpp
@@ -383,10 +383,12 @@ class ConvertTritonGPUToLLVM

      // If the load byte width is not eligible or the current compute
      // capability does not support async copy, then we decompose it
#ifndef USE_ROCM
if (triton::gpu::InsertSliceAsyncOp::getEligibleLoadByteWidth(
computeCapability)
.contains(byteWidth))
return;
#endif

// load
auto tmpTy =
@@ -424,6 +426,11 @@ class ConvertTritonGPUToLLVM
});

mod.walk([&](triton::gpu::AsyncWaitOp asyncWaitOp) -> void {
#ifdef USE_ROCM
assert(decomposed &&
"AsyncWait is not supported for ROCM and should be removed");
asyncWaitOp.erase();
#else
if (!triton::gpu::AsyncWaitOp::isSupported(computeCapability)) {
// async wait is supported in Ampere and later
asyncWaitOp.erase();
@@ -434,6 +441,7 @@
builder.create<triton::gpu::AsyncWaitOp>(asyncWaitOp.getLoc(), 0);
asyncWaitOp.erase();
}
#endif
});
}
};
3 changes: 3 additions & 0 deletions lib/Dialect/TritonGPU/Transforms/Combine.cpp
@@ -862,6 +862,9 @@ class RematerializeForward : public mlir::RewritePattern {
// -----------------------------------------------------------------------------
namespace {
int computeCapabilityToMMAVersion(int computeCapability) {
#ifdef USE_ROCM
return 1;
#endif
if (computeCapability < 70) {
return 0;
} else if (computeCapability < 80) {
176 changes: 176 additions & 0 deletions python/test/unit/language/test_core_amd.py
@@ -1051,6 +1051,182 @@ def kernel(X, Z, BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, AXIS: tl.constexp
np.testing.assert_equal(z_ref, z_tri)


# ---------------
# test dot
# ---------------


@pytest.mark.parametrize("M, N, K, num_warps, col_a, col_b, epilogue, allow_tf32, dtype",
[(*shape, 4, False, False, epilogue, allow_tf32, dtype)
for shape in [(64, 64, 64)]
for epilogue in ['none', 'trans', 'add-matrix', 'add-rows', 'add-cols', 'softmax', 'chain-dot']
for allow_tf32 in [True, False]
# for dtype in ['float16', 'float32']
for dtype in ['float32']
if not (allow_tf32 and (dtype in ['float16']))] +
[(*shape_nw, col_a, col_b, 'none', allow_tf32, dtype)
for shape_nw in [[128, 256, 32, 8],
[128, 16, 32, 4],
[32, 128, 64, 4],
[128, 128, 64, 4],
#[64, 128, 128, 4],
[32, 128, 64, 2],
[128, 128, 64, 2],
#[64, 128, 128, 2]
]
for allow_tf32 in [True]
for col_a in [True, False]
for col_b in [True, False]
#for dtype in ['int8', 'float16', 'float32']])
for dtype in ['float32']])
def test_dot(M, N, K, num_warps, col_a, col_b, epilogue, allow_tf32, dtype, device='cuda'):
capability = torch.cuda.get_device_capability()
if capability[0] < 7:
pytest.skip("Only test tl.dot() on devices with sm >= 70")
if capability[0] < 8:
if dtype == 'int8':
pytest.skip("Only test int8 on devices with sm >= 80")
elif dtype == 'float32' and allow_tf32:
pytest.skip("Only test tf32 on devices with sm >= 80")

torch.backends.cuda.matmul.allow_tf32 = allow_tf32

# triton kernel
@triton.jit
def kernel(X, stride_xm, stride_xk,
Y, stride_yk, stride_yn,
W, stride_wn, stride_wl,
Z, stride_zm, stride_zn,
BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr,
ADD_MATRIX: tl.constexpr, ADD_ROWS: tl.constexpr, ADD_COLS: tl.constexpr,
ALLOW_TF32: tl.constexpr,
DO_SOFTMAX: tl.constexpr, CHAIN_DOT: tl.constexpr,
COL_A: tl.constexpr, COL_B: tl.constexpr):
off_m = tl.arange(0, BLOCK_M)
off_n = tl.arange(0, BLOCK_N)
off_l = tl.arange(0, BLOCK_N)
off_k = tl.arange(0, BLOCK_K)
Xs = X + off_m[:, None] * stride_xm + off_k[None, :] * stride_xk
Ys = Y + off_k[:, None] * stride_yk + off_n[None, :] * stride_yn
Ws = W + off_n[:, None] * stride_wn + off_l[None, :] * stride_wl
Zs = Z + off_m[:, None] * stride_zm + off_n[None, :] * stride_zn
x = tl.load(Xs)
y = tl.load(Ys)
z = tl.dot(x, y, allow_tf32=ALLOW_TF32)
if ADD_MATRIX:
z += tl.load(Zs)
if ADD_ROWS:
ZRs = Z + off_m * stride_zm
z += tl.load(ZRs)[:, None]
if ADD_COLS:
ZCs = Z + off_n * stride_zn
z += tl.load(ZCs)[None, :]
if DO_SOFTMAX:
max = tl.max(z, 1)
z = z - max[:, None]
num = tl.exp(z)
den = tl.sum(num, 1)
z = num / den[:, None]
if CHAIN_DOT:
w = tl.load(Ws)
z = tl.dot(z.to(w.dtype), w)
tl.store(Zs, z)
# input
rs = RandomState(17)
if col_a:
x = numpy_random((K, M), dtype_str=dtype, rs=rs).T
else:
x = numpy_random((M, K), dtype_str=dtype, rs=rs)
if col_b:
y = numpy_random((N, K), dtype_str=dtype, rs=rs).T
else:
y = numpy_random((K, N), dtype_str=dtype, rs=rs)
w = numpy_random((N, N), dtype_str=dtype, rs=rs)
if 'int' not in dtype:
x *= .1
y *= .1
if dtype == 'float32' and allow_tf32:
x = (x.view('uint32') & np.uint32(0xffffe000)).view('float32')
y = (y.view('uint32') & np.uint32(0xffffe000)).view('float32')
w = (w.view('uint32') & np.uint32(0xffffe000)).view('float32')
x_tri = to_triton(x, device=device)
y_tri = to_triton(y, device=device)
w_tri = to_triton(w, device=device)
# triton result
if dtype == 'int8':
z = 1 + numpy_random((M, N), dtype_str='int32', rs=rs)
else:
z = 1 + numpy_random((M, N), dtype_str=dtype, rs=rs) * .1

z_tri = to_triton(z, device=device)
if epilogue == 'trans':
z_tri = torch.as_strided(z_tri, (M, N), z_tri.stride()[::-1])
pgm = kernel[(1, 1)](x_tri, x_tri.stride(0), x_tri.stride(1),
y_tri, y_tri.stride(0), y_tri.stride(1),
w_tri, w_tri.stride(0), w_tri.stride(1),
z_tri, z_tri.stride(0), z_tri.stride(1),
COL_A=col_a, COL_B=col_b,
BLOCK_M=M, BLOCK_K=K, BLOCK_N=N,
ADD_MATRIX=epilogue == 'add-matrix',
ADD_ROWS=epilogue == 'add-rows',
ADD_COLS=epilogue == 'add-cols',
DO_SOFTMAX=epilogue == 'softmax',
CHAIN_DOT=epilogue == 'chain-dot',
ALLOW_TF32=allow_tf32,
num_warps=num_warps)
# torch result
if dtype == 'int8':
        z_ref = np.matmul(x.astype(np.float32),
                          y.astype(np.float32)).astype(np.int32)
else:
z_ref = np.matmul(x, y)

if epilogue == 'add-matrix':
z_ref += z
if epilogue == 'add-rows':
z_ref += z[:, 0][:, None]
if epilogue == 'add-cols':
z_ref += z[0, :][None, :]
if epilogue == 'softmax':
num = np.exp(z_ref - np.max(z_ref, axis=-1, keepdims=True))
denom = np.sum(num, axis=-1, keepdims=True)
z_ref = num / denom
if epilogue == 'chain-dot':
z_ref = np.matmul(z_ref, w)
# compare
# print(z_ref[:,0], z_tri[:,0])
if dtype == 'float32':
# XXX: Somehow there's a larger difference when we use float32
np.testing.assert_allclose(z_ref, to_numpy(z_tri), rtol=0.01, atol=1e-3)
else:
np.testing.assert_allclose(z_ref, to_numpy(z_tri), rtol=0.01)
# make sure ld/st are vectorized
#ptx = pgm.asm['ptx']
#assert 'ld.global.v4' in ptx
#assert 'st.global.v4' in ptx
#if dtype == 'float32' and allow_tf32:
# assert 'mma.sync.aligned.m16n8k8.row.col.f32.tf32.tf32.f32' in ptx
#elif dtype == 'float32' and allow_tf32:
# assert 'mma.sync.aligned.m16n8k8.row.col.f32.tf32.tf32.f32' not in ptx
#elif dtype == 'int8':
# assert 'mma.sync.aligned.m16n8k32.row.col.satfinite.s32.s8.s8.s32' in ptx


def test_dot_without_load():
@triton.jit
def kernel(out):
pid = tl.program_id(axis=0)
a = tl.zeros((32, 32), tl.float32)
b = tl.zeros((32, 32), tl.float32)
c = tl.zeros((32, 32), tl.float32)
c = tl.dot(a, b)
pout = out + tl.arange(0, 32)[:, None] * 32 + tl.arange(0, 32)[None, :]
tl.store(pout, c)

out = torch.ones((32, 32), dtype=torch.float32, device="cuda")
kernel[(1,)](out)

# ---------------
# test arange
# ---------------
