Merged
Changes from all 67 commits
46794c4
[Enhancement] Refactor buffer index handling for improved precision a…
Jul 29, 2025
499daa3
Remove obsolete test script for AMD example, streamlining the example…
Jul 29, 2025
555537a
Remove unused dtype_size variable in AMD example script to streamline…
Jul 29, 2025
f84bc97
Add input configuration file and update AMD example script for enhanc…
Jul 30, 2025
21cf0c3
Remove input configuration file and obsolete test script; enhance AMD…
Jul 30, 2025
9b2fab3
Refactor AMD example script for FlashAttention-2
Jul 30, 2025
24e08ae
Refactor formatting in AMD FlashAttention example script
Jul 30, 2025
bc2663a
Update example_amd_flash_attn_fwd.py
LeiWang1999 Jul 31, 2025
4d427d9
Enhance AMD example script and update CI workflows
Aug 18, 2025
4fd8529
Merge branch 'main' into main
Alex4210987 Aug 18, 2025
cf99bef
Remove redundant tool cache cleanup step in AMD CI workflow
Aug 18, 2025
e839192
Remove `torch` dependency from `requirements-rocm.txt` to streamline …
Aug 18, 2025
70f3f6a
Add new AMD FlashAttention example and test script
Aug 23, 2025
2bf7961
Update configurations in `example_amd_flash_attn_fwd.py` for autotuner
Aug 23, 2025
f7f6131
Update submodule 'tvm' to commit 6ccc74f622c7ec4ac25d430d0f6546e7b9ed…
Aug 24, 2025
91e9548
Update submodule 'tvm' to commit 14ff70ab142b9e5a31bbf9c7923c8a697d41…
Aug 24, 2025
460c64f
Merge branch 'tile-ai:main' into main
Alex4210987 Aug 24, 2025
8eefca0
Merge branch 'tile-ai:main' into main
Alex4210987 Sep 3, 2025
7bd45c5
Add example for AMD Flash Attention backward pass implementation
Sep 3, 2025
4cf8c30
Merge branch 'amd_dev'
Sep 3, 2025
bc22219
Merge branch 'main' of https://github.com/Alex4210987/tilelang
Sep 3, 2025
50b97e1
Enhance AMD Flash Attention example with additional testing capabilities
Sep 3, 2025
05305f2
Update submodule TVM to commit a64a5926a6e59f5417ef2501f9d88b467337cf6a
Sep 3, 2025
923fc6d
Refactor HIP intrinsic rules to CUDA
Sep 3, 2025
7b7fda3
Update AMD CI workflow to uninstall specific PyTorch packages before …
Sep 3, 2025
1008679
Remove unused shared memory allocations in AMD Flash Attention backwa…
Sep 3, 2025
f490b4a
Remove unnecessary pip uninstall command from AMD CI workflow
Sep 3, 2025
b39ada8
Refactor DispatchHIPWarpActiveMask function in HIP intrinsic rules
Sep 3, 2025
d62b898
Refactor formatting of HIP intrinsic rule registrations
Sep 3, 2025
e7b0f30
Update file name and documentation for HIP intrinsic rules
Sep 3, 2025
8c73c9c
Enhance DispatchHIPShuffle function with clang-analyzer comments
Sep 3, 2025
c8aec22
lint fix
LeiWang1999 Sep 4, 2025
4549e0e
Merge branch 'main' of https://github.com/tile-ai/tilelang into Alex4…
LeiWang1999 Sep 4, 2025
ccadc2e
fix
LeiWang1999 Sep 4, 2025
b491082
Enhance autotuner configurations in example_amd_flash_attn_fwd.py by …
Sep 7, 2025
3289910
Add backward attention example to test script
Sep 7, 2025
10870e1
Refactor FlashAttention implementation in example_amd_flash_attn_bwd.…
Sep 7, 2025
f20cd33
Enhance FlashAttention backward implementation in example_amd_flash_a…
Sep 7, 2025
570c6c9
Enhance FlashAttention backward implementation in example_amd_flash_a…
Sep 7, 2025
fff5543
Refactor FlashAttention implementation in example_amd_flash_attn_bwd.…
Sep 8, 2025
d5e3b6b
Enhance FlashAttention backward implementation in example_amd_flash_a…
Sep 10, 2025
3f15c59
Expand autotuner configurations in example_amd_flash_attn_bwd.py and …
Sep 10, 2025
0582143
Enhance performance calculations and benchmarking in example_amd_flas…
Sep 10, 2025
e8f0d9f
Remove forward attention test commands from test.sh and retain backwa…
Sep 11, 2025
335bbc6
Refactor FlashAttention forward and backward implementations in examp…
Sep 18, 2025
cf8cc88
Refactor FlashAttention implementation in example_amd_flash_attn_bwd.py
Sep 20, 2025
3a00c4d
Enhance FlashAttention backward implementation in example_amd_flash_a…
Sep 20, 2025
3b839d2
Refactor configuration and tensor operations in example_amd_flash_att…
Sep 30, 2025
4c11021
Merge remote-tracking branch 'upstream/main'
Sep 30, 2025
bc9a5fb
Enhance HIP code generation and FP8 type support
Sep 30, 2025
dd5b64f
Enhance FP8 type support and clarify accumulator handling in HIP
Sep 30, 2025
42e5538
Remove deprecated files and update print statements for clarity in ex…
Oct 10, 2025
9d53c8a
Update print statement formatting for clarity in example_amd_flash_at…
Oct 10, 2025
cd3b6b5
Remove redundant verification results summary print statement in exam…
Oct 10, 2025
3072de6
Fix formatting inconsistencies in example_amd_flash_attn_bwd.py and e…
Oct 10, 2025
1913abb
Refactor and enhance HIP code generation for improved FP8 support
Oct 10, 2025
acaf988
Fix formatting issue in HIP code generation for MFMA call
Oct 10, 2025
4bc49cd
Refactor HIP code generation and enhance FP8 type handling
Oct 10, 2025
ae39e35
Merge branch 'main' into main
Alex4210987 Oct 10, 2025
0c0fa53
Remove unnecessary blank line in example_amd_flash_attn_bwd.py for im…
Oct 10, 2025
9a4a08f
Merge branch 'main' of https://github.com/Alex4210987/tilelang
Oct 10, 2025
8b345ae
Refactor backward attention implementation in example_amd_flash_attn_…
Oct 10, 2025
c34315c
Fix formatting by removing an unnecessary blank line in example_amd_f…
Oct 10, 2025
cd1564d
Merge branch 'main' of https://github.com/tile-ai/tilelang into Alex4…
LeiWang1999 Oct 14, 2025
426df21
Merge branch 'tile-ai:main' into main
Alex4210987 Oct 15, 2025
40eabcd
Add additional test cases for `assert_tl_matmul_correctness` with `fl…
Oct 15, 2025
5b6bcaa
Refactor test case formatting for `assert_tl_matmul_correctness` in `…
Oct 15, 2025
810 changes: 525 additions & 285 deletions examples/amd/example_amd_flash_attn_bwd.py

Large diffs are not rendered by default.

20 changes: 4 additions & 16 deletions examples/amd/example_amd_flash_attn_fwd.py
@@ -34,7 +34,7 @@ def get_configs():
block_N = [32, 64, 128, 256]
threads = [128, 256, 512]
num_split_q = [64, 128, 256]
num_stages = [0]
num_stages = [0, 1]
enable_rasterization = [True]
k_pack = [2]
panel_size = [7, 8]
@@ -60,18 +60,6 @@ def get_configs():
"qk_coalesced_width": qkw,
"v_coalesced_width": vw,
})
valid_configs.append({
'block_M': 64,
'block_N': 64,
'num_split_q': 64,
'threads': 256,
'num_stages': 1,
'enable_rasterization': True,
'k_pack': 2,
'panel_size': 64,
'qk_coalesced_width': 8,
'v_coalesced_width': 8,
})
return valid_configs


@@ -95,7 +83,7 @@ def fast_flashattn(
qk_coalesced_width: int,
v_coalesced_width: int,
):
scale = (1.0 / dim)**0.5 * 1.44269504
scale = (1.0 / dim)**0.5
head_kv = heads // groups
q_shape = [batch, seq_len, heads, dim]
kv_shape = [batch, seq_len, head_kv, dim]
@@ -185,15 +173,15 @@ def main(
T.reduce_max(acc_s, m_i, dim=1, clear=False)

for i in T.Parallel(block_M):
sf = T.exp2(m_prev[i] * scale - m_i[i] * scale)
sf = T.exp(m_prev[i] * scale - m_i[i] * scale)
l_i[i] *= sf
scale_factor[i] = sf

for i, j in T.Parallel(block_M, dim):
acc_o[i, j] *= scale_factor[i]

for i, j in T.Parallel(block_M, block_N):
acc_s[i, j] = T.exp2(acc_s[i, j] * scale - m_i[i] * scale)
acc_s[i, j] = T.exp(acc_s[i, j] * scale - m_i[i] * scale)

T.reduce_sum(acc_s, row_sum, dim=1)
for i in T.Parallel(block_M):
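The forward-example hunk above swaps the base-2 softmax formulation for the natural-exponent one: the old path folded log2(e) ≈ 1.44269504 into `scale` and used `T.exp2`, while the new path keeps `scale = (1.0 / dim)**0.5` and uses `T.exp`. The two are mathematically equivalent since exp2(x * log2(e)) = exp(x); a minimal numerical check in plain Python (not part of the PR) illustrating the identity:

import math

dim = 128
x = 0.37  # an arbitrary pre-softmax score

old_scale = (1.0 / dim) ** 0.5 * 1.44269504  # 1/sqrt(dim) * log2(e)
new_scale = (1.0 / dim) ** 0.5               # 1/sqrt(dim)

via_exp2 = 2.0 ** (x * old_scale)   # what the T.exp2 path computed
via_exp = math.exp(x * new_scale)   # what the T.exp path computes
assert abs(via_exp2 - via_exp) < 1e-8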
10 changes: 0 additions & 10 deletions examples/amd/test.sh

This file was deleted.

7 changes: 0 additions & 7 deletions examples/flash_attention/example_mha_bwd.py
@@ -38,14 +38,10 @@ def flash_fwd(
scores_sum = T.alloc_fragment([block_M], accum_dtype)
logsum = T.alloc_fragment([block_M], accum_dtype)

T.annotate_layout({Q_shared: tilelang.layout.make_swizzled_layout(Q_shared)})
T.copy(Q[bz, bx * block_M:(bx + 1) * block_M, by, :], Q_shared)
T.fill(acc_o, 0)
T.fill(logsum, 0)
T.fill(scores_max, -T.infinity(accum_dtype))
# T.copy(Q_shared, Q_local)
# for i, j in T.Parallel(block_M, dim):
# Q_local[i, j] *= scale
loop_range = (
T.ceildiv(
(bx + 1) * block_M, block_N) if is_causal else T.ceildiv(seq_len, block_N))
@@ -192,9 +188,6 @@ def flash_bwd(

T.annotate_layout({
dQ: make_dq_layout(dQ),
K_shared: tilelang.layout.make_swizzled_layout(K_shared),
dv_shared: tilelang.layout.make_swizzled_layout(dv_shared),
dk_shared: tilelang.layout.make_swizzled_layout(dk_shared),
})
T.copy(K[bz, by * block_M:(by + 1) * block_M, bx, :], K_shared)
T.copy(V[bz, by * block_M:(by + 1) * block_M, bx, :], V_shared)
28 changes: 22 additions & 6 deletions src/target/codegen_hip.cc
@@ -41,10 +41,18 @@ static std::string GetFP8Type(DataType type) {
stream << "fp8_e4" << vec << "_t";
} else if (type.code() == DataType::kFloat8_e4m3fnuz) {
stream << "fp8_e4" << vec << "_t";
} else if (type.code() == DataType::kFloat8_e4m3) {
stream << "fp8_e4" << vec << "_t";
} else if (type.code() == DataType::kFloat8_e4m3b11fnuz) {
stream << "fp8_e4" << vec << "_t";
} else if (type.code() == DataType::kFloat8_e5m2) {
stream << "fp8_e5" << vec << "_t";
} else if (type.code() == DataType::kFloat8_e5m2fnuz) {
stream << "fp8_e5" << vec << "_t";
} else if (type.code() == DataType::kFloat8_e8m0fnu) {
stream << "fp8_e8" << vec << "_t";
} else {
LOG(FATAL) << "Unsupported FP8 type in HIP codegen";
LOG(FATAL) << "Unsupported FP8 type in HIP codegen: " << type;
}
return stream.str();
}
@@ -926,10 +934,10 @@ void CodeGenTileLangHIP::VisitExpr_(const CallNode *op, std::ostream &os) {
{"float8_e4m3fnuzx8", "long"},
{"float32x16", "float32x16"}};
std::string call_mfma_code = R"({
*((({C_dtype}*){c_ref}) + {c_bias}) = {mfma_buildin}(*((({A_dtype}*){a_ref}) + {a_bias}),
*((({B_dtype}*){b_ref}) + {b_bias}),
*((({C_dtype}*){c_ref}) + {c_bias}), 0, 0, 0);
})";
*((({C_dtype}*){c_ref}) + {c_bias}) = {mfma_buildin}(*((({A_dtype}*){a_ref}) + {a_bias}),
*((({B_dtype}*){b_ref}) + {b_bias}),
*((({C_dtype}*){c_ref}) + {c_bias}), 0, 0, 0);
})";
std::string mfma_buildin = "__builtin_amdgcn_mfma_" + prefix;
Replacer replacer;

@@ -955,6 +963,13 @@ void CodeGenTileLangHIP::VisitExpr_(const CallNode *op, std::ostream &os) {
op->args, true, os);
} else if (op->op.same_as(tl::tl_gemm_sp())) {
LOG(FATAL) << "tl_gemm_sp is not supported on HIP";
} else if (op->op.same_as(tl::loop_break())) {
this->PrintIndent();
this->stream << "break;\n";
} else if (op->op.same_as(tl::no_set_max_nreg())) {
// HIP doesn't need explicit register management like CUDA
// This is a no-op for HIP
return;
} else {
CodeGenC::VisitExpr_(op, os);
}
@@ -1160,7 +1175,8 @@ inline void PrintConst(const FloatImmNode *op, std::ostream &os,
os << "bfloat16_t";
os << '(' << std::scientific << op->value << 'f' << ')';
return;
} else if (op->dtype.is_float8_e4m3fnuz()) {
} else if (op->dtype.is_float8_e4m3fnuz() || op->dtype.is_float8_e4m3() ||
op->dtype.is_float8_e4m3fn()) {
os << "fp8_e4_t";
os << '(' << std::scientific << op->value << 'f' << ')';
return;
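The MFMA hunk above emits a matrix-core call by filling a string template ({A_dtype}, {a_ref}, {a_bias}, and so on) with per-GEMM types, buffer references, and the __builtin_amdgcn_mfma_<prefix> builtin name. A rough Python sketch of that substitution step (the tile shape, type names, and buffer names below are illustrative assumptions, not values taken from this PR):

mfma_template = """{
  *((({C_dtype}*){c_ref}) + {c_bias}) = {mfma_buildin}(*((({A_dtype}*){a_ref}) + {a_bias}),
    *((({B_dtype}*){b_ref}) + {b_bias}),
    *((({C_dtype}*){c_ref}) + {c_bias}), 0, 0, 0);
}"""

# Hypothetical fp16 -> fp32 16x16x16 tile; every concrete name here is an assumption.
replacements = {
    "{mfma_buildin}": "__builtin_amdgcn_mfma_f32_16x16x16f16",
    "{A_dtype}": "float16x4",
    "{B_dtype}": "float16x4",
    "{C_dtype}": "float32x4",
    "{a_ref}": "A_local", "{a_bias}": "0",
    "{b_ref}": "B_local", "{b_bias}": "0",
    "{c_ref}": "C_local", "{c_bias}": "0",
}

code = mfma_template
for placeholder, value in replacements.items():
    code = code.replace(placeholder, value)
print(code)  # the HIP source fragment such a substitution would produce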
10 changes: 10 additions & 0 deletions src/tl_templates/hip/common.h
@@ -109,3 +109,13 @@ template <typename T1, typename T2>
TL_DEVICE void AtomicAdd(T1 *address, T2 val) {
atomicAdd(reinterpret_cast<T1 *>(address), static_cast<T1>(val));
}

// Overload for when the first argument is a value instead of a pointer
template <typename T1, typename T2>
TL_DEVICE void AtomicAdd(T1 address, T2 val) {
atomicAdd(reinterpret_cast<T1 *>(&address), static_cast<T1>(val));
}
Comment on lines +113 to +117

⚠️ Potential issue | 🔴 Critical

Critical logic error: taking address of pass-by-value parameter.

This overload takes T1 address by value and then uses &address to get a pointer. This retrieves the address of the local copy created for the function parameter, not the address of the caller's original variable. The atomic operation would be performed on the local copy, which is destroyed when the function returns, making the atomic operation completely ineffective.

Either remove this overload entirely if it's not needed, or if the intent was to support non-pointer types, the signature should be AtomicAdd(T1 &address, T2 val) (taking a reference).

Apply this diff to fix the signature to take a reference:

 // Overload for when the first argument is a value instead of a pointer
 template <typename T1, typename T2>
-TL_DEVICE void AtomicAdd(T1 address, T2 val) {
-  atomicAdd(reinterpret_cast<T1 *>(&address), static_cast<T1>(val));
+TL_DEVICE void AtomicAdd(T1 &address, T2 val) {
+  atomicAdd(&address, static_cast<T1>(val));
 }
🤖 Prompt for AI Agents
In src/tl_templates/hip/common.h around lines 113 to 117, the AtomicAdd overload
takes its first parameter by value and then uses &address, which takes the
address of the local copy instead of the caller's variable; change the function
signature to take the address parameter by reference (e.g., T1 &address) so
atomicAdd operates on the caller's memory (or remove the overload if not
needed), and update any callers if necessary to pass lvalues.


template <typename T1, typename T2> TL_DEVICE T1 AtomicAddRet(T1 &ref, T2 val) {
return atomicAdd(&ref, static_cast<T1>(val));
}
Comment on lines +113 to +121

⚠️ Potential issue | 🔴 Critical

Invalid AtomicAdd overload takes a value; atomics must target real memory

AtomicAdd(T1 address, T2 val) atomically updates a local copy via &address, not the intended memory; this is incorrect and can cause undefined behavior on HIP.

Remove this overload and, if needed, provide a reference overload (non‑returning) alongside AtomicAddRet:

-// Overload for when the first argument is a value instead of a pointer
-template <typename T1, typename T2>
-TL_DEVICE void AtomicAdd(T1 address, T2 val) {
-  atomicAdd(reinterpret_cast<T1 *>(&address), static_cast<T1>(val));
-}
+// Preferred non-returning reference overload (mirrors AtomicAddRet)
+template <typename T1, typename T2>
+TL_DEVICE void AtomicAdd(T1& ref, T2 val) {
+  atomicAdd(&ref, static_cast<T1>(val));
+}

Optionally constrain T1 to HIP‑supported atomic types via static_assert.

🤖 Prompt for AI Agents
In src/tl_templates/hip/common.h around lines 113 to 121, the overload
AtomicAdd(T1 address, T2 val) takes the address by value and does atomicAdd on
&address (a local copy) which is invalid; remove this value-taking overload,
replace with a non-returning reference overload AtomicAdd(T1 &address, T2 val)
that calls atomicAdd(&address, static_cast<T1>(val)), and (optionally) add a
static_assert or type trait to constrain T1 to HIP-supported atomic types to
catch misuse at compile time.

4 changes: 3 additions & 1 deletion src/tl_templates/hip/gemm.h
@@ -70,7 +70,9 @@ template <int M, int N, int K, int num_warp_m, int num_warp_n, bool TransposeA,
typename B_type, typename C_type, typename AccDataType = float>
class GemmTensorOp {
public:
static_assert(!clear_accum, "clear_accum=true is not supported yet");
// Note: clear_accum=true is not fully supported in HIP implementation
// but we'll handle it by manually clearing the accumulator
// static_assert(!clear_accum, "clear_accum=true is not supported yet");
Comment on lines +73 to +75

⚠️ Potential issue | 🔴 Critical

🧩 Analysis chain

Verify that clear_accum=true is properly handled.

The static_assert has been removed and replaced with a comment stating that clear_accum will be "manually cleared." However, examining the body() (lines 158-232) and body_rs() (lines 234-291) methods, there is no visible code that clears the accumulator when clear_accum=true.

If clear_accum=true is passed, the accumulator should be zeroed before accumulation begins, but this implementation appears to rely on the caller to pre-zero C_local.

Run the following script to check if clear_accum is handled elsewhere or if callers are responsible for zeroing:


🏁 Script executed:

#!/bin/bash
# Search for clear_accum usage patterns and accumulator initialization
rg -n -C5 'clear_accum.*true|C_local.*=.*0|memset.*C_local' --type=cpp --type=h

Length of output: 3003


Restore or implement clear_accum handling
Either re-enable static_assert(!clear_accum) on lines 73–75 or add code to zero C_local when clear_accum=true, as it’s currently unhandled.

🤖 Prompt for AI Agents
In src/tl_templates/hip/gemm.h around lines 73–75, restore handling for the
clear_accum template flag: either re-enable the static_assert(!clear_accum) to
fail compilation when clear_accum=true, or implement the runtime/compile-time
zeroing of C_local when clear_accum==true. If choosing static_assert, add the
assert back exactly for the HIP path; if choosing implementation, add a guarded
branch (compile-time if constexpr or preprocessor guarded by the template param)
that writes zeros into the accumulator buffer C_local before accumulation
begins, ensuring the zeroing matches the accumulator tile shape and memory
layout and does not break existing non-clear behavior.

Comment on lines +73 to +75

⚠️ Potential issue | 🔴 Critical

Do not drop support guard without implementing clear_accum behavior

Removing the static_assert enables clear_accum=true but no code zeros accumulators, yielding nondeterministic results.

Add zeroing when clear_accum is true. Example insertion at the start of body() and body_rs() before the ki loop:

if constexpr (clear_accum) {
  for (int i = 0; i < warp_rows; ++i) {
    for (int j = 0; j < warp_cols; ++j) {
      ((float32x4*)C_local)[i * warp_cols + j] = 0;
    }
  }
}

Optionally re‑enable the static_assert if you cannot guarantee correct zeroing on all code paths.

🤖 Prompt for AI Agents
In src/tl_templates/hip/gemm.h around lines 73-75, the static_assert for
clear_accum was removed but no accumulator zeroing was added, so enabling
clear_accum=true yields nondeterministic results; add a constexpr guard at the
start of both body() and body_rs() (before the ki loop) that zeroes the per-warp
accumulator memory when clear_accum is true (iterate warp_rows and warp_cols and
set the corresponding C_local entries to zero, e.g., by casting C_local to the
appropriate vector type and writing zeros), and if you cannot guarantee zeroing
on all code paths re-enable the static_assert to prevent enabling clear_accum
without proper initialization.


static constexpr int micro_size_x = 16;
static constexpr int micro_size_y = 16;
55 changes: 55 additions & 0 deletions src/tl_templates/hip/hip_fp8.h
@@ -5,6 +5,13 @@
using fp8_e4_t = __hip_fp8_e4m3_fnuz;
using fp8_e4_2_t = __hip_fp8x2_e4m3_fnuz;

// Additional FP8 types for compatibility
using fp8_e5_t = __hip_fp8_e5m2_fnuz;
using fp8_e5_2_t = __hip_fp8x2_e5m2_fnuz;
// Note: E8M0 types are not supported in current HIP version
// using fp8_e8_t = __hip_fp8_e8m0_fnuz;
// using fp8_e8_2_t = __hip_fp8x2_e8m0_fnuz;

// Simple wrapper that provides member access for generated code
struct fp8_e4_4_t {
union {
@@ -43,6 +50,54 @@ struct __align__(16) fp8_e4_16_t {
fp8_e4_8_t y;
};

// FP8 E5M2 vector types
struct fp8_e5_4_t {
union {
__hip_fp8x4_e5m2_fnuz data;
struct {
fp8_e5_t x, y, z, w;
};
};
__device__ fp8_e5_4_t() = default;
__device__ fp8_e5_4_t(const __hip_fp8x4_e5m2_fnuz &val) : data(val) {}
__device__ operator __hip_fp8x4_e5m2_fnuz() const { return data; }
};

struct __align__(8) fp8_e5_8_t {
fp8_e5_4_t x;
fp8_e5_4_t y;
};

struct __align__(16) fp8_e5_16_t {
fp8_e5_8_t x;
fp8_e5_8_t y;
};

// FP8 E8M0 vector types - not supported in current HIP version
/*
struct fp8_e8_4_t {
union {
__hip_fp8x4_e8m0_fnuz data;
struct {
fp8_e8_t x, y, z, w;
};
};
__device__ fp8_e8_4_t() = default;
__device__ fp8_e8_4_t(const __hip_fp8x4_e8m0_fnuz &val) : data(val) {}
__device__ operator __hip_fp8x4_e8m0_fnuz() const { return data; }
};

struct __align__(8) fp8_e8_8_t {
fp8_e8_4_t x;
fp8_e8_4_t y;
};

struct __align__(16) fp8_e8_16_t {
fp8_e8_8_t x;
fp8_e8_8_t y;
};
*/

__device__ fp8_e4_4_t make_fp8_e4_4_t(fp8_e4_t x, fp8_e4_t y, fp8_e4_t z,
fp8_e4_t w) {
// reinterpret the 4 fp8_e4_t values to signed char value and shift
6 changes: 6 additions & 0 deletions testing/python/amd/test_tilelang_gemm_mfma_intrinsic.py
@@ -238,6 +238,12 @@ def test_assert_tl_matmul():
128, 256, 256, "int8", "int32", b_transposed=False, accum_dtype="int32")
assert_tl_matmul_correctness(
128, 256, 256, "int8", "int32", b_transposed=False, accum_dtype="int32", k_pack=2)
assert_tl_matmul_correctness(128, 128, 128, "float8_e4m3fnuz", "float16")
assert_tl_matmul_correctness(128, 256, 256, "float8_e4m3fnuz", "float32")
assert_tl_matmul_correctness(128, 256, 256, "float8_e4m3fnuz", "float32", k_pack=2)
assert_tl_matmul_correctness(128, 256, 256, "float8_e4m3fnuz", "float32", b_transposed=False)
assert_tl_matmul_correctness(
128, 256, 256, "float8_e4m3fnuz", "float32", b_transposed=False, k_pack=2)


if __name__ == "__main__":