6 changes: 5 additions & 1 deletion benchmarks/backend_request_func.py
@@ -272,10 +272,14 @@ async def async_request_openai_completions(
try:
async with session.post(url=api_url, json=payload,
headers=headers) as response:
#print(f"RES = {response.status}")
if response.status == 200:
first_chunk_received = False

#print(response)
async for chunk_bytes in response.content:
chunk_bytes = chunk_bytes.strip()
#print(f"CB = {chunk_bytes}")
if not chunk_bytes:
continue

@@ -313,7 +317,7 @@ async def async_request_openai_completions(
else:
output.success = False
output.error = (
"Never received a valid chunk to calculate TTFT."
"Never received a valid chunk to calculate TTFT. "
"This response will be marked as failed!")
output.generated_text = generated_text
output.latency = most_recent_timestamp - st
2 changes: 1 addition & 1 deletion examples/offline_inference/basic.py
@@ -21,4 +21,4 @@
for output in outputs:
prompt = output.prompt
generated_text = output.outputs[0].text
print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
2 changes: 1 addition & 1 deletion tests/entrypoints/openai/test_metrics.py
@@ -153,7 +153,7 @@ async def test_metrics_counts(server: RemoteOpenAIServer,
assert sample.value == expected_value, (
f"{metric_name_w_suffix} expected value of "
f"{expected_value} did not match found value "
f"{sample.value}")
f"{sample.value}, use_v1={use_v1}")
break
assert found_suffix, (
f"Did not find {metric_name_w_suffix} in prom endpoint"
90 changes: 67 additions & 23 deletions vllm/v1/core/scheduler.py
@@ -13,10 +13,16 @@
from vllm.v1.core.scheduler_output import (CachedRequestData, NewRequestData,
SchedulerOutput)
from vllm.v1.engine import (EngineCoreEvent, EngineCoreEventType,
EngineCoreOutput, EngineCoreOutputs)
EngineCoreOutputs)
from vllm.v1.metrics.stats import SchedulerStats
from vllm.v1.outputs import ModelRunnerOutput
from vllm.v1.request import Request, RequestStatus
from vllm.v1.sample.rejection_sampler import INVALID_TOKEN_ID

import cProfile
import pyinstrument
import torch
import numpy as np

logger = init_logger(__name__)

@@ -96,6 +102,8 @@ def __init__(
self.encoder_cache_manager = EncoderCacheManager(
cache_size=encoder_cache_size)

#self.profiler = cProfile.Profile()

def schedule(self) -> "SchedulerOutput":
# NOTE(woosuk) on the scheduling algorithm:
# There's no "decoding phase" nor "prefill phase" in the scheduler.
@@ -470,6 +478,16 @@ def _try_schedule_encoder_inputs(
encoder_inputs_to_schedule.append(i)
return encoder_inputs_to_schedule, num_new_tokens, encoder_budget

def Xupdate_from_output(
self,
scheduler_output: "SchedulerOutput",
model_runner_output: "ModelRunnerOutput",
) -> EngineCoreOutputs:
self.profiler.enable()
res = self._update_from_output(scheduler_output, model_runner_output)
self.profiler.disable()
return res

def update_from_output(
self,
scheduler_output: "SchedulerOutput",
@@ -482,21 +500,27 @@ def update_from_output(
num_scheduled_tokens = scheduler_output.num_scheduled_tokens

new_running: List[Request] = []
outputs: List[EngineCoreOutput] = []
output = EngineCoreOutputs()

# NOTE(woosuk): As len(self.running) can be up to 1K or more, the below
# loop can be a performance bottleneck. We should do our best to avoid
# expensive operations inside the loop.
for request in self.running:
for i, request in enumerate(self.running):
req_id = request.request_id
num_tokens_scheduled = num_scheduled_tokens.get(req_id, 0)
if num_tokens_scheduled == 0:
# The request was not scheduled in this step.
new_running.append(request)
continue

#print(f"r2i = {model_runner_output.req_id_to_index}")

req_index = model_runner_output.req_id_to_index[req_id]
generated_token_ids = sampled_token_ids[req_index]

if not isinstance(generated_token_ids, np.ndarray):
generated_token_ids = [generated_token_ids]

if req_id not in scheduler_output.scheduled_spec_decode_tokens:
# When the request's num_computed_tokens catches up
# its num_tokens, the request generates output tokens.
@@ -537,17 +561,19 @@ def update_from_output(
if spec_token_ids is not None:
request.spec_token_ids = spec_token_ids[req_index]

# Get prompt logprobs for this request.
prompt_logprobs_tensors = prompt_logprobs_dict.get(req_id)

stopped = False
new_logprobs = None
new_token_ids: List[int] = []
num_new_tokens = 0

if request.num_computed_tokens >= request.num_tokens:
# This loop seems inefficient.
#print(f"G = {generated_token_ids}")
for output_token_id in generated_token_ids:
output_token_id = int(output_token_id)
if output_token_id == INVALID_TOKEN_ID:
continue
request.append_output_token_ids(output_token_id)
new_token_ids.append(output_token_id)
num_new_tokens = num_new_tokens + 1

# Check for stop and update request state.
# This must be called before we make the EngineCoreOutput.
@@ -564,27 +590,41 @@ def update_from_output(
new_logprobs = logprobs.slice(req_index, req_index + 1)

# Transmit partial if chunked prefill & prompt logprobs is enabled
if new_token_ids or prompt_logprobs_tensors is not None:
# Add EngineCoreOutput for this Request.
outputs.append(
EngineCoreOutput(
request_id=req_id,
new_token_ids=new_token_ids,
finish_reason=request.get_finished_reason(),
new_logprobs=new_logprobs,
new_prompt_logprobs_tensors=prompt_logprobs_tensors,
stop_reason=request.stop_reason,
events=request.take_events()))
if num_new_tokens > 0 or req_id in prompt_logprobs_dict:
# Update EngineCoreOutputs for this Request.
output.request_ids.append(req_id)

# TODO: try to eliminate this if all the offsets are adjacent?
output.new_token_id_offsets.append(model_runner_output.req_id_to_index[req_id])

if (num_new_tokens != 1 or
output.new_token_id_counts is not None):
if output.new_token_id_counts is None:
output.new_token_id_counts = [1] * i
output.new_token_id_counts.append(num_new_tokens)

if new_logprobs is not None:
output.new_logprobs[req_id] = new_logprobs

finish_reason = request.get_finished_reason()
if finish_reason is not None:
output.finish_reason[req_id] = (finish_reason, None)

events = request.take_events()
if events is not None:
if output.events is None:
output.events = [None] * i
output.events.append(events)

self.scheduled_req_ids.remove(request.request_id)
if not stopped:
new_running.append(request)

self.running = new_running
return EngineCoreOutputs(
outputs=outputs,
scheduler_stats=self.make_stats(),
)
output.new_token_ids = sampled_token_ids
output.new_prompt_logprobs_tensors = prompt_logprobs_dict
output.scheduler_stats = self.make_stats()
return output

def _check_stop(self, request: Request) -> bool:
if (request.num_tokens >= self.max_model_len
@@ -640,6 +680,10 @@ def finish_requests(
request.status = finished_status
self._free_request(request)

def print_stats(self):
#self.profiler.print_stats('cumulative')
pass

def _free_request(self, request: Request) -> None:
assert request.is_finished()
self.kv_cache_manager.free(request)
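For context on the update_from_output changes above: new_token_id_counts and events stay None while every scheduled request has the default (exactly one new token, no events), and each column is only materialized, with earlier entries backfilled, once a non-default value appears. A minimal standalone sketch of that pattern, using a hypothetical helper (append_column_value) rather than the scheduler's exact indexing:

from typing import List, Optional

# Illustrative helper, not part of vLLM: append a per-request value to a
# lazily materialized column whose entries default to `default`.
def append_column_value(column: Optional[List[int]], index: int,
                        value: int, default: int = 1) -> Optional[List[int]]:
    if column is None:
        if value == default:
            return None  # still all-default, nothing stored yet
        column = [default] * index  # backfill defaults for earlier requests
    column.append(value)
    return column

column = None
for i, num_new_tokens in enumerate([1, 1, 3, 1]):
    column = append_column_value(column, i, num_new_tokens)
print(column)  # [1, 1, 3, 1] -- only materialized at i == 2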
59 changes: 35 additions & 24 deletions vllm/v1/engine/__init__.py
@@ -2,9 +2,11 @@

import enum
import time
from typing import Any, List, Optional, Union
from typing import Any, Dict, List, Optional, Tuple, Union

import msgspec
import torch
import numpy as np

from vllm.lora.request import LoRARequest
from vllm.multimodal import MultiModalKwargs
@@ -85,25 +87,25 @@ def new_event(cls,
return cls(event_type, timestamp)


class EngineCoreOutput(
msgspec.Struct,
array_like=True, # type: ignore[call-arg]
omit_defaults=True, # type: ignore[call-arg]
gc=False): # type: ignore[call-arg]

request_id: str
new_token_ids: List[int]

new_logprobs: Optional[LogprobsLists] = None
new_prompt_logprobs_tensors: Optional[LogprobsTensors] = None

finish_reason: Optional[FinishReason] = None
stop_reason: Union[int, str, None] = None
events: Optional[List[EngineCoreEvent]] = None

@property
def finished(self) -> bool:
return self.finish_reason is not None
#class EngineCoreOutput(
# msgspec.Struct,
# array_like=True, # type: ignore[call-arg]
# omit_defaults=True, # type: ignore[call-arg]
# gc=False): # type: ignore[call-arg]
#
# request_id: str
# new_token_ids: List[int]
#
# new_logprobs: Optional[LogprobsLists] = None
# new_prompt_logprobs_tensors: Optional[LogprobsTensors] = None
#
# finish_reason: Optional[FinishReason] = None
# stop_reason: Union[int, str, None] = None
# events: Optional[List[EngineCoreEvent]] = None
#
# @property
# def finished(self) -> bool:
# return self.finish_reason is not None


class UtilityOutput(
@@ -124,11 +126,20 @@ class EngineCoreOutputs(
omit_defaults=True, # type: ignore[call-arg]
gc=False): # type: ignore[call-arg]

#NOTE(Nick): We could consider ways to make this more compact,
# e.g. columnwise layout

# [num_reqs]
outputs: List[EngineCoreOutput] = []
request_ids: List[str] = []
new_token_id_offsets : List[int] = []
new_token_id_counts: Optional[List[int]] = None # ndarray?
Review comment on lines +131 to +132 (Member):
Yes keep as array ... and we don't need both offsets and counts right?

new_token_ids: np.ndarray = np.empty(0, dtype=int) # Optional?

# req_id -> LogprobsLists
new_logprobs: Dict[str, LogprobsLists] = {}
Review comment (Member):
We should change these to LogprobsTensors too

# req_id -> LogprobsTensors
new_prompt_logprobs_tensors: Dict[str, LogprobsTensors] = {}

finish_reason: Dict[str, Tuple[FinishReason, Union[int, str, None]]] = {}
events: Optional[List[Optional[List[EngineCoreEvent]]]] = None
scheduler_stats: Optional[SchedulerStats] = None
timestamp: float = 0.0

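To make the columnwise EngineCoreOutputs layout concrete: request i is described by request_ids[i], by its row in the model runner's sampled-token array via new_token_id_offsets[i], and, when speculative decoding produced more than one token, by new_token_id_counts[i]. Below is a rough consumer-side sketch under those assumptions; split_new_tokens is a hypothetical helper, and it assumes the valid tokens sit at the front of each row, which the real output processor may handle differently. As the review comment above suggests, offsets and counts are arguably redundant once the token array itself is transmitted; the sketch keeps both only to mirror the fields as defined here.

from typing import Dict, List, Optional

import numpy as np

# Illustrative helper, not part of vLLM: rebuild {req_id: [token ids]} from
# the parallel columns of an EngineCoreOutputs-like structure.
def split_new_tokens(request_ids: List[str],
                     offsets: List[int],
                     counts: Optional[List[int]],
                     new_token_ids: np.ndarray) -> Dict[str, List[int]]:
    per_request: Dict[str, List[int]] = {}
    for i, req_id in enumerate(request_ids):
        row = np.atleast_1d(new_token_ids[offsets[i]])
        n = counts[i] if counts is not None else 1  # one token unless spec decode
        per_request[req_id] = [int(t) for t in row[:n]]
    return per_request

print(split_new_tokens(["req-0", "req-1"], [0, 1], None,
                       np.array([[11], [42]])))
# {'req-0': [11], 'req-1': [42]}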
32 changes: 21 additions & 11 deletions vllm/v1/engine/async_llm.py
@@ -4,8 +4,6 @@
import os
from typing import AsyncGenerator, List, Mapping, Optional, Type, Union

import numpy as np

from vllm.config import ModelConfig, VllmConfig
from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.engine.protocol import EngineClient
@@ -30,6 +28,9 @@
StatLoggerBase)
from vllm.v1.metrics.stats import IterationStats, SchedulerStats

import cProfile as profile
import pyinstrument

logger = init_logger(__name__)


@@ -254,23 +255,32 @@ async def _run_output_handler(self):
# Split outputs into chunks of at most
# VLLM_V1_OUTPUT_PROC_CHUNK_SIZE, so that we don't block the
# event loop for too long.
num_outputs = len(outputs.outputs)
if num_outputs <= VLLM_V1_OUTPUT_PROC_CHUNK_SIZE:
slices = (outputs.outputs, )
num_requests = len(outputs.request_ids)
if num_requests <= VLLM_V1_OUTPUT_PROC_CHUNK_SIZE:
num_chunks = 1
chunk_size = num_requests
rem = 0
else:
Review comment on lines +259 to 263 (Member):
Could just keep the else logic here?

Reply (Collaborator, author):
I got div by zero when I tried just the else code path, so I left both branches.

Reply (Member):
Hmm that should only be possible if outputs.request_ids is empty ... I don't remember if that should ever happen but if it does we would just skip the loop anyhow (unless we need to still update the iteration stats in this case)

slices = np.array_split(
outputs.outputs,
cdiv(num_outputs, VLLM_V1_OUTPUT_PROC_CHUNK_SIZE))
num_chunks = cdiv(num_requests,
VLLM_V1_OUTPUT_PROC_CHUNK_SIZE)
chunk_size = num_requests // num_chunks
rem = num_requests % num_chunks

slice_start = 0
for i in range(num_chunks):
adj = 1 if i < rem else 0
slice_end = slice_start + chunk_size + adj

for i, outputs_slice in enumerate(slices):
# 2) Process EngineCoreOutputs.
processed_outputs = self.output_processor.process_outputs(
outputs_slice, outputs.timestamp, iteration_stats)
outputs, slice_start, slice_end, outputs.timestamp,
iteration_stats)
slice_start = slice_end
# NOTE: RequestOutputs are pushed to their queues.
assert not processed_outputs.request_outputs

# Allow other asyncio tasks to run between chunks
if i + 1 < len(slices):
if i + 1 < num_chunks:
await asyncio.sleep(0)

# 3) Abort any reqs that finished due to stop strings.
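A quick sanity check on the chunking arithmetic above: splitting num_requests into cdiv(num_requests, VLLM_V1_OUTPUT_PROC_CHUNK_SIZE) chunks, with the first rem chunks one element larger, yields the same sizes np.array_split produced for the old list of outputs, while an empty request list would make num_chunks zero in the else branch and divide by zero, which is the case discussed in the review thread. The sketch below is standalone; chunk_bounds is a hypothetical helper and CHUNK_SIZE of 4 is an arbitrary stand-in for VLLM_V1_OUTPUT_PROC_CHUNK_SIZE.

from math import ceil
from typing import Iterator, Tuple

CHUNK_SIZE = 4  # arbitrary stand-in for VLLM_V1_OUTPUT_PROC_CHUNK_SIZE

# Illustrative helper, not part of vLLM: yield (start, end) ranges covering
# num_requests in near-equal chunks, mirroring np.array_split's sizing.
def chunk_bounds(num_requests: int) -> Iterator[Tuple[int, int]]:
    if num_requests <= CHUNK_SIZE:
        num_chunks, chunk_size, rem = 1, num_requests, 0
    else:
        num_chunks = ceil(num_requests / CHUNK_SIZE)  # cdiv
        chunk_size, rem = divmod(num_requests, num_chunks)
    start = 0
    for i in range(num_chunks):
        end = start + chunk_size + (1 if i < rem else 0)
        yield start, end
        start = end

print(list(chunk_bounds(10)))  # [(0, 4), (4, 7), (7, 10)]
print(list(chunk_bounds(0)))   # [(0, 0)] -- the guard keeps num_chunks >= 1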
5 changes: 3 additions & 2 deletions vllm/v1/engine/core.py
@@ -148,7 +148,7 @@ def step(self) -> EngineCoreOutputs:

if not self.scheduler.has_unfinished_requests():
return EngineCoreOutputs(
outputs=[], scheduler_stats=self.scheduler.make_stats())
scheduler_stats=self.scheduler.make_stats())

scheduler_output = self.scheduler.schedule()
output = self.model_executor.execute_model(scheduler_output)
@@ -200,7 +200,7 @@ def step_with_batch_queue(self) -> Optional[EngineCoreOutputs]:
# If the queue is empty (timeout at .get), return
# an empty EngineCoreOutputs for logging.
engine_core_outputs = EngineCoreOutputs(
outputs=[], scheduler_stats=self.scheduler.make_stats())
scheduler_stats=self.scheduler.make_stats())

return engine_core_outputs

@@ -287,6 +287,7 @@ def signal_handler(signum, frame):

finally:
if engine_core is not None:
engine_core.scheduler.print_stats()
engine_core.shutdown()

def run_busy_loop(self):
4 changes: 2 additions & 2 deletions vllm/v1/engine/llm_engine.py
@@ -145,12 +145,12 @@ def add_request(

def step(self) -> List[RequestOutput]:

# 1) Get EngineCoreOutput from the EngineCore.
# 1) Get EngineCoreOutputs from the EngineCore.
outputs = self.engine_core.get_output()

# 2) Process EngineCoreOutputs.
processed_outputs = self.output_processor.process_outputs(
outputs.outputs)
outputs, 0, len(outputs.request_ids))

# 3) Abort any reqs that finished due to stop strings.
self.engine_core.abort_requests(processed_outputs.reqs_to_abort)