Skip to content

Commit

Permalink
Port metrics from aioprometheus to prometheus_client (vllm-projec…
Browse files Browse the repository at this point in the history
  • Loading branch information
hmellor authored Feb 25, 2024
1 parent f7c1234 commit ef978fe
Show file tree
Hide file tree
Showing 9 changed files with 133 additions and 87 deletions.
2 changes: 1 addition & 1 deletion docs/source/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@

# Mock out external dependencies here.
autodoc_mock_imports = [
"torch", "transformers", "psutil", "aioprometheus", "sentencepiece",
"torch", "transformers", "psutil", "prometheus_client", "sentencepiece",
"vllm.cuda_utils", "vllm._C"
]

Expand Down
2 changes: 1 addition & 1 deletion requirements-neuron.txt
Original file line number Diff line number Diff line change
Expand Up @@ -6,4 +6,4 @@ neuronx-cc
fastapi
uvicorn[standard]
pydantic >= 2.0 # Required for OpenAI server.
aioprometheus[starlette]
prometheus_client
2 changes: 1 addition & 1 deletion requirements-rocm.txt
Original file line number Diff line number Diff line change
Expand Up @@ -10,4 +10,4 @@ transformers >= 4.38.0 # Required for Gemma.
fastapi
uvicorn[standard]
pydantic >= 2.0 # Required for OpenAI server.
aioprometheus[starlette]
prometheus_client
2 changes: 1 addition & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ xformers == 0.0.23.post1 # Required for CUDA 12.1.
fastapi
uvicorn[standard]
pydantic >= 2.0 # Required for OpenAI server.
aioprometheus[starlette]
prometheus_client
pynvml == 11.5.0
triton >= 2.1.0
cupy-cuda12x == 12.1.0 # Required for CUDA graphs. CUDA 11.8 users should install cupy-cuda11x instead.
2 changes: 2 additions & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -165,6 +165,7 @@ def __init__(
dtype: str = "half",
disable_log_stats: bool = True,
tensor_parallel_size: int = 1,
**kwargs,
) -> None:
self.model = LLM(
model=model_name,
Expand All @@ -174,6 +175,7 @@ def __init__(
swap_space=0,
disable_log_stats=disable_log_stats,
tensor_parallel_size=tensor_parallel_size,
**kwargs,
)

def generate(
Expand Down
25 changes: 14 additions & 11 deletions tests/metrics/test_metrics.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
import pytest
import vllm.engine.metrics

MODELS = [
"facebook/opt-125m",
Expand All @@ -16,10 +15,10 @@ def test_metric_counter_prompt_tokens(
dtype: str,
max_tokens: int,
) -> None:
# Reset metric
vllm.engine.metrics.counter_prompt_tokens.set_value({}, 0)

vllm_model = vllm_runner(model, dtype=dtype, disable_log_stats=False)
vllm_model = vllm_runner(model,
dtype=dtype,
disable_log_stats=False,
gpu_memory_utilization=0.4)
tokenizer = vllm_model.model.get_tokenizer()
prompt_token_counts = [len(tokenizer.encode(p)) for p in example_prompts]
# This test needs at least 2 prompts in a batch of different lengths to verify their token count is correct despite padding.
Expand All @@ -29,7 +28,9 @@ def test_metric_counter_prompt_tokens(
vllm_prompt_token_count = sum(prompt_token_counts)

_ = vllm_model.generate_greedy(example_prompts, max_tokens)
metric_count = vllm.engine.metrics.counter_prompt_tokens.get_value({})
stat_logger = vllm_model.model.llm_engine.stat_logger
metric_count = stat_logger.metrics.counter_prompt_tokens.labels(
**stat_logger.labels)._value.get()

assert vllm_prompt_token_count == metric_count, (
f"prompt token count: {vllm_prompt_token_count!r}\nmetric: {metric_count!r}"
Expand All @@ -46,13 +47,15 @@ def test_metric_counter_generation_tokens(
dtype: str,
max_tokens: int,
) -> None:
# Reset metric
vllm.engine.metrics.counter_generation_tokens.set_value({}, 0)

vllm_model = vllm_runner(model, dtype=dtype, disable_log_stats=False)
vllm_model = vllm_runner(model,
dtype=dtype,
disable_log_stats=False,
gpu_memory_utilization=0.4)
vllm_outputs = vllm_model.generate_greedy(example_prompts, max_tokens)
tokenizer = vllm_model.model.get_tokenizer()
metric_count = vllm.engine.metrics.counter_generation_tokens.get_value({})
stat_logger = vllm_model.model.llm_engine.stat_logger
metric_count = stat_logger.metrics.counter_generation_tokens.labels(
**stat_logger.labels)._value.get()
vllm_generation_count = 0
for i in range(len(example_prompts)):
vllm_output_ids, vllm_output_str = vllm_outputs[i]
Expand Down
3 changes: 2 additions & 1 deletion vllm/engine/llm_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,7 +128,8 @@ def __init__(
# Metric Logging.
if self.log_stats:
self.stat_logger = StatLogger(
local_interval=_LOCAL_LOGGING_INTERVAL_SEC)
local_interval=_LOCAL_LOGGING_INTERVAL_SEC,
labels=dict(model_name=model_config.model))

self.forward_dag = None
if USE_RAY_COMPILED_DAG:
Expand Down
170 changes: 107 additions & 63 deletions vllm/engine/metrics.py
Original file line number Diff line number Diff line change
@@ -1,66 +1,94 @@
from vllm.logger import init_logger
from aioprometheus import Counter, Gauge, Histogram
from prometheus_client import Counter, Gauge, Histogram, REGISTRY, disable_created_metrics

import time
import numpy as np
from typing import List
from typing import Dict, List
from dataclasses import dataclass

logger = init_logger(__name__)

labels = {}


def add_global_metrics_labels(**kwargs):
labels.update(kwargs)

disable_created_metrics()

# The begin-* and end* here are used by the documentation generator
# to extract the metrics definitions.


# begin-metrics-definitions
gauge_avg_prompt_throughput = Gauge("vllm:avg_prompt_throughput_toks_per_s",
"Average prefill throughput in tokens/s.")
gauge_avg_generation_throughput = Gauge(
"vllm:avg_generation_throughput_toks_per_s",
"Average generation throughput in tokens/s.")
counter_prompt_tokens = Counter("vllm:prompt_tokens_total",
"Number of prefill tokens processed.")
counter_generation_tokens = Counter("vllm:generation_tokens_total",
"Number of generation tokens processed.")

gauge_scheduler_running = Gauge(
"vllm:num_requests_running",
"Number of requests currently running on GPU.")
gauge_scheduler_swapped = Gauge("vllm:num_requests_swapped",
"Number of requests swapped to CPU.")
gauge_scheduler_waiting = Gauge("vllm:num_requests_waiting",
"Number of requests waiting to be processed.")

gauge_gpu_cache_usage = Gauge(
"vllm:gpu_cache_usage_perc",
"GPU KV-cache usage. 1 means 100 percent usage.")
gauge_cpu_cache_usage = Gauge(
"vllm:cpu_cache_usage_perc",
"CPU KV-cache usage. 1 means 100 percent usage.")

histogram_time_to_first_token = Histogram(
"vllm:time_to_first_token_seconds",
"Histogram of time to first token in seconds.",
buckets=[
0.001, 0.005, 0.01, 0.02, 0.04, 0.06, 0.08, 0.1, 0.25, 0.5, 0.75, 1.0,
2.5, 5.0, 7.5, 10.0
])
histogram_time_per_output_tokens = Histogram(
"vllm:time_per_output_token_seconds",
"Histogram of time per output token in seconds.",
buckets=[
0.01, 0.025, 0.05, 0.075, 0.1, 0.15, 0.2, 0.3, 0.4, 0.5, 0.75, 1.0, 2.5
])
histogram_e2e_request_latency = Histogram(
"vllm:e2e_request_latency_seconds",
"Histogram of end to end request latency in seconds.",
buckets=[1.0, 2.5, 5.0, 10.0, 15.0, 20.0, 30.0, 40.0, 50.0, 60.0])
class Metrics:

def __init__(self, labelnames: List[str]):
# Unregister any existing vLLM collectors
for collector in list(REGISTRY._collector_to_names):
if hasattr(collector, "_name") and "vllm" in collector._name:
REGISTRY.unregister(collector)

# System stats
self.gauge_scheduler_running = Gauge(
name="vllm:num_requests_running",
documentation="Number of requests currently running on GPU.",
labelnames=labelnames)
self.gauge_scheduler_swapped = Gauge(
name="vllm:num_requests_swapped",
documentation="Number of requests swapped to CPU.",
labelnames=labelnames)
self.gauge_scheduler_waiting = Gauge(
name="vllm:num_requests_waiting",
documentation="Number of requests waiting to be processed.",
labelnames=labelnames)
self.gauge_gpu_cache_usage = Gauge(
name="vllm:gpu_cache_usage_perc",
documentation="GPU KV-cache usage. 1 means 100 percent usage.",
labelnames=labelnames)
self.gauge_cpu_cache_usage = Gauge(
name="vllm:cpu_cache_usage_perc",
documentation="CPU KV-cache usage. 1 means 100 percent usage.",
labelnames=labelnames)

# Raw stats from last model iteration
self.counter_prompt_tokens = Counter(
name="vllm:prompt_tokens_total",
documentation="Number of prefill tokens processed.",
labelnames=labelnames)
self.counter_generation_tokens = Counter(
name="vllm:generation_tokens_total",
documentation="Number of generation tokens processed.",
labelnames=labelnames)
self.histogram_time_to_first_token = Histogram(
name="vllm:time_to_first_token_seconds",
documentation="Histogram of time to first token in seconds.",
labelnames=labelnames,
buckets=[
0.001, 0.005, 0.01, 0.02, 0.04, 0.06, 0.08, 0.1, 0.25, 0.5,
0.75, 1.0, 2.5, 5.0, 7.5, 10.0
])
self.histogram_time_per_output_token = Histogram(
name="vllm:time_per_output_token_seconds",
documentation="Histogram of time per output token in seconds.",
labelnames=labelnames,
buckets=[
0.01, 0.025, 0.05, 0.075, 0.1, 0.15, 0.2, 0.3, 0.4, 0.5, 0.75,
1.0, 2.5
])
self.histogram_e2e_request_latency = Histogram(
name="vllm:e2e_request_latency_seconds",
documentation="Histogram of end to end request latency in seconds.",
labelnames=labelnames,
buckets=[1.0, 2.5, 5.0, 10.0, 15.0, 20.0, 30.0, 40.0, 50.0, 60.0])

# Legacy metrics
self.gauge_avg_prompt_throughput = Gauge(
name="vllm:avg_prompt_throughput_toks_per_s",
documentation="Average prefill throughput in tokens/s.",
labelnames=labelnames,
)
self.gauge_avg_generation_throughput = Gauge(
name="vllm:avg_generation_throughput_toks_per_s",
documentation="Average generation throughput in tokens/s.",
labelnames=labelnames,
)


# end-metrics-definitions


Expand All @@ -87,7 +115,7 @@ class Stats:
class StatLogger:
"""StatLogger is used LLMEngine to log to Promethus and Stdout."""

def __init__(self, local_interval: float) -> None:
def __init__(self, local_interval: float, labels: Dict[str, str]) -> None:
# Metadata for logging locally.
self.last_local_log = time.monotonic()
self.local_interval = local_interval
Expand All @@ -96,6 +124,10 @@ def __init__(self, local_interval: float) -> None:
self.num_prompt_tokens: List[int] = []
self.num_generation_tokens: List[int] = []

# Prometheus metrics
self.labels = labels
self.metrics = Metrics(labelnames=list(labels.keys()))

def _get_throughput(self, tracked_stats: List[int], now: float) -> float:
return float(np.sum(tracked_stats) / (now - self.last_local_log))

Expand All @@ -105,23 +137,33 @@ def _local_interval_elapsed(self, now: float) -> bool:

def _log_prometheus(self, stats: Stats) -> None:
# Set system stat gauges.
gauge_scheduler_running.set(labels, stats.num_running)
gauge_scheduler_swapped.set(labels, stats.num_swapped)
gauge_scheduler_waiting.set(labels, stats.num_waiting)
gauge_gpu_cache_usage.set(labels, stats.gpu_cache_usage)
gauge_cpu_cache_usage.set(labels, stats.cpu_cache_usage)
self.metrics.gauge_scheduler_running.labels(**self.labels).set(
stats.num_running)
self.metrics.gauge_scheduler_swapped.labels(**self.labels).set(
stats.num_swapped)
self.metrics.gauge_scheduler_waiting.labels(**self.labels).set(
stats.num_waiting)
self.metrics.gauge_gpu_cache_usage.labels(**self.labels).set(
stats.gpu_cache_usage)
self.metrics.gauge_cpu_cache_usage.labels(**self.labels).set(
stats.cpu_cache_usage)

# Add to token counters.
counter_prompt_tokens.add(labels, stats.num_prompt_tokens)
counter_generation_tokens.add(labels, stats.num_generation_tokens)
self.metrics.counter_prompt_tokens.labels(**self.labels).inc(
stats.num_prompt_tokens)
self.metrics.counter_generation_tokens.labels(**self.labels).inc(
stats.num_generation_tokens)

# Observe request level latencies in histograms.
for ttft in stats.time_to_first_tokens:
histogram_time_to_first_token.observe(labels, ttft)
self.metrics.histogram_time_to_first_token.labels(
**self.labels).observe(ttft)
for tpot in stats.time_per_output_tokens:
histogram_time_per_output_tokens.observe(labels, tpot)
self.metrics.histogram_time_per_output_token.labels(
**self.labels).observe(tpot)
for e2e in stats.time_e2e_requests:
histogram_e2e_request_latency.observe(labels, e2e)
self.metrics.histogram_e2e_request_latency.labels(
**self.labels).observe(e2e)

def _log_prometheus_interval(self, prompt_throughput: float,
generation_throughput: float) -> None:
Expand All @@ -130,8 +172,10 @@ def _log_prometheus_interval(self, prompt_throughput: float,
# Moving forward, we should use counters like counter_prompt_tokens, counter_generation_tokens
# Which log raw data and calculate summaries using rate() on the grafana/prometheus side.
# See https://github.com/vllm-project/vllm/pull/2316#discussion_r1464204666
gauge_avg_prompt_throughput.set(labels, prompt_throughput)
gauge_avg_generation_throughput.set(labels, generation_throughput)
self.metrics.gauge_avg_prompt_throughput.labels(
**self.labels).set(prompt_throughput)
self.metrics.gauge_avg_generation_throughput.labels(
**self.labels).set(generation_throughput)

def log(self, stats: Stats) -> None:
"""Called by LLMEngine.
Expand Down
12 changes: 4 additions & 8 deletions vllm/entrypoints/openai/api_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,7 @@
import importlib
import inspect

from aioprometheus import MetricsMiddleware
from aioprometheus.asgi.starlette import metrics
from prometheus_client import make_asgi_app
import fastapi
import uvicorn
from http import HTTPStatus
Expand All @@ -18,7 +17,6 @@

from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.engine.async_llm_engine import AsyncLLMEngine
from vllm.engine.metrics import add_global_metrics_labels
from vllm.entrypoints.openai.protocol import CompletionRequest, ChatCompletionRequest, ErrorResponse
from vllm.logger import init_logger
from vllm.entrypoints.openai.serving_chat import OpenAIServingChat
Expand Down Expand Up @@ -141,8 +139,9 @@ def parse_args():
return parser.parse_args()


app.add_middleware(MetricsMiddleware) # Trace HTTP server metrics
app.add_route("/metrics", metrics) # Exposes HTTP metrics
# Add prometheus asgi middleware to route /metrics requests
metrics_app = make_asgi_app()
app.mount("/metrics", metrics_app)


@app.exception_handler(RequestValidationError)
Expand Down Expand Up @@ -242,9 +241,6 @@ async def authentication(request: Request, call_next):
openai_serving_completion = OpenAIServingCompletion(
engine, served_model, args.lora_modules)

# Register labels for metrics
add_global_metrics_labels(model_name=engine_args.model)

app.root_path = args.root_path
uvicorn.run(app,
host=args.host,
Expand Down

0 comments on commit ef978fe

Please sign in to comment.