126 changes: 0 additions & 126 deletions examples/gemm_fp8/example_tilelang_gemm_fp8_sm100.py

This file was deleted.

14 changes: 4 additions & 10 deletions src/op/copy.cc
@@ -1117,20 +1117,16 @@ Stmt CopyNode::LowerTmemCopy(const LowerArgs &T,
bool is_ld = false; // tcgen05.ld (tensor memory -> register)
bool is_st = false; // tcgen05.st (register -> tensor memory)
bool is_cp = false; // tcgen05.cp (shared memory -> tensor memory)
bool src_needs_pack =
16 == src->dtype.bits(); // if needs .pack::16b when is_ld
bool dst_needs_unpack =
16 == dst->dtype.bits(); // if needs .unpack::16b when is_st

if (src.scope() == "shared.tmem" && dst.scope() == "local.fragment") {
is_ld = true;
} else if (src.scope() == "local.fragment" && dst.scope() == "shared.tmem") {
is_st = true;
} else if (src.scope() == "shared.dyn" && dst.scope() == "shared.tmem") {
is_cp = true;
} else {
ICHECK(0) << "Unsupported tensor memory copy: " << "src scope = "
<< src.scope() << ", dst scope = " << dst.scope();
ICHECK(0) << "Unsupported tensor memory copy: "
<< "src scope = " << src.scope()
<< ", dst scope = " << dst.scope();
}
// Currently tcgen05.cp is not supported
// TODO (mzw) Support tcgen05.cp
@@ -1250,10 +1246,8 @@ Stmt CopyNode::LowerTmemCopy(const LowerArgs &T,
: relative_wg_idx * (num_chunks_each_wg * meta.width);
have_succeeded = true;
Array<PrimExpr> args;
const char *bool_str = src_needs_pack ? "true" : "false";
args.push_back(StringImm(meta.intrinsics_name + "<" +
std::to_string(num_chunks_each_wg) + ", " +
bool_str + ">"));
std::to_string(num_chunks_each_wg) + ">"));
args.push_back(
BufferLoad(src, {(int)logical_row_min,
(int)logical_col_min})); // Will be translated later
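Note on the copy.cc hunks above: with the `pack16`/`unpack16` bookkeeping removed, the tcgen05 variant is chosen purely from the (src, dst) memory scopes, and the intrinsic string now carries only the chunk count. Below is a minimal sketch of that logic written against the names visible in the hunks (`meta.intrinsics_name`, `num_chunks_each_wg`); the helper names and the exception-based error path are illustrative, not the actual lowering code.

```cpp
// Sketch only (not part of the diff): scope dispatch and intrinsic-name
// assembly as they look after this PR.
#include <stdexcept>
#include <string>

enum class TmemCopyKind { Ld, St, Cp };

// tcgen05.ld: tensor memory -> registers
// tcgen05.st: registers -> tensor memory
// tcgen05.cp: shared memory -> tensor memory (still unsupported in the hunk)
inline TmemCopyKind ClassifyTmemCopy(const std::string &src_scope,
                                     const std::string &dst_scope) {
  if (src_scope == "shared.tmem" && dst_scope == "local.fragment")
    return TmemCopyKind::Ld;
  if (src_scope == "local.fragment" && dst_scope == "shared.tmem")
    return TmemCopyKind::St;
  if (src_scope == "shared.dyn" && dst_scope == "shared.tmem")
    return TmemCopyKind::Cp;
  throw std::runtime_error("Unsupported tensor memory copy: src scope = " +
                           src_scope + ", dst scope = " + dst_scope);
}

// Before: intrinsics_name + "<" + chunks + ", " + pack16_bool + ">"
// After:  intrinsics_name + "<" + chunks + ">"
inline std::string MakeIntrinsicName(const std::string &intrinsics_name,
                                     int num_chunks_each_wg) {
  return intrinsics_name + "<" + std::to_string(num_chunks_each_wg) + ">";
}
```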
2 changes: 0 additions & 2 deletions src/op/gemm_py.cc
@@ -428,8 +428,6 @@ TVM_FFI_STATIC_INIT_BLOCK() {
result.push_back(Integer(meta.atom_m));
result.push_back(Integer(meta.atom_n));
result.push_back(Integer(meta.atom_k));
result.push_back(Integer(meta.enable_ws));
result.push_back(Integer(meta.enable_2cta));
}
return result;
});
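The gemm_py.cc hunk trims the metadata exposed over FFI to just the three atom dimensions. A rough sketch of the packing after the change, with a plain `std::vector<int>` standing in for the TVM `Array`/`Integer` wrappers used in the real code:

```cpp
// Sketch only: the per-shape metadata pushed to the FFI result now stops at
// atom_k; enable_ws / enable_2cta are no longer part of the returned tuple.
#include <vector>

struct TCGEN5MMAMeta {
  int atom_m, atom_n, atom_k; // matches src/op/tcgen5_meta.h after this PR
};

inline void AppendMeta(std::vector<int> &result, const TCGEN5MMAMeta &meta) {
  result.push_back(meta.atom_m);
  result.push_back(meta.atom_n);
  result.push_back(meta.atom_k);
  // Previously also pushed: meta.enable_ws, meta.enable_2cta.
}
```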
38 changes: 11 additions & 27 deletions src/op/tcgen5_meta.h
@@ -15,19 +15,16 @@ using runtime::DataType;

struct TCGEN5MMAMeta {
int atom_m, atom_n, atom_k;
bool enable_ws, enable_2cta;
};

inline std::pair<bool, TCGEN5MMAMeta>
GetTCGEN5MMAMeta(int M, int N, int K, DataType ab_dtype, DataType c_dtype) {
// TODO (lei) Currently not all shapes / dtypes are supported for TCGEN5MMA.
#define FAIL \
return { false, TCGEN5MMAMeta{0, 0, 0} }
#define SUCCESS(atom_m, atom_n, atom_k) \
return { \
false, TCGEN5MMAMeta { 0, 0, 0, false, false } \
}
#define SUCCESS(atom_m, atom_n, atom_k, use_ws, use_2cta) \
return { \
true, TCGEN5MMAMeta { atom_m, atom_n, atom_k, use_ws, use_2cta } \
true, TCGEN5MMAMeta { atom_m, atom_n, atom_k } \
}
std::vector<int> ws_valid_atom_ns = {256, 128, 64};
if ((ab_dtype.is_bfloat16() || ab_dtype.is_float16()) &&
@@ -37,52 +34,39 @@ GetTCGEN5MMAMeta(int M, int N, int K, DataType ab_dtype, DataType c_dtype) {
if (M % 128 == 0) {
for (int atom_n = 256; atom_n >= 16; atom_n -= 16)
if (N % atom_n == 0)
SUCCESS(128, atom_n, 16, false, false);
SUCCESS(128, atom_n, 16);
FAIL;
} else if (M % 64 == 0) {
for (int atom_n : ws_valid_atom_ns)
if (N % atom_n == 0)
SUCCESS(64, atom_n, 16, false, false);
SUCCESS(64, atom_n, 16);
FAIL;
} else if (M % 32 == 0) {
for (int atom_n : ws_valid_atom_ns)
if (N % atom_n == 0)
SUCCESS(32, atom_n, 16, false, false);
SUCCESS(32, atom_n, 16);
FAIL;
} else {
FAIL;
}
} else if ((ab_dtype.is_float8_e4m3fn() || ab_dtype.is_float8_e4m3() ||
ab_dtype.is_float8_e5m2() || ab_dtype.is_float8_e5m2fnuz() ||
ab_dtype.is_float6_e2m3fn() || ab_dtype.is_float6_e3m2fn() ||
ab_dtype.is_float4_e2m1fn()) &&
((c_dtype.is_float() && c_dtype.bits() == 32) ||
(c_dtype.is_float16() && c_dtype.bits() == 16))) {
} else if ((ab_dtype.is_float8_e4m3fn() || ab_dtype.is_float8_e5m2()) &&
(c_dtype.is_float() && c_dtype.bits() == 32)) {
if (K % 32 != 0)
FAIL;
if (M % 128 == 0) {
for (int atom_n : ws_valid_atom_ns)
if (N % atom_n == 0)
SUCCESS(128, atom_n, 32, true, false);
for (int atom_n = 256; atom_n >= 16; atom_n -= 16)
if (N % atom_n == 0)
SUCCESS(128, atom_n, 32, false, true);
for (int atom_n = 256; atom_n >= 8; atom_n -= 8)
if (N % atom_n == 0)
SUCCESS(128, atom_n, 32, false, false);
SUCCESS(128, atom_n, 32);
FAIL;
} else if (M % 64 == 0) {
for (int atom_n : ws_valid_atom_ns)
if (N % atom_n == 0)
SUCCESS(64, atom_n, 32, true, false);
for (int atom_n = 256; atom_n >= 8; atom_n -= 8)
if (N % atom_n == 0)
SUCCESS(128, atom_n, 32, false, false);
SUCCESS(64, atom_n, 32);
FAIL;
} else if (M % 32 == 0) {
for (int atom_n : ws_valid_atom_ns)
if (N % atom_n == 0)
SUCCESS(32, atom_n, 32, true, false);
SUCCESS(32, atom_n, 32);
FAIL;
} else {
FAIL;
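Usage sketch for the slimmed-down query: `GetTCGEN5MMAMeta` now returns only whether a TCGEN5MMA atom exists plus its (atom_m, atom_n, atom_k); the warp-specialization and 2-CTA flags are gone. The include path, `tl::` qualification, and the example problem shape below are assumptions, not taken from the diff.

```cpp
// Sketch only: querying the TCGEN5MMA atom table after this PR.
// Assumes tcgen5_meta.h and TVM's runtime::DataType are available, and that
// the function lives in namespace tl (qualification may differ).
#include "tcgen5_meta.h"

void PickAtomExample() {
  using tvm::runtime::DataType;
  // fp16 A/B with fp32 accumulation on a 256x128x64 problem tile.
  auto [ok, meta] = tl::GetTCGEN5MMAMeta(/*M=*/256, /*N=*/128, /*K=*/64,
                                         DataType::Float(16),
                                         DataType::Float(32));
  if (ok) {
    // For this shape the table above resolves to a 128x128x16 atom;
    // meta no longer carries enable_ws / enable_2cta.
    (void)meta.atom_m;
    (void)meta.atom_n;
    (void)meta.atom_k;
  }
}
```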
35 changes: 10 additions & 25 deletions src/tl_templates/cuda/copy_sm100.h
@@ -51,21 +51,6 @@ __device__ __forceinline__ void st_global_256(fp8_e4_32_t *ptr,
:
: "l"(ptr), "l"(val.x), "l"(val.y), "l"(val.z), "l"(val.w));
}
__device__ __forceinline__ ulonglong4 ld_global_256(const fp8_e5_32_t *ptr) {
ulonglong4 ret;
asm volatile("ld.global.v4.u64 {%0, %1, %2, %3}, [%4];"
: "=l"(ret.x), "=l"(ret.y), "=l"(ret.z), "=l"(ret.w)
: "l"(ptr));
return ret;
}

__device__ __forceinline__ void st_global_256(fp8_e5_32_t *ptr,
fp8_e5_32_t &val8) {
ulonglong4 &val = *((ulonglong4 *)&val8);
asm volatile("st.global.v4.u64 [%0], {%1, %2, %3, %4};"
:
: "l"(ptr), "l"(val.x), "l"(val.y), "l"(val.z), "l"(val.w));
}

__device__ __forceinline__ unsigned long long
pack_bfloat16x4(const bfloat16_t x, const bfloat16_t y, const bfloat16_t z,
@@ -110,38 +95,38 @@ __device__ __forceinline__ void tcgen05_ld_core(uint32_t const &tmem_start_col,
}
}

template <int N, bool pack16, typename dst_t>
template <int N, typename dst_t>
__device__ __forceinline__ void
tcgen05_ld_32dp32bNx(uint32_t const &tmem_start_col,
uint32_t const &tmem_col_offset, dst_t *dst_ptr) {
tcgen05_ld_core<tl::tmem_ld_32dp32bNx<pack16>, 7, N>(
tmem_start_col + tmem_col_offset, dst_ptr);
tcgen05_ld_core<tl::tmem_ld_32dp32bNx, 7, N>(tmem_start_col + tmem_col_offset,
dst_ptr);
tl::fence_view_async_tmem_load();
}

template <int N, bool pack16, typename dst_t>
template <int N, typename dst_t>
__device__ __forceinline__ void
tcgen05_ld_32dp64bNx(uint32_t const &tmem_start_col,
uint32_t const &tmem_col_offset, dst_t *dst_ptr) {
tcgen05_ld_core<tl::tmem_ld_32dp64bNx<pack16>, 7, N>(
tmem_start_col + tmem_col_offset, dst_ptr);
tcgen05_ld_core<tl::tmem_ld_32dp64bNx, 7, N>(tmem_start_col + tmem_col_offset,
dst_ptr);
tl::fence_view_async_tmem_load();
}

template <int N, bool pack16, typename dst_t>
template <int N, typename dst_t>
__device__ __forceinline__ void
tcgen05_ld_32dp128bNx(uint32_t const &tmem_start_col,
uint32_t const &tmem_col_offset, dst_t *dst_ptr) {
tcgen05_ld_core<tl::tmem_ld_32dp128bNx<pack16>, 6, N>(
tcgen05_ld_core<tl::tmem_ld_32dp128bNx, 6, N>(
tmem_start_col + tmem_col_offset, dst_ptr);
tl::fence_view_async_tmem_load();
}

template <int N, bool pack16, typename dst_t>
template <int N, typename dst_t>
__device__ __forceinline__ void
tcgen05_ld_32dp256bNx(uint32_t const &tmem_start_col,
uint32_t const &tmem_col_offset, dst_t *dst_ptr) {
tcgen05_ld_core<tl::tmem_ld_32dp256bNx<pack16>, 5, N>(
tcgen05_ld_core<tl::tmem_ld_32dp256bNx, 5, N>(
tmem_start_col + tmem_col_offset, dst_ptr);
tl::fence_view_async_tmem_load();
}
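Finally, a device-side sketch of the new loader signatures in copy_sm100.h: the `pack16` template parameter is gone from both the `tl::tmem_ld_*` policy types and the `tcgen05_ld_*Nx` wrappers, so callers pass only the chunk count and a destination pointer. The wrapper below is illustrative; it assumes copy_sm100.h is included, an sm_100a target, a tensor-memory column allocated elsewhere, and that the loader is visible without further namespace qualification (which may differ in the actual header).

```cuda
// Sketch only: invoking the de-templated tmem loader after this PR.
// tmem_start_col / tmem_col_offset address a previously allocated tensor
// memory region; kChunks is the loader's N template parameter.
template <int kChunks, typename DstT>
__device__ __forceinline__ void
load_fragment_from_tmem(uint32_t tmem_start_col, uint32_t tmem_col_offset,
                        DstT *dst_ptr) {
  // Before this PR:  tcgen05_ld_32dp32bNx<kChunks, /*pack16=*/false>(...)
  // After this PR:   only the chunk count remains as a template argument.
  tcgen05_ld_32dp32bNx<kChunks>(tmem_start_col, tmem_col_offset, dst_ptr);
  // The wrapper already issues tl::fence_view_async_tmem_load() internally,
  // so no extra fence is needed here.
}
```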