Commit c28f9bc

Browse files
sleepcooFlamingoPgHongbosherlock
committed
Support FlashMLA backend CUDA graph capture
Co-authored-by: yinfan98 <1106310035@qq.com>
Co-authored-by: Hongbosherlock <hongbosherlock@gmail.com>
1 parent 4649c5e commit c28f9bc

File tree

1 file changed: +3, -8 lines


python/sglang/srt/layers/attention/flashmla_backend.py

Lines changed: 3 additions & 8 deletions
@@ -27,6 +27,7 @@
 from sglang.srt.layers.radix_attention import RadixAttention
 from sglang.srt.model_executor.model_runner import ModelRunner
 from sglang.srt.speculative.eagle_utils import EagleDraftInput, EagleVerifyInput
+from sglang.srt.speculative.spec_info import SpecInfo


 # FlashMLA only supports pagesize=64
@@ -76,9 +77,7 @@ def __init__(
         self.num_local_heads = (
             model_runner.model_config.num_attention_heads // get_attention_tp_size()
         )
-        self.forward_metadata: Union[
-            PrefillMetadata, DecodeMetadata, FlashMLADecodeMetadata
-        ] = None
+        self.forward_metadata: Union[FlashMLADecodeMetadata] = None
         self.kv_lora_rank = model_runner.model_config.kv_lora_rank
         self.qk_nope_head_dim = model_runner.model_config.qk_nope_head_dim
         self.qk_rope_head_dim = model_runner.model_config.qk_rope_head_dim
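
An aside on the replacement annotation: Union with a single member is the same type as that member, and since the attribute starts out as None, Optional would state that explicitly. A minimal sketch of the equivalent spelling (purely illustrative; only the attribute name comes from the diff):

    from typing import Optional, Union

    # Union with one member collapses to that member.
    assert Union[int] is int

    # Equivalent spelling that also admits the None default:
    forward_metadata: Optional["FlashMLADecodeMetadata"] = None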
@@ -111,7 +110,6 @@ def init_forward_metadata(self, forward_batch: ForwardBatch):
             block_kv_indices,
             self.indices_updater_decode.req_to_token.size(1),
             max_seqlen_pad,
-            max_seqlen_pad,
         )
         mla_metadata, num_splits = get_mla_metadata(
             forward_batch.seq_lens.to(torch.int32),
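
For context, get_mla_metadata is the scheduling helper shipped with FlashMLA: it precomputes tile-scheduler metadata and KV-split counts from the per-request sequence lengths, which is why the diff casts seq_lens to int32. A hedged sketch of the call shape, following the FlashMLA README (the head-count arguments below are illustrative assumptions, not values taken from this file):

    import torch
    from flash_mla import get_mla_metadata

    cache_seqlens = torch.tensor([17, 130, 64], dtype=torch.int32, device="cuda")
    # Assumed example: 128 query heads per KV head, 1 KV head (MLA-style).
    tile_scheduler_metadata, num_splits = get_mla_metadata(cache_seqlens, 128, 1)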
@@ -136,7 +134,7 @@ def init_cuda_graph_state(
         if block_kv_indices is None:
             cuda_graph_kv_indices = torch.full(
                 (max_bs, (self.max_context_len + PAGE_SIZE) // PAGE_SIZE),
-                -1,
+                1,
                 dtype=torch.int32,
                 device="cuda",
             )
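
The buffer above is allocated once at the maximum batch size so graph capture sees a stable tensor; its second dimension is the padded number of KV pages per request. A small shape check, assuming PAGE_SIZE = 64 (the only page size FlashMLA supports, per the comment in the imports hunk) and toy sizes for the rest:

    import torch

    PAGE_SIZE = 64            # FlashMLA only supports pagesize=64
    max_bs, max_context_len = 4, 200
    num_pages = (max_context_len + PAGE_SIZE) // PAGE_SIZE  # padded page count
    kv_indices = torch.full((max_bs, num_pages), 1, dtype=torch.int32)
    assert kv_indices.shape == (4, 4)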
@@ -167,7 +165,6 @@ def init_forward_metadata_capture_cuda_graph(
     ):
         if forward_mode.is_decode_or_idle():
             if spec_info is None:
-
                 max_seqlen_pad = triton.cdiv(seq_lens.max().item(), PAGE_SIZE)

                 create_flashmla_kv_indices_triton[(bs,)](
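
triton.cdiv is plain ceiling division, so max_seqlen_pad is the page count needed for the longest sequence in the capture batch. For example:

    import triton

    # A 130-token sequence needs ceil(130 / 64) = 3 pages of size 64.
    assert triton.cdiv(130, 64) == 3
    assert triton.cdiv(128, 64) == 2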
@@ -178,7 +175,6 @@ def init_forward_metadata_capture_cuda_graph(
                     self.cuda_graph_kv_indices,
                     self.indices_updater_decode.req_to_token.size(1),
                     max_seqlen_pad,
-                    max_seqlen_pad,
                 )
                 mla_metadata, num_splits = get_mla_metadata(
                     seq_lens.to(torch.int32),
@@ -227,7 +223,6 @@ def init_forward_metadata_replay_cuda_graph(
             block_kv_indices,
             self.indices_updater_decode.req_to_token.size(1),
             max_seqlen_pad,
-            max_seqlen_pad,
         )
         mla_metadata, num_splits = get_mla_metadata(
             seq_lens.to(torch.int32),
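
Taken together, the three hooks touched here follow PyTorch's standard CUDA graph discipline: allocate static buffers up front (init_cuda_graph_state), fill them before capture (init_forward_metadata_capture_cuda_graph), and mutate them in place before each replay (init_forward_metadata_replay_cuda_graph), since captured kernels re-read fixed memory addresses. A minimal, self-contained sketch of that pattern with stock PyTorch APIs (the toy computation stands in for the attention kernels; nothing below is taken from the backend itself):

    import torch

    static_in = torch.zeros(8, 64, device="cuda")  # pre-allocated, like the kv-indices buffer

    # Warm up on a side stream before capture, as PyTorch recommends.
    s = torch.cuda.Stream()
    s.wait_stream(torch.cuda.current_stream())
    with torch.cuda.stream(s):
        static_out = static_in * 2
    torch.cuda.current_stream().wait_stream(s)

    g = torch.cuda.CUDAGraph()
    with torch.cuda.graph(g):        # capture: record the kernels once
        static_out = static_in * 2

    static_in.copy_(torch.randn(8, 64, device="cuda"))  # update input in place
    g.replay()                                           # re-run captured kernels
    print(static_out[0, 0])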
