Skip to content

Commit 98f93db

Browse files
[Bugfix] Remove redundant T.fill to fix precision issue (#667)
1 parent 722c2a8 commit 98f93db

File tree

1 file changed

+1
-1
lines changed

1 file changed

+1
-1
lines changed

examples/flash_decoding/example_gqa_decode.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -169,7 +169,7 @@ def flash_attn_split(
169169
T.fill(scores_max, -T.infinity(accum_dtype))
170170

171171
loop_range = T.ceildiv((seqlen_kv // num_split), block_N)
172-
T.fill(K_shared, 0)
172+
173173
for k in T.Pipelined(loop_range, num_stages=num_stages):
174174
T.copy(
175175
K[bid, (seqlen_kv // num_split) * sid +

0 commit comments

Comments
 (0)