Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
137 changes: 137 additions & 0 deletions examples/gemm_fp8/example_tilelang_gemm_amd.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,137 @@
import torch
import tilelang
import tilelang.language as T
from tilelang.utils.tensor import torch_assert_close
import itertools


def ref_program(A, B):
return (A.half() @ B.half().T).to(dtype=torch.float32)


def manual_check_prog(C, C_ref):
    """Autotuner hook: compare the first kernel output against the reference.

    Uses loose tolerances (rtol=0.01, atol=0.1) appropriate for fp8 inputs.
    """
    actual = C[0]
    expected = C_ref[0]
    torch_assert_close(actual, expected, rtol=0.01, atol=0.1)


def supply_prog(args):
    """Autotuner hook: build random fp8 inputs shaped like the kernel params.

    ``args`` holds the (M, K) A-param and (N, K) B-param placeholders; the
    returned tensors are scaled by 0.01 before the fp8 cast to keep values
    well inside the e4m3fnuz dynamic range.
    """
    a_param, b_param = args
    M, K = a_param.shape
    N, _ = b_param.shape

    def _random_fp8(rows, cols):
        # Sample in fp16 first, then downcast to the AMD fp8 storage type.
        data = torch.randn(rows, cols, dtype=torch.float16, device='cuda') * 0.01
        return data.to(dtype=torch.float8_e4m3fnuz)

    return [_random_fp8(M, K), _random_fp8(N, K)]


def get_configs():
    """Enumerate the autotuning search space.

    Returns a list of config dicts, one per point in the Cartesian product
    of all tunable parameters (72 combinations total).
    """
    search_space = {
        "block_M": [32, 64, 128],
        "block_N": [32, 64, 128],
        "block_K": [64, 128],
        "num_stages": [0],
        "num_threads": [256],
        "k_pack": [1, 2],
        "gemm_type": ["ss", "rs"],
    }
    names = list(search_space.keys())
    return [
        dict(zip(names, combo))
        for combo in itertools.product(*search_space.values())
    ]


@tilelang.autotune(
    configs=get_configs(),
    cache_input_tensors=True,
    ref_prog=ref_program,
    manual_check_prog=manual_check_prog,
    supply_prog=supply_prog)
@tilelang.jit(out_idx=[-1])
def fp8_matmul(M, N, K, block_M, block_N, block_K, num_stages, num_threads, k_pack, gemm_type):
    """Build an autotuned fp8 GEMM kernel computing C = A @ B^T.

    A is (M, K) and B is (N, K), both stored as float8_e4m3fnuz (the AMD
    fp8 variant); C is (M, N) in float32.  ``gemm_type`` selects which
    operand-placement variant is returned: "rs" keeps the A tile in a
    fragment (presumably register-resident — see alloc_fragment) while
    "ss" stages both A and B tiles through shared memory.
    """
    dtype = "float8_e4m3fnuz"  # fp8 storage type for A and B
    accum_dtype = "float"  # accumulate and emit C in fp32

    @T.prim_func
    def gemm_fp8_rs(
            A: T.Tensor((M, K), dtype),
            B: T.Tensor((N, K), dtype),
            C: T.Tensor((M, N), accum_dtype),
    ):
        # One (block_M x block_N) output tile per kernel instance;
        # bx walks N, by walks M.
        with T.Kernel(
                T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=num_threads) as (bx, by):
            A_local = T.alloc_fragment((block_M, block_K), dtype)
            B_shared = T.alloc_shared((block_N, block_K), dtype)
            C_local = T.alloc_fragment((block_M, block_N), accum_dtype)

            T.clear(C_local)
            # Walk the K dimension in block_K steps; num_stages controls
            # software pipelining of the copies against the MFMA math.
            for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=num_stages):
                T.copy(A[by * block_M, k * block_K], A_local)
                T.copy(B[bx * block_N, k * block_K], B_shared)
                # B is (N, K) so transpose_B=True yields A @ B^T.
                T.gemm(
                    A_local,
                    B_shared,
                    C_local,
                    transpose_B=True,
                    k_pack=k_pack,
                    policy=T.GemmWarpPolicy.FullRow)

            T.copy(C_local, C[by * block_M, bx * block_N])

    @T.prim_func
    def gemm_fp8_ss(
            A: T.Tensor((M, K), dtype),
            B: T.Tensor((N, K), dtype),
            C: T.Tensor((M, N), accum_dtype),
    ):
        # Same tiling as gemm_fp8_rs, but the A tile is staged through
        # shared memory instead of a fragment.
        with T.Kernel(
                T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=num_threads) as (bx, by):
            A_shared = T.alloc_shared((block_M, block_K), dtype)
            B_shared = T.alloc_shared((block_N, block_K), dtype)
            C_local = T.alloc_fragment((block_M, block_N), accum_dtype)

            T.clear(C_local)
            for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=num_stages):
                T.copy(A[by * block_M, k * block_K], A_shared)
                T.copy(B[bx * block_N, k * block_K], B_shared)
                T.gemm(
                    A_shared,
                    B_shared,
                    C_local,
                    transpose_B=True,
                    k_pack=k_pack,
                    policy=T.GemmWarpPolicy.FullRow)

            T.copy(C_local, C[by * block_M, bx * block_N])

    # Select the requested variant; reject anything outside {"ss", "rs"}.
    if gemm_type == "ss":
        return gemm_fp8_ss
    elif gemm_type == "rs":
        return gemm_fp8_rs
    else:
        raise ValueError(f"Invalid gemm_type: {gemm_type}")


