[V1][Metrics] Add GPU prefix cache hit rate % gauge #12592

Changes from all commits: 37915de, 6494dde, 7d0bed5, fd10665, 60e1637, c9f8cf3, 6d47433

@@ -9,6 +9,7 @@
 from vllm.config import ModelConfig
 from vllm.logger import init_logger
+from vllm.v1.core.kv_cache_utils import PrefixCachingMetrics
 from vllm.v1.engine import FinishReason
 from vllm.v1.metrics.stats import IterationStats, SchedulerStats

@@ -37,6 +38,9 @@ def _reset(self, now):
         self.num_prompt_tokens: List[int] = []
         self.num_generation_tokens: List[int] = []
 
+        # Prefix cache metrics. TODO: Make the interval configurable.
+        self.prefix_caching_metrics = PrefixCachingMetrics()
+
     def _local_interval_elapsed(self, now: float) -> bool:
         # Log every _LOCAL_LOGGING_INTERVAL_SEC.
         elapsed_time = now - self.last_log_time
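
For context, `PrefixCachingMetrics` appears to aggregate the per-step prefix cache stats over a recent interval (the TODO above is about making that interval configurable) and exposes a `hit_rate` that the log line further down reads. Below is only a minimal sketch of such a windowed aggregator, assuming a `PrefixCacheStats` with `queries`/`hits` fields as suggested elsewhere in this diff; the real class in `vllm.v1.core.kv_cache_utils` and its constructor arguments may differ.

```python
from collections import deque
from dataclasses import dataclass


@dataclass
class PrefixCacheStats:
    # Hypothetical per-step stats; field names inferred from this diff
    # (scheduler_stats.prefix_cache_stats.queries / .hits).
    queries: int = 0
    hits: int = 0


class PrefixCachingMetrics:
    """Sliding-window aggregator for the prefix cache hit rate (sketch)."""

    def __init__(self, max_recent_observations: int = 1000) -> None:
        # The window size stands in for the "interval" the TODO refers to.
        self.max_recent_observations = max_recent_observations
        self.window: deque = deque()
        self.aggregated_queries = 0
        self.aggregated_hits = 0

    def observe(self, stats: PrefixCacheStats) -> None:
        # Add the newest observation and evict the oldest ones so the
        # running totals only cover the recent window.
        self.window.append((stats.queries, stats.hits))
        self.aggregated_queries += stats.queries
        self.aggregated_hits += stats.hits
        while len(self.window) > self.max_recent_observations:
            old_queries, old_hits = self.window.popleft()
            self.aggregated_queries -= old_queries
            self.aggregated_hits -= old_hits

    @property
    def hit_rate(self) -> float:
        if self.aggregated_queries == 0:
            return 0.0
        return self.aggregated_hits / self.aggregated_queries
```

Bounding the window keeps the logged percentage tracking recent traffic instead of being diluted by the whole process history.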

@@ -58,6 +62,8 @@ def log(self, scheduler_stats: SchedulerStats,
 
         self._track_iteration_stats(iteration_stats)
 
+        self.prefix_caching_metrics.observe(scheduler_stats.prefix_cache_stats)
+
         now = time.monotonic()
         if not self._local_interval_elapsed(now):
             return

@@ -72,13 +78,15 @@ def log(self, scheduler_stats: SchedulerStats,
         logger.info(
             "Avg prompt throughput: %.1f tokens/s, "
             "Avg generation throughput: %.1f tokens/s, "
-            "Running: %d reqs, Waiting: %d reqs "
-            "GPU KV cache usage: %.1f%%.",
+            "Running: %d reqs, Waiting: %d reqs, "
+            "GPU KV cache usage: %.1f%%, "
+            "Prefix cache hit rate: %.1f%%",
             prompt_throughput,
             generation_throughput,
             scheduler_stats.num_running_reqs,
             scheduler_stats.num_waiting_reqs,
             scheduler_stats.gpu_cache_usage * 100,
+            self.prefix_caching_metrics.hit_rate * 100,

Comment on lines +83 to +89:

Hmm.. I wonder if it's worth the effort to have a separate … If I understand correctly, this will always be 0 if prefix caching is turned off, right? (In V0, it's always -1 in this case, so we might have to keep the same behavior.)

Good point. In fact, in V0 I disabled this log when prefix caching is disabled, but given that prefix caching is a first-class citizen in V1, I feel it might be fine. Open to other opinions, though.

Sounds good!

         )
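
As the review thread above notes, this windowed hit rate is simply 0 when prefix caching is turned off (no blocks are ever queried), whereas V0 reported -1 in that case. If V0's convention were ever wanted back, a hypothetical formatting helper along these lines could distinguish the two situations; it is not part of the PR, just an illustration of the alternative discussed above.

```python
def format_hit_rate(queries: int, hits: int) -> str:
    # Hypothetical helper, not in the PR: distinguish "no queries observed"
    # (e.g. prefix caching disabled) from a genuine 0% hit rate, similar to
    # V0 reporting -1 in that situation.
    if queries == 0:
        return "N/A"
    return f"{hits / queries * 100:.1f}%"


print(format_hit_rate(0, 0))     # "N/A"  (prefix caching off / no traffic yet)
print(format_hit_rate(200, 50))  # "25.0%"
```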

@@ -107,6 +115,18 @@ def __init__(self, model_config: ModelConfig):
             documentation="GPU KV-cache usage. 1 means 100 percent usage.",
             labelnames=labelnames).labels(*labelvalues)
 
+        self.counter_gpu_prefix_cache_queries = prometheus_client.Counter(
+            name="vllm:gpu_prefix_cache_queries",
+            documentation=
+            "GPU prefix cache queries, in terms of number of queried blocks.",
+            labelnames=labelnames).labels(*labelvalues)
+
+        self.counter_gpu_prefix_cache_hits = prometheus_client.Counter(
+            name="vllm:gpu_prefix_cache_hits",
+            documentation=
+            "GPU prefix cache hits, in terms of number of cached blocks.",
+            labelnames=labelnames).labels(*labelvalues)
+
         self.counter_prompt_tokens = prometheus_client.Counter(
             name="vllm:prompt_tokens_total",
             documentation="Number of prefill tokens processed.",
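
The two counters above follow the usual `prometheus_client` pattern of declaring a metric with `labelnames` and immediately binding concrete label values via `.labels(...)`, so every sample is tagged with the serving engine's identity (the model name in vLLM). A small stand-alone sketch of that pattern, with hypothetical metric and label values:

```python
import prometheus_client

# Hypothetical label setup mirroring the pattern in the diff; the demo
# metric name is deliberately different from the real vLLM one.
labelnames = ["model_name"]
labelvalues = ["my-model"]

demo_hits = prometheus_client.Counter(
    name="demo_gpu_prefix_cache_hits",
    documentation="Cached-block hits (demo counter, not the vLLM one).",
    labelnames=labelnames).labels(*labelvalues)

# The returned child counter is bound to model_name="my-model".
demo_hits.inc(3)
```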

@@ -170,6 +190,11 @@ def log(self, scheduler_stats: SchedulerStats,
 
         self.gauge_gpu_cache_usage.set(scheduler_stats.gpu_cache_usage)
 
+        self.counter_gpu_prefix_cache_queries.inc(
+            scheduler_stats.prefix_cache_stats.queries)
+        self.counter_gpu_prefix_cache_hits.inc(
+            scheduler_stats.prefix_cache_stats.hits)
+
         self.counter_prompt_tokens.inc(iteration_stats.num_prompt_tokens)
         self.counter_generation_tokens.inc(
             iteration_stats.num_generation_tokens)
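
Because the Prometheus side exposes raw query/hit counters rather than a precomputed percentage, the hit rate over any time window can be derived by the metrics backend. A rough sketch of the arithmetic with hypothetical per-step numbers is below; the PromQL in the comment is illustrative only, and the exposed series names may carry a `_total` suffix depending on the client and exposition format.

```python
# A dashboard would typically compute something like:
#   rate(vllm:gpu_prefix_cache_hits[5m])
#     / rate(vllm:gpu_prefix_cache_queries[5m]) * 100
# The same arithmetic over three hypothetical scheduler steps:
total_queries = 0
total_hits = 0
for step_queries, step_hits in [(64, 0), (64, 48), (32, 32)]:
    total_queries += step_queries
    total_hits += step_hits

# If prefix caching is disabled, both counters stay at 0 and the ratio is
# undefined, which is why the guard below falls back to 0.0.
hit_rate_pct = 100 * total_hits / total_queries if total_queries else 0.0
print(f"Prefix cache hit rate: {hit_rate_pct:.1f}%")  # 50.0%
```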