
Commit 0a666e7
init draft
Signed-off-by: rickyx <rickyx@anyscale.com>
rickyyx committed Nov 26, 2024
1 parent 1b583cf commit 0a666e7
Showing 8 changed files with 222 additions and 33 deletions.
11 changes: 8 additions & 3 deletions vllm/engine/metrics.py
@@ -451,9 +451,14 @@ def log(self, stats: Stats) -> None:
                 last_log=self.last_local_log)

         log_fn = logger.info
-        if not any((prompt_throughput, generation_throughput,
-                    self.last_prompt_throughput,
-                    self.last_generation_throughput)):
+        if not any((
+                prompt_throughput,
+                generation_throughput,
+                self.last_prompt_throughput,
+                self.last_generation_throughput,
+                stats.num_running_sys,
+                stats.num_waiting_sys,
+        )):
             # Avoid log noise on an idle production system
             log_fn = logger.debug

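Note on the hunk above: queued or running requests now count as activity, so an engine with requests in flight but zero measured throughput in the window (for example, mid-prefill) keeps logging at INFO. A minimal self-contained sketch of the new condition (the function and argument names here are illustrative, not from the codebase):

# Sketch of the updated idle check: non-empty queues defeat "idle".
def is_idle(prompt_tps: float, gen_tps: float,
            last_prompt_tps: float, last_gen_tps: float,
            num_running: int, num_waiting: int) -> bool:
    return not any((prompt_tps, gen_tps, last_prompt_tps, last_gen_tps,
                    num_running, num_waiting))

assert is_idle(0.0, 0.0, 0.0, 0.0, 0, 0)      # truly idle: log at DEBUG
assert not is_idle(0.0, 0.0, 0.0, 0.0, 1, 0)  # requests running: log at INFO
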
7 changes: 7 additions & 0 deletions vllm/v1/core/scheduler.py
@@ -9,6 +9,7 @@
 from vllm.v1.core.encoder_cache_manager import EncoderCacheManager
 from vllm.v1.core.kv_cache_manager import KVCacheManager
 from vllm.v1.engine import EngineCoreOutput
+from vllm.v1.engine.stats import SchedulerStats
 from vllm.v1.outputs import ModelRunnerOutput
 from vllm.v1.request import Request, RequestStatus

@@ -501,6 +501,12 @@ def get_num_unfinished_requests(self) -> int:
     def has_unfinished_requests(self) -> bool:
         return self.get_num_unfinished_requests() > 0

+    def get_stats(self) -> SchedulerStats:
+        return SchedulerStats(
+            num_running_reqs=len(self.running),
+            num_waiting_reqs=len(self.waiting),
+        )
+

 @dataclass
 class NewRequestData:
3 changes: 3 additions & 0 deletions vllm/v1/engine/__init__.py
@@ -7,6 +7,7 @@
 from vllm.lora.request import LoRARequest
 from vllm.multimodal import MultiModalDataDict, MultiModalPlaceholderDict
 from vllm.sampling_params import RequestOutputKind, SamplingParams
+from vllm.v1.engine.stats import EngineCoreStats


 @dataclass
@@ -67,6 +68,8 @@ class EngineCoreOutputs(msgspec.Struct,
     # [num_reqs]
     outputs: List[EngineCoreOutput]

+    stats: Optional[EngineCoreStats] = None
+

 @dataclass
 class EngineCoreProfile:
47 changes: 41 additions & 6 deletions vllm/v1/engine/async_llm.py
@@ -3,7 +3,7 @@

 from vllm.config import ModelConfig, VllmConfig
 from vllm.engine.arg_utils import AsyncEngineArgs
-from vllm.engine.metrics_types import StatLoggerBase
+from vllm.engine.metrics_types import StatLoggerBase, Stats
 from vllm.engine.protocol import EngineClient
 from vllm.inputs import INPUT_REGISTRY, InputRegistry, PromptType
 from vllm.inputs.preprocess import InputPreprocessor
@@ -16,10 +16,13 @@
 from vllm.transformers_utils.tokenizer import AnyTokenizer
 from vllm.transformers_utils.tokenizer_group import init_tokenizer_from_configs
 from vllm.usage.usage_lib import UsageContext
+from vllm.v1.engine import EngineCoreOutputs
 from vllm.v1.engine.async_stream import AsyncStream
 from vllm.v1.engine.core_client import EngineCoreClient
 from vllm.v1.engine.detokenizer import Detokenizer
 from vllm.v1.engine.processor import Processor
+from vllm.v1.engine.stats import (EngineCoreStats, initialize_stats_loggers,
+                                  make_stats)
 from vllm.v1.executor.gpu_executor import GPUExecutor

 logger = init_logger(__name__)
@@ -76,7 +79,18 @@ def __init__(
             asyncio_mode=True,
         )

-        self.output_handler = None
+        # Async tasks that run in the background.
+        self.output_handler: Optional[asyncio.Task] = None
+
+        # Stats loggers. If not provided, initialize from defaults.
+        self.stat_loggers: Dict[str, StatLoggerBase] = {}
+        if self.log_stats:
+            if stat_loggers is not None:
+                self.stat_loggers = stat_loggers
+            else:
+                self.stat_loggers = initialize_stats_loggers(vllm_config)
+        if self.stat_loggers:
+            logger.info("Logging stats to: %s", list(self.stat_loggers.keys()))

     def __del__(self):
         self.shutdown()
@@ -281,10 +295,13 @@ async def _run_output_handler(self):
         try:
             while True:
                 # 1) Pull EngineCoreOutput from the EngineCore.
-                outputs = await self.engine_core.get_output_async()
+                outputs: EngineCoreOutputs = (
+                    await self.engine_core.get_output_async())
+                self._log_stats(outputs.stats)

                 # 2) Detokenize based on the output.
-                request_outputs, reqs_to_abort = self.detokenizer.step(outputs)
+                request_outputs, reqs_to_abort = self.detokenizer.step(
+                    outputs.outputs)

                 # 3) Put the RequestOutputs into the per-request AsyncStreams.
                 self._process_request_outputs(request_outputs)
@@ -295,11 +312,27 @@
                 # 5) Abort any requests due to client cancellations.
                 await self._process_cancellations()

+        except asyncio.CancelledError:
+            logger.info("Engine shutting down.")
+            self.shutdown()
         except BaseException as e:
             logger.error(e)
             raise e

-    # TODO: can we eliminate these?
+    def _log_stats(self, engine_core_stats: EngineCoreStats):
+        if not self.stat_loggers:
+            # No stats to log.
+            return
+
+        stats: Stats = make_stats(engine_core_stats=engine_core_stats)
+        for stat_logger in self.stat_loggers.values():
+            # TODO(rickyx): we here assume loggers are lightweight and
+            # non-blocking. To make this more robust, we should really
+            # have an async logger interface, which implements actual
+            # logging that could be cpu-heavy in a separate process
+            # to minimize the latency impact on the frontend engine's
+            # event loop.
+            stat_logger.log(stats)

     async def abort(self, request_id: str) -> None:
         # Note: Who Calls this? I dont think this is actually used.
@@ -340,7 +373,9 @@ async def do_log_stats(
         scheduler_outputs=None,
         model_output=None,
     ) -> None:
-        logger.debug("Called do_log_stats.")
+        raise NotImplementedError(
+            "V1 stats logging should not be called by user. "
+            "The engine client handles logging internally.")

     async def check_health(self) -> None:
         logger.debug("Called check_health.")
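The TODO(rickyx) comment in _log_stats above assumes loggers are lightweight. One way to honor that assumption later, sketched below, is to hand stats to a background thread so a slow logger never blocks the frontend event loop; the StatsSink name and the threading approach are illustrative assumptions, not part of this commit:

import queue
import threading


class StatsSink:
    """Fire-and-forget stats logging off the event loop (hypothetical)."""

    def __init__(self, stat_loggers):
        self._queue: queue.Queue = queue.Queue()
        self._loggers = stat_loggers
        threading.Thread(target=self._drain, daemon=True).start()

    def log(self, stats) -> None:
        # Called from the event loop; enqueue and return immediately.
        self._queue.put_nowait(stats)

    def _drain(self) -> None:
        while True:
            stats = self._queue.get()
            for stat_logger in self._loggers.values():
                stat_logger.log(stats)  # may be slow; runs off the event loop
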
52 changes: 35 additions & 17 deletions vllm/v1/engine/core.py
@@ -20,6 +20,7 @@
                          EngineCoreProfile, EngineCoreRequest,
                          EngineCoreRequestType)
 from vllm.v1.engine.mm_input_mapper import MMInputMapper
+from vllm.v1.engine.stats import EngineCoreStats
 from vllm.v1.executor.gpu_executor import GPUExecutor
 from vllm.v1.request import Request, RequestStatus
 from vllm.v1.serial_utils import PickleEncoder
@@ -56,8 +57,11 @@ def __init__(

         assert vllm_config.model_config.task != "embedding"

-        logger.info("Initializing an LLM engine (v%s) with config: %s",
-                    VLLM_VERSION, vllm_config)
+        logger.info(
+            "Initializing an LLM engine (v%s) with config: %s",
+            VLLM_VERSION,
+            vllm_config,
+        )

         # Setup Model.
         self.model_executor = executor_class(vllm_config)
@@ -72,9 +76,11 @@
         self.mm_input_mapper = MMInputMapper(vllm_config.model_config)

         # Setup scheduler.
-        self.scheduler = Scheduler(vllm_config.scheduler_config,
-                                   vllm_config.cache_config,
-                                   vllm_config.lora_config)
+        self.scheduler = Scheduler(
+            vllm_config.scheduler_config,
+            vllm_config.cache_config,
+            vllm_config.lora_config,
+        )

         self._last_logging_time = time.time()

@@ -87,8 +93,10 @@ def _initialize_kv_caches(self,
             num_gpu_blocks_override = cache_config.num_gpu_blocks_override
             logger.info(
                 "Overriding num_gpu_blocks=%d with "
-                "num_gpu_blocks_override=%d", num_gpu_blocks,
-                num_gpu_blocks_override)
+                "num_gpu_blocks_override=%d",
+                num_gpu_blocks,
+                num_gpu_blocks_override,
+            )
             num_gpu_blocks = num_gpu_blocks_override

         num_cpu_blocks = 0
@@ -172,7 +180,7 @@ def __init__(

     @contextmanager
     def make_socket(self, path: str, type: Any) -> Iterator[zmq.Socket]:
-        """Context manager for use """
+        """Context manager for use"""

         ctx = zmq.Context()
         try:
@@ -245,7 +253,7 @@ def make_engine_core_process(
             "vllm_config": vllm_config,
             "executor_class": executor_class,
             "usage_context": usage_context,
-            "should_shutdown": should_shutdown
+            "should_shutdown": should_shutdown,
         }
         # Run EngineCore busy loop in background process.
         proc = context.Process(target=EngineCoreProc.run_engine_core,
@@ -284,7 +292,7 @@ def run_busy_loop(self):
                         self._handle_client_request(req)
                         break
                     except queue.Empty:
-                        self._log_stats()
+                        self._log_stats(self._make_stats(engine_outputs=[]))
                         logger.debug("EngineCore busy loop waiting.")
                         if self.should_shutdown:
                             return
@@ -297,21 +305,27 @@
             # 3) Step the engine core.
             outputs = self.step()

+            stats = self._make_stats(engine_outputs=outputs, )
+
             # 4) Put EngineCoreOutputs into the output queue.
-            self.output_queue.put_nowait(outputs)
+            self.output_queue.put_nowait((outputs, stats))

-            self._log_stats()
+            self._log_stats(stats)
+
+    def _make_stats(self,
+                    engine_outputs: List[EngineCoreOutput]) -> EngineCoreStats:
+        return EngineCoreStats(scheduler_stats=self.scheduler.get_stats(), )

-    def _log_stats(self):
+    def _log_stats(self, stats: EngineCoreStats):
         """Log basic stats every LOGGING_TIME_S"""

         now = time.time()

         if now - self._last_logging_time > LOGGING_TIME_S:
             logger.info(
                 "RUNNING: %s | WAITING: %s",
-                len(self.scheduler.running),
-                len(self.scheduler.waiting),
+                stats.scheduler_stats.num_running_reqs,
+                stats.scheduler_stats.num_waiting_reqs,
             )

             self._last_logging_time = now
@@ -367,7 +381,11 @@ def process_output_socket(self, output_path: str):

         with self.make_socket(output_path, zmq.constants.PUSH) as socket:
             while True:
-                engine_core_outputs = self.output_queue.get()
-                outputs = EngineCoreOutputs(outputs=engine_core_outputs)
+                engine_core_outputs, engine_core_stats = self.output_queue.get(
+                )
+                outputs = EngineCoreOutputs(
+                    outputs=engine_core_outputs,
+                    stats=engine_core_stats,
+                )
                 encoder.encode_into(outputs, buffer)
                 socket.send_multipart((buffer, ), copy=False)
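
Because EngineCoreOutputs is a msgspec.Struct, the stats field added in this commit rides the existing encode path in process_output_socket. A self-contained sketch of the round trip, with simplified stand-in structs (the real types carry EngineCoreOutput entries and additional msgspec options):

import msgspec
from typing import List, Optional


class SchedulerStats(msgspec.Struct):
    num_running_reqs: int = 0
    num_waiting_reqs: int = 0


class EngineCoreStats(msgspec.Struct):
    scheduler_stats: Optional[SchedulerStats] = None


class EngineCoreOutputs(msgspec.Struct):
    outputs: List[str]  # stand-in for the real EngineCoreOutput entries
    stats: Optional[EngineCoreStats] = None


# Encode on the EngineCoreProc side, decode on the client side.
payload = msgspec.msgpack.encode(
    EngineCoreOutputs(
        outputs=[],
        stats=EngineCoreStats(scheduler_stats=SchedulerStats(
            num_running_reqs=2, num_waiting_reqs=5))))
decoded = msgspec.msgpack.decode(payload, type=EngineCoreOutputs)
assert decoded.stats.scheduler_stats.num_waiting_reqs == 5
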
10 changes: 5 additions & 5 deletions vllm/v1/engine/core_client.py
@@ -65,7 +65,7 @@ async def profile(self, is_start=True) -> None:
     def abort_requests(self, request_ids: List[str]) -> None:
         raise NotImplementedError

-    async def get_output_async(self) -> List[EngineCoreOutput]:
+    async def get_output_async(self) -> EngineCoreOutputs:
         raise NotImplementedError

     async def add_request_async(self, request: EngineCoreRequest) -> None:
@@ -178,10 +178,10 @@ class SyncMPClient(MPClient):
     def __init__(self, *args, **kwargs):
         super().__init__(*args, asyncio_mode=False, **kwargs)

-    def get_output(self) -> List[EngineCoreOutput]:
+    def get_output(self) -> EngineCoreOutputs:

         (frame, ) = self.output_socket.recv_multipart(copy=False)
-        engine_core_outputs = self.decoder.decode(frame.buffer).outputs
+        engine_core_outputs = self.decoder.decode(frame.buffer)
         return engine_core_outputs

     def _send_input(
@@ -210,10 +210,10 @@ class AsyncMPClient(MPClient):
     def __init__(self, *args, **kwargs):
         super().__init__(*args, asyncio_mode=True, **kwargs)

-    async def get_output_async(self) -> List[EngineCoreOutput]:
+    async def get_output_async(self) -> EngineCoreOutputs:

         frames = await self.output_socket.recv_multipart(copy=False)
-        engine_core_outputs = self.decoder.decode(frames[0].buffer).outputs
+        engine_core_outputs = self.decoder.decode(frames[0].buffer)

         return engine_core_outputs

21 changes: 19 additions & 2 deletions vllm/v1/engine/llm_engine.py
@@ -14,9 +14,12 @@
 from vllm.sampling_params import SamplingParams
 from vllm.transformers_utils.tokenizer_group import init_tokenizer_from_configs
 from vllm.usage.usage_lib import UsageContext
+from vllm.v1.engine import EngineCoreOutputs
 from vllm.v1.engine.core_client import EngineCoreClient
 from vllm.v1.engine.detokenizer import Detokenizer
 from vllm.v1.engine.processor import Processor
+from vllm.v1.engine.stats import (EngineCoreStats, initialize_stats_loggers,
+                                  make_stats)
 from vllm.v1.executor.gpu_executor import GPUExecutor

 logger = init_logger(__name__)
@@ -71,6 +74,13 @@ def __init__(
             asyncio_mode=False,
         )

+        self.stat_loggers: Dict[str, StatLoggerBase] = {}
+        if log_stats:
+            self.stat_loggers = stat_loggers or initialize_stats_loggers(
+                vllm_config)
+        if self.stat_loggers:
+            logger.info("Logging stats to: %s", list(self.stat_loggers.keys()))
+
     @classmethod
     def from_engine_args(
         cls,
@@ -146,18 +156,25 @@ def add_request(
     def step(self) -> List[RequestOutput]:

         # 1) Get EngineCoreOutput from the EngineCore.
-        engine_core_outputs = self.engine_core.get_output()
+        engine_core_outputs: EngineCoreOutputs = self.engine_core.get_output()

         # 2) Detokenizer the EngineCoreOutput.
         request_outputs, requests_to_abort = self.detokenizer.step(
-            engine_core_outputs)
+            engine_core_outputs.outputs)

         # 3) Abort requests that finished due to stopping criteria.
         if requests_to_abort:
             self.abort_request(requests_to_abort)

         return request_outputs

+    def _log_stats(self, engine_core_stats: EngineCoreStats) -> None:
+        if not self.stat_loggers:
+            return
+        stats = make_stats(engine_core_stats)
+        for logger in self.stat_loggers.values():
+            logger.log(stats)
+
     # TODO(rob): Can we get rid of these?

     def get_model_config(self):
(The eighth changed file did not render on the page; from the imports above it is presumably the new vllm/v1/engine/stats.py.)
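
For reference, a hypothetical reconstruction of that module's core types, inferred purely from the call sites in this diff (the field and function names appear above; the defaults and the trimmed-down Stats stand-in are assumptions). The diff also calls initialize_stats_loggers(vllm_config), which presumably builds the default logger dict; its body is not recoverable from this page.

import time
from dataclasses import dataclass
from typing import Optional


@dataclass
class SchedulerStats:
    # Produced by Scheduler.get_stats() above.
    num_running_reqs: int = 0
    num_waiting_reqs: int = 0


@dataclass
class EngineCoreStats:
    scheduler_stats: Optional[SchedulerStats] = None


@dataclass
class Stats:
    # Stand-in for vllm.engine.metrics_types.Stats; only the fields this
    # diff touches are shown here.
    now: float = 0.0
    num_running_sys: int = 0
    num_waiting_sys: int = 0


def make_stats(engine_core_stats: Optional[EngineCoreStats] = None) -> Stats:
    """Map v1 engine-core stats onto the v0 Stats record (sketch)."""
    scheduler_stats = (engine_core_stats.scheduler_stats
                       if engine_core_stats else None) or SchedulerStats()
    return Stats(
        now=time.time(),
        num_running_sys=scheduler_stats.num_running_reqs,
        num_waiting_sys=scheduler_stats.num_waiting_reqs,
    )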
