Commit dbca650

markmc authored and kylesayrs committed
[V1][Metrics] Initial speculative decoding metrics (#15151)
Signed-off-by: Mark McLoughlin <markmc@redhat.com>
Signed-off-by: Kyle Sayers <kylesayrs@gmail.com>
1 parent 82df1e8 commit dbca650

5 files changed, +204 -2 lines changed


tests/v1/core/test_scheduler.py

Lines changed: 95 additions & 0 deletions
@@ -611,3 +611,98 @@ def test_schedule_concurrent_batches(enable_prefix_caching: Optional[bool],
         prompt_logprobs_dict={},
     )
     scheduler.update_from_output(scheduler_output1, model_runner_output)
+
+
+# Note - these test cases mirror some of those in test_rejection_sampler.py
+@pytest.mark.parametrize(
+    "spec_tokens,output_tokens,expected",
+    [
+        ([[1, 2, 3]], [[1, 2, 3, 4]], (3, 3)),  # perfect match
+        ([[1, 2, 3]], [[1, 5]], (3, 1)),  # early mismatch
+        ([[1, 2], [3]], [[1, 2, 5], [3, 4]], (3, 3)),  # multiple sequences
+        ([[1]], [[1, 2]], (1, 1)),  # single token sequence
+        ([[]], [[5]], (0, 0)),  # empty sequence
+        ([[1, 2, 3], [4, 5, 6]], [[1, 2, 7], [4, 8]],
+         (6, 3)),  # multiple mismatches
+    ])
+def test_schedule_spec_decoding_stats(spec_tokens, output_tokens, expected):
+    """Test scheduling behavior with speculative decoding.
+
+    This test verifies that:
+    1. Speculated tokens get scheduled correctly
+    2. Spec decoding stats properly count number of draft and accepted tokens
+    """
+    scheduler = create_scheduler()
+    requests = create_requests(num_requests=len(spec_tokens), num_tokens=1)
+    req_ids = []
+    req_to_index = {}
+    for i, request in enumerate(requests):
+        scheduler.add_request(request)
+        req_ids.append(request.request_id)
+        req_to_index[request.request_id] = i
+
+    # Schedule a decode, which will also draft speculative tokens
+    output = scheduler.schedule()
+    assert len(output.scheduled_new_reqs) == len(requests)
+    assert output.total_num_scheduled_tokens == len(requests)
+    for i in range(len(requests)):
+        req_id = requests[i].request_id
+        assert output.num_scheduled_tokens[req_id] == 1
+        assert req_id not in output.scheduled_spec_decode_tokens
+
+    model_runner_output = ModelRunnerOutput(
+        req_ids=req_ids,
+        req_id_to_index=req_to_index,
+        sampled_token_ids=[[0] for _ in range(len(requests))],
+        spec_token_ids=spec_tokens,
+        logprobs=None,
+        prompt_logprobs_dict={},
+    )
+    engine_core_outputs = scheduler.update_from_output(output,
+                                                       model_runner_output)
+
+    for i in range(len(requests)):
+        running_req = scheduler.running[i]
+        # The prompt token
+        assert running_req.num_computed_tokens == 1
+        # The prompt token and the sampled token
+        assert running_req.num_tokens == 2
+        # The prompt token, the sampled token, and the speculated tokens
+        assert running_req.num_tokens_with_spec == 2 + len(spec_tokens[i])
+
+    # No draft or accepted tokens counted yet
+    assert engine_core_outputs.scheduler_stats.spec_decoding_stats is not None
+    stats = engine_core_outputs.scheduler_stats.spec_decoding_stats
+    assert stats.num_draft_tokens == 0
+    assert stats.num_accepted_tokens == 0
+
+    # Schedule the speculated tokens for validation
+    output = scheduler.schedule()
+    assert len(output.scheduled_new_reqs) == 0
+    # The sampled token and speculated tokens
+    assert output.total_num_scheduled_tokens == \
+        len(requests) + sum(len(ids) for ids in spec_tokens)
+    for i in range(len(requests)):
+        req_id = requests[i].request_id
+        assert output.num_scheduled_tokens[req_id] == 1 + len(spec_tokens[i])
+        if spec_tokens[i]:
+            assert len(output.scheduled_spec_decode_tokens[req_id]) == \
+                len(spec_tokens[i])
+        else:
+            assert req_id not in output.scheduled_spec_decode_tokens
+
+    model_runner_output = ModelRunnerOutput(
+        req_ids=req_ids,
+        req_id_to_index=req_to_index,
+        sampled_token_ids=output_tokens,
+        spec_token_ids=None,
+        logprobs=None,
+        prompt_logprobs_dict={},
+    )
+    engine_core_outputs = scheduler.update_from_output(output,
+                                                       model_runner_output)
+
+    assert engine_core_outputs.scheduler_stats.spec_decoding_stats is not None
+    stats = engine_core_outputs.scheduler_stats.spec_decoding_stats
+    assert stats.num_draft_tokens == expected[0]
+    assert stats.num_accepted_tokens == expected[1]
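Note: the expected (num_draft_tokens, num_accepted_tokens) tuples in the parametrize table above can be reproduced by hand from the counting rule used in the scheduler change in the next file: every scheduled draft token counts as a draft token, and every sampled token except the final one counts as accepted. A standalone sketch (illustrative only, not part of this commit):

    def expected_spec_stats(spec_tokens, output_tokens):
        # Sum over requests: all scheduled drafts are draft tokens; all
        # sampled tokens except the final (bonus/recovery) one are accepted.
        num_draft = sum(len(draft) for draft in spec_tokens)
        num_accepted = sum(len(sampled) - 1 for sampled in output_tokens)
        return num_draft, num_accepted

    assert expected_spec_stats([[1, 2, 3]], [[1, 2, 3, 4]]) == (3, 3)  # perfect match
    assert expected_spec_stats([[1, 2, 3]], [[1, 5]]) == (3, 1)        # early mismatch
    assert expected_spec_stats([[]], [[5]]) == (0, 0)                  # empty sequence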

vllm/v1/core/sched/scheduler.py

Lines changed: 13 additions & 2 deletions
@@ -23,6 +23,7 @@
 from vllm.v1.metrics.stats import SchedulerStats
 from vllm.v1.outputs import ModelRunnerOutput
 from vllm.v1.request import Request, RequestStatus
+from vllm.v1.spec_decode.metrics import SpecDecodingStats
 from vllm.v1.structured_output import StructuredOutputManager

 logger = init_logger(__name__)
@@ -552,6 +553,7 @@ def update_from_output(
         spec_token_ids = model_runner_output.spec_token_ids
         logprobs = model_runner_output.logprobs
         prompt_logprobs_dict = model_runner_output.prompt_logprobs_dict
+        spec_decoding_stats = SpecDecodingStats() if self.log_stats else None
         num_scheduled_tokens = scheduler_output.num_scheduled_tokens

         new_running: list[Request] = []
@@ -584,6 +586,11 @@
                                        len(generated_token_ids))
                 request.num_computed_tokens -= num_tokens_rejected

+                if spec_decoding_stats is not None:
+                    spec_decoding_stats.observe(
+                        num_draft_tokens=len(scheduled_spec_token_ids),
+                        num_accepted_tokens=len(generated_token_ids) - 1)
+
             cached_encoder_input_ids = (
                 self.encoder_cache_manager.get_cached_input_ids(request))
             # OPTIMIZATION: Avoid list(set) if the set is empty.
@@ -657,7 +664,7 @@
         self.running = new_running
         engine_core_outputs = EngineCoreOutputs(
             outputs=outputs,
-            scheduler_stats=self.make_stats(),
+            scheduler_stats=self.make_stats(spec_decoding_stats),
         )
         if self.include_finished_set:
             #TODO currently sending duplicates here, improve this
@@ -724,12 +731,16 @@ def get_num_unscheduled_requests(self) -> int:
     def reset_prefix_cache(self) -> bool:
         return self.kv_cache_manager.reset_prefix_cache()

-    def make_stats(self) -> Optional[SchedulerStats]:
+    def make_stats(
+        self,
+        spec_decoding_stats: Optional[SpecDecodingStats] = None,
+    ) -> Optional[SchedulerStats]:
         if not self.log_stats:
             return None
         return SchedulerStats(
             num_running_reqs=len(self.running),
             num_waiting_reqs=len(self.waiting),
             gpu_cache_usage=self.kv_cache_manager.usage,
             prefix_cache_stats=self.kv_cache_manager.make_prefix_cache_stats(),
+            spec_decoding_stats=spec_decoding_stats,
         )
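A quick worked example of the accounting in the update_from_output hunk above, matching the "early mismatch" test case earlier (3 draft tokens scheduled, only the first accepted); the token values below are made up for illustration:

    scheduled_spec_token_ids = [101, 102, 103]   # 3 drafts scheduled
    generated_token_ids = [101, 777]             # 1 accepted draft + 1 recovery token

    num_tokens_rejected = (len(scheduled_spec_token_ids) + 1 -
                           len(generated_token_ids))       # 2 rejected
    num_draft_tokens = len(scheduled_spec_token_ids)        # 3 drafted
    num_accepted_tokens = len(generated_token_ids) - 1      # 1 accepted; the final
                                                            # sampled token is never
                                                            # counted as a draft
    assert (num_draft_tokens, num_accepted_tokens) == (3, 1)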

vllm/v1/metrics/loggers.py

Lines changed: 33 additions & 0 deletions
@@ -12,6 +12,7 @@
 from vllm.v1.core.kv_cache_utils import PrefixCachingMetrics
 from vllm.v1.engine import FinishReason
 from vllm.v1.metrics.stats import IterationStats, SchedulerStats
+from vllm.v1.spec_decode.metrics import SpecDecodingMetrics

 logger = init_logger(__name__)

@@ -38,6 +39,7 @@ def __init__(self, engine_index: int = 0):
         # Prefix cache metrics. This cannot be reset.
         # TODO: Make the interval configurable.
         self.prefix_caching_metrics = PrefixCachingMetrics()
+        self.spec_decoding_metrics = SpecDecodingMetrics()

     def _reset(self, now):
         self.last_log_time = now
@@ -65,6 +67,10 @@ def record(self, scheduler_stats: SchedulerStats,

         self.prefix_caching_metrics.observe(scheduler_stats.prefix_cache_stats)

+        if scheduler_stats.spec_decoding_stats is not None:
+            self.spec_decoding_metrics.observe(
+                scheduler_stats.spec_decoding_stats)
+
         self.last_scheduler_stats = scheduler_stats

     def log(self):
@@ -94,6 +100,9 @@ def log(self):
             self.prefix_caching_metrics.hit_rate * 100,
         )

+        if scheduler_stats.spec_decoding_stats is not None:
+            self.spec_decoding_metrics.log()
+

 class PrometheusStatLogger(StatLoggerBase):

@@ -302,6 +311,24 @@ def __init__(self, vllm_config: VllmConfig, engine_index: int = 0):
                 self.labelname_running_lora_adapters,
             ])

+        #
+        # Speculative Decoding metrics
+        # The acceptance rate can be calculated using a PromQL query:
+        #
+        #   rate(vllm:spec_decode_num_accepted_tokens_total[$interval]) /
+        #   rate(vllm:spec_decode_num_draft_tokens_total[$interval])
+        #
+        self.counter_spec_decode_num_draft_tokens = \
+            prometheus_client.Counter(
+                name="vllm:spec_decode_num_draft_tokens_total",
+                documentation="Number of draft tokens.",
+                labelnames=labelnames).labels(*labelvalues)
+        self.counter_spec_decode_num_accepted_tokens = \
+            prometheus_client.Counter(
+                name="vllm:spec_decode_num_accepted_tokens_total",
+                documentation="Number of accepted tokens.",
+                labelnames=labelnames).labels(*labelvalues)
+
         #
         # Cache config info metric
         #
@@ -338,6 +365,12 @@ def record(self, scheduler_stats: SchedulerStats,
             self.counter_gpu_prefix_cache_hits.inc(
                 scheduler_stats.prefix_cache_stats.hits)

+        if scheduler_stats.spec_decoding_stats is not None:
+            self.counter_spec_decode_num_draft_tokens.inc(
+                scheduler_stats.spec_decoding_stats.num_draft_tokens)
+            self.counter_spec_decode_num_accepted_tokens.inc(
+                scheduler_stats.spec_decoding_stats.num_accepted_tokens)
+
         if iteration_stats is None:
             return
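For context, a minimal standalone sketch of the prometheus_client pattern used above (the "model_name" label and "example-model" value are illustrative, not taken from the commit): a counter is declared once with label names, bound to concrete label values, and then incremented on each record() call; the acceptance rate is the ratio of the two counters' rates, as in the PromQL comment.

    import prometheus_client

    # Declare a labelled counter and bind it to this engine's label values,
    # mirroring the .labels(*labelvalues) call in the diff above.
    counter_draft = prometheus_client.Counter(
        name="vllm:spec_decode_num_draft_tokens_total",
        documentation="Number of draft tokens.",
        labelnames=["model_name"]).labels("example-model")

    counter_draft.inc(3)  # e.g. three draft tokens were scheduled this step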

vllm/v1/metrics/stats.py

Lines changed: 4 additions & 0 deletions
@@ -4,6 +4,8 @@
 from dataclasses import dataclass, field
 from typing import TYPE_CHECKING, Optional

+from vllm.v1.spec_decode.metrics import SpecDecodingStats
+
 if TYPE_CHECKING:
     from vllm.v1.engine import EngineCoreEvent, EngineCoreOutput, FinishReason
     from vllm.v1.engine.output_processor import RequestState
@@ -35,6 +37,8 @@ class SchedulerStats:
     prefix_cache_stats: PrefixCacheStats = field(
         default_factory=PrefixCacheStats)

+    spec_decoding_stats: Optional[SpecDecodingStats] = None
+

 @dataclass
 class LoRAStats:

vllm/v1/spec_decode/metrics.py

Lines changed: 59 additions & 0 deletions
@@ -0,0 +1,59 @@
+# SPDX-License-Identifier: Apache-2.0
+
+from dataclasses import dataclass
+
+import numpy as np
+
+from vllm.logger import init_logger
+
+logger = init_logger(__name__)
+
+
+@dataclass
+class SpecDecodingStats:
+    num_draft_tokens: int = 0
+    num_accepted_tokens: int = 0
+
+    def take(self):
+        copied = SpecDecodingStats(self.num_draft_tokens,
+                                   self.num_accepted_tokens)
+        self.reset()
+        return copied
+
+    def reset(self):
+        self.num_draft_tokens = 0
+        self.num_accepted_tokens = 0
+
+    def observe(self, num_draft_tokens: int, num_accepted_tokens: int):
+        self.num_draft_tokens += num_draft_tokens
+        self.num_accepted_tokens += num_accepted_tokens
+
+
+class SpecDecodingMetrics:
+
+    def __init__(self):
+        self.reset()
+
+    def reset(self):
+        self.num_draft_tokens: list[int] = []
+        self.num_accepted_tokens: list[int] = []
+
+    def observe(self, spec_decoding_stats: SpecDecodingStats):
+        self.num_draft_tokens.append(spec_decoding_stats.num_draft_tokens)
+        self.num_accepted_tokens.append(
+            spec_decoding_stats.num_accepted_tokens)
+
+    def log(self):
+        num_draft_tokens = np.sum(self.num_draft_tokens)
+        num_accepted_tokens = np.sum(self.num_accepted_tokens)
+
+        draft_acceptance_rate = (num_accepted_tokens / num_draft_tokens
+                                 if num_draft_tokens > 0 else float("nan"))
+
+        logger.info(
+            "Speculative metrics: "
+            "Draft acceptance rate: %.3f, "
+            "Number of accepted tokens: %d, "
+            "Number of draft tokens: %d, ", draft_acceptance_rate,
+            num_accepted_tokens, num_draft_tokens)
+        self.reset()
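Putting the pieces together, a minimal usage sketch of the new module (the step values are illustrative): the scheduler fills a per-step SpecDecodingStats, the logging StatLogger accumulates them in SpecDecodingMetrics, and log() reports the draft acceptance rate before resetting.

    from vllm.v1.spec_decode.metrics import SpecDecodingMetrics, SpecDecodingStats

    metrics = SpecDecodingMetrics()

    # Step 1: 3 draft tokens scheduled, 2 accepted.
    step = SpecDecodingStats()
    step.observe(num_draft_tokens=3, num_accepted_tokens=2)
    metrics.observe(step)

    # Step 2: 2 draft tokens scheduled, both accepted.
    step = SpecDecodingStats()
    step.observe(num_draft_tokens=2, num_accepted_tokens=2)
    metrics.observe(step)

    metrics.log()  # should log a draft acceptance rate of 0.800 (4 accepted / 5 drafted)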
