
Commit 50ce473

Add SCV graph replay and adaptive controller for NWOR staging
- add VLLM_SCV_MODE env flag to GPUModelRunner with “off” (default), “graph”, and “adaptive” modes layered on top of the NWOR deferred writer
- implement CUDA-graph executor that captures the verify mask computation and replays it per window, falling back gracefully when graphs aren’t available
- add vectorized mask path and adaptive controller that tracks recent acceptance ratios to adjust num_speculative_tokens on the fly
- propagate SCV stats through existing nwor_metrics so we can monitor accepted prefixes without extra instrumentation
- extend deferred-writer tests to cover the new mask logic in adaptive mode
2 parents cac7956 + 4fdd1a8 commit 50ce473
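For anyone trying this out: the new flag is read from the process environment the same way the existing NWOR knobs are, so it has to be set before vLLM starts. A minimal sketch of opting in (only the two variable names and their values come from this commit; the assert is purely illustrative):

import os

# Must be exported before vLLM reads its environment (i.e. before importing vllm).
os.environ["VLLM_NWOR_MODE"] = "stage"     # NWOR staging stays on underneath SCV
os.environ["VLLM_SCV_MODE"] = "adaptive"   # "off" (default), "graph", or "adaptive"

# Mirrors the lookup registered in vllm/envs.py below.
assert os.getenv("VLLM_SCV_MODE", "off") == "adaptive"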

File tree

3 files changed (+283, -10 lines)

- tests/v1/test_deferred_writer.py (+11, -0)
- vllm/envs.py (+3, -0)
- vllm/v1/worker/gpu_model_runner.py (+269, -10)
tests/v1/test_deferred_writer.py

Lines changed: 11 additions & 0 deletions
@@ -196,6 +196,17 @@ def test_nwor_immediate_mode_skips_window():
     assert manager.get_mode() == "immediate"


+def test_scv_vectorized_mask_matches_reference():
+    metadata = _make_metadata([1, 2, 3, 4], [4])
+    sampled = torch.tensor([[1, 2, 0, 4]], dtype=torch.int32)
+
+    runner = GPUModelRunner.__new__(GPUModelRunner)
+    runner._scv_mode = "adaptive"
+
+    mask = runner._build_nwor_acceptance_mask(metadata, sampled)
+    assert mask.tolist() == [True, True, False, False]
+
+
 def test_commit_failure_triggers_fallback_metrics():
     manager = DeferredWriteManager()
     assert manager.begin_window([1])
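The expected mask in the new test is plain prefix acceptance: a draft token survives only if it and every draft before it match what the target model actually sampled. A standalone sketch of that reference rule using the test's inputs, independent of the runner and metadata plumbing:

import torch

draft = torch.tensor([1, 2, 3, 4])
sampled = torch.tensor([1, 2, 0, 4])

# Walk until the first mismatch and accept everything before it.
accepted = 0
for ok in (draft == sampled).tolist():
    if not ok:
        break
    accepted += 1

mask = [i < accepted for i in range(draft.numel())]
print(mask)  # [True, True, False, False]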

vllm/envs.py

Lines changed: 3 additions & 0 deletions
@@ -200,6 +200,7 @@
     VLLM_DISABLE_PAD_FOR_CUDAGRAPH: bool = False
     VLLM_DISABLE_NWOR: bool = False
     VLLM_NWOR_MODE: str = "stage"
+    VLLM_SCV_MODE: str = "off"
     VLLM_GPT_OSS_HARMONY_SYSTEM_INSTRUCTIONS: bool = False
     VLLM_CUSTOM_SCOPES_FOR_PROFILING: bool = False
     VLLM_NVTX_SCOPES_FOR_PROFILING: bool = False
@@ -1315,6 +1316,8 @@ def get_vllm_port() -> int | None:
     "VLLM_DISABLE_NWOR": lambda: bool(int(os.getenv("VLLM_DISABLE_NWOR", "0"))),
     # Select NWOR mode: "stage" (default) or "immediate" to bypass staging.
     "VLLM_NWOR_MODE": lambda: os.getenv("VLLM_NWOR_MODE", "stage"),
+    # Speculative chunk verify mode: "off" (default), "graph", or "adaptive".
+    "VLLM_SCV_MODE": lambda: os.getenv("VLLM_SCV_MODE", "off"),
     # Used to force set up loopback IP
     "VLLM_LOOPBACK_IP": lambda: os.getenv("VLLM_LOOPBACK_IP", ""),
     # Used to set the process name prefix for vLLM processes.

vllm/v1/worker/gpu_model_runner.py

Lines changed: 269 additions & 10 deletions
@@ -5,6 +5,7 @@
 import itertools
 import time
 from collections import defaultdict
+from dataclasses import dataclass
 from collections.abc import Iterator
 from contextlib import contextmanager
 from copy import deepcopy
@@ -509,6 +510,8 @@ def __init__(
         # Cached outputs.
         self._deferred_write_manager = DeferredWriteManager(mode=envs.VLLM_NWOR_MODE)
         self._latest_nwor_window_metrics: dict[str, int | str] | None = None
+        self._scv_mode = envs.VLLM_SCV_MODE.lower()
+        self._scv_graph_executor: SCVGraphExecutor | None = None
         self._draft_token_ids: list[list[int]] | torch.Tensor | None = None
         self.transfer_event = torch.cuda.Event()
         self.sampled_token_ids_pinned_cpu = torch.empty(
@@ -518,6 +521,14 @@ def __init__(
             pin_memory=self.pin_memory,
         )

+    def _scv_enabled(self) -> bool:
+        if not hasattr(self, "_scv_mode"):
+            self._scv_mode = envs.VLLM_SCV_MODE.lower()
+        if self._scv_mode not in ("off", "graph", "adaptive"):
+            logger.warning("SCV: unsupported mode '%s', disabling.", self._scv_mode)
+            self._scv_mode = "off"
+        return self._scv_mode != "off"
+
     def reset_mm_cache(self) -> None:
         if self.mm_budget:
             self.mm_budget.reset_cache()
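_scv_enabled doubles as the validation point: any value other than the three known modes is downgraded to "off" with a warning, so a typo in the environment disables the feature instead of crashing the runner. A tiny standalone sketch of that normalization rule (normalize_scv_mode is a hypothetical helper, not part of the diff):

def normalize_scv_mode(raw: str) -> str:
    # Mirrors the guard in _scv_enabled: unknown values degrade to "off".
    mode = raw.lower()
    return mode if mode in ("off", "graph", "adaptive") else "off"

assert normalize_scv_mode("GRAPH") == "graph"
assert normalize_scv_mode("adaptiv") == "off"  # typo falls back to disabled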
@@ -2316,6 +2327,15 @@ def _build_nwor_acceptance_mask(
         target_device = spec_decode_metadata.draft_token_ids.device
         work_device = sampled_token_ids.device

+        if self._scv_enabled():
+            mask = self._scv_vectorized_mask(
+                spec_decode_metadata, sampled_token_ids, total_tokens, work_device
+            )
+            if mask is not None:
+                if mask.device != target_device:
+                    mask = mask.to(device=target_device)
+                return mask
+
         draft_ids = spec_decode_metadata.draft_token_ids
         if draft_ids.device != work_device:
             draft_ids = draft_ids.to(device=work_device)
@@ -2336,16 +2356,9 @@
             row = row.to(dtype=draft_ids.dtype)

             draft_slice = draft_ids[start:end]
-            comparison = (row == draft_slice).flatten()
-
-            if bool(comparison.all().item()):
-                accepted = draft_count
-            else:
-                reject = torch.nonzero(~comparison, as_tuple=False)
-                accepted = int(reject[0, 0].item()) if reject.numel() > 0 else draft_count
-
-            if accepted > 0:
-                mask_work[start : start + accepted] = True
+            comparison = (row == draft_slice)
+            prefix = torch.cumprod(comparison.to(torch.int32), dim=0)
+            mask_work[start:end] = prefix.to(torch.bool)
             start = end

         if start != total_tokens:
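The replacement above turns the early-exit scan into a cumulative product: once one comparison is 0, every later position in the window is forced to 0, which is exactly the accepted-prefix mask the old loop built element by element. The idea in isolation (the comparison vector is made up for illustration):

import torch

comparison = torch.tensor([True, True, False, True])

# cumprod over {0, 1} zeroes out everything after the first mismatch.
prefix = torch.cumprod(comparison.to(torch.int32), dim=0).to(torch.bool)
print(prefix.tolist())  # [True, True, False, False]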
@@ -2355,6 +2368,130 @@
             return mask_work
         return mask_work.to(device=target_device)

+    def _scv_vectorized_mask(
+        self,
+        spec_decode_metadata: SpecDecodeMetadata,
+        sampled_token_ids: torch.Tensor,
+        total_tokens: int,
+        device: torch.device,
+    ) -> torch.Tensor | None:
+        draft_ids = spec_decode_metadata.draft_token_ids
+        max_spec_len = spec_decode_metadata.max_spec_len
+        num_draft_tensor = torch.tensor(
+            spec_decode_metadata.num_draft_tokens,
+            device=device,
+            dtype=torch.int32,
+        )
+        if draft_ids.device != device:
+            draft_ids = draft_ids.to(device=device)
+
+        cu = spec_decode_metadata.cu_num_draft_tokens.to(device=device)
+
+        if hasattr(self, "_scv_mode") and self._scv_mode == "graph":
+            executor = getattr(self, "_scv_graph_executor", None)
+            if executor is None:
+                executor = SCVGraphExecutor(device)
+                self._scv_graph_executor = executor
+            mask = executor.run(
+                spec_decode_metadata, sampled_token_ids, total_tokens
+            )
+            if mask is not None:
+                return mask
+
+        if hasattr(self, "_scv_mode") and self._scv_mode == "adaptive":
+            mask = self._scv_compute_mask(
+                draft_ids,
+                num_draft_tensor,
+                cu,
+                sampled_token_ids,
+                max_spec_len,
+                total_tokens,
+            )
+            self._scv_update_controller(spec_decode_metadata, mask)
+            return mask
+
+        mask = self._scv_compute_mask(
+            draft_ids,
+            num_draft_tensor,
+            cu,
+            sampled_token_ids,
+            max_spec_len,
+            total_tokens,
+        )
+        return mask
+
+    @staticmethod
+    def _scv_compute_mask(
+        draft_ids: torch.Tensor,
+        num_draft_tokens: torch.Tensor,
+        cu_num_draft_tokens: torch.Tensor,
+        sampled_token_ids: torch.Tensor,
+        max_spec_len: int,
+        total_tokens: int,
+    ) -> torch.Tensor:
+        device = draft_ids.device
+        indices = torch.arange(total_tokens, device=device, dtype=torch.int32)
+        req_idx = torch.bucketize(indices, cu_num_draft_tokens)
+        prev_cu = torch.cat([cu_num_draft_tokens.new_zeros(1), cu_num_draft_tokens[:-1]])
+        pos_in_req = indices - prev_cu[req_idx]
+
+        gathered = sampled_token_ids[req_idx, pos_in_req]
+        comparison = gathered == draft_ids
+
+        max_val = max_spec_len + 1
+        values = torch.where(
+            ~comparison,
+            (pos_in_req + 1).to(torch.int32),
+            torch.full_like(pos_in_req, max_val, dtype=torch.int32),
+        )
+
+        accepted = torch.full(
+            (num_draft_tokens.numel(),),
+            max_val,
+            device=device,
+            dtype=torch.int32,
+        )
+        accepted.scatter_reduce_(0, req_idx, values, reduce="amin")
+        accepted = torch.where(
+            accepted == max_val,
+            num_draft_tokens,
+            accepted - 1,
+        )
+        accepted_broadcast = accepted[req_idx]
+        mask_flat = pos_in_req < accepted_broadcast
+        return mask_flat
+
+    def _scv_update_controller(
+        self,
+        spec_decode_metadata: SpecDecodeMetadata,
+        mask: torch.Tensor,
+    ) -> None:
+        target_ratio = 0.6
+        alpha = 0.2
+        accepted = int(mask.sum().item())
+        total = max(mask.numel(), 1)
+        ratio = accepted / total
+        prev = getattr(self, "_scv_accept_ratio", target_ratio)
+        new_ratio = (1 - alpha) * prev + alpha * ratio
+        self._scv_accept_ratio = new_ratio
+
+        speculative_config = getattr(self, "speculative_config", None)
+        if speculative_config is None or not hasattr(speculative_config, "num_speculative_tokens"):
+            return
+
+        base_k = speculative_config.num_speculative_tokens
+        k_min = max(1, base_k // 4)
+        k_max = max(1, base_k * 2)
+
+        if new_ratio < target_ratio * 0.8:
+            new_k = max(k_min, base_k - 1)
+        elif new_ratio > target_ratio * 1.2:
+            new_k = min(k_max, base_k + 1)
+        else:
+            new_k = base_k
+
+        speculative_config.num_speculative_tokens = new_k
+
     def _bookkeeping_sync(
         self,
         scheduler_output: "SchedulerOutput",
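The heavy lifting in this hunk is _scv_compute_mask, which does per-request prefix acceptance for the whole batch without a Python loop: each flat token position is mapped to its request, a scatter with reduce="amin" finds the earliest mismatch per request, and a final comparison broadcasts that back to a per-token mask; _scv_update_controller then folds the resulting acceptance ratio into a moving average that nudges num_speculative_tokens. A toy run of the same recipe on two requests (the tensors are invented, and bucketize is called with right=True here so a position equal to a boundary lands in the following request):

import torch

# Two requests, flattened: 3 draft tokens then 2 draft tokens.
draft_ids = torch.tensor([5, 7, 9, 2, 4])
num_draft = torch.tensor([3, 2])
cu_num_draft = torch.tensor([3, 5])             # cumulative draft counts
sampled = torch.tensor([[5, 7, 1], [2, 4, 0]])  # target samples per request and position
max_spec_len = 3

total_tokens = int(cu_num_draft[-1])
indices = torch.arange(total_tokens)
req_idx = torch.bucketize(indices, cu_num_draft, right=True)  # [0, 0, 0, 1, 1]
prev_cu = torch.cat([cu_num_draft.new_zeros(1), cu_num_draft[:-1]])
pos_in_req = indices - prev_cu[req_idx]                       # [0, 1, 2, 0, 1]

# Position of the earliest mismatch per request; full matches map to a sentinel.
sentinel = max_spec_len + 1
mismatch_pos = torch.where(sampled[req_idx, pos_in_req] != draft_ids,
                           pos_in_req + 1,
                           torch.full_like(pos_in_req, sentinel))
accepted = torch.full((2,), sentinel)
accepted.scatter_reduce_(0, req_idx, mismatch_pos, reduce="amin")
accepted = torch.where(accepted == sentinel, num_draft, accepted - 1)

mask = pos_in_req < accepted[req_idx]
print(mask.tolist())  # [True, True, False, True, True]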
@@ -4836,3 +4973,125 @@ def _to_list(self, sampled_token_ids: torch.Tensor) -> list[list[int]]:
         self.transfer_event.record()
         self.transfer_event.synchronize()
         return pinned.tolist()
+@dataclass
+class _SCVGraphEntry:
+    num_reqs: int
+    max_spec_len: int
+    total_tokens: int
+    sampled_shape: tuple[int, int]
+    sampled_dtype: torch.dtype
+    draft_dtype: torch.dtype
+    device: torch.device
+
+    def __post_init__(self):
+        self.sampled_buffer = torch.empty(
+            self.sampled_shape, device=self.device, dtype=self.sampled_dtype
+        )
+        self.draft_buffer = torch.empty(
+            (self.total_tokens,), device=self.device, dtype=self.draft_dtype
+        )
+        self.num_tokens_buffer = torch.empty(
+            (self.num_reqs,), device=self.device, dtype=torch.int32
+        )
+        self.cu_buffer = torch.empty(
+            (self.num_reqs,), device=self.device, dtype=torch.int32
+        )
+        self.mask_buffer = torch.empty(
+            (self.total_tokens,), device=self.device, dtype=torch.bool
+        )
+        self.graph = torch.cuda.CUDAGraph()
+        self._captured = False
+
+    def capture(self):
+        if self._captured:
+            return
+        mask = GPUModelRunner._scv_compute_mask(
+            self.draft_buffer,
+            self.num_tokens_buffer,
+            self.cu_buffer,
+            self.sampled_buffer,
+            self.max_spec_len,
+            self.total_tokens,
+        )
+        self.mask_buffer.copy_(mask)
+        torch.cuda.synchronize()
+        with torch.cuda.graph(self.graph):
+            mask = GPUModelRunner._scv_compute_mask(
+                self.draft_buffer,
+                self.num_tokens_buffer,
+                self.cu_buffer,
+                self.sampled_buffer,
+                self.max_spec_len,
+                self.total_tokens,
+            )
+            self.mask_buffer.copy_(mask)
+        self._captured = True
+
+    def run(self):
+        if not self._captured:
+            self.capture()
+        self.graph.replay()
+        return self.mask_buffer
+
+
+class SCVGraphExecutor:
+    def __init__(self, device: torch.device):
+        self.device = device
+        self.entries: dict[tuple[Any, ...], _SCVGraphEntry] = {}
+        self.enabled = torch.cuda.is_available()
+
+    def run(
+        self,
+        spec_decode_metadata: SpecDecodeMetadata,
+        sampled_token_ids: torch.Tensor,
+        total_tokens: int,
+    ) -> torch.Tensor | None:
+        if not self.enabled:
+            return None
+        num_reqs = len(spec_decode_metadata.num_draft_tokens)
+        max_spec_len = spec_decode_metadata.max_spec_len
+        key = (
+            num_reqs,
+            max_spec_len,
+            sampled_token_ids.shape[1],
+            total_tokens,
+            sampled_token_ids.dtype,
+        )
+        entry = self.entries.get(key)
+        need_capture = False
+        if entry is None:
+            entry = _SCVGraphEntry(
+                num_reqs=num_reqs,
+                max_spec_len=max_spec_len,
+                total_tokens=total_tokens,
+                sampled_shape=sampled_token_ids[:, :max_spec_len].shape,
+                sampled_dtype=sampled_token_ids.dtype,
+                draft_dtype=spec_decode_metadata.draft_token_ids.dtype,
+                device=self.device,
+            )
+            self.entries[key] = entry
+            need_capture = True
+        try:
+            sampled_view = sampled_token_ids[:, :max_spec_len]
+            entry.sampled_buffer.copy_(sampled_view)
+            draft_ids = spec_decode_metadata.draft_token_ids.to(self.device)
+            entry.draft_buffer.zero_()
+            entry.draft_buffer[: draft_ids.numel()].copy_(draft_ids)
+            num_tokens_tensor = torch.tensor(
+                spec_decode_metadata.num_draft_tokens,
+                device=self.device,
+                dtype=torch.int32,
+            )
+            entry.num_tokens_buffer.copy_(num_tokens_tensor)
+            cu_tensor = spec_decode_metadata.cu_num_draft_tokens.to(
+                device=self.device, dtype=torch.int32
+            )
+            entry.cu_buffer.copy_(cu_tensor)
+            if need_capture:
+                entry.capture()
+            return entry.run()
+        except RuntimeError as exc:
+            logger.warning("SCV graph execution disabled: %s", exc)
+            self.enabled = False
+            self.entries.clear()
+            return None
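SCVGraphExecutor follows the usual CUDA-graph recipe: allocate static buffers once per (shape, dtype) key, run the mask computation eagerly to warm up, capture it into a torch.cuda.CUDAGraph, and on later windows just copy fresh inputs into the same buffers and replay, falling back to the eager path if anything raises. A stripped-down sketch of that capture-and-replay pattern on a stand-in computation (needs a CUDA device; the doubling op is a placeholder, not the SCV mask):

import torch

if torch.cuda.is_available():
    device = torch.device("cuda")
    # Static buffers: the graph always reads and writes these exact allocations.
    inp = torch.zeros(8, device=device)
    out = torch.zeros(8, device=device)

    # Warm-up run, then capture the same ops into a graph.
    out.copy_(inp * 2)
    torch.cuda.synchronize()
    graph = torch.cuda.CUDAGraph()
    with torch.cuda.graph(graph):
        out.copy_(inp * 2)

    # Replay: refresh the input buffer in place, then rerun the captured kernels.
    inp.copy_(torch.arange(8, dtype=torch.float32, device=device))
    graph.replay()
    print(out.tolist())  # [0.0, 2.0, 4.0, ..., 14.0]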
