2727 from sglang .srt .layers .radix_attention import RadixAttention
2828 from sglang .srt .model_executor .model_runner import ModelRunner
2929 from sglang .srt .speculative .eagle_utils import EagleDraftInput , EagleVerifyInput
30+ from sglang .srt .speculative .spec_info import SpecInfo
3031
3132
3233# FlashMLA only supports pagesize=64
@@ -76,9 +77,7 @@ def __init__(
7677 self .num_local_heads = (
7778 model_runner .model_config .num_attention_heads // get_attention_tp_size ()
7879 )
79- self .forward_metadata : Union [
80- PrefillMetadata , DecodeMetadata , FlashMLADecodeMetadata
81- ] = None
80+ self .forward_metadata : Union [FlashMLADecodeMetadata ] = None
8281 self .kv_lora_rank = model_runner .model_config .kv_lora_rank
8382 self .qk_nope_head_dim = model_runner .model_config .qk_nope_head_dim
8483 self .qk_rope_head_dim = model_runner .model_config .qk_rope_head_dim
@@ -111,7 +110,6 @@ def init_forward_metadata(self, forward_batch: ForwardBatch):
111110 block_kv_indices ,
112111 self .indices_updater_decode .req_to_token .size (1 ),
113112 max_seqlen_pad ,
114- max_seqlen_pad ,
115113 )
116114 mla_metadata , num_splits = get_mla_metadata (
117115 forward_batch .seq_lens .to (torch .int32 ),
@@ -136,7 +134,7 @@ def init_cuda_graph_state(
136134 if block_kv_indices is None :
137135 cuda_graph_kv_indices = torch .full (
138136 (max_bs , (self .max_context_len + PAGE_SIZE ) // PAGE_SIZE ),
139- - 1 ,
137+ 1 ,
140138 dtype = torch .int32 ,
141139 device = "cuda" ,
142140 )
@@ -167,7 +165,6 @@ def init_forward_metadata_capture_cuda_graph(
167165 ):
168166 if forward_mode .is_decode_or_idle ():
169167 if spec_info is None :
170-
171168 max_seqlen_pad = triton .cdiv (seq_lens .max ().item (), PAGE_SIZE )
172169
173170 create_flashmla_kv_indices_triton [(bs ,)](
@@ -178,7 +175,6 @@ def init_forward_metadata_capture_cuda_graph(
178175 self .cuda_graph_kv_indices ,
179176 self .indices_updater_decode .req_to_token .size (1 ),
180177 max_seqlen_pad ,
181- max_seqlen_pad ,
182178 )
183179 mla_metadata , num_splits = get_mla_metadata (
184180 seq_lens .to (torch .int32 ),
@@ -227,7 +223,6 @@ def init_forward_metadata_replay_cuda_graph(
227223 block_kv_indices,
228224 self.indices_updater_decode.req_to_token.size(1),
229225 max_seqlen_pad,
230- max_seqlen_pad,
231226 )
232227 mla_metadata, num_splits = get_mla_metadata(
233228 seq_lens.to(torch.int32),
0 commit comments