[V1] AsyncLLMEngine #9826

Draft · wants to merge 65 commits into base: main

Commits (65)
8f8662e
prototype
robertgshaw2-neuralmagic Oct 26, 2024
01c4ca8
revert spurious 2.5 changes
robertgshaw2-neuralmagic Oct 26, 2024
1ad8a48
stash
robertgshaw2-neuralmagic Oct 26, 2024
f9084f6
cleanup
robertgshaw2-neuralmagic Oct 26, 2024
72bccd9
add MQLLMEnginev1
robertgshaw2-neuralmagic Oct 26, 2024
a6cab52
work with MQLLMEngine
robertgshaw2-neuralmagic Oct 27, 2024
885ed16
format
robertgshaw2-neuralmagic Oct 27, 2024
3ed66cf
cleanup formatting
robertgshaw2-neuralmagic Oct 27, 2024
8ae8ce9
revert example change
robertgshaw2-neuralmagic Oct 27, 2024
5c72515
update comment
robertgshaw2-neuralmagic Oct 27, 2024
f9b33fa
formatting
robertgshaw2-neuralmagic Oct 27, 2024
82539b9
updated
robertgshaw2-neuralmagic Oct 27, 2024
d42a54e
stash
robertgshaw2-neuralmagic Oct 27, 2024
3a2d02a
format
robertgshaw2-neuralmagic Oct 27, 2024
6028ee1
Merge branch 'main' into rs-prototype-2
robertgshaw2-neuralmagic Oct 27, 2024
6bd37c1
update
robertgshaw2-neuralmagic Oct 27, 2024
196d822
revert bind/connect
robertgshaw2-neuralmagic Oct 27, 2024
a089cd1
revert comment
robertgshaw2-neuralmagic Oct 27, 2024
974aa06
formatting
robertgshaw2-neuralmagic Oct 27, 2024
fe1e1b4
formatting tweaks
robertgshaw2-neuralmagic Oct 27, 2024
9c27fbb
move detokenizer into engine
robertgshaw2-neuralmagic Oct 27, 2024
95b5af1
format
robertgshaw2-neuralmagic Oct 27, 2024
3999279
stash
robertgshaw2-neuralmagic Oct 27, 2024
b4dd571
revert bad import
robertgshaw2-neuralmagic Oct 27, 2024
f01f992
format
robertgshaw2-neuralmagic Oct 28, 2024
be333fa
format
robertgshaw2-neuralmagic Oct 28, 2024
aefb498
add files
robertgshaw2-neuralmagic Oct 28, 2024
6d7f473
stash
robertgshaw2-neuralmagic Oct 28, 2024
f431f8a
update
robertgshaw2-neuralmagic Oct 29, 2024
be431e4
update
robertgshaw2-neuralmagic Oct 29, 2024
36b7fa5
fix api client example to work with v1
robertgshaw2-neuralmagic Oct 29, 2024
3a5ce74
formatting
robertgshaw2-neuralmagic Oct 29, 2024
0d0251e
updated
robertgshaw2-neuralmagic Oct 29, 2024
046d78f
update
robertgshaw2-neuralmagic Oct 29, 2024
34c0665
update
robertgshaw2-neuralmagic Oct 29, 2024
52b790f
stash
robertgshaw2-neuralmagic Oct 30, 2024
4f9a86e
Stash
robertgshaw2-neuralmagic Oct 30, 2024
697b98f
stash
robertgshaw2-neuralmagic Oct 30, 2024
fa5c01d
LLMEngineWorking
robertgshaw2-neuralmagic Oct 30, 2024
0ca42d8
format
robertgshaw2-neuralmagic Oct 30, 2024
b6497d5
updated
robertgshaw2-neuralmagic Oct 30, 2024
ae88c73
updated
robertgshaw2-neuralmagic Oct 30, 2024
2161152
update
robertgshaw2-neuralmagic Oct 31, 2024
6a57297
added processor
robertgshaw2-neuralmagic Oct 31, 2024
3665602
updated
robertgshaw2-neuralmagic Oct 31, 2024
ed567ca
updated
robertgshaw2-neuralmagic Oct 31, 2024
f4005da
updated formats
robertgshaw2-neuralmagic Oct 31, 2024
67a53ed
revert
robertgshaw2-neuralmagic Oct 31, 2024
458b54f
finished
robertgshaw2-neuralmagic Oct 31, 2024
75ff707
updated
robertgshaw2-neuralmagic Oct 31, 2024
669648f
split core process into separate class
njhill Oct 31, 2024
127f09c
stash
robertgshaw2-neuralmagic Oct 31, 2024
99f683e
Merge pull request #22 from njhill/rework-splitcore
robertgshaw2-neuralmagic Oct 31, 2024
dc6163c
updated
robertgshaw2-neuralmagic Oct 31, 2024
d21cb8f
updated
robertgshaw2-neuralmagic Oct 31, 2024
565ffa6
working again
robertgshaw2-neuralmagic Oct 31, 2024
2960fbc
format
robertgshaw2-neuralmagic Oct 31, 2024
5d23709
updated
robertgshaw2-neuralmagic Oct 31, 2024
f2f2e40
updated
robertgshaw2-neuralmagic Oct 31, 2024
c10c9d8
better interface
robertgshaw2-neuralmagic Oct 31, 2024
b8767a9
formatting
robertgshaw2-neuralmagic Oct 31, 2024
ab783e1
format
robertgshaw2-neuralmagic Oct 31, 2024
423f47d
update
robertgshaw2-neuralmagic Oct 31, 2024
7c977d3
updated
robertgshaw2-neuralmagic Nov 1, 2024
3c14bdf
format
robertgshaw2-neuralmagic Nov 1, 2024
7 changes: 3 additions & 4 deletions examples/api_client.py
@@ -26,7 +26,6 @@ def post_http_request(prompt: str,
pload = {
"prompt": prompt,
"n": n,
"use_beam_search": True,
"temperature": 0.0,
"max_tokens": 16,
"stream": stream,
@@ -58,7 +57,7 @@ def get_response(response: requests.Response) -> List[str]:
parser = argparse.ArgumentParser()
parser.add_argument("--host", type=str, default="localhost")
parser.add_argument("--port", type=int, default=8000)
parser.add_argument("--n", type=int, default=4)
parser.add_argument("--n", type=int, default=1)
parser.add_argument("--prompt", type=str, default="San Francisco is a")
parser.add_argument("--stream", action="store_true")
args = parser.parse_args()
@@ -77,8 +76,8 @@ def get_response(response: requests.Response) -> List[str]:
num_printed_lines = 0
for i, line in enumerate(h):
num_printed_lines += 1
print(f"Beam candidate {i}: {line!r}", flush=True)
print(f"Output {i}: {line!r}", flush=True)
else:
output = get_response(response)
for i, line in enumerate(output):
print(f"Beam candidate {i}: {line!r}", flush=True)
print(f"Output {i}: {line!r}", flush=True)
7 changes: 6 additions & 1 deletion vllm/entrypoints/api_server.py
@@ -15,15 +15,20 @@
from fastapi.responses import JSONResponse, Response, StreamingResponse

from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.engine.async_llm_engine import AsyncLLMEngine
from vllm.entrypoints.launcher import serve_http
from vllm.envs import VLLM_USE_V1
from vllm.logger import init_logger
from vllm.sampling_params import SamplingParams
from vllm.usage.usage_lib import UsageContext
from vllm.utils import (FlexibleArgumentParser, iterate_with_cancellation,
random_uuid)
from vllm.version import __version__ as VLLM_VERSION

if VLLM_USE_V1:
from vllm.v1.engine.async_llm_engine import AsyncLLMEngine # type: ignore
else:
from vllm.engine.async_llm_engine import AsyncLLMEngine # type: ignore

logger = init_logger("vllm.entrypoints.api_server")

TIMEOUT_KEEP_ALIVE = 5 # seconds.
126 changes: 110 additions & 16 deletions vllm/entrypoints/llm.py
@@ -1,4 +1,5 @@
import itertools
import time
import warnings
from contextlib import contextmanager
from typing import (Any, ClassVar, Dict, List, Optional, Sequence, Tuple,
@@ -14,7 +15,7 @@
apply_hf_chat_template,
apply_mistral_chat_template,
parse_chat_messages)
from vllm.inputs import PromptType, TextPrompt, TokensPrompt
from vllm.inputs import INPUT_REGISTRY, PromptType, TextPrompt, TokensPrompt
from vllm.inputs.parse import parse_and_batch_prompt
from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
@@ -32,7 +33,10 @@
from vllm.utils import Counter, deprecate_args, deprecate_kwargs, is_list_of

if envs.VLLM_USE_V1:
from vllm.v1.engine.core import EngineCoreClient # type: ignore
from vllm.v1.engine.detokenizer import Detokenizer # type: ignore
from vllm.v1.engine.llm_engine import LLMEngine # type: ignore
from vllm.v1.engine.processor import Processor # type: ignore
else:
from vllm.engine.llm_engine import LLMEngine # type: ignore

@@ -195,8 +199,50 @@
mm_processor_kwargs=mm_processor_kwargs,
**kwargs,
)
self.llm_engine = LLMEngine.from_engine_args(
engine_args, usage_context=UsageContext.LLM_CLASS)

# TODO: should this be wrapped in a class?
if envs.VLLM_USE_V1:
engine_config = engine_args.create_engine_config()
robertgshaw2-neuralmagic (Collaborator, Author) commented:

@njhill - I'm starting to think this stuff should be wrapped into LLMEngine.

njhill (Member) commented:

@robertgshaw2-neuralmagic yeah, I was thinking we might have a completely separate LLM class, but that may be tricky if we want to be able to switch existing code with the env var.
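A possible shape for such a wrapper, purely as an illustration of the idea discussed above (the class name and method set are hypothetical, not part of this PR):

# Hypothetical sketch only: one way the V1 components wired up below could be
# bundled behind a single engine-like object instead of living on LLM directly.
class V1EngineBundle:

    def __init__(self, processor, detokenizer, engine_core_client):
        self.processor = processor                    # Inputs -> EngineCoreRequests
        self.detokenizer = detokenizer                # EngineCoreOutputs -> RequestOutputs
        self.engine_core_client = engine_core_client  # client for the core process

    def add_request(self, request_id, prompt, params, arrival_time,
                    lora_request=None, prompt_adapter_request=None, priority=0):
        # Mirrors the three-step add_request flow shown later in this diff.
        detokenizer_req, engine_core_req = self.processor.process_inputs(
            request_id, prompt, params, arrival_time, lora_request, None,
            prompt_adapter_request, priority)
        self.detokenizer.add_request(detokenizer_req)
        self.engine_core_client.add_request(engine_core_req)

    def step(self):
        # Pull new outputs from the core process and detokenize them.
        engine_core_outputs = self.engine_core_client.get_output()
        return self.detokenizer.step(engine_core_outputs)

    def has_unfinished_requests(self):
        return self.detokenizer.has_unfinished_requests()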

executor_class = LLMEngine._get_executor_cls(engine_config)
self.task = engine_config.model_config.task
self.supported_tasks = engine_config.model_config.task

# Processor (converts Inputs --> EngineCoreRequests)
self.processor = Processor(engine_config.model_config,
engine_config.parallel_config,
engine_config.scheduler_config,
engine_config.lora_config,
INPUT_REGISTRY)

# Detokenizer (converts EngineCoreOutputs --> RequestOutput)
self.detokenizer = Detokenizer(
engine_config.model_config.tokenizer)

# EngineCoreClient.
self.engine_core_client = EngineCoreClient(
executor_class,
engine_config.model_config,
engine_config.cache_config,
engine_config.parallel_config,
engine_config.scheduler_config,
engine_config.device_config,
engine_config.load_config,
engine_config.lora_config,
engine_config.speculative_config,
engine_config.decoding_config,
engine_config.observability_config,
engine_config.prompt_adapter_config,
UsageContext.LLM_CLASS,
use_async_sockets=False,
)

else:
self.llm_engine = LLMEngine.from_engine_args(
engine_args, usage_context=UsageContext.LLM_CLASS)

self.task = self.llm_engine.model_config.task
self.supported_tasks = self.llm_engine.model_config.task

self.request_counter = Counter()

def get_tokenizer(self) -> AnyTokenizer:
@@ -337,14 +383,14 @@
considered legacy and may be deprecated in the future. You should
instead pass them via the ``inputs`` parameter.
"""
task = self.llm_engine.model_config.task
task = self.task
if task != "generate":
messages = [
"LLM.generate() is only supported for (conditional) generation "
"models (XForCausalLM, XForConditionalGeneration).",
]

supported_tasks = self.llm_engine.model_config.supported_tasks
supported_tasks = self.supported_tasks
if "generate" in supported_tasks:
messages.append(
"Your model supports the 'generate' task, but is "
@@ -724,11 +770,11 @@
considered legacy and may be deprecated in the future. You should
instead pass them via the ``inputs`` parameter.
"""
task = self.llm_engine.model_config.task
task = self.task
if task != "embedding":
messages = ["LLM.encode() is only supported for embedding models."]

supported_tasks = self.llm_engine.model_config.supported_tasks
supported_tasks = self.supported_tasks
if "embedding" in supported_tasks:
messages.append(
"Your model supports the 'embedding' task, but is "
@@ -868,14 +914,28 @@
priority: int = 0,
) -> None:
request_id = str(next(self.request_counter))
self.llm_engine.add_request(
request_id,
prompt,
params,
lora_request=lora_request,
prompt_adapter_request=prompt_adapter_request,
priority=priority,
)

if envs.VLLM_USE_V1:
# 1) Convert input --> DetokenizerRequest / EngineCoreRequest.
detokenizer_req, engine_core_req = self.processor.process_inputs(
request_id, prompt, params, time.time(), lora_request, None,
prompt_adapter_request, priority)

# 2) Add the request to Detokenizer (this process).
self.detokenizer.add_request(detokenizer_req)

# 3) Add the EngineCoreRequest to EngineCore (separate process).
self.engine_core_client.add_request(engine_core_req)

else:
self.llm_engine.add_request(
request_id,
prompt,
params,
lora_request=lora_request,
prompt_adapter_request=prompt_adapter_request,
priority=priority,
)

def _add_guided_params(
self,
@@ -898,9 +958,43 @@
whitespace_pattern=guided_options.guided_whitespace_pattern)
return params

def _run_engine_v1(
self, use_tqdm: bool
) -> List[Union[RequestOutput, EmbeddingRequestOutput]]:
# Initialize tqdm.
if use_tqdm:
num_requests = self.detokenizer.get_num_unfinished_requests()
pbar = tqdm(
total=num_requests,
desc="Processed prompts",
dynamic_ncols=True,
)

# Run the engine.
request_outputs: List[Union[RequestOutput,
EmbeddingRequestOutput]] = []
while self.detokenizer.has_unfinished_requests():
engine_core_outputs = self.engine_core_client.get_output()
outputs = self.detokenizer.step(engine_core_outputs)
for output in outputs:
if output.finished:
request_outputs.append(output)
if use_tqdm:
pbar.update(1)

if use_tqdm:
pbar.close()
# Sort the outputs by request ID.
# This is necessary because some requests may finish earlier than
# their preceding requests.
return sorted(outputs, key=lambda x: int(x.request_id))

def _run_engine(
self, *, use_tqdm: bool
) -> List[Union[RequestOutput, EmbeddingRequestOutput]]:
if envs.VLLM_USE_V1:
return self._run_engine_v1(use_tqdm)

Check failure on line 997 in vllm/entrypoints/llm.py (GitHub Actions / mypy 3.8, 3.10, 3.11, 3.12): Missing return statement [return]
# Initialize tqdm.
if use_tqdm:
num_requests = self.llm_engine.get_num_unfinished_requests()
@@ -941,7 +1035,7 @@
# Sort the outputs by request ID.
# This is necessary because some requests may finish earlier than
# their preceding requests.
return sorted(outputs, key=lambda x: int(x.request_id))
outputs = sorted(outputs, key=lambda x: int(x.request_id))

def _is_encoder_decoder_model(self):
return self.llm_engine.is_encoder_decoder_model()
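Since the diff keeps the public LLM API unchanged and selects the engine via the VLLM_USE_V1 environment variable at import time, a minimal usage sketch for exercising the V1 path might look like this (the model name is only an example):

# Sketch: driving the V1 path through the unchanged public LLM API.
# Set VLLM_USE_V1 before importing vllm, because the import-time branch in
# vllm/entrypoints/llm.py reads envs.VLLM_USE_V1 when the module loads.
import os
os.environ["VLLM_USE_V1"] = "1"

from vllm import LLM, SamplingParams

llm = LLM(model="facebook/opt-125m")  # example model
params = SamplingParams(temperature=0.0, max_tokens=16)
outputs = llm.generate(["San Francisco is a"], params)
for out in outputs:
    print(out.outputs[0].text)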
30 changes: 30 additions & 0 deletions vllm/outputs.py
@@ -113,6 +113,36 @@ def __init__(
self.encoder_prompt = encoder_prompt
self.encoder_prompt_token_ids = encoder_prompt_token_ids

@classmethod
def create_empty(cls, request_id: str, prompt: Optional[str],
prompt_token_ids: Optional[List[int]]) -> "RequestOutput":
"""Initialize a new "empty" RequestOutput object."""

# TODO: Support `n` > 1.
completion_output = CompletionOutput(
index=0,
text="",
token_ids=[],
cumulative_logprob=None,
logprobs=None, # TODO
finish_reason=None,
stop_reason=None,
lora_request=None,
)

return RequestOutput(
request_id=request_id,
prompt=prompt,
prompt_token_ids=prompt_token_ids,
prompt_logprobs=None, # TODO
outputs=[completion_output],
finished=False,
metrics=None,
lora_request=None,
encoder_prompt=None,
encoder_prompt_token_ids=None,
)

@classmethod
def from_seq_group(
cls, seq_group: SequenceGroup, use_cache: bool,
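A small sketch of how a streaming consumer such as the V1 Detokenizer might use the new create_empty() helper; the incremental-update part is an assumption for illustration, not code from this PR:

# Hypothetical usage of RequestOutput.create_empty for incremental streaming.
from vllm.outputs import RequestOutput

request_output = RequestOutput.create_empty(
    request_id="req-0",
    prompt="San Francisco is a",
    prompt_token_ids=[2, 4286, 2986, 16, 10],  # example token ids
)

# As new tokens arrive from the engine core, append them to the single
# CompletionOutput created by create_empty (n > 1 is still a TODO).
new_token_ids = [102, 103]
new_text = " city"
completion = request_output.outputs[0]
completion.token_ids.extend(new_token_ids)
completion.text += new_text

# Mark the request finished once the engine reports a stop condition.
request_output.finished = True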
27 changes: 20 additions & 7 deletions vllm/v1/core/scheduler.py
@@ -1,12 +1,13 @@
from collections import deque
from dataclasses import dataclass
from typing import Deque, Dict, Iterable, List, Optional, Set, Tuple, Union
from typing import Deque, Dict, Iterable, List, Optional, Set, Union

from vllm.config import CacheConfig, LoRAConfig, SchedulerConfig
from vllm.logger import init_logger
from vllm.multimodal import MultiModalDataDict
from vllm.sampling_params import SamplingParams
from vllm.v1.core.kv_cache_manager import KVCacheManager
from vllm.v1.engine import EngineCoreOutput
from vllm.v1.outputs import ModelRunnerOutput
from vllm.v1.request import Request, RequestStatus

@@ -227,13 +228,12 @@ def update_from_output(
self,
scheduler_output: "SchedulerOutput",
model_runner_output: "ModelRunnerOutput",
) -> List[Tuple[Request, int]]:
) -> List[EngineCoreOutput]:
robertgshaw2-neuralmagic (Collaborator, Author) commented:

I'm not sure it makes sense for this method to be in scheduler.py. The only scheduler-related piece here is updating self.running.
# NOTE(woosuk): This method doesn't consider speculative decoding.
sampled_token_ids = model_runner_output.sampled_token_ids_cpu.tolist()
num_scheduled_tokens = scheduler_output.num_scheduled_tokens
new_running: List[Request] = []
# (request, num_sampled_tokens)
sampled: List[Tuple[Request, int]] = []
engine_core_outputs: List[EngineCoreOutput] = []
for request in self.running:
req_id = request.request_id
request.num_computed_tokens += num_scheduled_tokens[req_id]
@@ -247,17 +247,30 @@
# generates at most one token at each step.
token_id = sampled_token_ids[req_index]
request.output_token_ids.append(token_id)
sampled.append((request, 1))
num_new_tokens = 1

# TODO: Update the KV cache manager for prefix caching.

# Check if the request is finished.
# Check for stop and update request state.
# This must be called before we make the EngineCoreOutput.
stopped = self._check_stop(request)

# Add EngineCoreOutput for this Request.
output = EngineCoreOutput(
request_id=req_id,
new_token_ids=request.output_token_ids[-num_new_tokens:],
finished=request.is_finished(),
finish_reason=request.get_finished_reason(),
stop_reason=request.stop_reason)
engine_core_outputs.append(output)

# Break out of the loop.
if stopped:
continue

new_running.append(request)
self.running = new_running
return sampled
return engine_core_outputs

def _check_stop(self, request: Request) -> bool:
if (request.num_tokens >= self.max_model_len