@@ -89,10 +89,9 @@ def __init__(
8989 self .use_full_cuda_graph = (
9090 self .compilation_config .cudagraph_mode .has_full_cudagraphs ()
9191 )
92+ self .max_cudagraph_size = self .compilation_config .max_capture_size
9293
9394 if self .use_full_cuda_graph and self .fa_aot_schedule :
94- self .max_cudagraph_size = self .compilation_config .max_capture_size
95-
9695 if self .max_cudagraph_size > 992 :
9796 # This condition derives from FA3's internal heuristic.
9897 # TODO(woosuk): Support larger cudagraph sizes.
@@ -114,7 +113,14 @@ def __init__(
114113 self .max_num_splits = 1
115114
116115 def _schedule_decode (
117- self , num_reqs , cu_query_lens , max_query_len , seqlens , max_seq_len , causal
116+ self ,
117+ num_reqs ,
118+ cu_query_lens ,
119+ max_query_len ,
120+ seqlens ,
121+ max_seq_len ,
122+ causal ,
123+ max_num_splits ,
118124 ):
119125 if self .fa_aot_schedule :
120126 return get_scheduler_metadata (
@@ -130,7 +136,7 @@ def _schedule_decode(
130136 page_size = self .page_size ,
131137 cu_seqlens_q = cu_query_lens ,
132138 causal = causal ,
133- num_splits = self . max_num_splits ,
139+ num_splits = max_num_splits ,
134140 )
135141 return None
136142
@@ -148,17 +154,25 @@ def _build_decode(
148154 max_query_len = query_lens_cpu .max ().item ()
149155 max_seq_len = seq_lens_device .max ().item ()
150156
157+ # For Flash Attention MLA + full cudagraph
158+ max_num_splits = 0
159+ if self .use_full_cuda_graph and num_decode_tokens <= self .max_cudagraph_size :
160+ # NOTE(woosuk): Setting num_splits > 1 may increase the memory
161+ # usage, because the intermediate buffers of size [num_splits,
162+ # num_heads, num_tokens, head_size] are allocated. Therefore,
163+ # we only set num_splits when using cuda graphs.
164+ max_num_splits = self .max_num_splits
165+
151166 scheduler_metadata = self ._schedule_decode (
152167 num_reqs = seq_lens_cpu .numel (),
153168 cu_query_lens = query_start_loc_device ,
154169 max_query_len = max_query_len ,
155170 seqlens = seq_lens_device ,
156171 max_seq_len = max_seq_len ,
157172 causal = True ,
173+ max_num_splits = max_num_splits ,
158174 )
159175
160- # For FA3 + full cudagraph
161- max_num_splits = 0
162176 if self .use_full_cuda_graph and scheduler_metadata is not None :
163177 n = scheduler_metadata .shape [0 ]
164178 # Ensure the persistent buffer is large enough
@@ -174,13 +188,6 @@ def _build_decode(
174188 self .scheduler_metadata [n :] = 0
175189 scheduler_metadata = self .scheduler_metadata [:n ]
176190
177- if num_decode_tokens <= self .max_cudagraph_size :
178- # NOTE(woosuk): Setting num_splits > 1 may increase the memory
179- # usage, because the intermediate buffers of size [num_splits,
180- # num_heads, num_tokens, head_size] are allocated. Therefore,
181- # we only set num_splits when using cuda graphs.
182- max_num_splits = self .max_num_splits
183-
184191 if vllm_is_batch_invariant ():
185192 max_num_splits = 1
186193
0 commit comments