55import itertools
66import time
77from collections import defaultdict
8+ from dataclasses import dataclass
89from collections .abc import Iterator
910from contextlib import contextmanager
1011from copy import deepcopy
@@ -509,6 +510,8 @@ def __init__(
509510 # Cached outputs.
510511 self ._deferred_write_manager = DeferredWriteManager (mode = envs .VLLM_NWOR_MODE )
511512 self ._latest_nwor_window_metrics : dict [str , int | str ] | None = None
513+ self ._scv_mode = envs .VLLM_SCV_MODE .lower ()
514+ self ._scv_graph_executor : SCVGraphExecutor | None = None
512515 self ._draft_token_ids : list [list [int ]] | torch .Tensor | None = None
513516 self .transfer_event = torch .cuda .Event ()
514517 self .sampled_token_ids_pinned_cpu = torch .empty (
@@ -518,6 +521,14 @@ def __init__(
518521 pin_memory = self .pin_memory ,
519522 )
520523
524+ def _scv_enabled (self ) -> bool :
525+ if not hasattr (self , "_scv_mode" ):
526+ self ._scv_mode = envs .VLLM_SCV_MODE .lower ()
527+ if self ._scv_mode not in ("off" , "graph" , "adaptive" ):
528+ logger .warning ("SCV: unsupported mode '%s', disabling." , self ._scv_mode )
529+ self ._scv_mode = "off"
530+ return self ._scv_mode != "off"
531+
521532 def reset_mm_cache (self ) -> None :
522533 if self .mm_budget :
523534 self .mm_budget .reset_cache ()
@@ -2316,6 +2327,15 @@ def _build_nwor_acceptance_mask(
23162327 target_device = spec_decode_metadata .draft_token_ids .device
23172328 work_device = sampled_token_ids .device
23182329
2330+ if self ._scv_enabled ():
2331+ mask = self ._scv_vectorized_mask (
2332+ spec_decode_metadata , sampled_token_ids , total_tokens , work_device
2333+ )
2334+ if mask is not None :
2335+ if mask .device != target_device :
2336+ mask = mask .to (device = target_device )
2337+ return mask
2338+
23192339 draft_ids = spec_decode_metadata .draft_token_ids
23202340 if draft_ids .device != work_device :
23212341 draft_ids = draft_ids .to (device = work_device )
@@ -2336,16 +2356,9 @@ def _build_nwor_acceptance_mask(
23362356 row = row .to (dtype = draft_ids .dtype )
23372357
23382358 draft_slice = draft_ids [start :end ]
2339- comparison = (row == draft_slice ).flatten ()
2340-
2341- if bool (comparison .all ().item ()):
2342- accepted = draft_count
2343- else :
2344- reject = torch .nonzero (~ comparison , as_tuple = False )
2345- accepted = int (reject [0 , 0 ].item ()) if reject .numel () > 0 else draft_count
2346-
2347- if accepted > 0 :
2348- mask_work [start : start + accepted ] = True
2359+ comparison = (row == draft_slice )
2360+ prefix = torch .cumprod (comparison .to (torch .int32 ), dim = 0 )
2361+ mask_work [start :end ] = prefix .to (torch .bool )
23492362 start = end
23502363
23512364 if start != total_tokens :
@@ -2355,6 +2368,130 @@ def _build_nwor_acceptance_mask(
23552368 return mask_work
23562369 return mask_work .to (device = target_device )
23572370
2371+ def _scv_vectorized_mask (
2372+ self ,
2373+ spec_decode_metadata : SpecDecodeMetadata ,
2374+ sampled_token_ids : torch .Tensor ,
2375+ total_tokens : int ,
2376+ device : torch .device ,
2377+ ) -> torch .Tensor | None :
2378+ draft_ids = spec_decode_metadata .draft_token_ids
2379+ max_spec_len = spec_decode_metadata .max_spec_len
2380+ num_draft_tensor = torch .tensor (
2381+ spec_decode_metadata .num_draft_tokens ,
2382+ device = device ,
2383+ dtype = torch .int32 ,
2384+ )
2385+ if draft_ids .device != device :
2386+ draft_ids = draft_ids .to (device = device )
2387+
2388+ cu = spec_decode_metadata .cu_num_draft_tokens .to (device = device )
2389+
2390+ if hasattr (self , "_scv_mode" ) and self ._scv_mode == "graph" :
2391+ executor = getattr (self , "_scv_graph_executor" , None )
2392+ if executor is None :
2393+ executor = SCVGraphExecutor (device )
2394+ self ._scv_graph_executor = executor
2395+ mask = executor .run (
2396+ spec_decode_metadata , sampled_token_ids , total_tokens
2397+ )
2398+ if mask is not None :
2399+ return mask
2400+
2401+ if hasattr (self , "_scv_mode" ) and self ._scv_mode == "adaptive" :
2402+ mask = self ._scv_compute_mask (
2403+ draft_ids ,
2404+ num_draft_tensor ,
2405+ cu ,
2406+ sampled_token_ids ,
2407+ max_spec_len ,
2408+ total_tokens ,
2409+ )
2410+ self ._scv_update_controller (spec_decode_metadata , mask )
2411+ return mask
2412+
2413+ mask = self ._scv_compute_mask (
2414+ draft_ids ,
2415+ num_draft_tensor ,
2416+ cu ,
2417+ sampled_token_ids ,
2418+ max_spec_len ,
2419+ total_tokens ,
2420+ )
2421+ return mask
2422+
2423+ @staticmethod
2424+ def _scv_compute_mask (
2425+ draft_ids : torch .Tensor ,
2426+ num_draft_tokens : torch .Tensor ,
2427+ cu_num_draft_tokens : torch .Tensor ,
2428+ sampled_token_ids : torch .Tensor ,
2429+ max_spec_len : int ,
2430+ total_tokens : int ,
2431+ ) -> torch .Tensor :
2432+ device = draft_ids .device
2433+ indices = torch .arange (total_tokens , device = device , dtype = torch .int32 )
2434+ req_idx = torch .bucketize (indices , cu_num_draft_tokens )
2435+ prev_cu = torch .cat ([cu_num_draft_tokens .new_zeros (1 ), cu_num_draft_tokens [:- 1 ]])
2436+ pos_in_req = indices - prev_cu [req_idx ]
2437+
2438+ gathered = sampled_token_ids [req_idx , pos_in_req ]
2439+ comparison = gathered == draft_ids
2440+
2441+ max_val = max_spec_len + 1
2442+ values = torch .where (
2443+ ~ comparison ,
2444+ (pos_in_req + 1 ).to (torch .int32 ),
2445+ torch .full_like (pos_in_req , max_val , dtype = torch .int32 ),
2446+ )
2447+
2448+ accepted = torch .full (
2449+ (num_draft_tokens .numel (),),
2450+ max_val ,
2451+ device = device ,
2452+ dtype = torch .int32 ,
2453+ )
2454+ accepted .scatter_reduce_ (0 , req_idx , values , reduce = "amin" )
2455+ accepted = torch .where (
2456+ accepted == max_val ,
2457+ num_draft_tokens ,
2458+ accepted - 1 ,
2459+ )
2460+ accepted_broadcast = accepted [req_idx ]
2461+ mask_flat = pos_in_req < accepted_broadcast
2462+ return mask_flat
2463+
2464+ def _scv_update_controller (
2465+ self ,
2466+ spec_decode_metadata : SpecDecodeMetadata ,
2467+ mask : torch .Tensor ,
2468+ ) -> None :
2469+ target_ratio = 0.6
2470+ alpha = 0.2
2471+ accepted = int (mask .sum ().item ())
2472+ total = max (mask .numel (), 1 )
2473+ ratio = accepted / total
2474+ prev = getattr (self , "_scv_accept_ratio" , target_ratio )
2475+ new_ratio = (1 - alpha ) * prev + alpha * ratio
2476+ self ._scv_accept_ratio = new_ratio
2477+
2478+ speculative_config = getattr (self , "speculative_config" , None )
2479+ if speculative_config is None or not hasattr (speculative_config , "num_speculative_tokens" ):
2480+ return
2481+
2482+ base_k = speculative_config .num_speculative_tokens
2483+ k_min = max (1 , base_k // 4 )
2484+ k_max = max (1 , base_k * 2 )
2485+
2486+ if new_ratio < target_ratio * 0.8 :
2487+ new_k = max (k_min , base_k - 1 )
2488+ elif new_ratio > target_ratio * 1.2 :
2489+ new_k = min (k_max , base_k + 1 )
2490+ else :
2491+ new_k = base_k
2492+
2493+ speculative_config .num_speculative_tokens = new_k
2494+
23582495 def _bookkeeping_sync (
23592496 self ,
23602497 scheduler_output : "SchedulerOutput" ,
@@ -4836,3 +4973,125 @@ def _to_list(self, sampled_token_ids: torch.Tensor) -> list[list[int]]:
48364973 self .transfer_event .record ()
48374974 self .transfer_event .synchronize ()
48384975 return pinned .tolist ()
@dataclass
class _SCVGraphEntry:
    """Static buffers plus a captured CUDA graph for one SCV problem shape.

    CUDA-graph replay re-executes recorded kernels against the same memory
    addresses, so every input the graph reads is staged into a pre-allocated
    buffer here and the result is written into ``mask_buffer``. One entry
    corresponds to one cache key in ``SCVGraphExecutor.entries``.
    """

    # Shape/dtype parameters that fix the captured graph's workload.
    num_reqs: int
    max_spec_len: int
    total_tokens: int
    sampled_shape: tuple[int, int]
    sampled_dtype: torch.dtype
    draft_dtype: torch.dtype
    device: torch.device

    def __post_init__(self):
        # Static staging buffers; SCVGraphExecutor.run() copies fresh data
        # into these before every replay.
        self.sampled_buffer = torch.empty(
            self.sampled_shape, device=self.device, dtype=self.sampled_dtype
        )
        self.draft_buffer = torch.empty(
            (self.total_tokens,), device=self.device, dtype=self.draft_dtype
        )
        self.num_tokens_buffer = torch.empty(
            (self.num_reqs,), device=self.device, dtype=torch.int32
        )
        self.cu_buffer = torch.empty(
            (self.num_reqs,), device=self.device, dtype=torch.int32
        )
        # Output buffer the captured graph writes the acceptance mask into.
        self.mask_buffer = torch.empty(
            (self.total_tokens,), device=self.device, dtype=torch.bool
        )
        self.graph = torch.cuda.CUDAGraph()
        self._captured = False

    def capture(self):
        """Capture the mask computation into the CUDA graph (idempotent)."""
        if self._captured:
            return
        # Warm-up eager run before capture (the usual CUDA-graph pattern:
        # lets lazy allocations/kernels initialize outside the capture),
        # then synchronize so no prior work is recorded into the graph.
        mask = GPUModelRunner._scv_compute_mask(
            self.draft_buffer,
            self.num_tokens_buffer,
            self.cu_buffer,
            self.sampled_buffer,
            self.max_spec_len,
            self.total_tokens,
        )
        self.mask_buffer.copy_(mask)
        torch.cuda.synchronize()
        # Record the same computation; replay() later re-runs it against the
        # staging buffers at their fixed addresses.
        with torch.cuda.graph(self.graph):
            mask = GPUModelRunner._scv_compute_mask(
                self.draft_buffer,
                self.num_tokens_buffer,
                self.cu_buffer,
                self.sampled_buffer,
                self.max_spec_len,
                self.total_tokens,
            )
            self.mask_buffer.copy_(mask)
        self._captured = True

    def run(self):
        """Replay the graph (capturing first if needed); returns mask_buffer.

        Note: the returned tensor is the shared output buffer and is
        overwritten by the next replay.
        """
        if not self._captured:
            self.capture()
        self.graph.replay()
        return self.mask_buffer
5035+
5036+
class SCVGraphExecutor:
    """Runs the SCV acceptance-mask computation under cached CUDA graphs.

    One ``_SCVGraphEntry`` (static buffers + captured graph) is kept per
    distinct problem shape. ``run`` returns None — signalling the caller to
    fall back to the eager path — when CUDA is unavailable or a graph
    capture/replay ever fails, in which case the executor disables itself.
    """

    def __init__(self, device: torch.device):
        self.device = device
        self.entries: dict[tuple[Any, ...], _SCVGraphEntry] = {}
        self.enabled = torch.cuda.is_available()

    def run(
        self,
        spec_decode_metadata: SpecDecodeMetadata,
        sampled_token_ids: torch.Tensor,
        total_tokens: int,
    ) -> torch.Tensor | None:
        if not self.enabled:
            return None

        request_count = len(spec_decode_metadata.num_draft_tokens)
        max_spec_len = spec_decode_metadata.max_spec_len
        # Everything that fixes the captured graph's workload goes into the key.
        cache_key = (
            request_count,
            max_spec_len,
            sampled_token_ids.shape[1],
            total_tokens,
            sampled_token_ids.dtype,
        )

        entry = self.entries.get(cache_key)
        is_new_entry = entry is None
        if is_new_entry:
            entry = _SCVGraphEntry(
                num_reqs=request_count,
                max_spec_len=max_spec_len,
                total_tokens=total_tokens,
                sampled_shape=sampled_token_ids[:, :max_spec_len].shape,
                sampled_dtype=sampled_token_ids.dtype,
                draft_dtype=spec_decode_metadata.draft_token_ids.dtype,
                device=self.device,
            )
            self.entries[cache_key] = entry

        try:
            # Stage fresh inputs into the entry's static buffers: the captured
            # graph always reads from these fixed addresses.
            entry.sampled_buffer.copy_(sampled_token_ids[:, :max_spec_len])
            draft_on_device = spec_decode_metadata.draft_token_ids.to(self.device)
            entry.draft_buffer.zero_()
            entry.draft_buffer[: draft_on_device.numel()].copy_(draft_on_device)
            entry.num_tokens_buffer.copy_(
                torch.tensor(
                    spec_decode_metadata.num_draft_tokens,
                    device=self.device,
                    dtype=torch.int32,
                )
            )
            entry.cu_buffer.copy_(
                spec_decode_metadata.cu_num_draft_tokens.to(
                    device=self.device, dtype=torch.int32
                )
            )
            if is_new_entry:
                entry.capture()
            return entry.run()
        except RuntimeError as exc:
            # Any capture/replay failure permanently disables graph mode;
            # callers fall back to the eager computation.
            logger.warning("SCV graph execution disabled: %s", exc)
            self.enabled = False
            self.entries.clear()
            return None
0 commit comments