Commit 478238d

[WIP][V1][Metrics] Speculative decoding metrics
Fixes #13990, part of #10582

Omitting system efficiency for now.

Signed-off-by: Mark McLoughlin <markmc@redhat.com>
1 parent c6bc003 commit 478238d

File tree

6 files changed: +141, -8 lines changed


vllm/v1/core/sched/scheduler.py

Lines changed: 17 additions & 2 deletions
@@ -23,6 +23,7 @@
 from vllm.v1.metrics.stats import SchedulerStats
 from vllm.v1.outputs import ModelRunnerOutput
 from vllm.v1.request import Request, RequestStatus
+from vllm.v1.spec_decode.metrics import SpecDecodingStats
 from vllm.v1.structured_output import StructuredOutputManager

 logger = init_logger(__name__)
@@ -565,6 +566,7 @@ def update_from_output(
         spec_token_ids = model_runner_output.spec_token_ids
         logprobs = model_runner_output.logprobs
         prompt_logprobs_dict = model_runner_output.prompt_logprobs_dict
+        spec_decoding_stats = SpecDecodingStats() if self.log_stats else None
         num_scheduled_tokens = scheduler_output.num_scheduled_tokens

         new_running: list[Request] = []
@@ -597,6 +599,15 @@
                     len(generated_token_ids))
                 request.num_computed_tokens -= num_tokens_rejected

+                if spec_decoding_stats is not None:
+                    # FIXME: If a drafter proposes zero tokens, we should
+                    # treat this as if num_spec_tokens were proposed and
+                    # all rejected to allow fair comparisons between drafters
+                    spec_decoding_stats.observe(
+                        num_draft_tokens=len(scheduled_spec_token_ids),
+                        num_accepted_tokens=len(generated_token_ids) - 1,
+                        num_emitted_tokens=len(generated_token_ids))
+
             cached_encoder_input_ids = (
                 self.encoder_cache_manager.get_cached_input_ids(request))
             # OPTIMIZATION: Avoid list(set) if the set is empty.
@@ -672,7 +683,7 @@ def update_from_output(
         self.running = new_running
         engine_core_outputs = EngineCoreOutputs(
             outputs=outputs,
-            scheduler_stats=self.make_stats(),
+            scheduler_stats=self.make_stats(spec_decoding_stats),
         )
         if self.include_finished_set:
             #TODO currently sending duplicates here, improve this
@@ -739,12 +750,16 @@ def get_num_unscheduled_requests(self) -> int:
     def reset_prefix_cache(self) -> bool:
         return self.kv_cache_manager.reset_prefix_cache()

-    def make_stats(self) -> Optional[SchedulerStats]:
+    def make_stats(
+        self,
+        spec_decoding_stats: Optional[SpecDecodingStats] = None,
+    ) -> Optional[SchedulerStats]:
         if not self.log_stats:
             return None
         return SchedulerStats(
             num_running_reqs=len(self.running),
             num_waiting_reqs=len(self.waiting),
             gpu_cache_usage=self.kv_cache_manager.usage,
             prefix_cache_stats=self.kv_cache_manager.make_prefix_cache_stats(),
+            spec_decoding_stats=spec_decoding_stats,
         )
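
For a single request in one step, the observe() call above derives all three counts from list lengths. Below is a small standalone sketch of that arithmetic with made-up token IDs; the reading that every generated token except the last corresponds to an accepted draft (the target model contributes one token of its own after verification) is an interpretation of the hunk, not something the commit states.

# Hypothetical values for one request in one scheduler step; the variable
# names mirror the diff above, the numbers are invented.
scheduled_spec_token_ids = [11, 12, 13, 14, 15]  # 5 draft tokens were scheduled
generated_token_ids = [11, 12, 13, 99]           # 4 tokens came back from the target model

num_draft_tokens = len(scheduled_spec_token_ids)     # 5 drafted
num_accepted_tokens = len(generated_token_ids) - 1   # 3 drafts accepted
num_emitted_tokens = len(generated_token_ids)        # 4 tokens emitted this step

assert 0 <= num_accepted_tokens <= num_draft_tokens
print(num_draft_tokens, num_accepted_tokens, num_emitted_tokens)  # 5 3 4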

vllm/v1/engine/async_llm.py

Lines changed: 2 additions & 1 deletion
@@ -74,7 +74,8 @@ def __init__(
             for i in range(vllm_config.parallel_config.data_parallel_size):
                 loggers: list[StatLoggerBase] = []
                 if logger.isEnabledFor(logging.INFO):
-                    loggers.append(LoggingStatLogger(engine_index=i))
+                    loggers.append(
+                        LoggingStatLogger(vllm_config, engine_index=i))
                 loggers.append(
                     PrometheusStatLogger(vllm_config, engine_index=i))
                 self.stat_loggers.append(loggers)

vllm/v1/metrics/loggers.py

Lines changed: 42 additions & 1 deletion
@@ -12,6 +12,7 @@
 from vllm.v1.core.kv_cache_utils import PrefixCachingMetrics
 from vllm.v1.engine import FinishReason
 from vllm.v1.metrics.stats import IterationStats, SchedulerStats
+from vllm.v1.spec_decode.metrics import SpecDecodingMetrics

 logger = init_logger(__name__)

@@ -31,13 +32,15 @@ def log(self): # noqa

 class LoggingStatLogger(StatLoggerBase):

-    def __init__(self, engine_index: int = 0):
+    def __init__(self, vllm_config: VllmConfig, engine_index: int = 0):
         self.engine_index = engine_index
         self._reset(time.monotonic())
         self.last_scheduler_stats = SchedulerStats()
         # Prefix cache metrics. This cannot be reset.
         # TODO: Make the interval configurable.
         self.prefix_caching_metrics = PrefixCachingMetrics()
+        self.spec_decoding_metrics = SpecDecodingMetrics(
+            vllm_config.speculative_config)

     def _reset(self, now):
         self.last_log_time = now
@@ -65,6 +68,10 @@ def record(self, scheduler_stats: SchedulerStats,

         self.prefix_caching_metrics.observe(scheduler_stats.prefix_cache_stats)

+        if scheduler_stats.spec_decoding_stats is not None:
+            self.spec_decoding_metrics.observe(
+                scheduler_stats.spec_decoding_stats)
+
         self.last_scheduler_stats = scheduler_stats

     def log(self):
@@ -94,6 +101,9 @@ def log(self):
             self.prefix_caching_metrics.hit_rate * 100,
         )

+        if scheduler_stats.spec_decoding_stats is not None:
+            self.spec_decoding_metrics.log()
+

 class PrometheusStatLogger(StatLoggerBase):

@@ -302,6 +312,29 @@ def __init__(self, vllm_config: VllmConfig, engine_index: int = 0):
                 self.labelname_running_lora_adapters,
             ])

+        #
+        # Speculative Decoding metrics
+        # The acceptance rate can be calculated using a PromQL query:
+        #
+        #   rate(vllm:spec_decode_num_accepted_tokens_total[$interval]) /
+        #     rate(vllm:spec_decode_num_draft_tokens_total[$interval])
+        #
+        self.counter_spec_decode_num_draft_tokens = \
+            prometheus_client.Counter(
+                name="vllm:spec_decode_num_draft_tokens_total",
+                documentation="Number of draft tokens.",
+                labelnames=labelnames).labels(*labelvalues)
+        self.counter_spec_decode_num_accepted_tokens = \
+            prometheus_client.Counter(
+                name="vllm:spec_decode_num_accepted_tokens_total",
+                documentation="Number of accepted tokens.",
+                labelnames=labelnames).labels(*labelvalues)
+        self.counter_spec_decode_num_emitted_tokens = \
+            prometheus_client.Counter(
+                name="vllm:spec_decode_num_emitted_tokens_total",
+                documentation="Number of emitted tokens.",
+                labelnames=labelnames).labels(*labelvalues)
+
         #
         # Cache config info metric
         #
@@ -338,6 +371,14 @@ def record(self, scheduler_stats: SchedulerStats,
         self.counter_gpu_prefix_cache_hits.inc(
             scheduler_stats.prefix_cache_stats.hits)

+        if scheduler_stats.spec_decoding_stats is not None:
+            self.counter_spec_decode_num_draft_tokens.inc(
+                scheduler_stats.spec_decoding_stats.num_draft_tokens)
+            self.counter_spec_decode_num_accepted_tokens.inc(
+                scheduler_stats.spec_decoding_stats.num_accepted_tokens)
+            self.counter_spec_decode_num_emitted_tokens.inc(
+                scheduler_stats.spec_decoding_stats.num_emitted_tokens)
+
         if iteration_stats is None:
             return
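
The PromQL comment in the hunk above already gives the intended way to derive the acceptance rate from these counters. Purely as an in-process illustration of the same ratio, here is a sketch using stand-alone prometheus_client counters; the demo_* names and the numbers are invented and deliberately not vLLM's metric names.

import prometheus_client

# Stand-alone demo counters, registered in the default registry.
drafted = prometheus_client.Counter(
    "demo_spec_decode_num_draft_tokens_total", "Draft tokens.")
accepted = prometheus_client.Counter(
    "demo_spec_decode_num_accepted_tokens_total", "Accepted tokens.")

drafted.inc(5)   # one step: 5 tokens drafted
accepted.inc(3)  # ... of which 3 were accepted

# Read the counters back the same way a scraper would see them.
get = prometheus_client.REGISTRY.get_sample_value
rate = (get("demo_spec_decode_num_accepted_tokens_total") /
        get("demo_spec_decode_num_draft_tokens_total"))
print(f"draft acceptance rate: {rate:.2f}")  # 0.60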

vllm/v1/metrics/stats.py

Lines changed: 4 additions & 0 deletions
@@ -4,6 +4,8 @@
 from dataclasses import dataclass, field
 from typing import TYPE_CHECKING, Optional

+from vllm.v1.spec_decode.metrics import SpecDecodingStats
+
 if TYPE_CHECKING:
     from vllm.v1.engine import EngineCoreEvent, EngineCoreOutput, FinishReason
     from vllm.v1.engine.output_processor import RequestState
@@ -35,6 +37,8 @@ class SchedulerStats:
     prefix_cache_stats: PrefixCacheStats = field(
         default_factory=PrefixCacheStats)

+    spec_decoding_stats: Optional[SpecDecodingStats] = None
+

 @dataclass
 class LoRAStats:

vllm/v1/spec_decode/metrics.py

Lines changed: 72 additions & 0 deletions
@@ -0,0 +1,72 @@
+# SPDX-License-Identifier: Apache-2.0
+
+from dataclasses import dataclass
+
+import numpy as np
+
+from vllm.config import SpeculativeConfig
+from vllm.logger import init_logger
+
+logger = init_logger(__name__)
+
+
+@dataclass
+class SpecDecodingStats:
+    num_draft_tokens: int = 0
+    num_accepted_tokens: int = 0
+    num_emitted_tokens: int = 0
+
+    def take(self):
+        copied = SpecDecodingStats(self.num_draft_tokens,
+                                   self.num_accepted_tokens,
+                                   self.num_emitted_tokens)
+        self.reset()
+        return copied
+
+    def reset(self):
+        self.num_draft_tokens = 0
+        self.num_accepted_tokens = 0
+        self.num_emitted_tokens = 0
+
+    def observe(self, num_draft_tokens: int, num_accepted_tokens: int,
+                num_emitted_tokens: int):
+        self.num_draft_tokens += num_draft_tokens
+        self.num_accepted_tokens += num_accepted_tokens
+        self.num_emitted_tokens += num_emitted_tokens
+
+
+class SpecDecodingMetrics:
+
+    def __init__(self, speculative_config: SpeculativeConfig):
+        self.num_spec_tokens = (speculative_config.num_speculative_tokens
+                                if speculative_config is not None else 0)
+        self.reset()
+
+    def reset(self):
+        self.num_draft_tokens: list[int] = []
+        self.num_accepted_tokens: list[int] = []
+        self.num_emitted_tokens: list[int] = []
+
+    def observe(self, spec_decoding_stats: SpecDecodingStats):
+        self.num_draft_tokens.append(spec_decoding_stats.num_draft_tokens)
+        self.num_accepted_tokens.append(
+            spec_decoding_stats.num_accepted_tokens)
+        self.num_emitted_tokens.append(spec_decoding_stats.num_emitted_tokens)
+
+    def log(self):
+        num_draft_tokens = np.sum(self.num_draft_tokens)
+        num_accepted_tokens = np.sum(self.num_accepted_tokens)
+        num_emitted_tokens = np.sum(self.num_emitted_tokens)
+
+        draft_acceptance_rate = (num_accepted_tokens / num_draft_tokens
+                                 if num_draft_tokens > 0 else float("nan"))
+
+        logger.info(
+            "Speculative metrics: "
+            "Draft acceptance rate: %.3f, "
+            "Number of speculative tokens: %d, "
+            "Number of accepted tokens: %d, "
+            "Number of draft tokens: %d, "
+            "Number of emitted tokens: %d.",
+            draft_acceptance_rate, self.num_spec_tokens, num_accepted_tokens,
+            num_draft_tokens, num_emitted_tokens)
+        self.reset()
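
A rough usage sketch of the new module, assuming it is importable as below once this commit is applied; the per-step counts are hypothetical. The scheduler builds up one SpecDecodingStats per stats interval via observe(), and the acceptance rate is simply accepted over drafted.

from vllm.v1.spec_decode.metrics import SpecDecodingStats

# One scheduler step's observations across two requests (invented counts).
stats = SpecDecodingStats()
stats.observe(num_draft_tokens=5, num_accepted_tokens=3, num_emitted_tokens=4)
stats.observe(num_draft_tokens=5, num_accepted_tokens=5, num_emitted_tokens=6)

# Acceptance rate for this snapshot: accepted / drafted.
rate = stats.num_accepted_tokens / stats.num_draft_tokens
print(f"draft acceptance rate: {rate:.3f}")  # 8 / 10 -> 0.800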

vllm/v1/worker/gpu_model_runner.py

Lines changed: 4 additions & 4 deletions
@@ -1151,20 +1151,20 @@ def generate_draft_token_ids(
         self,
         sampled_token_ids: list[list[int]],
         sampling_metadata: SamplingMetadata,
-    ) -> list[list[int]]:
+    ) -> list[Optional[list[int]]]:
         # TODO(woosuk): Optimize.
-        draft_token_ids: list[list[int]] = []
+        draft_token_ids: list[Optional[list[int]]] = []
         for i, sampled_ids in enumerate(sampled_token_ids):
             num_sampled_ids = len(sampled_ids)
             if not num_sampled_ids:
                 # Skip speculative decoding.
-                draft_token_ids.append([])
+                draft_token_ids.append(None)
                 continue

             # Skip requests that require top-p, top-k, etc.
             req_id = self.input_batch.req_ids[i]
             if not is_spec_decode_supported(req_id, self.input_batch):
-                draft_token_ids.append([])
+                draft_token_ids.append(None)
                 continue

             # Add sampled_token_ids to token_ids_cpu.