Add spec-infer-related metrics to Prometheus metrics #4582

Open · wants to merge 2 commits into base: main
70 changes: 63 additions & 7 deletions vllm/engine/metrics.py
@@ -60,6 +60,29 @@ def __init__(self, labelnames: List[str], max_model_len: int):
documentation="CPU KV-cache usage. 1 means 100 percent usage.",
labelnames=labelnames)

# Speculative infer Status in %
Collaborator:

Please add better descriptions.

Please name these:

  • vllm:spec_decode_system_efficiency
  • vllm:spec_decode_boost_ratio
  • vllm:spec_decode_draft_acceptance_rate

What is the difference between these?
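
(For reference, a minimal sketch of how the existing acceptance-rate and system-efficiency ratios are derived in vllm/spec_decode/metrics.py; the formulas below are an illustration of the current code paths, not part of this PR, and the suggested boost-ratio metric is not covered.)

def get_max_num_emitted_tokens(draft_tokens: int, k: int) -> int:
    # Each speculative step over one sequence proposes k draft tokens and can
    # emit at most k accepted tokens plus one bonus token from the target model.
    num_spec_steps = draft_tokens // k
    return num_spec_steps * (k + 1)

def derive_spec_decode_ratios(accepted_tokens: int, draft_tokens: int,
                              emitted_tokens: int, k: int):
    # Fraction of proposed draft tokens that the verifier accepted.
    draft_acceptance_rate = (accepted_tokens / draft_tokens
                             if draft_tokens > 0 else float("nan"))
    # Emitted tokens relative to the best case where every draft is accepted.
    max_emitted = get_max_num_emitted_tokens(draft_tokens, k)
    system_efficiency = (emitted_tokens / max_emitted
                         if max_emitted > 0 else float("nan"))
    return draft_acceptance_rate, system_efficiency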

Collaborator:

We prefer to use Counters over Gauges.

Is there a way these metrics could be expressed as Counters with the rate function used in PromQL to compute the rates?

Contributor (Author):

Maybe we could directly expose the total emitted tokens along with the number of steps, so that users can compute whatever ratios they want from those counters?
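
(A minimal sketch of that idea, assuming plain prometheus_client Counters; the metric and helper names below are illustrative, not the names this PR registers. With cumulative counters, consumers can derive the ratios at query time, e.g. rate(accepted_tokens[5m]) / rate(draft_tokens[5m]) in PromQL.)

from prometheus_client import Counter

# Illustrative counters; the PR registers names like vllm:draft_tokens with labelnames.
draft_tokens_total = Counter(
    "spec_decode_draft_tokens", "Draft tokens proposed by the proposal method.")
accepted_tokens_total = Counter(
    "spec_decode_accepted_tokens", "Draft tokens accepted by the verifier.")
emitted_tokens_total = Counter(
    "spec_decode_emitted_tokens", "Tokens emitted by the whole system.")

def record_spec_step(num_draft: int, num_accepted: int, num_emitted: int) -> None:
    # Counters only ever increase; ratios are computed at query time
    # instead of being exported as Gauges.
    draft_tokens_total.inc(num_draft)
    accepted_tokens_total.inc(num_accepted)
    emitted_tokens_total.inc(num_emitted)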

self.counter_draft_tokens = Counter(
name="vllm:draft_tokens",
documentation=
"Number of speculative tokens produced by the proposal method.",
labelnames=labelnames)
self.counter_emitted_tokens = Counter(
name="vllm:emitted_tokens",
documentation="Number of tokens emitted by the entire system.",
labelnames=labelnames)
self.counter_accepted_tokens = Counter(
name="vllm:accepted_tokens",
documentation=
"Number of token accepted by the verification routine",
labelnames=labelnames)
self.counter_num_spec_tokens = Counter(
name="vllm:num_spec_tokens",
documentation="Number of speculative tokens per sequence.",
labelnames=labelnames)
self.counter_num_specs = Counter(
name="vllm:num_specs",
documentation="Number of speculative verification steps taken.",
labelnames=labelnames)
# Iteration stats
self.counter_num_preemption = Counter(
name="vllm:num_preemptions_total",
@@ -215,6 +238,13 @@ def __init__(self, local_interval: float, labels: Dict[str, str],
self.last_local_log = time.time()
self.local_interval = local_interval

# Metadata for tracking spec-infer-related aggregated data between logging intervals
self.last_accepted_tokens = 0
self.last_emitted_tokens = 0
self.last_draft_tokens = 0
self.last_spec_tokens = 0
self.last_specs = 0

# Tracked stats over current local logging interval.
self.num_prompt_tokens: List[int] = []
self.num_generation_tokens: List[int] = []
@@ -248,6 +278,30 @@ def _log_prometheus(self, stats: Stats) -> None:
self._log_gauge(self.metrics.gauge_cpu_cache_usage,
stats.cpu_cache_usage_sys)

if stats.spec_decode_metrics is not None:
# assume we have one bonus token each step
self._log_counter(self.metrics.counter_draft_tokens,
(stats.spec_decode_metrics.draft_tokens -
self.last_draft_tokens))
self._log_counter(self.metrics.counter_emitted_tokens,
(stats.spec_decode_metrics.emitted_tokens -
self.last_emitted_tokens))
self._log_counter(self.metrics.counter_accepted_tokens,
(stats.spec_decode_metrics.accepted_tokens -
self.last_accepted_tokens))
self._log_counter(self.metrics.counter_num_spec_tokens,
(stats.spec_decode_metrics.num_spec_tokens -
self.last_spec_tokens))
self._log_counter(
self.metrics.counter_num_specs,
(stats.spec_decode_metrics.num_specs - self.last_specs))
self.last_draft_tokens = stats.spec_decode_metrics.draft_tokens
self.last_emitted_tokens = stats.spec_decode_metrics.emitted_tokens
self.last_accepted_tokens = (
stats.spec_decode_metrics.accepted_tokens)
self.last_spec_tokens = stats.spec_decode_metrics.num_spec_tokens
self.last_specs = stats.spec_decode_metrics.num_specs

# Iteration level data
self._log_counter(self.metrics.counter_num_preemption,
stats.num_preemption_iter)
@@ -366,10 +420,12 @@ def log(self, stats: Stats) -> None:
def _format_spec_decode_metrics_str(
self, metrics: "SpecDecodeWorkerMetrics") -> str:

return ("Speculative metrics: "
f"Draft acceptance rate: {metrics.draft_acceptance_rate:.3f}, "
f"System efficiency: {metrics.system_efficiency:.3f}, "
f"Number of speculative tokens: {metrics.num_spec_tokens}, "
f"Number of accepted tokens: {metrics.accepted_tokens}, "
f"Number of draft tokens tokens: {metrics.draft_tokens}, "
f"Number of emitted tokens tokens: {metrics.emitted_tokens}.")
return (
"Speculative metrics: "
f"Draft acceptance rate: {metrics.draft_acceptance_rate:.3f}, "
f"System efficiency: {metrics.system_efficiency:.3f}, "
f"Number of speculative verification taken: {metrics.num_specs}, "
f"Number of speculative tokens: {metrics.num_spec_tokens}, "
f"Number of accepted tokens: {metrics.accepted_tokens}, "
f"Number of draft tokens tokens: {metrics.draft_tokens}, "
f"Number of emitted tokens tokens: {metrics.emitted_tokens}.")
6 changes: 6 additions & 0 deletions vllm/model_executor/layers/rejection_sampler.py
@@ -37,6 +37,7 @@ def __init__(self,
self.num_accepted_tokens: Optional[torch.Tensor] = None
self.num_emitted_tokens: Optional[torch.Tensor] = None
self.num_draft_tokens: int = 0
self.num_specs: int = 0

def init_gpu_tensors(self, rank: int) -> None:
assert self.num_accepted_tokens is None
@@ -330,6 +331,11 @@ def _create_output(
self.num_emitted_tokens += (output_with_bonus_tokens != -1).sum()
self.num_draft_tokens += batch_size * k

# k might not be constant if dynamic speculation is enabled.
# For the ngram case, batch_size might be 0 if nothing matched.
if batch_size > 0:
self.num_specs += k

return output_with_bonus_tokens

def _raise_if_incorrect_shape(
7 changes: 7 additions & 0 deletions vllm/spec_decode/metrics.py
@@ -41,6 +41,9 @@ class SpecDecodeWorkerMetrics:
# The number of speculative tokens per sequence.
num_spec_tokens: int

# The number of speculative verification steps that have been taken.
num_specs: int


Timer = Callable[[], float]

@@ -70,6 +73,7 @@ def __init__(self,
self._aggregate_num_emitted_tokens = torch.tensor(
0, dtype=torch.long, device="cpu", pin_memory=pin_memory)
self._aggregate_num_draft_tokens = 0
self._aggregate_num_specs = 0

self._rejsample_metrics_collect_interval_s = collect_interval_s
self._last_metrics_collect_time = self._timer()
@@ -124,6 +128,8 @@ def _copy_rejsample_metrics_async(self) -> torch.cuda.Event:
# required.
self._aggregate_num_draft_tokens = (
self._rejection_sampler.num_draft_tokens)
# Number of speculative verification steps that have been taken
self._aggregate_num_specs = self._rejection_sampler.num_specs

aggregate_metrics_ready = torch.cuda.Event()
aggregate_metrics_ready.record(self._copy_stream)
@@ -162,6 +168,7 @@ def _collect_rejsample_metrics(

return SpecDecodeWorkerMetrics(
num_spec_tokens=k,
num_specs=self._aggregate_num_specs,
Collaborator:

We can calculate this as draft_tokens // k; there's no need to record it separately.

Collaborator:

See get_max_num_emitted_tokens.
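
(A sketch of the suggested derivation, mirroring what get_max_num_emitted_tokens already does; it only holds while k is constant, which the rejection-sampler change above notes may not be true under dynamic speculation or when ngram proposes nothing.)

def estimate_num_spec_verifications(draft_tokens: int, k: int) -> int:
    # Each (sequence, step) verification consumes exactly k draft tokens,
    # so the step count can be reconstructed rather than tracked separately.
    return draft_tokens // k if k > 0 else 0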

draft_acceptance_rate=draft_acceptance_rate,
system_efficiency=system_efficiency,
accepted_tokens=accepted_tokens,