
Commit cd3ecad

[WIP][V1][Metrics] Speculative decoding metrics

Fixes #13990, part of #10582

Signed-off-by: Mark McLoughlin <markmc@redhat.com>

1 parent 86c6239

9 files changed: +251 −22 lines

tests/v1/sample/test_rejection_sampler.py (+43 −6)
@@ -95,6 +95,10 @@ def test_perfect_match(rejection_sampler):
                             device=logits.device)
     assert torch.equal(output, expected)
 
+    assert rejection_sampler.stats.num_draft_tokens == 3
+    assert rejection_sampler.stats.num_accepted_tokens == 3
+    assert rejection_sampler.stats.num_emitted_tokens == 4
+
 
 def test_early_mismatch(rejection_sampler):
     """Test when there's an early mismatch in tokens"""

@@ -122,6 +126,10 @@ def test_early_mismatch(rejection_sampler):
     )
     assert torch.equal(output, expected)
 
+    assert rejection_sampler.stats.num_draft_tokens == 3
+    assert rejection_sampler.stats.num_accepted_tokens == 1
+    assert rejection_sampler.stats.num_emitted_tokens == 2
+
 
 def test_multiple_sequences(rejection_sampler):
     """Test handling multiple sequences of speculated tokens"""

@@ -148,6 +156,10 @@ def test_multiple_sequences(rejection_sampler):
                             device=logits.device)
     assert torch.equal(output, expected)
 
+    assert rejection_sampler.stats.num_draft_tokens == 3
+    assert rejection_sampler.stats.num_accepted_tokens == 3
+    assert rejection_sampler.stats.num_emitted_tokens == 5
+
 
 def test_single_token_sequence(rejection_sampler):
     """Test handling sequences with single token"""

@@ -171,6 +183,10 @@ def test_single_token_sequence(rejection_sampler):
     expected = torch.tensor([[1, 2]], dtype=torch.int, device=logits.device)
     assert torch.equal(output, expected)
 
+    assert rejection_sampler.stats.num_draft_tokens == 1
+    assert rejection_sampler.stats.num_accepted_tokens == 1
+    assert rejection_sampler.stats.num_emitted_tokens == 2
+
 
 def test_empty_sequence(rejection_sampler):
     """Test handling empty sequence of speculated tokens"""

@@ -194,6 +210,10 @@ def test_empty_sequence(rejection_sampler):
     expected = torch.tensor([[5]], dtype=torch.int, device=logits.device)
     assert torch.equal(output, expected)
 
+    assert rejection_sampler.stats.num_draft_tokens == 0
+    assert rejection_sampler.stats.num_accepted_tokens == 0
+    assert rejection_sampler.stats.num_emitted_tokens == 1
+
 
 def test_multiple_mismatches(rejection_sampler):
     """Test handling multiple sequences with mismatches"""

@@ -223,17 +243,24 @@ def test_multiple_mismatches(rejection_sampler):
     )
     assert torch.equal(output, expected)
 
+    assert rejection_sampler.stats.num_draft_tokens == 6
+    assert rejection_sampler.stats.num_accepted_tokens == 3
+    assert rejection_sampler.stats.num_emitted_tokens == 5
+
 
 @pytest.mark.parametrize(
-    "spec_tokens,output_tokens,expected",
+    "spec_tokens,output_tokens,expected,expected_stats",
     [
-        ([[1, 2]], [[1, 2, 3]], [[1, 2, 3]]),  # Perfect match with bonus
-        ([[1]], [[2, 3]], [[2, PLACEHOLDER_TOKEN_ID]]),  # First mismatch
-        ([[1, 2], [3, 4]], [[1, 5, 6], [3, 4, 7]],
-         [[1, 5, PLACEHOLDER_TOKEN_ID], [3, 4, 7]]),  # Mixed matches
+        ([[1, 2]], [[1, 2, 3]], [[1, 2, 3]],
+         (2, 2, 3)),  # Perfect match with bonus
+        ([[1]], [[2, 3]], [[2, PLACEHOLDER_TOKEN_ID]],
+         (1, 0, 1)),  # First mismatch
+        ([[1, 2], [3, 4]], [[1, 5, 6], [3, 4, 7]
+          ], [[1, 5, PLACEHOLDER_TOKEN_ID], [3, 4, 7]],
+         (4, 3, 5)),  # Mixed matches
     ])
 def test_parametrized_cases(rejection_sampler, spec_tokens, output_tokens,
-                            expected):
+                            expected, expected_stats):
     """Parametrized test for various matching scenarios"""
     metadata = create_sampling_metadata(all_greedy=True)
     logits = create_logits_tensor(output_tokens)

@@ -254,6 +281,10 @@ def test_parametrized_cases(rejection_sampler, spec_tokens, output_tokens,
                                    device=logits.device)
     assert torch.equal(output, expected_tensor)
 
+    assert rejection_sampler.stats.num_draft_tokens == expected_stats[0]
+    assert rejection_sampler.stats.num_accepted_tokens == expected_stats[1]
+    assert rejection_sampler.stats.num_emitted_tokens == expected_stats[2]
+
 
 ########################### Tests for Random Sampling ###################
 @pytest.mark.parametrize("k", [1, 3, 5])

@@ -314,6 +345,12 @@ def test_deterministic_when_seeded(
 
         results.append(rep_result)
 
+    stats = rejection_sampler.stats.take()
+    assert stats.num_draft_tokens == num_tokens
+    assert stats.num_emitted_tokens >= batch_size
+    assert (stats.num_emitted_tokens -
+            batch_size) == stats.num_accepted_tokens
+
     for i in range(batch_size):
         if seeded_mask[i]:
             for j in range(1, n_rep):
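The assertions above rely on a `stats` object exposed by the rejection sampler, with `num_draft_tokens`, `num_accepted_tokens`, and `num_emitted_tokens` counters plus a `take()` method. That object comes from the new `vllm/v1/spec_decode/metrics.py` module, which is part of this commit but not included in this excerpt. A minimal sketch of what the tests imply it looks like, inferred only from the assertions above and not taken from the actual implementation:

```python
# Hypothetical sketch of SpecDecodingStats, inferred only from the test
# assertions above; the real class in vllm/v1/spec_decode/metrics.py may
# differ in shape and behaviour.
from dataclasses import dataclass


@dataclass
class SpecDecodingStats:
    num_draft_tokens: int = 0      # tokens proposed by the draft model
    num_accepted_tokens: int = 0   # draft tokens accepted by rejection sampling
    num_emitted_tokens: int = 0    # tokens actually emitted (accepted + bonus)

    def take(self) -> "SpecDecodingStats":
        # Assumed semantics: return a snapshot of the counters and reset them,
        # matching the drain-style use in test_deterministic_when_seeded().
        snapshot = SpecDecodingStats(self.num_draft_tokens,
                                     self.num_accepted_tokens,
                                     self.num_emitted_tokens)
        self.num_draft_tokens = 0
        self.num_accepted_tokens = 0
        self.num_emitted_tokens = 0
        return snapshot
```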

vllm/v1/core/scheduler.py (+8 −2)
@@ -20,6 +20,7 @@
 from vllm.v1.metrics.stats import SchedulerStats
 from vllm.v1.outputs import ModelRunnerOutput
 from vllm.v1.request import Request, RequestStatus
+from vllm.v1.spec_decode.metrics import SpecDecodingStats
 from vllm.v1.structured_output import StructuredOutputManager
 
 logger = init_logger(__name__)

@@ -533,6 +534,7 @@ def update_from_output(
         spec_token_ids = model_runner_output.spec_token_ids
         logprobs = model_runner_output.logprobs
         prompt_logprobs_dict = model_runner_output.prompt_logprobs_dict
+        spec_decoding_stats = model_runner_output.spec_decoding_stats
         num_scheduled_tokens = scheduler_output.num_scheduled_tokens
 
         new_running: list[Request] = []

@@ -645,7 +647,7 @@ def update_from_output(
         self.running = new_running
         return EngineCoreOutputs(
             outputs=outputs,
-            scheduler_stats=self.make_stats(),
+            scheduler_stats=self.make_stats(spec_decoding_stats),
         )
 
     def _check_stop(self, request: Request) -> bool:

@@ -733,12 +735,16 @@ def get_num_unscheduled_requests(self) -> int:
     def reset_prefix_cache(self) -> bool:
         return self.kv_cache_manager.reset_prefix_cache()
 
-    def make_stats(self) -> Optional[SchedulerStats]:
+    def make_stats(
+        self,
+        spec_decoding_stats: Optional[SpecDecodingStats] = None,
+    ) -> Optional[SchedulerStats]:
         if not self.log_stats:
             return None
         return SchedulerStats(
             num_running_reqs=len(self.running),
             num_waiting_reqs=len(self.waiting),
             gpu_cache_usage=self.kv_cache_manager.usage,
             prefix_cache_stats=self.kv_cache_manager.make_prefix_cache_stats(),
+            spec_decoding_stats=spec_decoding_stats,
         )
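For context, a small usage sketch of the extended `make_stats()` signature. This is illustrative only: the `SpecDecodingStats` keyword arguments are assumed field names, and the helper function and its name are hypothetical, not code from this commit.

```python
# Illustrative only: how update_from_output() now threads the per-step
# speculative decoding stats into SchedulerStats via make_stats().
from typing import Optional

from vllm.v1.core.scheduler import Scheduler
from vllm.v1.metrics.stats import SchedulerStats
from vllm.v1.spec_decode.metrics import SpecDecodingStats


def stats_for_step(scheduler: Scheduler) -> Optional[SchedulerStats]:
    # Hypothetical per-step counters; field names are assumed.
    step_stats = SpecDecodingStats(num_draft_tokens=8,
                                   num_accepted_tokens=6,
                                   num_emitted_tokens=7)
    # Returns None when the scheduler was built with log_stats=False.
    # Omitting the argument keeps the old behaviour: the field defaults to None.
    return scheduler.make_stats(spec_decoding_stats=step_stats)
```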

vllm/v1/engine/async_llm.py (+1 −1)
@@ -66,7 +66,7 @@ def __init__(
         self.stat_loggers: list[StatLoggerBase] = []
         if self.log_stats:
             if logger.isEnabledFor(logging.INFO):
-                self.stat_loggers.append(LoggingStatLogger())
+                self.stat_loggers.append(LoggingStatLogger(vllm_config))
             self.stat_loggers.append(PrometheusStatLogger(vllm_config))
 
         # Tokenizer (+ ensure liveness if running in another process).

vllm/v1/metrics/loggers.py (+39 −1)
@@ -12,6 +12,7 @@
 from vllm.v1.core.kv_cache_utils import PrefixCachingMetrics
 from vllm.v1.engine import FinishReason
 from vllm.v1.metrics.stats import IterationStats, SchedulerStats
+from vllm.v1.spec_decode.metrics import SpecDecodingMetrics
 
 logger = init_logger(__name__)
 

@@ -31,12 +32,14 @@ def log(self): # noqa
 
 class LoggingStatLogger(StatLoggerBase):
 
-    def __init__(self):
+    def __init__(self, vllm_config: VllmConfig):
         self._reset(time.monotonic())
         self.last_scheduler_stats = SchedulerStats()
         # Prefix cache metrics. This cannot be reset.
         # TODO: Make the interval configurable.
         self.prefix_caching_metrics = PrefixCachingMetrics()
+        self.spec_decoding_metrics = SpecDecodingMetrics(
+            vllm_config.speculative_config)
 
     def _reset(self, now):
         self.last_log_time = now

@@ -64,6 +67,10 @@ def record(self, scheduler_stats: SchedulerStats,
 
         self.prefix_caching_metrics.observe(scheduler_stats.prefix_cache_stats)
 
+        if scheduler_stats.spec_decoding_stats is not None:
+            self.spec_decoding_metrics.observe(
+                scheduler_stats.spec_decoding_stats)
+
         self.last_scheduler_stats = scheduler_stats
 
     def log(self):

@@ -91,6 +98,9 @@ def log(self):
             self.prefix_caching_metrics.hit_rate * 100,
         )
 
+        if scheduler_stats.spec_decoding_stats is not None:
+            self.spec_decoding_metrics.log()
+
 
 class PrometheusStatLogger(StatLoggerBase):
 

@@ -296,6 +306,26 @@ def __init__(self, vllm_config: VllmConfig):
                     self.labelname_running_lora_adapters,
                 ])
 
+        #
+        # Speculative Decoding metrics
+        # FIXME: add note on acceptance rate and system efficiency
+        #
+        self.counter_spec_decode_num_draft_tokens = \
+            prometheus_client.Counter(
+                name="vllm:spec_decode_num_draft_tokens_total",
+                documentation="Number of draft tokens.",
+                labelnames=labelnames).labels(*labelvalues)
+        self.counter_spec_decode_num_accepted_tokens = \
+            prometheus_client.Counter(
+                name="vllm:spec_decode_num_accepted_tokens_total",
+                documentation="Number of accepted tokens.",
+                labelnames=labelnames).labels(*labelvalues)
+        self.counter_spec_decode_num_emitted_tokens = \
+            prometheus_client.Counter(
+                name="vllm:spec_decode_num_emitted_tokens_total",
+                documentation="Number of emitted tokens.",
+                labelnames=labelnames).labels(*labelvalues)
+
         #
         # Cache config info metric
         #

@@ -332,6 +362,14 @@ def record(self, scheduler_stats: SchedulerStats,
         self.counter_gpu_prefix_cache_hits.inc(
             scheduler_stats.prefix_cache_stats.hits)
 
+        if scheduler_stats.spec_decoding_stats is not None:
+            self.counter_spec_decode_num_draft_tokens.inc(
+                scheduler_stats.spec_decoding_stats.num_draft_tokens)
+            self.counter_spec_decode_num_accepted_tokens.inc(
+                scheduler_stats.spec_decoding_stats.num_accepted_tokens)
+            self.counter_spec_decode_num_emitted_tokens.inc(
+                scheduler_stats.spec_decoding_stats.num_emitted_tokens)
+
         if iteration_stats is None:
             return
 
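The three counters are enough to derive the two ratios the FIXME above refers to. A worked example using the figures asserted in `test_multiple_mismatches` (two sequences, three speculated tokens each); the definitions below are the conventional speculative-decoding ones and an assumption here, not documentation added by this commit:

```python
# Worked example of ratios derivable from the new counters, using the numbers
# from test_multiple_mismatches above. The formulas assume every sequence
# speculates exactly k tokens per step.
num_draft_tokens = 6      # proposed by the draft model
num_accepted_tokens = 3   # survived rejection sampling
num_emitted_tokens = 5    # returned to the user (accepted + bonus/recovered)

k = 3                                  # speculated tokens per sequence
num_spec_seqs = num_draft_tokens // k  # sequences that ran a speculative step

# Draft acceptance rate: fraction of proposed tokens that were accepted.
draft_acceptance_rate = num_accepted_tokens / num_draft_tokens      # 0.5

# System efficiency: emitted tokens relative to the ideal k + 1 per sequence.
system_efficiency = num_emitted_tokens / (num_spec_seqs * (k + 1))  # 0.625
```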

vllm/v1/metrics/stats.py (+4 −0)
@@ -4,6 +4,8 @@
 from dataclasses import dataclass, field
 from typing import TYPE_CHECKING, Optional
 
+from vllm.v1.spec_decode.metrics import SpecDecodingStats
+
 if TYPE_CHECKING:
     from vllm.v1.engine import EngineCoreEvent, EngineCoreOutput, FinishReason
     from vllm.v1.output_processor import RequestState

@@ -35,6 +37,8 @@ class SchedulerStats:
     prefix_cache_stats: PrefixCacheStats = field(
         default_factory=PrefixCacheStats)
 
+    spec_decoding_stats: Optional[SpecDecodingStats] = None
+
 
 @dataclass
 class LoRAStats:

vllm/v1/outputs.py (+7 −0)
@@ -5,6 +5,8 @@
 
 import torch
 
+from vllm.v1.spec_decode.metrics import SpecDecodingStats
+
 
 class LogprobsLists(NamedTuple):
 

@@ -50,6 +52,8 @@ class SamplerOutput:
     sampled_token_ids: torch.Tensor
     logprobs_tensors: Optional[LogprobsTensors]
 
+    spec_decoding_stats: Optional[SpecDecodingStats] = None
+
 
 # ModelRunnerOutput is serialized and sent to the scheduler process.
 # This is expensive for torch.Tensor so prefer to use list instead.

@@ -81,6 +85,8 @@ class ModelRunnerOutput:
     # [prompt_len]
     prompt_logprobs_dict: dict[str, Optional[LogprobsTensors]]
 
+    spec_decoding_stats: Optional[SpecDecodingStats] = None
+
 
 EMPTY_MODEL_RUNNER_OUTPUT = ModelRunnerOutput(
     req_ids=[],

@@ -89,4 +95,5 @@ class ModelRunnerOutput:
     spec_token_ids=None,
     logprobs=None,
     prompt_logprobs_dict={},
+    spec_decoding_stats=None,
 )
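End to end, the counters flow from the rejection sampler into `SamplerOutput` and `ModelRunnerOutput`, are attached to `SchedulerStats` by the scheduler, and are consumed by both stat loggers. The `SpecDecodingMetrics` aggregator used by `LoggingStatLogger` lives in the same new `vllm/v1/spec_decode/metrics.py` module that is not shown in this excerpt; the sketch below is inferred purely from its call sites above (constructed with `vllm_config.speculative_config`, then `observe()` per step and `log()` per interval) and is not the committed implementation.

```python
# Hypothetical sketch of SpecDecodingMetrics, inferred from its call sites in
# loggers.py above; the real class in vllm/v1/spec_decode/metrics.py may differ.
from vllm.config import SpeculativeConfig  # assumed location of the config type
from vllm.logger import init_logger
from vllm.v1.spec_decode.metrics import SpecDecodingStats

logger = init_logger(__name__)


class SpecDecodingMetrics:

    def __init__(self, speculative_config: SpeculativeConfig):
        self.speculative_config = speculative_config
        self.reset()

    def reset(self):
        self.num_draft_tokens = 0
        self.num_accepted_tokens = 0
        self.num_emitted_tokens = 0

    def observe(self, spec_decoding_stats: SpecDecodingStats):
        # Accumulate the per-iteration counters until the next log() call.
        self.num_draft_tokens += spec_decoding_stats.num_draft_tokens
        self.num_accepted_tokens += spec_decoding_stats.num_accepted_tokens
        self.num_emitted_tokens += spec_decoding_stats.num_emitted_tokens

    def log(self):
        draft_acceptance_rate = (self.num_accepted_tokens /
                                 self.num_draft_tokens
                                 if self.num_draft_tokens > 0 else float("nan"))
        logger.info(
            "Speculative decoding: accepted %d of %d draft tokens "
            "(acceptance rate %.1f%%), emitted %d tokens.",
            self.num_accepted_tokens, self.num_draft_tokens,
            draft_acceptance_rate * 100, self.num_emitted_tokens)
        self.reset()
```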
