Skip to content

Commit e853b3c

Browse files
Add Accuracy Test (#13)
* updated Signed-off-by: rshaw@neuralmagic.com <robertgshaw2@gmail.com> * updated Signed-off-by: rshaw@neuralmagic.com <robertgshaw2@gmail.com> * updated Signed-off-by: rshaw@neuralmagic.com <robertgshaw2@gmail.com> * updated Signed-off-by: rshaw@neuralmagic.com <robertgshaw2@gmail.com> --------- Signed-off-by: rshaw@neuralmagic.com <robertgshaw2@gmail.com>
1 parent 83f2872 commit e853b3c

File tree

3 files changed

+225
-0
lines changed

3 files changed

+225
-0
lines changed
Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,52 @@
1+
#!/bin/bash
# End-to-end accuracy test driver for disaggregated prefill/decode:
# starts a prefill vLLM server (port 8100) and a decode vLLM server
# (port 8200), fronts them with a toy proxy (port 8192), then runs the
# lm-eval accuracy test against the proxy.

set -xe

# Model to run.
MODEL_NAME=Qwen/Qwen3-0.6B

# Cleanup function: terminate background jobs and any stray python processes.
# Every step is failure-tolerant so that, under `set -e`, an empty pgrep/kill
# does not abort the trap handler.
cleanup() {
    echo "Cleaning up..."
    kill $(jobs -pr) 2> /dev/null || true
    pgrep python | xargs -r kill -9 || true
    pkill -f python || true
    echo "Cleanup complete. Exiting."
}

# FIX: 'cleanup' was defined but never invoked (the trap ran a bare
# 'kill $(jobs -pr)'). Run cleanup from the trap so the vLLM servers and
# proxy are reaped on Ctrl+C, termination, and normal exit alike.
trap cleanup SIGINT SIGTERM EXIT

# Waits for a vLLM server on localhost:$1 to answer /v1/completions
# (20 minute timeout). Returns 0 on success, 1 on timeout.
wait_for_server() {
  local port=$1
  timeout 1200 bash -c "
    until curl -s localhost:${port}/v1/completions > /dev/null; do
      sleep 1
    done" && return 0 || return 1
}

# Prefill instance.
CUDA_VISIBLE_DEVICES=0 NIXL_ROLE="SENDER" vllm serve $MODEL_NAME \
    --port 8100 \
    --enforce-eager \
    --disable-log-requests \
    --kv-transfer-config '{"kv_connector":"NixlConnector","kv_role":"kv_both"}' &

# Decode instance.
CUDA_VISIBLE_DEVICES=1 NIXL_ROLE="RECVER" vllm serve $MODEL_NAME \
    --port 8200 \
    --enforce-eager \
    --disable-log-requests \
    --kv-transfer-config '{"kv_connector":"NixlConnector","kv_role":"kv_both"}' &

# wait until prefill and decode instances are ready
wait_for_server 8100
wait_for_server 8200

# Proxy server.
python toy_proxy_server.py --port 8192 &

# Run lm eval.
python3 -m pytest -s -x test_accuracy.py
Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,28 @@
1+
# SPDX-License-Identifier: Apache-2.0
2+
import lm_eval
3+
4+
# Model served by the prefill/decode deployment under test.
MODEL_NAME = "Qwen/Qwen3-0.6B"
# Number of concurrent requests lm-eval issues against the proxy.
NUM_CONCURRENT = 100
# lm-eval benchmark task to run.
TASK = "gsm8k"
# Metric key read from the lm-eval results dict for this task.
FILTER = "exact_match,strict-match"
# NOTE(review): despite the name, this is used as an ABSOLUTE tolerance
# around EXPECTED_VALUE, not a relative one.
RTOL = 0.03
# Reference gsm8k accuracy for this model/setup.
EXPECTED_VALUE = 0.41
10+
11+
12+
def test_accuracy():
    """Run the end to end accuracy test.

    Points lm-eval's local-completions backend at the toy proxy
    (localhost:8192) and checks the measured gsm8k score lies within
    RTOL of EXPECTED_VALUE.
    """
    model_args = (
        f"model={MODEL_NAME},"
        f"base_url=http://localhost:8192/v1/completions,"
        f"num_concurrent={NUM_CONCURRENT},tokenized_requests=False"
    )

    results = lm_eval.simple_evaluate(
        model="local-completions",
        model_args=model_args,
        tasks=TASK,
    )

    measured_value = results["results"][TASK][FILTER]
    # Equivalent to the pairwise comparison: the score must fall strictly
    # inside the (EXPECTED_VALUE - RTOL, EXPECTED_VALUE + RTOL) window.
    in_window = EXPECTED_VALUE - RTOL < measured_value < EXPECTED_VALUE + RTOL
    assert in_window, f"Expected: {EXPECTED_VALUE} | Measured: {measured_value}"
Lines changed: 145 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,145 @@
1+
# SPDX-License-Identifier: Apache-2.0
2+
3+
import argparse
4+
import os
5+
import uuid
6+
from contextlib import asynccontextmanager
7+
8+
import httpx
9+
from fastapi import FastAPI, Request
10+
from fastapi.responses import StreamingResponse
11+
12+
13+
@asynccontextmanager
async def lifespan(app: FastAPI):
    """
    Lifespan context manager to handle startup and shutdown events.

    Startup: creates one persistent httpx client per backend (prefiller
    and decoder), stored on app.state. Shutdown: closes both clients.
    Relies on the module-global `global_args` set in the __main__ block.
    """
    # Startup: build one base URL per backend from the parsed CLI args.
    prefiller_base_url = (
        f"http://{global_args.prefiller_host}:{global_args.prefiller_port}/v1")
    decoder_base_url = (
        f"http://{global_args.decoder_host}:{global_args.decoder_port}/v1")

    # timeout=None: vLLM prefill/decode calls can be arbitrarily long.
    app.state.prefill_client = httpx.AsyncClient(
        timeout=None, base_url=prefiller_base_url)
    app.state.decode_client = httpx.AsyncClient(
        timeout=None, base_url=decoder_base_url)

    yield

    # Shutdown: close the persistent clients.
    await app.state.prefill_client.aclose()
    await app.state.decode_client.aclose()
33+
34+
# Update FastAPI app initialization to use lifespan, so the httpx clients
# are created at startup and closed at shutdown.
app = FastAPI(lifespan=lifespan)
36+
37+
38+
def parse_args():
    """Parse the proxy's CLI options.

    Options cover the proxy's own listen address plus the host/port of
    the prefiller and decoder backends it forwards to.
    """
    parser = argparse.ArgumentParser()
    # (flag, value type, default) — table keeps the six options uniform.
    option_table = (
        ("--port", int, 8000),
        ("--host", str, "localhost"),
        ("--prefiller-host", str, "localhost"),
        ("--prefiller-port", int, 8100),
        ("--decoder-host", str, "localhost"),
        ("--decoder-port", int, 8200),
    )
    for flag, value_type, default in option_table:
        parser.add_argument(flag, type=value_type, default=default)
    return parser.parse_args()
49+
50+
51+
# Initialize variables to hold the persistent clients
# (replaced with real httpx.AsyncClient instances by lifespan() at startup).
app.state.prefill_client = None
app.state.decode_client = None
54+
55+
56+
async def send_request_to_service(client: httpx.AsyncClient, endpoint: str,
                                  req_data: dict, request_id: str):
    """
    Send a request to a service using a persistent client.

    Used for the (non-streaming) prefill leg: forces do_remote_decode on
    and streaming off, then POSTs and raises on any HTTP error status.
    Works on a copy so the caller's req_data is left untouched.
    """
    payload = dict(req_data)
    payload["do_remote_decode"] = True
    payload["stream"] = False
    headers = {
        "Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}",
        "X-Request-Id": request_id,
    }

    response = await client.post(endpoint, json=payload, headers=headers)
    response.raise_for_status()
    return response
72+
73+
74+
async def stream_service_response(client: httpx.AsyncClient, endpoint: str,
                                  req_data: dict, remote_block_ids: list[int],
                                  remote_engine_id: str, request_id: str):
    """
    Asynchronously stream the response from a service using a persistent client.

    Used for the decode leg: tags the request with do_remote_prefill plus
    the KV-transfer metadata from the prefill response, then yields the
    raw response bytes chunk by chunk.
    NOTE: mutates req_data in place (matches the original behavior).
    """
    headers = {
        "Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}",
        "X-Request-Id": request_id,
    }
    req_data.update({
        "do_remote_prefill": True,
        "remote_block_ids": remote_block_ids,
        "remote_engine_id": remote_engine_id,
    })

    async with client.stream("POST", endpoint, json=req_data,
                             headers=headers) as response:
        response.raise_for_status()
        async for chunk in response.aiter_bytes():
            yield chunk
92+
93+
94+
@app.post("/v1/completions")
async def handle_completions(request: Request):
    """Proxy a /v1/completions call through prefill, then stream decode.

    Sends the request to the prefill backend first, extracts the
    KV-transfer metadata from its response, and streams the decode
    backend's output back to the caller.
    """
    try:
        req_data = await request.json()

        request_id = str(uuid.uuid4())

        # Leg 1: non-streaming request to the prefill service.
        prefill_response = await send_request_to_service(
            app.state.prefill_client, "/completions", req_data, request_id)

        # Pull out the KV-transfer metadata the decoder needs.
        prefill_json = prefill_response.json()
        remote_block_ids = prefill_json.get('remote_block_ids', [])
        remote_engine_id = prefill_json.get('remote_engine_id', '')

        # Attach the metadata to the request forwarded to the decoder.
        req_data['remote_block_ids'] = remote_block_ids
        req_data['remote_engine_id'] = remote_engine_id

        # Leg 2: stream the decode service's response through to the client.
        async def generate_stream():
            async for chunk in stream_service_response(
                    app.state.decode_client,
                    "/completions",
                    req_data,
                    remote_block_ids=remote_block_ids,
                    remote_engine_id=remote_engine_id,
                    request_id=request_id):
                yield chunk

        return StreamingResponse(generate_stream(),
                                 media_type="application/json")

    except Exception as e:
        # Log the full traceback, then re-raise so FastAPI returns a 500.
        import traceback
        print("Error occurred in disagg prefill proxy server"
              " - completions endpoint")
        print(e)
        # format_exc() is "".join(format_exception(*sys.exc_info())).
        print(traceback.format_exc())
        raise
138+
139+
140+
if __name__ == '__main__':
    # FIX: the original had `global global_args` here, which is a no-op at
    # module scope — assignment below already creates the module global
    # that lifespan() reads.
    global_args = parse_args()

    # Imported lazily so the module can be imported without uvicorn installed.
    import uvicorn
    uvicorn.run(app, host=global_args.host, port=global_args.port)

0 commit comments

Comments
 (0)