[Frontend] Multiprocessing for OpenAI Server with zeromq #6883
Changes from 1 commit: bed649a
@@ -0,0 +1,25 @@
+from dataclasses import dataclass
+from typing import Optional, Mapping
+from enum import Enum
+
+from vllm.inputs import PromptInputs
+from vllm.lora.request import LoRARequest
+from vllm.prompt_adapter.request import PromptAdapterRequest
+from vllm.sampling_params import SamplingParams
+
+VLLM_GENERATE_RPC_PATH = "tcp://localhost:5570"
+VLLM_GET_DATA_RPC_PATH = "tcp://localhost:5571"
+VLLM_IS_READY_RPC_PATH = "tcp://localhost:5572"
+
+
+@dataclass
+class GenerateRequest:
+    inputs: PromptInputs
+    sampling_params: SamplingParams
+    request_id: str
+    lora_request: Optional[LoRARequest] = None
+    trace_headers: Optional[Mapping[str, str]] = None
+    prompt_adapter_request: Optional[PromptAdapterRequest] = None
+
+
+class GetDataRequest(Enum):
+    MODEL_CONFIG = 1
@@ -1,62 +1,31 @@
-from vllm import AsyncLLMEngine
+from typing import AsyncIterator, Optional, Mapping
+
+from vllm.config import ModelConfig
 from vllm.inputs import PromptInputs
 from vllm.lora.request import LoRARequest
 from vllm.outputs import RequestOutput
 from vllm.prompt_adapter.request import PromptAdapterRequest
 from vllm.sampling_params import SamplingParams
-from transformers import AutoTokenizer
-from dataclasses import dataclass
+from vllm.entrypoints.openai.rpc import (
+    VLLM_GENERATE_RPC_PATH, VLLM_GET_DATA_RPC_PATH, GenerateRequest, GetDataRequest)

 import zmq
 import zmq.asyncio
 import pickle

-MODEL = "meta-llama/Meta-Llama-3-8B-Instruct"
-ADDRESS = "ipc:///tmp/zmqtest"
-
-@dataclass
-class RCPRequest:
-    inputs: PromptInputs
-    sampling_params: SamplingParams
-    request_id: str
-
-
-class RPCClient(AsyncLLMEngine):
+class RPCClient:
     def __init__(self):
-        self.engine_use_ray = False
-        self.worker_use_ray = False
-        self.log_requests = False
-        self.engine = None
-
-        self.tokenizer = AutoTokenizer.from_pretrained(MODEL)
-
         self.context = zmq.asyncio.Context()
-
-    @property
-    def is_running(self) -> bool:
-        return True
+        self.is_ready_socket = self.context.socket(zmq.REP)
+        self.get_data_socket = self.context.socket(zmq.REQ)
+        self.get_data_socket.connect(VLLM_GET_DATA_RPC_PATH)

-    @property
-    def is_stopped(self) -> bool:
-        return False
-
-    @property
-    def errored(self) -> bool:
-        return False
-
-    async def get_tokenizer(
-        self,
-        lora_request: Optional[LoRARequest] = None,
-    ) -> "PreTrainedTokenizer":
-        # TODO: what to return :/
-        return self.tokenizer
-
-    def start_background_loop(self):
-        # TODO something lol
-        pass
+    async def get_model_config(self) -> ModelConfig:
+        self.get_data_socket.send(pickle.dumps(GetDataRequest.MODEL_CONFIG))
+        return pickle.loads(await self.get_data_socket.recv())

     async def generate(
         self,
@@ -67,19 +36,28 @@ async def generate(
         trace_headers: Optional[Mapping[str, str]] = None,
         prompt_adapter_request: Optional[PromptAdapterRequest] = None
     ) -> AsyncIterator[RequestOutput]:

         # Connect to RPC socket for Request-Reply pattern,
         # Note that we use DEALER to enable asynchronous communication
         # to enable streaming.
         socket = self.context.socket(zmq.DEALER)
-        socket.connect(ADDRESS)
+        socket.connect(VLLM_GENERATE_RPC_PATH)

+        # Send GenerateRequest to the RPC Server.
         await socket.send_multipart([
             pickle.dumps(
Review comment: As @robertgshaw2-neuralmagic suggested, let's use msgspec?

Review comment: Let's separate the messaging protocol optimizations to a separate PR.
-                RCPRequest(
+                GenerateRequest(
                     inputs=inputs,
                     sampling_params=sampling_params,
-                    request_id=request_id
+                    request_id=request_id,
+                    lora_request=lora_request,
+                    trace_headers=trace_headers,
+                    prompt_adapter_request=prompt_adapter_request
                 ), pickle.HIGHEST_PROTOCOL
             )
         ])

         # Stream back the results from the RPC Server.
         while True:
             message = await socket.recv()
             request_output = pickle.loads(message)
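For context on the msgspec suggestion above, here is a minimal sketch of what swapping pickle for msgspec could look like. It is illustrative only and not part of this PR: the field types are simplified stand-ins, since the real GenerateRequest carries vLLM objects (PromptInputs, SamplingParams, LoRARequest) that would need msgspec-compatible definitions or custom encode/decode hooks.

# Hypothetical sketch (not part of this PR): msgspec-based serialization in
# place of pickle, as suggested in review. Field types are simplified.
from typing import Optional

import msgspec


class GenerateRequestMsg(msgspec.Struct):
    # Simplified stand-ins for PromptInputs / SamplingParams.
    prompt: str
    sampling_params: dict
    request_id: str
    lora_name: Optional[str] = None


encoder = msgspec.msgpack.Encoder()
decoder = msgspec.msgpack.Decoder(GenerateRequestMsg)

# Client side: would replace pickle.dumps(...) before socket.send_multipart.
payload = encoder.encode(
    GenerateRequestMsg(prompt="Hello!",
                       sampling_params={"max_tokens": 16},
                       request_id="req-0"))

# Server side: would replace pickle.loads(message) in RPCServer.generate.
request = decoder.decode(payload)
assert request.request_id == "req-0"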
@@ -1,21 +1,52 @@
-from vllm import AsyncEngineArgs, AsyncLLMEngine
 import asyncio
 import pickle
 import zmq
 import zmq.asyncio

-from .client import MODEL, ADDRESS
+from vllm import AsyncLLMEngine
+from vllm.usage.usage_lib import UsageContext
+from vllm.entrypoints.openai.rpc import (VLLM_GENERATE_RPC_PATH,
+                                         VLLM_GET_DATA_RPC_PATH,
+                                         VLLM_IS_READY_RPC_PATH,
+                                         GetDataRequest)

 class RPCServer:
-    def __init__(self):
+    def __init__(self, async_engine_args):
+        # Initialize engine first.
+        self.engine = AsyncLLMEngine.from_engine_args(
+            async_engine_args, UsageContext.OPENAI_API_SERVER)
+
+        # Initialize context.
         self.context = zmq.asyncio.Context()
-        self.socket = self.context.socket(zmq.ROUTER)
-        self.socket.bind(ADDRESS)

+        # Init socket for readiness state.
+        self.is_ready_socket = self.context.socket(zmq.REP)
+        self.is_ready_socket.bind(VLLM_IS_READY_RPC_PATH)
+
+        # Init socket for generation.
+        self.generate_socket = self.context.socket(zmq.ROUTER)
+        self.generate_socket.bind(VLLM_GENERATE_RPC_PATH)
+
+        # TODO (robertgshaw2-neuralmagic):
+        # add socket for generation without streaming
+
+        # Init socket for simple data requests.
+        self.get_data_socket = self.context.socket(zmq.REP)
+        self.get_data_socket.bind(VLLM_GET_DATA_RPC_PATH)
+
+        # Setup polling so we can listen on both sockets.
+        self.poller = zmq.asyncio.Poller()
+        self.poller.register(self.generate_socket, zmq.POLLIN)
+        self.poller.register(self.get_data_socket, zmq.POLLIN)

+    async def get_data(self, message):
+        request_type = pickle.loads(message)
+        if request_type == GetDataRequest.MODEL_CONFIG:
+            return await self.engine.get_model_config()
+        else:
+            raise ValueError(f"Unknown request type: {request_type}")

-        self.running_tasks = set()
-        self.engine = AsyncLLMEngine.from_engine_args(
-            AsyncEngineArgs(model=MODEL,
-                            enable_chunked_prefill=True))

     async def generate(self, identity, message):
         request = pickle.loads(message)
@@ -29,23 +60,40 @@ async def generate(self, identity, message):
             identity,
             pickle.dumps(request_output, pickle.HIGHEST_PROTOCOL)
         ])

     async def run_loop(self):
+        # Notify the RPC client that we are ready to recieve requests.
+        await self.is_ready_socket.send_string("Ready!")
Review comment: This ready string should be a global shared constant.
+        self.is_ready_socket.close()
+
+        # Avoid GC of running tasks.
+        running_tasks = set()
         while True:
-            identity, message = await self.socket.recv_multipart()
+            try:
+                socks = dict(await self.poller.poll())
+            except KeyboardInterrupt:
+                # TODO: should there be some other exception here?
+                break

-            # Process the request in the background.
-            task = asyncio.create_task(self.generate(identity=identity,
-                                                     message=message))
+            task = None
+            if self.generate_socket in socks:
+                identity, message = await self.generate_socket.recv_multipart()
+                task = asyncio.create_task(self.generate(identity, message))
+
+            elif self.get_data_socket in socks:
+                message = await self.get_data_socket.recv()
+                task = asyncio.create_task(self.get_data(message))

             # We need to keep around a strong reference to the task,
             # to avoid the task disappearing mid-execution as running tasks
             # can be GC'ed. Below is a common "fire-and-forget" tasks
             # https://docs.python.org/3/library/asyncio-task.html#asyncio.create_task
-            self.running_tasks.add(task)
-            task.add_done_callback(self.running_tasks.discard)
+            if task is not None:
+                running_tasks.add(task)
+                task.add_done_callback(running_tasks.discard)

+        # TODO: Do I need to close the generate / get_data sockets?
Review comment: I think it's best to close sockets explicitly if possible, per https://pyzmq.readthedocs.io/en/latest/api/zmq.html#zmq.Socket.close

-if __name__ == "__main__":
-    server = RPCServer()
+def run_rpc_server(async_engine_args):
+    server = RPCServer(async_engine_args=async_engine_args)
     asyncio.run(server.run_loop())
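Two of the inline suggestions above, sketched concretely. The names are hypothetical and not part of this commit: the readiness string is hoisted into a shared constant next to the VLLM_*_RPC_PATH values, and the server's sockets and context are released explicitly once the run loop exits.

# Hypothetical sketch following the review comments above; names are illustrative.

# Shared constant for the readiness handshake, defined alongside the
# VLLM_*_RPC_PATH constants so RPCServer and RPCClient compare the same string.
VLLM_RPC_READY_STR = "Ready!"
# The server would then do:
#     await self.is_ready_socket.send_string(VLLM_RPC_READY_STR)


def close_server_sockets(server) -> None:
    """Explicitly release the server's zmq resources when run_loop exits,
    per https://pyzmq.readthedocs.io/en/latest/api/zmq.html#zmq.Socket.close,
    instead of relying on garbage collection. `server` is assumed to be the
    RPCServer from the diff above."""
    server.generate_socket.close()
    server.get_data_socket.close()
    server.context.term()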
Review comment: At the very least we should be choosing randomly available ports above the "user space" of the 1000s port range. See https://pyzmq.readthedocs.io/en/latest/api/zmq.html#zmq.Socket.bind_to_random_port

Review comment: @mgoin yeah, that'd be better. Any idea how we'd notify the clients what port to connect to in that case? The server that would be calling .bind_to_random_port() is in a different process than the OpenAI server that needs to connect clients to it.

Review comment: The OpenAI server can listen first, pass the ports to the server process, and then the server just connects to it?
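A rough sketch of the flow discussed here, assuming the front-end (OpenAI) process binds a random free port first and then hands it to the engine process, which connects to it. The handoff via process arguments is illustrative; the PR as written still uses the fixed tcp://localhost:557x paths.

# Hypothetical sketch of the random-port handshake discussed above.
import multiprocessing
import zmq


def engine_side(port: int) -> None:
    # Engine process: connect to the port the front end chose.
    ctx = zmq.Context()
    sock = ctx.socket(zmq.REP)
    sock.connect(f"tcp://127.0.0.1:{port}")
    sock.recv()            # wait for the front end's ping
    sock.send(b"Ready!")
    sock.close()
    ctx.term()


if __name__ == "__main__":
    ctx = zmq.Context()
    sock = ctx.socket(zmq.REQ)
    # Front-end process: bind to a random free port (default range sits above
    # the well-known/user ports), per
    # https://pyzmq.readthedocs.io/en/latest/api/zmq.html#zmq.Socket.bind_to_random_port
    port = sock.bind_to_random_port("tcp://127.0.0.1")
    # Hand the chosen port to the engine process (here via process args;
    # any mechanism used to spawn the server process would work).
    proc = multiprocessing.Process(target=engine_side, args=(port,))
    proc.start()
    sock.send(b"ping")
    print(sock.recv())     # b"Ready!"
    proc.join()
    sock.close()
    ctx.term()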