Commit faae7a7

alexm-redhat authored and yewentao256 committed
[Bugfix] [B200] cutlass_mla - ensure kv_split == 1 for batch size > 1 (#25509)
Signed-off-by: Alexander Matveev <amatveev@redhat.com>
Signed-off-by: yewentao256 <zhyanwentao@126.com>
1 parent d562c2e commit faae7a7

1 file changed, +2 -2 lines changed


csrc/attention/mla/cutlass_sm100_mla/device/sm100_mla.hpp

Lines changed: 2 additions & 2 deletions
@@ -135,10 +135,10 @@ class MLA {
     max_splits = min(16, max_splits);
 
     // TODO: This avoids a hang when the batch size larger than 1 and
-    // there is more than 4 kv_splits.
+    // there is more than 1 kv_splits.
     // Discuss with NVIDIA how this can be fixed.
     if (B > 1) {
-      max_splits = min(2, max_splits);
+      max_splits = min(1, max_splits);
     }
 
     // printf(" max_splits = %d\n", max_splits);
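
The effect of the change can be illustrated in isolation. The following is a minimal sketch, not the code from sm100_mla.hpp: the helper name `clamp_max_splits` and its signature are hypothetical, and it assumes `B` is the batch size and `max_splits` is the candidate number of KV splits computed earlier in the surrounding function. It reproduces only the two clamping steps visible in the hunk above.

```cpp
#include <algorithm>

// Hypothetical standalone helper (name and signature are assumptions,
// not part of the original source). Mirrors the post-commit clamping:
//   1. cap the split count at 16,
//   2. for batch sizes above 1, cap it at 1 to avoid the reported hang.
int clamp_max_splits(int B, int max_splits) {
  max_splits = std::min(16, max_splits);
  if (B > 1) {
    // Before this commit the cap here was 2; it is now 1.
    max_splits = std::min(1, max_splits);
  }
  return max_splits;
}
```

Under these assumptions, a call such as `clamp_max_splits(1, 32)` yields 16, while any batch size above 1 with a positive candidate value yields 1. Because the cap uses `min`, a candidate value already below 1 passes through unchanged.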
