
Commit 2f808e6

[Bugfix] StatLoggers: cache spec decode metrics when they get collected. (#6645)

Signed-off-by: Thomas Parnell <tpa@zurich.ibm.com>

1 parent 01c16ed commit 2f808e6

File tree: 2 files changed, +122 −16 lines

- tests/metrics/test_metrics.py
- vllm/engine/metrics.py
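Before this change, both `LoggingStatLogger` and `PrometheusStatLogger` only emitted speculative-decoding metrics when `stats.spec_decode_metrics` happened to be non-None on the very step that crossed the logging interval; because those metrics are collected asynchronously and only on some iterations, they were usually dropped. The fix caches the latest sample on the logger and flushes it at the next interval. Below is a minimal, self-contained sketch of that pattern; the `SpecDecodeSample`/`Stats` dataclasses and `CachingStatLogger` are simplified stand-ins for vLLM's real types, not the actual implementation:

```python
import time
from dataclasses import dataclass
from typing import Optional


@dataclass
class SpecDecodeSample:
    """Hypothetical stand-in for vLLM's SpecDecodeWorkerMetrics."""
    draft_acceptance_rate: float
    emitted_tokens: int


@dataclass
class Stats:
    """Hypothetical stand-in for vllm.engine.metrics.Stats."""
    now: float
    spec_decode_metrics: Optional[SpecDecodeSample] = None


class CachingStatLogger:
    """Caches the latest spec decode sample until the next log interval."""

    def __init__(self, local_interval: float) -> None:
        self.local_interval = local_interval
        self.last_local_log = time.time()
        # Survives the iterations between async collection and logging.
        self.spec_decode_metrics: Optional[SpecDecodeSample] = None

    def log(self, stats: Stats) -> None:
        # Cache whenever a sample arrives, even on non-logging iterations.
        if stats.spec_decode_metrics is not None:
            self.spec_decode_metrics = stats.spec_decode_metrics

        # Flush on the logging interval, then clear the cache so each
        # sample is emitted at most once.
        if stats.now - self.last_local_log >= self.local_interval:
            if self.spec_decode_metrics is not None:
                print("spec decode:", self.spec_decode_metrics)
            self.last_local_log = stats.now
            self.spec_decode_metrics = None


t0 = time.time()
logger = CachingStatLogger(local_interval=5.0)
# Sample collected between log intervals: cached rather than dropped.
logger.log(Stats(now=t0 + 1.0, spec_decode_metrics=SpecDecodeSample(0.8, 6)))
# Interval has elapsed: the cached sample is flushed here.
logger.log(Stats(now=t0 + 6.0))
```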

tests/metrics/test_metrics.py (91 additions, 0 deletions)

```diff
@@ -1,3 +1,4 @@
+import time
 from typing import List
 
 import pytest
@@ -10,6 +11,8 @@
 from vllm.engine.metrics import RayPrometheusStatLogger
 from vllm.sampling_params import SamplingParams
 
+from ..conftest import cleanup
+
 MODELS = [
     "facebook/opt-125m",
 ]
@@ -219,6 +222,94 @@ def test_metric_spec_decode(
         "does not meet expectation")
 
 
+@pytest.mark.parametrize("model", MODELS)
+@pytest.mark.parametrize("dtype", ["half"])
+@pytest.mark.parametrize("max_tokens", [10])
+@pytest.mark.parametrize("log_interval", [1, 3, 5, 7])
+def test_metric_spec_decode_interval(
+    vllm_runner,
+    example_prompts,
+    model: str,
+    dtype: str,
+    max_tokens: int,
+    log_interval: int,
+) -> None:
+    k = 5
+
+    engine_args = EngineArgs(model=model,
+                             dtype=dtype,
+                             disable_log_stats=False,
+                             gpu_memory_utilization=0.4,
+                             speculative_model=model,
+                             num_speculative_tokens=k,
+                             use_v2_block_manager=True,
+                             enforce_eager=True)
+
+    engine = LLMEngine.from_engine_args(engine_args)
+
+    try:
+
+        engine.add_request(
+            "request-id-0",
+            example_prompts[0],
+            SamplingParams(max_tokens=max_tokens),
+        )
+
+        # set log interval
+        stat_logger = engine.stat_loggers['prometheus']
+        stat_logger.local_interval = log_interval
+
+        # prefill
+        engine.step()
+
+        # wait for 5 seconds to ensure that spec decode metrics
+        # get triggered in first decode step
+        time.sleep(5)
+
+        # first decode step should trigger async collection of metrics
+        engine.step()
+
+        # wait one second to allow H2D transfer to finish
+        time.sleep(1)
+
+        # second decode step should now be able to collect the spec
+        # decode stats and the request should also be finished
+        engine.step()
+
+        # must have finished now
+        assert not engine.has_unfinished_requests()
+
+        # wait to ensure logging occurs
+        time.sleep(log_interval)
+
+        # force logging
+        engine.step()
+
+        # Note that the purpose of this test is to verify spec decode
+        # metrics instead of functional correctness, so the expected values
+        # are intended to be loose.
+        metric_name_to_expected_fn = {
+            "gauge_spec_decode_draft_acceptance_rate": lambda v: 0 <= v <= 1,
+            "gauge_spec_decode_efficiency": lambda v: 0 <= v <= 1,
+            "counter_spec_decode_num_accepted_tokens": lambda v: 0 <= v <= k,
+            "counter_spec_decode_num_draft_tokens": lambda v: v == k,
+            "counter_spec_decode_num_emitted_tokens":
+            lambda v: 0 <= v <= k + 1,
+        }
+
+        for metric_name, is_expected in metric_name_to_expected_fn.items():
+            metric_val = getattr(
+                stat_logger.metrics,
+                metric_name).labels(**stat_logger.labels)._value.get()
+            assert is_expected(metric_val), (
+                f"the value of metric {metric_name} ({metric_val}) "
+                "does not meet expectation")
+
+    finally:
+        del engine
+        cleanup()
+
+
 def assert_metrics(engine: LLMEngine, disable_log_stats: bool,
                    num_requests: int) -> None:
     if disable_log_stats:
```
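The new test exercises the cached path end-to-end: it shrinks `local_interval`, drives the engine through prefill and two decode steps (sleeping in between so the asynchronous spec decode metric collection can finish), and then asserts deliberately loose bounds on the Prometheus gauges and counters. It should be runnable on its own with a standard pytest invocation, e.g. `pytest tests/metrics/test_metrics.py -k test_metric_spec_decode_interval`, assuming a GPU environment that can host both the target and draft copies of the model.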

vllm/engine/metrics.py (31 additions, 16 deletions)

```diff
@@ -355,6 +355,7 @@ def __init__(self, local_interval: float) -> None:
         self.num_generation_tokens: List[int] = []
         self.last_local_log = time.time()
         self.local_interval = local_interval
+        self.spec_decode_metrics: Optional["SpecDecodeWorkerMetrics"] = None
 
     @abstractmethod
     def info(self, type: str, obj: SupportsMetricsInfo) -> None:
@@ -364,6 +365,12 @@ def info(self, type: str, obj: SupportsMetricsInfo) -> None:
     def log(self, stats: Stats) -> None:
         raise NotImplementedError
 
+    def maybe_update_spec_decode_metrics(self, stats: Stats):
+        """Save spec decode metrics (since they are unlikely
+        to be emitted at same time as log interval)."""
+        if stats.spec_decode_metrics is not None:
+            self.spec_decode_metrics = stats.spec_decode_metrics
+
 
 class LoggingStatLogger(StatLoggerBase):
     """LoggingStatLogger is used in LLMEngine to log to Stdout."""
@@ -379,6 +386,9 @@ def log(self, stats: Stats) -> None:
         self.num_prompt_tokens.append(stats.num_prompt_tokens_iter)
         self.num_generation_tokens.append(stats.num_generation_tokens_iter)
 
+        # Update spec decode metrics
+        self.maybe_update_spec_decode_metrics(stats)
+
         # Log locally every local_interval seconds.
         if local_interval_elapsed(stats.now, self.last_local_log,
                                   self.local_interval):
@@ -408,15 +418,16 @@ def log(self, stats: Stats) -> None:
                 stats.cpu_cache_usage_sys * 100,
             )
 
+            if self.spec_decode_metrics is not None:
+                logger.info(
+                    self._format_spec_decode_metrics_str(
+                        self.spec_decode_metrics))
+
             # Reset tracked stats for next interval.
             self.num_prompt_tokens = []
             self.num_generation_tokens = []
             self.last_local_log = stats.now
-
-            if stats.spec_decode_metrics is not None:
-                logger.info(
-                    self._format_spec_decode_metrics_str(
-                        stats.spec_decode_metrics))
+            self.spec_decode_metrics = None
 
     def _format_spec_decode_metrics_str(
             self, metrics: "SpecDecodeWorkerMetrics") -> str:
@@ -533,6 +544,9 @@ def log(self, stats: Stats):
         self.num_prompt_tokens.append(stats.num_prompt_tokens_iter)
         self.num_generation_tokens.append(stats.num_generation_tokens_iter)
 
+        # Update spec decode metrics
+        self.maybe_update_spec_decode_metrics(stats)
+
         # Log locally every local_interval seconds.
         if local_interval_elapsed(stats.now, self.last_local_log,
                                   self.local_interval):
@@ -550,26 +564,27 @@ def log(self, stats: Stats):
                 prompt_throughput=prompt_throughput,
                 generation_throughput=generation_throughput)
 
-            # Reset tracked stats for next interval.
-            self.num_prompt_tokens = []
-            self.num_generation_tokens = []
-            self.last_local_log = stats.now
-
-            if stats.spec_decode_metrics is not None:
+            if self.spec_decode_metrics is not None:
                 self._log_gauge(
                     self.metrics.gauge_spec_decode_draft_acceptance_rate,
-                    stats.spec_decode_metrics.draft_acceptance_rate)
+                    self.spec_decode_metrics.draft_acceptance_rate)
                 self._log_gauge(self.metrics.gauge_spec_decode_efficiency,
-                                stats.spec_decode_metrics.system_efficiency)
+                                self.spec_decode_metrics.system_efficiency)
                 self._log_counter(
                     self.metrics.counter_spec_decode_num_accepted_tokens,
-                    stats.spec_decode_metrics.accepted_tokens)
+                    self.spec_decode_metrics.accepted_tokens)
                 self._log_counter(
                     self.metrics.counter_spec_decode_num_draft_tokens,
-                    stats.spec_decode_metrics.draft_tokens)
+                    self.spec_decode_metrics.draft_tokens)
                 self._log_counter(
                     self.metrics.counter_spec_decode_num_emitted_tokens,
-                    stats.spec_decode_metrics.emitted_tokens)
+                    self.spec_decode_metrics.emitted_tokens)
+
+            # Reset tracked stats for next interval.
+            self.num_prompt_tokens = []
+            self.num_generation_tokens = []
+            self.last_local_log = stats.now
+            self.spec_decode_metrics = None
 
 
 class RayPrometheusStatLogger(PrometheusStatLogger):
```
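Note the ordering in both loggers after this change: the cached spec decode metrics are emitted first, and the end-of-interval reset block now also clears `self.spec_decode_metrics`, so each collected sample is logged at most once and a stale sample cannot leak into a later interval.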
