[v1][WIP] Metrics & Stats prototype #10651

Draft
wants to merge 10 commits into base: main
5 changes: 5 additions & 0 deletions vllm/config.py
@@ -1125,6 +1125,8 @@ class SchedulerConfig:

chunked_prefill_enabled: bool = field(init=False)

log_stats: bool = True

def __post_init__(self) -> None:
if self.max_num_batched_tokens is None:
if self.enable_chunked_prefill:
@@ -2039,6 +2041,9 @@ class ObservabilityConfig:
# If set, collects the model execute time for the request.
collect_model_execute_time: bool = False

# If set, collects stats for the engine.
log_stats: bool = True

def __post_init__(self):
if not is_otel_available() and self.otlp_traces_endpoint is not None:
raise ValueError(
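A minimal sketch of what the new defaults imply: stats collection stays on unless it is explicitly disabled. The construction below is illustrative, not part of this diff, and assumes ObservabilityConfig remains default-constructible as it is on main.

```python
# Illustrative only (not part of this diff): log_stats defaults to True and is
# turned off explicitly. Assumes ObservabilityConfig stays default-constructible.
from vllm.config import ObservabilityConfig

obs = ObservabilityConfig()                      # log_stats defaults to True
obs_quiet = ObservabilityConfig(log_stats=False)
assert obs.log_stats and not obs_quiet.log_stats
```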
5 changes: 4 additions & 1 deletion vllm/engine/arg_utils.py
@@ -1153,7 +1153,9 @@ def create_engine_config(self,
multi_step_stream_outputs=self.multi_step_stream_outputs,
send_delta_data=(envs.VLLM_USE_RAY_SPMD_WORKER
and parallel_config.use_ray),
policy=self.scheduling_policy)
policy=self.scheduling_policy,
log_stats=not self.disable_log_stats,
)
lora_config = LoRAConfig(
bias_enabled=self.enable_lora_bias,
max_lora_rank=self.max_lora_rank,
@@ -1196,6 +1198,7 @@ def create_engine_config(self,
or "all" in detailed_trace_modules,
collect_model_execute_time="worker" in detailed_trace_modules
or "all" in detailed_trace_modules,
log_stats=not self.disable_log_stats,
)

config = VllmConfig(
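For context, a hedged sketch of the plumbing this hunk adds: the existing --disable-log-stats flag (EngineArgs.disable_log_stats) is expected to feed both new log_stats fields. The model name below is only an example.

```python
# Sketch (assumption, not asserted by this diff): how the existing
# --disable-log-stats flag is expected to propagate after this change.
from vllm.engine.arg_utils import EngineArgs

args = EngineArgs(model="facebook/opt-125m", disable_log_stats=True)
# After create_engine_config(), both new fields should mirror the flag:
#   scheduler_config.log_stats == observability_config.log_stats
#                              == (not args.disable_log_stats) == False
```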
3 changes: 2 additions & 1 deletion vllm/engine/llm_engine.py
@@ -439,7 +439,8 @@ def get_tokenizer_for_seq(sequence: Sequence) -> AnyTokenizer:
local_interval=_LOCAL_LOGGING_INTERVAL_SEC,
labels=dict(
model_name=self.model_config.served_model_name),
max_model_len=self.model_config.max_model_len),
max_model_len=self.model_config.max_model_len,
),
}
self.stat_loggers["prometheus"].info("cache_config",
self.cache_config)
11 changes: 8 additions & 3 deletions vllm/engine/metrics.py
@@ -451,9 +451,14 @@ def log(self, stats: Stats) -> None:
last_log=self.last_local_log)

log_fn = logger.info
if not any((prompt_throughput, generation_throughput,
self.last_prompt_throughput,
self.last_generation_throughput)):
if not any((
prompt_throughput,
generation_throughput,
self.last_prompt_throughput,
self.last_generation_throughput,
stats.num_running_sys,
stats.num_waiting_sys,
)):
# Avoid log noise on an idle production system
log_fn = logger.debug

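A standalone sketch of the gating behaviour above, using only the stdlib logging module: the level only drops to DEBUG when there is no throughput and nothing running or waiting, so a loaded-but-stalled system still logs at INFO.

```python
# Standalone illustration of the idle-suppression gate (not vLLM code).
import logging

logger = logging.getLogger("stats-sketch")

def pick_log_fn(prompt_tput, gen_tput, last_prompt_tput, last_gen_tput,
                num_running, num_waiting):
    # Mirror the any(...) check above: any non-zero signal keeps INFO logging.
    if not any((prompt_tput, gen_tput, last_prompt_tput, last_gen_tput,
                num_running, num_waiting)):
        return logger.debug  # idle system: avoid log noise
    return logger.info

assert pick_log_fn(0.0, 0.0, 0.0, 0.0, 0, 0) == logger.debug
assert pick_log_fn(0.0, 0.0, 0.0, 0.0, 3, 1) == logger.info
```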
62 changes: 31 additions & 31 deletions vllm/engine/metrics_types.py
@@ -13,7 +13,7 @@

import time
from abc import ABC, abstractmethod
from dataclasses import dataclass
from dataclasses import dataclass, field
from typing import Dict, List, Optional, Protocol

from vllm.spec_decode.metrics import SpecDecodeWorkerMetrics
@@ -26,44 +26,44 @@ class Stats:

# System stats (should have _sys suffix)
# Scheduler State
num_running_sys: int
num_waiting_sys: int
num_swapped_sys: int
num_running_sys: int = 0
num_waiting_sys: int = 0
num_swapped_sys: int = 0
# KV Cache Usage in %
gpu_cache_usage_sys: float
cpu_cache_usage_sys: float
gpu_cache_usage_sys: float = 0.0
cpu_cache_usage_sys: float = 0.0
# Prefix caching block hit rate
cpu_prefix_cache_hit_rate: float
gpu_prefix_cache_hit_rate: float
cpu_prefix_cache_hit_rate: float = 0.0
gpu_prefix_cache_hit_rate: float = 0.0

# Iteration stats (should have _iter suffix)
num_prompt_tokens_iter: int
num_generation_tokens_iter: int
num_tokens_iter: int
time_to_first_tokens_iter: List[float]
time_per_output_tokens_iter: List[float]
num_preemption_iter: int
num_prompt_tokens_iter: int = 0
num_generation_tokens_iter: int = 0
num_tokens_iter: int = 0
time_to_first_tokens_iter: List[float] = field(default_factory=list)
time_per_output_tokens_iter: List[float] = field(default_factory=list)
num_preemption_iter: int = 0

# Request stats (should have _requests suffix)
# Latency
time_e2e_requests: List[float]
time_queue_requests: List[float]
time_inference_requests: List[float]
time_prefill_requests: List[float]
time_decode_requests: List[float]
time_in_queue_requests: List[float]
model_forward_time_requests: List[float]
model_execute_time_requests: List[float]
time_e2e_requests: List[float] = field(default_factory=list)
time_queue_requests: List[float] = field(default_factory=list)
time_inference_requests: List[float] = field(default_factory=list)
time_prefill_requests: List[float] = field(default_factory=list)
time_decode_requests: List[float] = field(default_factory=list)
time_in_queue_requests: List[float] = field(default_factory=list)
model_forward_time_requests: List[float] = field(default_factory=list)
model_execute_time_requests: List[float] = field(default_factory=list)
# Metadata
num_prompt_tokens_requests: List[int]
num_generation_tokens_requests: List[int]
n_requests: List[int]
max_num_generation_tokens_requests: List[int]
max_tokens_requests: List[int]
finished_reason_requests: List[str]
waiting_lora_adapters: List[str]
running_lora_adapters: List[str]
max_lora: str
num_prompt_tokens_requests: List[int] = field(default_factory=list)
num_generation_tokens_requests: List[int] = field(default_factory=list)
n_requests: List[int] = field(default_factory=list)
max_num_generation_tokens_requests: List[int] = field(default_factory=list)
max_tokens_requests: List[int] = field(default_factory=list)
finished_reason_requests: List[str] = field(default_factory=list)
waiting_lora_adapters: List[str] = field(default_factory=list)
running_lora_adapters: List[str] = field(default_factory=list)
max_lora: str = "0"

spec_decode_metrics: Optional["SpecDecodeWorkerMetrics"] = None

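A standalone sketch of why the new defaults use field(default_factory=list): mutable defaults on a dataclass must go through a factory, and per-field defaults let a Stats snapshot be filled in incrementally rather than requiring every field at construction time. MiniStats below is a made-up stand-in, not vLLM code.

```python
# Standalone illustration (not vLLM code) of the default_factory pattern used
# above: each instance gets its own list, so snapshots do not share state.
from dataclasses import dataclass, field
from typing import List

@dataclass
class MiniStats:
    num_running_sys: int = 0
    time_to_first_tokens_iter: List[float] = field(default_factory=list)

a, b = MiniStats(), MiniStats()
a.time_to_first_tokens_iter.append(0.12)
assert b.time_to_first_tokens_iter == []   # lists are per-instance
assert b.num_running_sys == 0
```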
5 changes: 5 additions & 0 deletions vllm/envs.py
@@ -70,6 +70,7 @@
VLLM_DISABLED_KERNELS: List[str] = []
VLLM_USE_V1: bool = False
VLLM_ENABLE_V1_MULTIPROCESSING: bool = False
VLLM_STATS_ENGINE_POLLING_INTERVAL_S: int = 1


def get_default_cache_root():
@@ -457,6 +458,10 @@ def get_default_config_root():
# If set, enable multiprocessing in LLM for the V1 code path.
"VLLM_ENABLE_V1_MULTIPROCESSING":
lambda: bool(int(os.getenv("VLLM_ENABLE_V1_MULTIPROCESSING", "0"))),

# Interval in seconds to poll the engine stats.
"VLLM_STATS_ENGINE_POLLING_INTERVAL_S":
lambda: int(os.getenv("VLLM_STATS_ENGINE_POLLING_INTERVAL_S", "1")),
}

# end-env-vars-definition
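A short sketch of how the new variable would be read, assuming it is consumed through vllm.envs like the other entries in this table (attribute access evaluates the lambda above at lookup time, so the int() conversion is applied).

```python
# Assumed usage (not shown in this diff): read the polling interval through
# vllm.envs so it arrives as an int.
import os

os.environ["VLLM_STATS_ENGINE_POLLING_INTERVAL_S"] = "5"

import vllm.envs as envs

interval_s = envs.VLLM_STATS_ENGINE_POLLING_INTERVAL_S
assert interval_s == 5
```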
94 changes: 94 additions & 0 deletions vllm/v1/core/common.py
@@ -0,0 +1,94 @@
from dataclasses import dataclass
from typing import TYPE_CHECKING, Dict, List, Optional, Set, Tuple

from vllm.sampling_params import SamplingParams
from vllm.v1.request import Request

if TYPE_CHECKING:
from vllm.multimodal import MultiModalKwargs
from vllm.multimodal.base import PlaceholderRange


@dataclass
class NewRequestData:
req_id: str
prompt_token_ids: List[int]
prompt: Optional[str]
mm_inputs: List["MultiModalKwargs"]
mm_positions: List["PlaceholderRange"]
sampling_params: SamplingParams
block_ids: List[int]
num_computed_tokens: int

@classmethod
def from_request(
cls,
request: Request,
block_ids: List[int],
num_computed_tokens: int,
) -> "NewRequestData":
return cls(
req_id=request.request_id,
prompt_token_ids=request.prompt_token_ids,
prompt=request.prompt,
mm_inputs=request.mm_inputs,
mm_positions=request.mm_positions,
sampling_params=request.sampling_params,
block_ids=block_ids,
num_computed_tokens=num_computed_tokens,
)


@dataclass
class ResumedRequestData:
req_id: str
block_ids: List[int]
num_computed_tokens: int

@classmethod
def from_request(
cls,
request: Request,
block_ids: List[int],
num_computed_tokens: int,
) -> "ResumedRequestData":
return cls(
req_id=request.request_id,
block_ids=block_ids,
num_computed_tokens=num_computed_tokens,
)


@dataclass
class RunningRequestData:
req_id: str
new_block_ids: List[int]
num_computed_tokens: int

@classmethod
def from_request(
cls,
request: Request,
new_block_ids: List[int],
num_computed_tokens: int,
) -> "RunningRequestData":
return cls(
req_id=request.request_id,
new_block_ids=new_block_ids,
num_computed_tokens=num_computed_tokens,
)


@dataclass
class SchedulerOutput:
scheduled_new_reqs: List[NewRequestData]
scheduled_resumed_reqs: List[ResumedRequestData]
scheduled_running_reqs: List[RunningRequestData]

num_scheduled_tokens: Dict[str, int]
total_num_scheduled_tokens: int
scheduled_encoder_inputs: Dict[str, List[int]]

preempted_req_ids: Set[str]
finished_req_ids: Set[str]
free_encoder_input_ids: List[Tuple[str, int]]
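For orientation, a sketch of assembling a SchedulerOutput for a single new request using only the dataclasses defined in this new file. It runs only against this branch, and the request id, block ids, and token counts are made up for illustration.

```python
# Illustrative only: one new request scheduled for its 3 prompt tokens.
from vllm.sampling_params import SamplingParams
from vllm.v1.core.common import NewRequestData, SchedulerOutput

new_req = NewRequestData(
    req_id="req-0",
    prompt_token_ids=[1, 2, 3],
    prompt="Hello",
    mm_inputs=[],
    mm_positions=[],
    sampling_params=SamplingParams(),
    block_ids=[0, 1],
    num_computed_tokens=0,
)

scheduler_output = SchedulerOutput(
    scheduled_new_reqs=[new_req],
    scheduled_resumed_reqs=[],
    scheduled_running_reqs=[],
    num_scheduled_tokens={"req-0": 3},
    total_num_scheduled_tokens=3,
    scheduled_encoder_inputs={},
    preempted_req_ids=set(),
    finished_req_ids=set(),
    free_encoder_input_ids=[],
)
```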
12 changes: 12 additions & 0 deletions vllm/v1/core/kv_cache_manager.py
@@ -7,6 +7,7 @@
KVCacheBlock, hash_block_tokens,
hash_request_tokens)
from vllm.v1.request import Request
from vllm.v1.stats.common import KVCacheStats

logger = init_logger(__name__)

@@ -337,6 +338,17 @@ def _get_cached_block(self,
return self.cached_block_hash_to_block[block_hash][first_block_id]
return None

def get_kv_cache_stats(self) -> KVCacheStats:
num_free_blocks = self.free_block_queue.num_free_blocks
num_used_blocks = self.num_gpu_blocks - num_free_blocks
assert num_used_blocks <= self.num_gpu_blocks
return KVCacheStats(
gpu_cache_usage_sys=num_used_blocks / self.num_gpu_blocks,
# TODO: we might be able to compute this from the request's
# cached-token count.
gpu_prefix_cache_hit_rate=0.0,
)

def _touch(self, blocks: List[KVCacheBlock]) -> None:
"""Touch a block increases its reference count by 1, and may remove
the block from the free queue. This is used when a block is hit by
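A quick arithmetic check of the usage formula in get_kv_cache_stats, with made-up block counts (not vLLM code):

```python
# Standalone check: usage is the fraction of GPU KV-cache blocks allocated.
num_gpu_blocks = 1000
num_free_blocks = 250

num_used_blocks = num_gpu_blocks - num_free_blocks
gpu_cache_usage_sys = num_used_blocks / num_gpu_blocks
assert gpu_cache_usage_sys == 0.75
```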