Skip to content

Commit d57a3b0

Browse files
authored
[Misc.] Scale kkt after reduction
1 parent 10a640b commit d57a3b0

File tree

1 file changed

+2
-2
lines changed

1 file changed

+2
-2
lines changed

fla/ops/common/chunk_scaled_dot_kkt.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -58,14 +58,14 @@ def chunk_scaled_dot_kkt_fwd_kernel(
5858
for i_k in range(tl.cdiv(K, BK)):
5959
p_k = tl.make_block_ptr(k + (bos*H + i_h) * K, (T, K), (H*K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
6060
b_k = tl.load(p_k, boundary_check=(0, 1))
61-
b_kb = b_k * b_beta[:, None]
62-
b_A += tl.dot(b_kb.to(b_k.dtype), tl.trans(b_k))
61+
b_A += tl.dot(b_k, tl.trans(b_k))
6362

6463
if USE_G:
6564
p_g = tl.make_block_ptr(g_cumsum + bos*H + i_h, (T,), (H,), (i_t * BT,), (BT,), (0,))
6665
b_g = tl.load(p_g, boundary_check=(0,))
6766
b_g_diff = b_g[:, None] - b_g[None, :]
6867
b_A = b_A * exp(b_g_diff)
68+
b_A = b_A * b_beta[:, None]
6969

7070
m_A = (o_t[:, None] > o_t[None, :]) & (m_t[:, None] & m_t)
7171
b_A = tl.where(m_A, b_A, 0)

0 commit comments

Comments
 (0)