[ Frontend ] Multiprocessing for OpenAI Server with zeromq #6883

Merged: 84 commits, merged on Aug 3, 2024.
Showing changes from 1 commit.

Commits (84):
bed649a
:alembic: add backend proto file
joerunde Jul 25, 2024
7de9d49
:recycle: move proto to grpc/pb
joerunde Jul 25, 2024
9394a62
:sparkles: add proto compilation
joerunde Jul 25, 2024
dd8bf96
updated
robertgshaw2-redhat Jul 25, 2024
5c7fbff
kinda working
robertgshaw2-redhat Jul 25, 2024
952e8ef
:construction: more wip
joerunde Jul 25, 2024
e8eac95
fixed
robertgshaw2-redhat Jul 25, 2024
938a843
:bug: fixup race condition
joerunde Jul 25, 2024
2b8d7cd
:bug: remove timeout
joerunde Jul 25, 2024
ea02d39
format
robertgshaw2-redhat Jul 26, 2024
4a2dc46
streaming
robertgshaw2-redhat Jul 26, 2024
30f2bc9
removed breaks
robertgshaw2-redhat Jul 26, 2024
c718b68
pushing current state
robertgshaw2-redhat Jul 26, 2024
b3d25c6
:alembic: try unix sockets
joerunde Jul 26, 2024
2765b17
:zap: no background loop
joerunde Jul 26, 2024
b219778
spurious change
robertgshaw2-redhat Jul 26, 2024
932ea23
remove spurious change
robertgshaw2-redhat Jul 26, 2024
f029114
spurious changes
robertgshaw2-redhat Jul 26, 2024
6854758
spurioous change
robertgshaw2-redhat Jul 26, 2024
3b5ff66
:bug: whoops
joerunde Jul 26, 2024
79247c3
:memo: log stuff
joerunde Jul 26, 2024
a39ebc0
stash
robertgshaw2-redhat Jul 26, 2024
ef257f1
pushing up
robertgshaw2-redhat Jul 26, 2024
a6c9bc5
stash
robertgshaw2-redhat Jul 28, 2024
d7490bc
actually working
robertgshaw2-redhat Jul 28, 2024
f68fd60
cleanup
robertgshaw2-redhat Jul 28, 2024
38b5b9c
more cleanup
robertgshaw2-redhat Jul 28, 2024
bc54311
cleanup
robertgshaw2-redhat Jul 28, 2024
3cccebb
stash
robertgshaw2-redhat Jul 28, 2024
4b78e29
more cleanup
robertgshaw2-redhat Jul 28, 2024
345bfdd
setup
robertgshaw2-redhat Jul 28, 2024
cfbb001
cleanup
robertgshaw2-redhat Jul 28, 2024
d811b42
format
robertgshaw2-redhat Jul 28, 2024
852534e
cleaning up
robertgshaw2-redhat Jul 28, 2024
e42be96
zlib
robertgshaw2-redhat Jul 28, 2024
5202a59
Revert "zlib"
robertgshaw2-redhat Jul 28, 2024
71b1bf9
turn on chunked prefill
robertgshaw2-redhat Jul 28, 2024
a499079
move RPC code into oai server
robertgshaw2-redhat Jul 29, 2024
88a1d08
format
robertgshaw2-redhat Jul 29, 2024
13ce2f1
format
robertgshaw2-redhat Jul 29, 2024
bb8ac06
trying to flow it through
robertgshaw2-redhat Jul 29, 2024
6ebdb3d
cleaning
robertgshaw2-redhat Jul 29, 2024
24c8100
cleaning
robertgshaw2-redhat Jul 29, 2024
e707049
cleaning
robertgshaw2-redhat Jul 29, 2024
baaf6bc
add stubs
robertgshaw2-redhat Jul 29, 2024
9d19d92
format
robertgshaw2-redhat Jul 29, 2024
f1be4b8
working with single launch...
robertgshaw2-redhat Jul 29, 2024
8e417ad
working end to end - with some hacks
robertgshaw2-redhat Jul 29, 2024
4c16c5e
:goal_net: handle shutdown and request errors
joerunde Jul 29, 2024
6ddd4a7
:art: fmt and clean up shutdown handler
joerunde Jul 29, 2024
6d7da74
:bug: fixup type hint for queue
joerunde Jul 29, 2024
97ea04d
:sparkles: update chat endpoint
joerunde Jul 29, 2024
6d753a4
:bug: fixup zmq constant types
joerunde Jul 29, 2024
38e308e
:sparkles: hook up de/tokenize
joerunde Jul 29, 2024
ec19a7b
:recycle: add VLLMBackend protocol
joerunde Jul 29, 2024
453939b
Frontend mp flag (#384)
joerunde Jul 30, 2024
1f33286
Features / Cleanup for MP Frontend (#387)
robertgshaw2-redhat Jul 31, 2024
5362952
Use random port for backend (#390)
joerunde Jul 31, 2024
7214fb8
Await socket operations + some other minor cleanup (#391)
njhill Jul 31, 2024
98a7dab
:sparkles: health check round 2 (#392)
joerunde Jul 31, 2024
f5f0b45
Add tokenizer (#394)
robertgshaw2-redhat Jul 31, 2024
0b351c0
Socket context (#393)
joerunde Jul 31, 2024
79fcc44
Logit bias (#395)
robertgshaw2-redhat Jul 31, 2024
9da8c4a
Merge remote-tracking branch 'upstream/main' into isolate-oai-server-…
joerunde Jul 31, 2024
4c65f74
:bug: messed up the revert in the merge commit :(
joerunde Jul 31, 2024
9bc97f1
fix (#396)
robertgshaw2-redhat Jul 31, 2024
68d8612
Merge remote-tracking branch 'upstream/main' into isolate-oai-server-…
joerunde Jul 31, 2024
4337fe7
format
robertgshaw2-redhat Aug 1, 2024
779d9bd
stash
robertgshaw2-redhat Aug 1, 2024
a6044a3
Fix failed tests (#398)
robertgshaw2-redhat Aug 1, 2024
100189f
Merge branch 'main' into isolate-oai-server-process
robertgshaw2-redhat Aug 1, 2024
0fc8545
fixed merge conflicts
robertgshaw2-redhat Aug 1, 2024
6383091
updated
robertgshaw2-redhat Aug 1, 2024
a09f57f
cleaning
robertgshaw2-redhat Aug 1, 2024
1bdbfcb
:white_check_mark: add test for multiprocessing flag (#399)
joerunde Aug 1, 2024
f3c0f1c
:sparkles: pipe tracing flag (#400)
joerunde Aug 1, 2024
9c415ad
integration tests for old backend
robertgshaw2-redhat Aug 1, 2024
62036ad
rename
robertgshaw2-redhat Aug 1, 2024
a177d87
cleaning
robertgshaw2-redhat Aug 1, 2024
9ca3b93
ordering
robertgshaw2-redhat Aug 1, 2024
f8b5fb1
fix embedding model feedback
robertgshaw2-redhat Aug 1, 2024
fca5a71
Update vllm/entrypoints/openai/rpc/server.py
robertgshaw2-redhat Aug 1, 2024
5f07f86
format
robertgshaw2-redhat Aug 1, 2024
bd0fd76
Merge branch 'main' into isolate-oai-server-process
robertgshaw2-redhat Aug 2, 2024
Commit 88a1d089586280a42da3badb01c93bc8a055a397 ("format")
robertgshaw2-redhat committed Jul 29, 2024
46 changes: 22 additions & 24 deletions vllm/entrypoints/openai/api_server.py
@@ -4,6 +4,7 @@
import re
import signal
from contextlib import asynccontextmanager
from multiprocessing import Process
from http import HTTPStatus
from typing import Optional, Set

@@ -37,8 +38,9 @@
from vllm.entrypoints.openai.serving_embedding import OpenAIServingEmbedding
from vllm.entrypoints.openai.serving_tokenization import (
OpenAIServingTokenization)
from vllm.entrypoints.openai.rpc.client import RPCClient
from vllm.entrypoints.openai.rpc.server import run_rpc_server
from vllm.logger import init_logger
from vllm.usage.usage_lib import UsageContext
from vllm.utils import FlexibleArgumentParser
from vllm.version import __version__ as VLLM_VERSION

@@ -216,29 +218,19 @@ async def authentication(request: Request, call_next):

async def build_server(
args,
llm_engine: Optional[AsyncLLMEngine] = None,
**uvicorn_kwargs,
) -> uvicorn.Server:
app = build_app(args)

# if args.served_model_name is not None:
# served_model_names = args.served_model_name
# else:
# served_model_names = [args.model]

served_model_names = "meta-llama/Meta-Llama-3-8B-Instruct"

from vllm.grpc.client import RPCClient
engine = RPCClient()

# global engine, engine_args

# engine_args = AsyncEngineArgs.from_cli_args(args)
# engine = (llm_engine
# if llm_engine is not None else AsyncLLMEngine.from_engine_args(
# engine_args, usage_context=UsageContext.OPENAI_API_SERVER))

# model_config = await engine.get_model_config()
if args.served_model_name is not None:
served_model_names = args.served_model_name
else:
served_model_names = [args.model]

print("HERE")
rpc_client = RPCClient()
model_config = await rpc_client.get_model_config()
print("HERE2")

if args.disable_log_requests:
request_logger = None
@@ -309,13 +301,17 @@ async def build_server(
return uvicorn.Server(config)


async def run_server(args, llm_engine=None, **uvicorn_kwargs) -> None:
async def run_server(args, **uvicorn_kwargs) -> None:
logger.info("vLLM API server version %s", VLLM_VERSION)
logger.info("args: %s", args)


logger.info("Starting RPC Server")
rpc_server_process = Process(target=run_rpc_server,
args=(AsyncEngineArgs.from_cli_args(args),))
rpc_server_process.start()

server = await build_server(
args,
llm_engine,
**uvicorn_kwargs,
)

@@ -332,10 +328,12 @@ def signal_handler() -> None:

try:
await server_task
rpc_server_process.join()
except asyncio.CancelledError:
print("Gracefully stopping http server")
await server.shutdown()

rpc_server_process.join()


if __name__ == "__main__":
# NOTE(simon):
25 changes: 25 additions & 0 deletions vllm/entrypoints/openai/rpc/__init__.py
@@ -0,0 +1,25 @@
from dataclasses import dataclass
from typing import Optional, Mapping
from enum import Enum

from vllm.inputs import PromptInputs
from vllm.lora.request import LoRARequest
from vllm.prompt_adapter.request import PromptAdapterRequest
from vllm.sampling_params import SamplingParams

VLLM_GENERATE_RPC_PATH = "tcp://localhost:5570"
VLLM_GET_DATA_RPC_PATH = "tcp://localhost:5571"
VLLM_IS_READY_RPC_PATH = "tcp://localhost:5572"
Review comment (Member): At the very least we should be choosing randomly available ports above the "user space" of the 1000s port range. See https://pyzmq.readthedocs.io/en/latest/api/zmq.html#zmq.Socket.bind_to_random_port

Review comment (Collaborator): @mgoin yeah, that'd be better. Any idea how we'd notify the clients what port to connect to in that case? The server that would be calling .bind_to_random_port() is in a different process than the OpenAI server that needs to connect clients to it.

Review comment (Member): The OpenAI server can listen first, pass the ports to the server process, and then the server just connects to it? (A rough sketch of this idea follows after this file's diff.)


@dataclass
class GenerateRequest:
inputs: PromptInputs
sampling_params: SamplingParams
request_id: str
lora_request: Optional[LoRARequest] = None
trace_headers: Optional[Mapping[str, str]] = None
prompt_adapter_request: Optional[PromptAdapterRequest] = None


class GetDataRequest(Enum):
MODEL_CONFIG = 1
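
A rough sketch of the approach floated in the review thread above: the frontend (OpenAI server) binds a socket to a random ephemeral port first, then passes that port to the engine process, which connects instead of binding. The function and variable names below are illustrative only, not code from this PR; the PR itself later revisits this in the "Use random port for backend (#390)" commit.

# Hedged sketch of the "frontend binds first" idea; run_engine_proc and
# start_backend are made-up names for illustration.
from multiprocessing import Process

import zmq


def run_engine_proc(port: int) -> None:
    # Engine-side process: connect to the port the frontend already bound.
    ctx = zmq.Context()
    socket = ctx.socket(zmq.DEALER)
    socket.connect(f"tcp://127.0.0.1:{port}")
    # ... the engine loop would service requests on `socket` here ...


def start_backend():
    # Frontend side: bind to a random free port in the ephemeral range
    # (bind_to_random_port defaults to min_port=49152) and get the port back.
    ctx = zmq.Context()
    frontend_socket = ctx.socket(zmq.ROUTER)
    port = frontend_socket.bind_to_random_port("tcp://127.0.0.1")

    # Hand the chosen port to the engine process so it knows where to connect.
    proc = Process(target=run_engine_proc, args=(port,))
    proc.start()
    return frontend_socket, proc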
68 changes: 23 additions & 45 deletions vllm/entrypoints/openai/rpc/client.py
@@ -1,62 +1,31 @@
from vllm import AsyncLLMEngine
from typing import AsyncIterator, Optional, Mapping

from vllm.config import ModelConfig
from vllm.inputs import PromptInputs
from vllm.lora.request import LoRARequest
from vllm.outputs import RequestOutput
from vllm.prompt_adapter.request import PromptAdapterRequest
from vllm.sampling_params import SamplingParams
from transformers import AutoTokenizer
from dataclasses import dataclass
from vllm.entrypoints.openai.rpc import (
VLLM_GENERATE_RPC_PATH, VLLM_GET_DATA_RPC_PATH, GenerateRequest, GetDataRequest)

import zmq
import zmq.asyncio
import pickle

MODEL = "meta-llama/Meta-Llama-3-8B-Instruct"
ADDRESS = "ipc:///tmp/zmqtest"

@dataclass
class RCPRequest:
inputs: PromptInputs
sampling_params: SamplingParams
request_id: str


class RPCClient(AsyncLLMEngine):
class RPCClient:
def __init__(self):
self.engine_use_ray = False
self.worker_use_ray = False
self.log_requests = False
self.engine = None

self.tokenizer = AutoTokenizer.from_pretrained(MODEL)

self.context = zmq.asyncio.Context()


@property
def is_running(self) -> bool:
return True
self.is_ready_socket = self.context.socket(zmq.REP)
self.get_data_socket = self.context.socket(zmq.REQ)
self.get_data_socket.connect(VLLM_GET_DATA_RPC_PATH)

@property
def is_stopped(self) -> bool:
return False

@property
def errored(self) -> bool:
return False

async def get_tokenizer(
self,
lora_request: Optional[LoRARequest] = None,
) -> "PreTrainedTokenizer":
# TODO: what to return :/
return self.tokenizer

def start_background_loop(self):
# TODO something lol
pass
async def get_model_config(self) -> ModelConfig:
self.get_data_socket.send(pickle.dumps(GetDataRequest.MODEL_CONFIG))
return pickle.loads(await self.get_data_socket.recv())


async def generate(
self,
@@ -67,19 +36,28 @@ async def generate(
trace_headers: Optional[Mapping[str, str]] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None
) -> AsyncIterator[RequestOutput]:

# Connect to RPC socket for Request-Reply pattern,
# Note that we use DEALER so that communication is asynchronous,
# which is needed for streaming.
socket = self.context.socket(zmq.DEALER)
socket.connect(ADDRESS)
socket.connect(VLLM_GENERATE_RPC_PATH)

# Send GenerateRequest to the RPC Server.
await socket.send_multipart([
pickle.dumps(
Review comment (Member): As @robertgshaw2-neuralmagic suggested, let's use msgspec? (See the sketch after this file's diff.)

Review comment (Collaborator, Author): Let's separate the messaging protocol optimizations into a separate PR.

RCPRequest(
GenerateRequest(
inputs=inputs,
sampling_params=sampling_params,
request_id=request_id
request_id=request_id,
lora_request=lora_request,
trace_headers=trace_headers,
prompt_adapter_request=prompt_adapter_request
), pickle.HIGHEST_PROTOCOL
)
])

# Stream back the results from the RPC Server.
while True:
message = await socket.recv()
request_output = pickle.loads(message)
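
On the msgspec suggestion in the review above, a rough sketch of what swapping pickle for msgspec's msgpack codec could look like. The Struct below is a simplified stand-in for GenerateRequest (real vLLM types such as SamplingParams would need msgspec-aware conversion, e.g. an enc_hook), and none of this is code from the PR:

from typing import Optional

import msgspec


class GenerateRequestMsg(msgspec.Struct):
    # Simplified stand-in for the GenerateRequest dataclass defined earlier.
    prompt: str
    request_id: str
    max_tokens: int = 16
    temperature: float = 1.0
    lora_name: Optional[str] = None


encoder = msgspec.msgpack.Encoder()
decoder = msgspec.msgpack.Decoder(GenerateRequestMsg)

# Client side: serialize the request before socket.send_multipart([...]).
frame = encoder.encode(
    GenerateRequestMsg(prompt="Hello", request_id="req-0", max_tokens=32))

# Server side: decode the received frame back into a typed object.
request = decoder.decode(frame)
assert request.request_id == "req-0"

Relative to pickle, this keeps the wire format compact and avoids unpickling arbitrary objects off the socket, which is presumably the motivation for the suggestion; as noted in the reply, the PR defers this to a follow-up.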
84 changes: 66 additions & 18 deletions vllm/entrypoints/openai/rpc/server.py
@@ -1,21 +1,52 @@
from vllm import AsyncEngineArgs, AsyncLLMEngine
import asyncio
import pickle
import zmq
import zmq.asyncio

from .client import MODEL, ADDRESS
from vllm import AsyncLLMEngine
from vllm.usage.usage_lib import UsageContext
from vllm.entrypoints.openai.rpc import (VLLM_GENERATE_RPC_PATH,
VLLM_GET_DATA_RPC_PATH,
VLLM_IS_READY_RPC_PATH,
GetDataRequest)

class RPCServer:
def __init__(self):
def __init__(self, async_engine_args):
# Initialize engine first.
self.engine = AsyncLLMEngine.from_engine_args(
async_engine_args, UsageContext.OPENAI_API_SERVER)

# Initialize context.
self.context = zmq.asyncio.Context()
self.socket = self.context.socket(zmq.ROUTER)
self.socket.bind(ADDRESS)

# Init socket for readiness state.
self.is_ready_socket = self.context.socket(zmq.REP)
self.is_ready_socket.bind(VLLM_IS_READY_RPC_PATH)

# Init socket for generation.
self.generate_socket = self.context.socket(zmq.ROUTER)
self.generate_socket.bind(VLLM_GENERATE_RPC_PATH)

# TODO (robertgshaw2-neuralmagic):
# add socket for generation without streaming

# Init socket for simple data requests.
self.get_data_socket = self.context.socket(zmq.REP)
self.get_data_socket.bind(VLLM_GET_DATA_RPC_PATH)

# Setup polling so we can listen on both sockets.
self.poller = zmq.asyncio.Poller()
self.poller.register(self.generate_socket, zmq.POLLIN)
self.poller.register(self.get_data_socket, zmq.POLLIN)


async def get_data(self, message):
request_type = pickle.loads(message)
if request_type == GetDataRequest.MODEL_CONFIG:
return await self.engine.get_model_config()
else:
raise ValueError(f"Unknown request type: {request_type}")

self.running_tasks = set()
self.engine = AsyncLLMEngine.from_engine_args(
AsyncEngineArgs(model=MODEL,
enable_chunked_prefill=True))

async def generate(self, identity, message):
request = pickle.loads(message)
@@ -29,23 +60,40 @@ async def generate(self, identity, message):
identity,
pickle.dumps(request_output, pickle.HIGHEST_PROTOCOL)
])

async def run_loop(self):
# Notify the RPC client that we are ready to receive requests.
await self.is_ready_socket.send_string("Ready!")
Review comment (Member): This ready string should be a global shared constant.

self.is_ready_socket.close()

# Avoid GC of running tasks.
running_tasks = set()
while True:
identity, message = await self.socket.recv_multipart()
try:
socks = dict(await self.poller.poll())
except KeyboardInterrupt:
# TODO: should there be some other exception here?
break

# Process the request in the background.
task = asyncio.create_task(self.generate(identity=identity,
message=message))
task = None
if self.generate_socket in socks:
identity, message = await self.generate_socket.recv_multipart()
task = asyncio.create_task(self.generate(identity, message))

elif self.get_data_socket in socks:
message = await self.get_data_socket.recv()
task = asyncio.create_task(self.get_data(message))

# We need to keep around a strong reference to the task,
# to avoid the task disappearing mid-execution as running tasks
# can be GC'ed. This follows the common "fire-and-forget" task pattern:
# https://docs.python.org/3/library/asyncio-task.html#asyncio.create_task
self.running_tasks.add(task)
task.add_done_callback(self.running_tasks.discard)
if task is not None:
running_tasks.add(task)
task.add_done_callback(running_tasks.discard)

# TODO: Do I need to close the generate / get_data sockets?
Review comment (Member): I think it's best to close sockets explicitly if possible (see https://pyzmq.readthedocs.io/en/latest/api/zmq.html#zmq.Socket.close). From the docs: "If this is not called, the socket will automatically be closed when it is garbage collected, in which case you may see a ResourceWarning about the unclosed socket." (A sketch follows after this file's diff.)


if __name__ == "__main__":
server = RPCServer()
def run_rpc_server(async_engine_args):
server = RPCServer(async_engine_args=async_engine_args)
asyncio.run(server.run_loop())
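
On the TODO about closing the generate / get_data sockets, one possible shape: a minimal sketch assuming a cleanup() method added to RPCServer, with the server's other methods taken from the diff above. Neither the method nor the try/finally below is part of this commit.

import asyncio


class RPCServer:
    # __init__, get_data, generate and run_loop as in the diff above.

    def cleanup(self) -> None:
        # Close each RPC socket explicitly rather than relying on GC.
        self.generate_socket.close()
        self.get_data_socket.close()
        # destroy() closes any sockets still open on this context and then
        # terminates it; linger=0 drops unsent messages instead of blocking.
        self.context.destroy(linger=0)


def run_rpc_server(async_engine_args):
    server = RPCServer(async_engine_args=async_engine_args)
    try:
        asyncio.run(server.run_loop())
    finally:
        server.cleanup()

With pyzmq, context.destroy() on its own would also close the sockets, but closing them explicitly first keeps the intent obvious and avoids the ResourceWarning the reviewer quotes.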
4 changes: 2 additions & 2 deletions vllm/entrypoints/openai/serving_completion.py
@@ -45,7 +45,7 @@ class OpenAIServingCompletion(OpenAIServing):
def __init__(
self,
engine: AsyncLLMEngine,
# model_config: ModelConfig,
model_config: ModelConfig,
served_model_names: List[str],
*,
lora_modules: Optional[List[LoRAModulePath]],
@@ -54,7 +54,7 @@ def __init__(
return_tokens_as_token_ids: bool = False,
):
super().__init__(engine=engine,
# model_config=model_config,
model_config=model_config,
served_model_names=served_model_names,
lora_modules=lora_modules,
prompt_adapters=prompt_adapters,