Skip to content

Commit b2c0650

Browse files
frankie-ysDarkLight1337KuntaiDu
authored
[P/D]Provide bucket algorithm rate limiter for proxy_server (#22643)
Signed-off-by: frankie-ys <yongshengwang@cmbchina.com> Signed-off-by: frankie <wangyongsheng686@gmail.com> Co-authored-by: Cyrus Leung <cyrus.tl.leung@gmail.com> Co-authored-by: Kuntai Du <kuntai@uchicago.edu>
1 parent b2f6c24 commit b2c0650

File tree

3 files changed

+272
-52
lines changed

3 files changed

+272
-52
lines changed
Lines changed: 188 additions & 52 deletions
Original file line numberDiff line numberDiff line change
@@ -1,63 +1,199 @@
11
# SPDX-License-Identifier: Apache-2.0
22
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
33

4+
import argparse
5+
import asyncio
6+
import logging
47
import os
58

69
import aiohttp
7-
from quart import Quart, make_response, request
8-
9-
AIOHTTP_TIMEOUT = aiohttp.ClientTimeout(total=6 * 60 * 60)
10-
11-
app = Quart(__name__)
12-
13-
14-
async def forward_request(url, data):
15-
async with aiohttp.ClientSession(timeout=AIOHTTP_TIMEOUT) as session:
10+
from quart import Quart, Response, make_response, request
11+
from rate_limiter import RateLimiter
12+
from request_queue import RequestQueue
13+
14+
# Configure logging
15+
logging.basicConfig(level=logging.INFO)
16+
logger = logging.getLogger(__name__)
17+
18+
19+
def parse_args(argv=None):
    """Parse command line arguments for the P/D disaggregation proxy server.

    Args:
        argv: Optional list of argument strings to parse instead of
            ``sys.argv[1:]``. Passing an explicit list makes the function
            usable (and testable) without touching process-level argv.

    Returns:
        argparse.Namespace with ``timeout``, ``max_concurrent``,
        ``queue_size``, ``rate_limit``, ``port``, ``prefill_url`` and
        ``decode_url`` attributes.
    """
    parser = argparse.ArgumentParser(description="vLLM P/D disaggregation proxy server")

    # Backend request handling / capacity knobs
    parser.add_argument(
        "--timeout",
        type=float,
        default=300,
        help="Timeout for backend service requests in seconds (default: 300)",
    )
    parser.add_argument(
        "--max-concurrent",
        type=int,
        default=100,
        help="Maximum concurrent requests to backend services (default: 100)",
    )
    parser.add_argument(
        "--queue-size",
        type=int,
        default=500,
        help="Maximum number of requests in the queue (default: 500)",
    )
    parser.add_argument(
        "--rate-limit",
        type=int,
        default=40,
        help="Maximum requests per second (default: 40)",
    )

    # Server / backend endpoints
    parser.add_argument(
        "--port",
        type=int,
        default=8000,
        help="Port to run the server on (default: 8000)",
    )
    parser.add_argument(
        "--prefill-url",
        type=str,
        default="http://localhost:8100/v1/completions",
        help="Prefill service endpoint URL",
    )
    parser.add_argument(
        "--decode-url",
        type=str,
        default="http://localhost:8200/v1/completions",
        help="Decode service endpoint URL",
    )

    return parser.parse_args(argv)
68+
69+
70+
def main():
    """Configure and run the P/D disaggregation proxy server.

    Parses CLI arguments, wires up the rate limiter and request queue,
    registers the Quart routes, and starts the server (blocking).
    """
    args = parse_args()

    # Initialize configuration using command line parameters
    AIOHTTP_TIMEOUT = aiohttp.ClientTimeout(total=args.timeout)
    MAX_CONCURRENT_REQUESTS = args.max_concurrent
    REQUEST_QUEUE_SIZE = args.queue_size
    RATE_LIMIT = args.rate_limit
    PREFILL_SERVICE_URL = args.prefill_url
    DECODE_SERVICE_URL = args.decode_url
    PORT = args.port

    app = Quart(__name__)

    # Initialize the rate limiter and request queue; the nested handlers
    # below capture these via closure rather than reading app.config.
    rate_limiter = RateLimiter(RATE_LIMIT)
    request_queue = RequestQueue(MAX_CONCURRENT_REQUESTS, REQUEST_QUEUE_SIZE)

    # Attach the configuration object to the application instance
    # (exposed for introspection; the handlers use the closure variables).
    app.config.update(
        {
            "AIOHTTP_TIMEOUT": AIOHTTP_TIMEOUT,
            "rate_limiter": rate_limiter,
            "request_queue": request_queue,
            "PREFILL_SERVICE_URL": PREFILL_SERVICE_URL,
            "DECODE_SERVICE_URL": DECODE_SERVICE_URL,
        }
    )

    # Start queue processing on app startup
    @app.before_serving
    async def startup():
        """Start request processing task when app starts serving."""
        asyncio.create_task(request_queue.process())

    async def forward_request(url, data):
        """Stream a POST to a backend service with rate limiting.

        Async generator: yields response body chunks on success, or a
        single JSON error payload on backend/connection/timeout failure.
        NOTE(review): error payloads are yielded into an already-started
        stream, so the HTTP status seen by the client is still 200.
        """
        headers = {"Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}"}

        # Use rate limiter as context manager; one token per forwarded request.
        async with (
            rate_limiter,
            aiohttp.ClientSession(timeout=AIOHTTP_TIMEOUT) as session,
        ):
            try:
                async with session.post(
                    url=url, json=data, headers=headers
                ) as response:
                    if response.status == 200:
                        # Stream response chunks
                        async for chunk_bytes in response.content.iter_chunked(1024):
                            yield chunk_bytes
                    else:
                        # Handle backend service errors
                        error_text = await response.text()
                        logger.error(
                            "Backend service error: %s - %s",
                            response.status,
                            error_text,
                        )
                        yield b'{"error": "Backend service error"}'
            except aiohttp.ClientError as e:
                # Handle connection errors
                logger.error("Connection error to %s: %s", url, str(e))
                yield b'{"error": "Service unavailable"}'
            except asyncio.TimeoutError:
                # Handle timeout errors
                logger.error("Timeout connecting to %s", url)
                yield b'{"error": "Service timeout"}'

    async def process_request():
        """Process a single request through prefill and decode stages."""
        try:
            original_request_data = await request.get_json()

            # Create prefill request (max_tokens=1) so the prefill backend
            # only populates the KV cache without generating output.
            prefill_request = original_request_data.copy()
            prefill_request["max_tokens"] = 1

            # Execute prefill stage; drain the generator, output is discarded.
            async for _ in forward_request(PREFILL_SERVICE_URL, prefill_request):
                continue

            # Execute decode stage and stream response
            generator = forward_request(DECODE_SERVICE_URL, original_request_data)
            response = await make_response(generator)
            response.timeout = None  # Disable timeout for streaming response
            return response

        except Exception:
            logger.exception("Error processing request")
            return Response(
                response=b'{"error": "Internal server error"}',
                status=500,
                content_type="application/json",
            )

    @app.route("/v1/completions", methods=["POST"])
    async def handle_request():
        """Handle incoming API requests with concurrency and rate limiting.

        NOTE(review): asyncio.create_task starts the coroutine immediately,
        so the queue does not gate when work begins — it only bounds how
        many admitted requests exist and lets process() await them under
        the semaphore. Confirm this is the intended admission model.
        """
        # Create task for request processing
        task = asyncio.create_task(process_request())

        # Enqueue request or reject if queue is full
        if not await request_queue.enqueue(task):
            return Response(
                response=b'{"error": "Server busy, try again later"}',
                status=503,
                content_type="application/json",
            )

        try:
            # Return the response from the processing task
            return await task
        except asyncio.CancelledError:
            # Handle task cancellation (timeout or queue full)
            logger.warning("Request cancelled due to timeout or queue full")
            return Response(
                response=b'{"error": "Request cancelled"}',
                status=503,
                content_type="application/json",
            )

    # Start the Quart server; host can be set to 0.0.0.0 to listen externally.
    app.run(port=PORT)
60196

61197

62198
# Entry point: parse args and run the proxy server when executed directly.
if __name__ == "__main__":
    main()
Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,45 @@
1+
# SPDX-License-Identifier: Apache-2.0
2+
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3+
4+
import asyncio
5+
import time
6+
7+
8+
class RateLimiter:
    """Fixed-window async rate limiter.

    The full allowance of ``rate_limit`` tokens is restored once more than
    one second has elapsed since the last refill. (Despite the original
    "token bucket" label, tokens are not replenished continuously.)
    Usable as an async context manager: ``async with limiter: ...``.
    """

    def __init__(self, rate_limit):
        """
        Args:
            rate_limit: Maximum number of acquisitions per one-second window.
        """
        self.rate_limit = rate_limit  # Requests per second
        self.num_available_tokens = rate_limit  # Tokens left in current window
        self.last_refill = time.monotonic()  # Monotonic time of last refill
        self.lock = asyncio.Lock()  # Protects token state across coroutines

    async def acquire(self):
        """Block until a token is available, then consume one.

        Returns:
            True once a token has been acquired.
        """
        while True:
            async with self.lock:
                current_time = time.monotonic()
                elapsed = current_time - self.last_refill

                # Restore the full allowance once per one-second window.
                if elapsed > 1.0:
                    self.num_available_tokens = self.rate_limit
                    self.last_refill = current_time

                # Fast path: a token is available in the current window.
                if self.num_available_tokens > 0:
                    self.num_available_tokens -= 1
                    return True

                # No tokens left: wait until the current window ends.
                # Clamp at 0 so a non-positive rate_limit (empty refill)
                # cannot produce a negative sleep and busy-spin.
                wait_time = max(0.0, 1.0 - elapsed)

            # Sleep OUTSIDE the lock (the original slept while holding it,
            # which serialized every waiting coroutine behind one sleeper).
            await asyncio.sleep(wait_time)

    async def __aenter__(self):
        """Enter async context manager - acquire a token."""
        await self.acquire()
        return self

    async def __aexit__(self, exc_type, exc_value, traceback):
        """Exit async context manager - no cleanup needed."""
        pass
Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,39 @@
1+
# SPDX-License-Identifier: Apache-2.0
2+
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3+
4+
import asyncio
5+
from collections import deque
6+
7+
8+
class RequestQueue:
    """Bounded FIFO queue of request tasks with a concurrency cap."""

    def __init__(self, max_concurrent, max_queue_size):
        """
        Args:
            max_concurrent: Maximum number of tasks awaited concurrently.
            max_queue_size: Maximum number of tasks held in the queue.
        """
        # Maximum concurrent requests
        self.max_concurrent = max_concurrent
        self.max_queue_size = max_queue_size  # Maximum queue size
        # Concurrency control
        self.semaphore = asyncio.Semaphore(max_concurrent)
        self.queue = deque()  # Pending request tasks (FIFO)
        self.queue_size = 0  # Current queue size
        self.lock = asyncio.Lock()  # Guards queue and queue_size

    async def enqueue(self, task):
        """Add a request task to the queue.

        Returns:
            True if the task was queued, False if the queue is full.
        """
        async with self.lock:
            if self.queue_size >= self.max_queue_size:
                return False

            self.queue.append(task)
            self.queue_size += 1
            return True

    async def _run_with_limit(self, task):
        """Await one task while holding a semaphore slot.

        Exceptions and cancellation are swallowed here on purpose: the
        request handler that created the task also awaits it and is where
        failures must surface. Letting them propagate would kill the
        process() loop for every subsequent request.
        """
        async with self.semaphore:
            try:
                await task
            except (Exception, asyncio.CancelledError):
                pass

    async def process(self):
        """Continuously drain the queue, awaiting tasks concurrently.

        Fix vs. original: the semaphore used to be held by this single
        loop across ``await task``, which serialized all requests
        (effective concurrency of 1 regardless of max_concurrent) and let
        one failed/cancelled task terminate the loop. Each task is now
        supervised in its own coroutine under the semaphore.
        """
        while True:
            if self.queue:
                async with self.lock:
                    task = self.queue.popleft()
                    self.queue_size -= 1
                asyncio.create_task(self._run_with_limit(task))
            await asyncio.sleep(0.01)  # Yield control to event loop

0 commit comments

Comments
 (0)