Commit d8e319b

DarkLight1337 authored and jimpang committed
[CI/Build] Simplify OpenAI server setup in tests (vllm-project#5100)
1 parent 809496b commit d8e319b

File tree

6 files changed: +285 -238 lines changed
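Of the six files changed, only the two test modules below appear in this excerpt. The helper at the center of the refactor, RemoteOpenAIServer in tests/utils.py, is imported and used but its diff is not shown. What follows is a minimal sketch of what such a helper could look like, not the actual implementation: the diffs below only confirm that the constructor takes a list of CLI args, that get_async_client() returns an openai.AsyncOpenAI, and that VLLM_PATH is exported. The subprocess launch, /health polling, fixed port, startup timeout, and dummy API key are all assumptions.

# A sketch, NOT the actual tests/utils.py from this commit (that diff is
# not shown in this excerpt). Everything beyond the constructor signature,
# get_async_client(), and VLLM_PATH is an assumption.
import subprocess
import sys
import time
from pathlib import Path

import openai
import requests

VLLM_PATH = Path(__file__).parent.parent  # assumed: the repo root


class RemoteOpenAIServer:
    DUMMY_API_KEY = "token-abc123"  # assumed; vLLM ignores the key by default
    BASE_URL = "http://localhost:8000"  # assumed fixed port, as in the old tests

    def __init__(self, cli_args, *, startup_timeout=600):
        # Launch the OpenAI-compatible server as a child process.
        self.proc = subprocess.Popen([
            sys.executable, "-m", "vllm.entrypoints.openai.api_server",
            *cli_args
        ])
        self._wait_until_ready(startup_timeout)

    def _wait_until_ready(self, timeout):
        # Poll the server's /health endpoint until it responds.
        deadline = time.time() + timeout
        while time.time() < deadline:
            try:
                if requests.get(f"{self.BASE_URL}/health").status_code == 200:
                    return
            except requests.ConnectionError:
                pass
            if self.proc.poll() is not None:
                raise RuntimeError("Server process exited before becoming healthy")
            time.sleep(0.5)
        raise RuntimeError("Server did not become healthy in time")

    def get_async_client(self):
        return openai.AsyncOpenAI(base_url=f"{self.BASE_URL}/v1",
                                  api_key=self.DUMMY_API_KEY)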

tests/async_engine/test_openapi_server_ray.py

Lines changed: 15 additions & 16 deletions
@@ -4,16 +4,22 @@
 # and debugging.
 import ray

-from ..utils import ServerRunner
+from ..utils import VLLM_PATH, RemoteOpenAIServer

 # any model with a chat template should work here
 MODEL_NAME = "facebook/opt-125m"


 @pytest.fixture(scope="module")
-def server():
-    ray.init()
-    server_runner = ServerRunner.remote([
+def ray_ctx():
+    ray.init(runtime_env={"working_dir": VLLM_PATH})
+    yield
+    ray.shutdown()
+
+
+@pytest.fixture(scope="module")
+def server(ray_ctx):
+    return RemoteOpenAIServer([
         "--model",
         MODEL_NAME,
         # use half precision for speed and memory savings in CI environment
@@ -24,22 +30,15 @@ def server():
         "--enforce-eager",
         "--engine-use-ray"
     ])
-    ray.get(server_runner.ready.remote())
-    yield server_runner
-    ray.shutdown()


 @pytest.fixture(scope="module")
-def client():
-    client = openai.AsyncOpenAI(
-        base_url="http://localhost:8000/v1",
-        api_key="token-abc123",
-    )
-    yield client
+def client(server):
+    return server.get_async_client()


 @pytest.mark.asyncio
-async def test_check_models(server, client: openai.AsyncOpenAI):
+async def test_check_models(client: openai.AsyncOpenAI):
     models = await client.models.list()
     models = models.data
     served_model = models[0]
@@ -48,7 +47,7 @@ async def test_check_models(server, client: openai.AsyncOpenAI):


 @pytest.mark.asyncio
-async def test_single_completion(server, client: openai.AsyncOpenAI):
+async def test_single_completion(client: openai.AsyncOpenAI):
     completion = await client.completions.create(model=MODEL_NAME,
                                                  prompt="Hello, my name is",
                                                  max_tokens=5,
@@ -72,7 +71,7 @@ async def test_single_completion(server, client: openai.AsyncOpenAI):


 @pytest.mark.asyncio
-async def test_single_chat_session(server, client: openai.AsyncOpenAI):
+async def test_single_chat_session(client: openai.AsyncOpenAI):
     messages = [{
         "role": "system",
         "content": "you are a helpful assistant"
(new file) Lines changed: 113 additions & 0 deletions
@@ -0,0 +1,113 @@
+import openai
+import pytest
+import ray
+
+from ..utils import VLLM_PATH, RemoteOpenAIServer
+
+EMBEDDING_MODEL_NAME = "intfloat/e5-mistral-7b-instruct"
+
+pytestmark = pytest.mark.openai
+
+
+@pytest.fixture(scope="module")
+def ray_ctx():
+    ray.init(runtime_env={"working_dir": VLLM_PATH})
+    yield
+    ray.shutdown()
+
+
+@pytest.fixture(scope="module")
+def embedding_server(ray_ctx):
+    return RemoteOpenAIServer([
+        "--model",
+        EMBEDDING_MODEL_NAME,
+        # use half precision for speed and memory savings in CI environment
+        "--dtype",
+        "bfloat16",
+        "--enforce-eager",
+        "--max-model-len",
+        "8192",
+        "--enforce-eager",
+    ])
+
+
+@pytest.mark.asyncio
+@pytest.fixture(scope="module")
+def embedding_client(embedding_server):
+    return embedding_server.get_async_client()
+
+
+@pytest.mark.asyncio
+@pytest.mark.parametrize(
+    "model_name",
+    [EMBEDDING_MODEL_NAME],
+)
+async def test_single_embedding(embedding_client: openai.AsyncOpenAI,
+                                model_name: str):
+    input_texts = [
+        "The chef prepared a delicious meal.",
+    ]
+
+    # test single embedding
+    embeddings = await embedding_client.embeddings.create(
+        model=model_name,
+        input=input_texts,
+        encoding_format="float",
+    )
+    assert embeddings.id is not None
+    assert len(embeddings.data) == 1
+    assert len(embeddings.data[0].embedding) == 4096
+    assert embeddings.usage.completion_tokens == 0
+    assert embeddings.usage.prompt_tokens == 9
+    assert embeddings.usage.total_tokens == 9
+
+    # test using token IDs
+    input_tokens = [1, 1, 1, 1, 1]
+    embeddings = await embedding_client.embeddings.create(
+        model=model_name,
+        input=input_tokens,
+        encoding_format="float",
+    )
+    assert embeddings.id is not None
+    assert len(embeddings.data) == 1
+    assert len(embeddings.data[0].embedding) == 4096
+    assert embeddings.usage.completion_tokens == 0
+    assert embeddings.usage.prompt_tokens == 5
+    assert embeddings.usage.total_tokens == 5
+
+
+@pytest.mark.asyncio
+@pytest.mark.parametrize(
+    "model_name",
+    [EMBEDDING_MODEL_NAME],
+)
+async def test_batch_embedding(embedding_client: openai.AsyncOpenAI,
+                               model_name: str):
+    # test List[str]
+    input_texts = [
+        "The cat sat on the mat.", "A feline was resting on a rug.",
+        "Stars twinkle brightly in the night sky."
+    ]
+    embeddings = await embedding_client.embeddings.create(
+        model=model_name,
+        input=input_texts,
+        encoding_format="float",
+    )
+    assert embeddings.id is not None
+    assert len(embeddings.data) == 3
+    assert len(embeddings.data[0].embedding) == 4096
+
+    # test List[List[int]]
+    input_tokens = [[4, 5, 7, 9, 20], [15, 29, 499], [24, 24, 24, 24, 24],
+                    [25, 32, 64, 77]]
+    embeddings = await embedding_client.embeddings.create(
+        model=model_name,
+        input=input_tokens,
+        encoding_format="float",
+    )
+    assert embeddings.id is not None
+    assert len(embeddings.data) == 4
+    assert len(embeddings.data[0].embedding) == 4096
+    assert embeddings.usage.completion_tokens == 0
+    assert embeddings.usage.prompt_tokens == 17
+    assert embeddings.usage.total_tokens == 17
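The same endpoint can be exercised outside pytest with the synchronous client. A usage sketch, assuming a server started with the embedding_server arguments above is reachable at localhost:8000 (the host, port, and API key are assumptions carried over from the pre-refactor tests, not something this diff specifies):

import openai

# Assumption: a vLLM OpenAI-compatible server serving the embedding model
# is already listening on localhost:8000; the API key value is arbitrary.
client = openai.OpenAI(base_url="http://localhost:8000/v1",
                       api_key="token-abc123")

resp = client.embeddings.create(
    model="intfloat/e5-mistral-7b-instruct",
    input=["The chef prepared a delicious meal."],
    encoding_format="float",
)
# e5-mistral-7b-instruct embeddings are 4096-dimensional, matching the
# assertions in the tests above.
print(len(resp.data[0].embedding))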
