From 5306db8d4de169d7e10725af052734bab6596b1f Mon Sep 17 00:00:00 2001
From: Lei Wen
Date: Fri, 7 Jun 2024 10:38:53 +0800
Subject: [PATCH] add spec infer stats to prometheus metrics, and add a new
 boost_ratio metric to directly show how much spec infer helps in saving
 decoding steps

Signed-off-by: Lei Wen
---
 vllm/engine/metrics.py          | 70 +++++++++++++++++--
 .../layers/rejection_sampler.py |  5 ++
 vllm/spec_decode/metrics.py     |  7 ++
 3 files changed, 75 insertions(+), 7 deletions(-)

diff --git a/vllm/engine/metrics.py b/vllm/engine/metrics.py
index ae7ae144bc04f..b7b3285d50859 100644
--- a/vllm/engine/metrics.py
+++ b/vllm/engine/metrics.py
@@ -60,6 +60,29 @@ def __init__(self, labelnames: List[str], max_model_len: int):
             documentation="CPU KV-cache usage. 1 means 100 percent usage.",
             labelnames=labelnames)
 
+        # Speculative inference stats
+        self.counter_draft_tokens = Counter(
+            name="vllm:draft_tokens",
+            documentation=
+            "Number of speculative tokens produced by the proposal method.",
+            labelnames=labelnames)
+        self.counter_emitted_tokens = Counter(
+            name="vllm:emitted_tokens",
+            documentation="Number of tokens emitted by the entire system.",
+            labelnames=labelnames)
+        self.counter_accepted_tokens = Counter(
+            name="vllm:accepted_tokens",
+            documentation=
+            "Number of tokens accepted by the verification routine.",
+            labelnames=labelnames)
+        self.counter_num_spec_tokens = Counter(
+            name="vllm:num_spec_tokens",
+            documentation="Number of speculative tokens per sequence.",
+            labelnames=labelnames)
+        self.counter_num_specs = Counter(
+            name="vllm:num_specs",
+            documentation="Number of speculative verification steps taken.",
+            labelnames=labelnames)
         # Iteration stats
         self.counter_num_preemption = Counter(
             name="vllm:num_preemptions_total",
@@ -215,6 +238,13 @@ def __init__(self, local_interval: float, labels: Dict[str, str],
         self.last_local_log = time.time()
         self.local_interval = local_interval
 
+        # Last seen spec infer aggregates, used to compute counter deltas.
+        self.last_accepted_tokens = 0
+        self.last_emitted_tokens = 0
+        self.last_draft_tokens = 0
+        self.last_spec_tokens = 0
+        self.last_specs = 0
+
         # Tracked stats over current local logging interval.
         self.num_prompt_tokens: List[int] = []
         self.num_generation_tokens: List[int] = []
 
@@ -248,6 +278,30 @@ def _log_prometheus(self, stats: Stats) -> None:
         self._log_gauge(self.metrics.gauge_cpu_cache_usage,
                         stats.cpu_cache_usage_sys)
 
+        if stats.spec_decode_metrics is not None:
+            # Assume one bonus token is emitted on each verification step.
+            self._log_counter(self.metrics.counter_draft_tokens,
+                              (stats.spec_decode_metrics.draft_tokens -
+                               self.last_draft_tokens))
+            self._log_counter(self.metrics.counter_emitted_tokens,
+                              (stats.spec_decode_metrics.emitted_tokens -
+                               self.last_emitted_tokens))
+            self._log_counter(self.metrics.counter_accepted_tokens,
+                              (stats.spec_decode_metrics.accepted_tokens -
+                               self.last_accepted_tokens))
+            self._log_counter(self.metrics.counter_num_spec_tokens,
+                              (stats.spec_decode_metrics.num_spec_tokens -
+                               self.last_spec_tokens))
+            self._log_counter(
+                self.metrics.counter_num_specs,
+                (stats.spec_decode_metrics.num_specs - self.last_specs))
+            self.last_draft_tokens = stats.spec_decode_metrics.draft_tokens
+            self.last_emitted_tokens = stats.spec_decode_metrics.emitted_tokens
+            self.last_accepted_tokens = (
+                stats.spec_decode_metrics.accepted_tokens)
+            self.last_spec_tokens = stats.spec_decode_metrics.num_spec_tokens
+            self.last_specs = stats.spec_decode_metrics.num_specs
+
         # Iteration level data
         self._log_counter(self.metrics.counter_num_preemption,
                           stats.num_preemption_iter)
@@ -366,10 +420,12 @@ def log(self, stats: Stats) -> None:
 
     def _format_spec_decode_metrics_str(
             self, metrics: "SpecDecodeWorkerMetrics") -> str:
-        return ("Speculative metrics: "
-                f"Draft acceptance rate: {metrics.draft_acceptance_rate:.3f}, "
-                f"System efficiency: {metrics.system_efficiency:.3f}, "
-                f"Number of speculative tokens: {metrics.num_spec_tokens}, "
-                f"Number of accepted tokens: {metrics.accepted_tokens}, "
-                f"Number of draft tokens tokens: {metrics.draft_tokens}, "
-                f"Number of emitted tokens tokens: {metrics.emitted_tokens}.")
+        return (
+            "Speculative metrics: "
+            f"Draft acceptance rate: {metrics.draft_acceptance_rate:.3f}, "
+            f"System efficiency: {metrics.system_efficiency:.3f}, "
+            f"Number of speculative verifications taken: {metrics.num_specs}, "
+            f"Number of speculative tokens: {metrics.num_spec_tokens}, "
+            f"Number of accepted tokens: {metrics.accepted_tokens}, "
+            f"Number of draft tokens: {metrics.draft_tokens}, "
+            f"Number of emitted tokens: {metrics.emitted_tokens}.")
diff --git a/vllm/model_executor/layers/rejection_sampler.py b/vllm/model_executor/layers/rejection_sampler.py
index 1f2ab7e2870ca..7c2b067b22c71 100644
--- a/vllm/model_executor/layers/rejection_sampler.py
+++ b/vllm/model_executor/layers/rejection_sampler.py
@@ -37,6 +37,7 @@ def __init__(self,
         self.num_accepted_tokens: Optional[torch.Tensor] = None
         self.num_emitted_tokens: Optional[torch.Tensor] = None
         self.num_draft_tokens: int = 0
+        self.num_specs: int = 0
 
     def init_gpu_tensors(self, rank: int) -> None:
         assert self.num_accepted_tokens is None
@@ -330,6 +331,10 @@ def _create_output(
         self.num_emitted_tokens += (output_with_bonus_tokens != -1).sum()
         self.num_draft_tokens += batch_size * k
 
+        # k might not be constant if dynamic speculation is enabled.
+        # For the ngram case, batch_size might be 0 if nothing matched.
+        self.num_specs += batch_size * k
+
         return output_with_bonus_tokens
 
     def _raise_if_incorrect_shape(
diff --git a/vllm/spec_decode/metrics.py b/vllm/spec_decode/metrics.py
index ab1d96c558de7..7fd120c9b88eb 100644
--- a/vllm/spec_decode/metrics.py
+++ b/vllm/spec_decode/metrics.py
@@ -41,6 +41,9 @@ class SpecDecodeWorkerMetrics:
     # The number of speculative tokens per sequence.
     num_spec_tokens: int
 
+    # The number of speculative verification steps that have been taken.
+    num_specs: int
+
 
 Timer = Callable[[], float]
 
@@ -70,6 +73,7 @@ def __init__(self,
         self._aggregate_num_emitted_tokens = torch.tensor(
             0, dtype=torch.long, device="cpu", pin_memory=pin_memory)
         self._aggregate_num_draft_tokens = 0
+        self._aggregate_num_specs = 0
 
         self._rejsample_metrics_collect_interval_s = collect_interval_s
         self._last_metrics_collect_time = self._timer()
@@ -124,6 +128,8 @@ def _copy_rejsample_metrics_async(self) -> torch.cuda.Event:
         # required.
         self._aggregate_num_draft_tokens = (
             self._rejection_sampler.num_draft_tokens)
+        # Number of speculative verification steps taken so far.
+        self._aggregate_num_specs = (self._rejection_sampler.num_specs)
 
         aggregate_metrics_ready = torch.cuda.Event()
         aggregate_metrics_ready.record(self._copy_stream)
@@ -162,6 +168,7 @@ def _collect_rejsample_metrics(
 
         return SpecDecodeWorkerMetrics(
             num_spec_tokens=k,
+            num_specs=self._aggregate_num_specs,
             draft_acceptance_rate=draft_acceptance_rate,
             system_efficiency=system_efficiency,
             accepted_tokens=accepted_tokens,
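
Note for reviewers: the hunk that registers the boost_ratio metric mentioned in the subject is not included above, so the sketch below only illustrates one plausible way the new counters could be combined into such a ratio. The helper name, the example numbers, and the boost-ratio definition used here (tokens emitted per speculative verification step) are assumptions for illustration, not part of this patch.

# Illustrative sketch only (not part of the patch). Assumes boost_ratio means
# "tokens emitted per speculative verification step", i.e. how many plain
# decoding steps each verification step replaces.
from typing import Dict


def derive_spec_ratios(draft_tokens: int, accepted_tokens: int,
                       emitted_tokens: int,
                       num_specs: int) -> Dict[str, float]:
    """Derive ratios from the raw counters this patch exposes
    (vllm:draft_tokens, vllm:accepted_tokens, vllm:emitted_tokens,
    vllm:num_specs)."""
    acceptance = (accepted_tokens / draft_tokens
                  if draft_tokens else float("nan"))
    # Without speculation, emitting `emitted_tokens` tokens would take
    # `emitted_tokens` decoding steps; with speculation it took `num_specs`
    # verification steps, so their ratio measures the saved steps.
    boost = emitted_tokens / num_specs if num_specs else float("nan")
    return {"draft_acceptance_rate": acceptance, "boost_ratio": boost}


if __name__ == "__main__":
    # 100 verification steps, 500 draft tokens, 350 accepted, 450 emitted
    # (accepted plus one bonus token per step) -> acceptance 0.70, boost 4.5.
    print(derive_spec_ratios(500, 350, 450, 100))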