def test_gemm_fp8(M, N, K):
    """Compile the autotuned fp8 kernel and validate it against the fp16 reference."""
    kernel = fp8_matmul(M, N, K)

    def _random_fp8(rows, cols):
        data = torch.randn(rows, cols, dtype=torch.float16, device='cuda') * 0.01
        return data.to(dtype=torch.float8_e4m3fnuz)

    a = _random_fp8(M, K)
    b = _random_fp8(N, K)
    c = kernel(a, b)
    ref_c = ref_program(a, b)
    torch_assert_close(c, ref_c, rtol=1e-2, atol=1e-2)
    print("passed~")


if __name__ == "__main__":
    # Smoke test on a square 512x512x512 problem; requires a CUDA/ROCm device.
    test_gemm_fp8(512, 512, 512)
58 changes: 39 additions & 19 deletions src/layout/gemm_layouts.cc
Original file line number Diff line number Diff line change
Expand Up @@ -59,21 +59,39 @@ From https://github.com/RadeonOpenCompute/amd_matrix_instruction_calculator
./matrix_calculator.py --architecture cdna1 --instruction v_mfma_f32_16x16x16f16
--detail-instruction
*/
Fragment makeGemmFragmentAB16x16CDNA() {
Fragment makeGemmFragmentAB16x16CDNA(const int k_pack) {
IterVar i = make_itervar("i", 16);
IterVar j = make_itervar("j", 16 * k_pack);
IterVar rep = make_itervar("rep", 1);
PrimExpr forward_thread = 16 * FloorDiv(j->var, 4 * k_pack) + i;
PrimExpr index = FloorMod(j->var, 4 * k_pack);
return Fragment({i, j}, {index}, forward_thread, rep);
}

Fragment makeGemmFragmentAB16x16CDNATransposed(const int k_pack) {
IterVar i = make_itervar("i", 16 * k_pack);
IterVar j = make_itervar("j", 16);
IterVar rep = make_itervar("rep", 1);
PrimExpr forward_thread = 16 * FloorDiv(j->var, 4) + i;
PrimExpr index = FloorMod(j->var, 4);
PrimExpr forward_thread = 16 * FloorDiv(i->var, 4 * k_pack) + j;
PrimExpr index = FloorMod(i->var, 4 * k_pack);
return Fragment({i, j}, {index}, forward_thread, rep);
}

Fragment makeGemmFragmentAB16x16CDNATransposed() {
Fragment makeGemmFragmentAB16x32CDNA(const int k_pack) {
IterVar i = make_itervar("i", 16);
IterVar j = make_itervar("j", 32 * k_pack);
IterVar rep = make_itervar("rep", 1);
PrimExpr forward_thread = 16 * FloorDiv(j->var, 8 * k_pack) + i;
PrimExpr index = FloorMod(j->var, 8 * k_pack);
return Fragment({i, j}, {index}, forward_thread, rep);
}

Fragment makeGemmFragmentAB16x32CDNATransposed(const int k_pack) {
IterVar i = make_itervar("i", 32 * k_pack);
IterVar j = make_itervar("j", 16);
IterVar rep = make_itervar("rep", 1);
PrimExpr forward_thread = 16 * FloorDiv(i->var, 4) + j;
PrimExpr index = FloorMod(i->var, 4);
PrimExpr forward_thread = 16 * FloorDiv(i->var, 8 * k_pack) + j;
PrimExpr index = FloorMod(i->var, 8 * k_pack);
return Fragment({i, j}, {index}, forward_thread, rep);
}

Expand Down Expand Up @@ -224,27 +242,34 @@ Fragment makeGemmFragmentB(const int block_m, const int block_n,
Fragment makeGemmFragmentACDNA(const int block_m, const int block_n,
const int block_k, const int warp_m,
const int warp_n, const int element_size,
bool transposed) {
const int k_pack, bool transposed) {
// assume not transposed
ICHECK(block_m % warp_m == 0);
ICHECK(block_n % warp_n == 0);
ICHECK(warp_m % 16 == 0);
ICHECK(block_k % 16 == 0);
const int mfma_k = k_pack * (element_size == 16 ? 16 : 32);
ICHECK(block_k % mfma_k == 0);
ICHECK(element_size == 8 || element_size == 16)
<< "element bitwidth=" << element_size;
if (transposed) {
auto base_layout =
makeGemmFragmentAB16x16CDNATransposed()->Repeat({1, 1}, false, false);
element_size == 16
? makeGemmFragmentAB16x16CDNATransposed(k_pack)->Repeat(
{1, 1}, false, false)
: makeGemmFragmentAB16x32CDNATransposed(k_pack)->Repeat(
{1, 1}, false, false);
auto warp_layout =
base_layout->Repeat({block_k / 16, warp_m / 16}, false, true);
base_layout->Repeat({block_k / mfma_k, warp_m / 16}, false, true);
auto block_layout = warp_layout->Repeat({1, block_m / warp_m}, true, true)
->Replicate(block_n / warp_n);
return block_layout;
} else {
auto base_layout =
makeGemmFragmentAB16x16CDNA()->Repeat({1, 1}, false, false);
element_size == 16
? makeGemmFragmentAB16x16CDNA(k_pack)->Repeat({1, 1}, false, false)
: makeGemmFragmentAB16x32CDNA(k_pack)->Repeat({1, 1}, false, false);
auto warp_layout =
base_layout->Repeat({warp_m / 16, block_k / 16}, false, false);
base_layout->Repeat({warp_m / 16, block_k / mfma_k}, false, false);
auto block_layout = warp_layout->Repeat({block_m / warp_m, 1}, true, true)
->Replicate(block_n / warp_n);
return block_layout;
Expand Down Expand Up @@ -397,7 +422,7 @@ Layout makeMatrixCoreSwizzleLayout(int stride, int continuous, int element_size,
const int numBanks = 32;
const int bankBitWidth = 32;
const int SIMDWidth = 16;
const int vecSize = 4 * kPack;
const int vecSize = (64 / element_size) * kPack;
const int innerDimLength = continuous;
const int typeWidthInBit = element_size;

Expand Down Expand Up @@ -616,12 +641,7 @@ Layout makeGemmABLayoutHopper(int mat_stride, int mat_continuous,

Layout makeGemmABLayoutCDNA(int stride, int continuous, int element_size,
int kPack) {
int vector_size = 128 / element_size;
if (continuous % (vector_size * 4) == 0)
return makeMatrixCoreSwizzleLayout(stride, continuous, element_size, kPack);
else {
return makeGemmABLayoutPadded(stride, continuous, element_size);
}
return makeMatrixCoreSwizzleLayout(stride, continuous, element_size, kPack);
}
} // namespace tl
} // namespace tvm
2 changes: 1 addition & 1 deletion src/layout/layout.h
Original file line number Diff line number Diff line change
Expand Up @@ -154,7 +154,7 @@ Fragment makeGemmFragmentB(const int block_m, const int block_n,
Fragment makeGemmFragmentACDNA(const int block_m, const int block_n,
const int block_k, const int warp_m,
const int warp_n, const int element_size,
bool transposed = false);
const int k_pack, bool transposed = false);

Comment on lines 154 to 158
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🛠️ Refactor suggestion

Public API now requires k_pack; consider defaulting to 1 to minimize breakage.

Adding a mandatory k_pack breaks all external callers. If backward-compat isn’t a concern, fine; otherwise, default it to 1 so existing FP16/BF16 paths keep compiling.

Apply:

-Fragment makeGemmFragmentACDNA(const int block_m, const int block_n,
-                               const int block_k, const int warp_m,
-                               const int warp_n, const int element_size,
-                               const int k_pack, bool transposed = false);
+Fragment makeGemmFragmentACDNA(const int block_m, const int block_n,
+                               const int block_k, const int warp_m,
+                               const int warp_n, const int element_size,
+                               const int k_pack = 1, bool transposed = false);
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
Fragment makeGemmFragmentACDNA(const int block_m, const int block_n,
const int block_k, const int warp_m,
const int warp_n, const int element_size,
bool transposed = false);
const int k_pack, bool transposed = false);
Fragment makeGemmFragmentACDNA(const int block_m, const int block_n,
const int block_k, const int warp_m,
const int warp_n, const int element_size,
const int k_pack = 1, bool transposed = false);
🤖 Prompt for AI Agents
In src/layout/layout.h around lines 154-158, the newly added mandatory k_pack
parameter breaks existing external callers; make k_pack optional by giving it a
default value of 1 in the function declaration so previous FP16/BF16 call sites
continue to compile. Update the prototype to set k_pack = 1 and ensure the
corresponding implementation/definition (and any inline/overload declarations)
uses the same default or handles the absent argument consistently.

// Default Memory Layout
Layout makeGemmLayoutLinear(int stride, int continuous);
Expand Down
6 changes: 1 addition & 5 deletions src/op/gemm.cc
Original file line number Diff line number Diff line change
Expand Up @@ -582,7 +582,7 @@ LayoutMap GemmNode::InferLayout(const LayoutInferArgs &T,
results.Set(A, shared_layout);
} else if (A.scope() == "local.fragment") {
auto fragment = makeGemmFragmentACDNA(M, N, K, M / warp_m, N / warp_n,
A->dtype.bits(), trans_A);
A->dtype.bits(), kPack, trans_A);
results.Set(A, fragment->BindThreadRange(thread_range));
} else {
ICHECK(0);
Expand All @@ -594,10 +594,6 @@ LayoutMap GemmNode::InferLayout(const LayoutInferArgs &T,
*as_const_int(B->shape[dim_B - 1]), B->dtype.bits(), kPack);

results.Set(B, shared_layout);
} else if (B.scope() == "local.fragment") {
auto fragment =
makeGemmFragmentB(M, N, K, M / warp_m, N / warp_n, trans_B);
results.Set(B, fragment->BindThreadRange(thread_range));
} else {
ICHECK(0);
}
Expand Down
Loading
Loading