forked from vllm-project/vllm
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathtest_error_handling.py
293 lines (221 loc) · 9.52 KB
/
test_error_handling.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
"""Test that various errors are handled properly."""
import asyncio
import tempfile
import time
import uuid
from unittest.mock import Mock
import pytest
from tests.mq_llm_engine.utils import RemoteMQLLMEngine
from vllm import SamplingParams
from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.engine.llm_engine import LLMEngine
from vllm.engine.multiprocessing import MQEngineDeadError
from vllm.engine.multiprocessing.engine import MQLLMEngine
from vllm.entrypoints.openai.api_server import build_async_engine_client
from vllm.entrypoints.openai.cli_args import make_arg_parser
from vllm.lora.request import LoRARequest
from vllm.usage.usage_lib import UsageContext
from vllm.utils import FlexibleArgumentParser
# Model shared by every test in this module.
MODEL = "google/gemma-1.1-2b-it"
ENGINE_ARGS = AsyncEngineArgs(model=MODEL, enforce_eager=True)
# Error type/value injected into the engine by the run_with_evil_* helpers
# below; the tests assert that this exact error surfaces on the client side.
RAISED_ERROR = KeyError
RAISED_VALUE = "foo"
@pytest.fixture(scope="function")
def tmp_socket():
    """Yield a unique IPC socket address rooted in a temporary directory."""
    with tempfile.TemporaryDirectory() as socket_dir:
        yield f"ipc://{socket_dir}/{uuid.uuid4()}"
def run_with_evil_forward(engine_args: AsyncEngineArgs, ipc_path: str):
    """Engine entrypoint whose model forward pass always fails.

    Starts an MQLLMEngine whose executor raises RAISED_ERROR(RAISED_VALUE)
    on every execute_model call.
    """
    failing_executor = Mock(side_effect=RAISED_ERROR(RAISED_VALUE))
    engine = MQLLMEngine.from_engine_args(
        engine_args=engine_args,
        usage_context=UsageContext.UNKNOWN_CONTEXT,
        ipc_path=ipc_path)
    # Every forward pass from now on raises RAISED_ERROR(RAISED_VALUE).
    engine.engine.model_executor.execute_model = failing_executor
    engine.start()
@pytest.mark.asyncio
async def test_evil_forward(tmp_socket):
    """An error raised during a forward pass must mark the engine dead."""
    with RemoteMQLLMEngine(engine_args=ENGINE_ARGS,
                           ipc_path=tmp_socket,
                           run_fn=run_with_evil_forward) as engine:
        client = await engine.make_client()

        # Server should be healthy after initial probe.
        await asyncio.sleep(2.0)
        await client.check_health()

        # The first generate triggers the mocked forward pass; the failure
        # surfaces to the client as MQEngineDeadError.
        with pytest.raises(MQEngineDeadError):
            async for _ in client.generate(prompt="Hello my name is",
                                           sampling_params=SamplingParams(),
                                           request_id=uuid.uuid4()):
                pass
        assert client.errored

        # Once the failure has propagated, health checks re-raise the
        # original error (RAISED_ERROR) that killed the engine.
        await asyncio.sleep(1.0)
        with pytest.raises(RAISED_ERROR):
            await client.check_health()
        assert client.errored

        # Shutdown.
        client.close()
def run_with_evil_model_executor_health(engine_args: AsyncEngineArgs,
                                        ipc_path: str):
    """Engine entrypoint whose executor health check always fails.

    Starts an MQLLMEngine whose model_executor.check_health raises
    RAISED_ERROR on every probe.
    """
    failing_health_check = Mock(side_effect=RAISED_ERROR)
    engine = MQLLMEngine.from_engine_args(
        engine_args=engine_args,
        usage_context=UsageContext.UNKNOWN_CONTEXT,
        ipc_path=ipc_path)
    # Every executor health probe from now on raises RAISED_ERROR.
    engine.engine.model_executor.check_health = failing_health_check
    engine.start()
@pytest.mark.asyncio
async def test_failed_health_check(tmp_socket):
    """A failing executor health check must surface to the client."""
    with RemoteMQLLMEngine(
            engine_args=ENGINE_ARGS,
            ipc_path=tmp_socket,
            run_fn=run_with_evil_model_executor_health) as engine:
        client = await engine.make_client()
        assert client.is_running

        # Wait long enough for the periodic health probe to fire; it should
        # throw RAISED_ERROR.
        await asyncio.sleep(15.)
        with pytest.raises(RAISED_ERROR):
            await client.check_health()
        assert client.errored

        # Generate call should throw ENGINE_DEAD_ERROR
        with pytest.raises(MQEngineDeadError):
            async for _ in client.generate(prompt="Hello my name is",
                                           sampling_params=SamplingParams(),
                                           request_id=uuid.uuid4()):
                pass

        client.close()
def run_with_evil_abort(engine_args: AsyncEngineArgs, ipc_path: str):
    """Engine entrypoint whose abort handler always fails.

    Starts an MQLLMEngine whose abort_request raises RAISED_ERROR whenever
    a request abort is processed.
    """
    failing_abort = Mock(side_effect=RAISED_ERROR)
    engine = MQLLMEngine.from_engine_args(
        engine_args=engine_args,
        usage_context=UsageContext.UNKNOWN_CONTEXT,
        ipc_path=ipc_path)
    # Every abort processed by the engine from now on raises RAISED_ERROR.
    engine.engine.abort_request = failing_abort
    engine.start()
@pytest.mark.asyncio
async def test_failed_abort(tmp_socket):
    """An error raised while aborting a request kills the engine."""
    with RemoteMQLLMEngine(engine_args=ENGINE_ARGS,
                           ipc_path=tmp_socket,
                           run_fn=run_with_evil_abort) as engine:
        client = await engine.make_client()
        assert client.is_running

        # The engine is healthy before anything is aborted.
        await client.check_health()

        # Abort a request id the engine never saw; the mocked abort_request
        # handler raises and kills the engine.
        await client.abort(request_id="foo")

        # Subsequent generate calls fail with an engine-dead error that
        # references the original KeyError("foo").
        with pytest.raises(MQEngineDeadError) as dead_error:
            async for _ in client.generate(
                    prompt="Hello my name is",
                    sampling_params=SamplingParams(max_tokens=10),
                    request_id=uuid.uuid4()):
                pass
        assert "KeyError" in repr(dead_error.value)
        assert client.errored

        # Health checks re-raise the original error.
        with pytest.raises(RAISED_ERROR):
            await client.check_health()

        client.close()
@pytest.mark.asyncio
async def test_batch_error(tmp_socket):
    """All in-flight requests of a batch fail together when the engine dies."""
    with RemoteMQLLMEngine(engine_args=ENGINE_ARGS,
                           ipc_path=tmp_socket,
                           run_fn=run_with_evil_abort) as engine:
        client = await engine.make_client()
        assert client.is_running

        # First check health should work.
        await client.check_health()

        # One long-running generate request; launched in a batch below.
        async def do_generate(client):
            # min_tokens=2048 keeps the engine busy long enough for it to
            # process the abort request (below) that will crash it.
            params = SamplingParams(min_tokens=2048, max_tokens=2048)
            async for _ in client.generate(prompt="Hello my name is",
                                           sampling_params=params,
                                           request_id=uuid.uuid4()):
                pass

        tasks = [asyncio.create_task(do_generate(client)) for _ in range(10)]

        # This abort hits the mocked abort_request handler, which raises
        # and errors the engine while the batch is still running.
        await client.abort(request_id="foo")

        # Every request in the failed batch should surface the same
        # exception, wrapped as MQEngineDeadError.
        errors = await asyncio.gather(*tasks, return_exceptions=True)
        for e in errors:
            assert isinstance(e, MQEngineDeadError)
            assert "KeyError" in repr(e)

        client.close()
@pytest.mark.asyncio
async def test_bad_request(tmp_socket):
    """An invalid request errors on the client without crashing the server."""
    with RemoteMQLLMEngine(engine_args=ENGINE_ARGS,
                           ipc_path=tmp_socket) as engine:
        client = await engine.make_client()

        # Generating with a nonexistent LoRA adapter fails with ValueError...
        bad_lora = LoRARequest("invalid-lora", 1, "invalid-path")
        with pytest.raises(ValueError):
            async for _ in client.generate(prompt="Hello my name is",
                                           sampling_params=SamplingParams(),
                                           request_id="abcd-1",
                                           lora_request=bad_lora):
                pass

        # ...but the server stays up and serves the next request fine.
        async for _ in client.generate(prompt="Hello my name is",
                                       sampling_params=SamplingParams(),
                                       request_id="abcd-2"):
            pass

        # Shutdown.
        client.close()
@pytest.mark.asyncio
async def test_mp_crash_detection(monkeypatch):
    """vLLM must shut down promptly if the engine crashes during startup.

    Patches LLMEngine.__init__ to raise, then checks that
    build_async_engine_client notices the dead engine process and returns
    within 60 seconds instead of hanging.
    """
    parser = FlexibleArgumentParser(description="vLLM's remote OpenAI server.")
    parser = make_arg_parser(parser)
    args = parser.parse_args([])

    # When LLMEngine is loaded, it will crash.
    # NOTE: __init__ is invoked with self plus the engine arguments, so the
    # replacement must accept them; a zero-argument function would fail with
    # TypeError rather than raising the intended ValueError.
    def mock_init(self, *init_args, **init_kwargs):
        raise ValueError

    monkeypatch.setattr(LLMEngine, "__init__", mock_init)

    start = time.perf_counter()
    async with build_async_engine_client(args):
        pass
    end = time.perf_counter()

    assert end - start < 60, ("Expected vLLM to gracefully shutdown in <60s "
                              "if there is an error in the startup.")
@pytest.mark.asyncio
async def test_mp_cuda_init():
    """Engine startup must not crash when CUDA is already initialized
    in the API server process."""
    import torch
    torch.cuda.init()

    parser = FlexibleArgumentParser(description="vLLM's remote OpenAI server.")
    parser = make_arg_parser(parser)
    args = parser.parse_args([])

    # Building the engine client (which spawns the engine process) should
    # succeed despite the pre-initialized CUDA context above.
    async with build_async_engine_client(args):
        pass
@pytest.mark.asyncio
async def test_engine_process_death(tmp_socket):
    """Clients must detect and report an engine process that dies abruptly."""
    with RemoteMQLLMEngine(engine_args=ENGINE_ARGS,
                           ipc_path=tmp_socket) as engine:
        client = await engine.make_client()
        assert client.is_running

        # Simulate an abrupt engine crash.
        engine.proc.kill()

        # New generate calls fail with an engine-dead error.
        with pytest.raises(MQEngineDeadError):
            async for _ in client.generate(prompt="Hello my name is",
                                           sampling_params=SamplingParams(),
                                           request_id=uuid.uuid4()):
                pass

        # Health checks report that the engine process died.
        with pytest.raises(RuntimeError, match="Engine process .* died"):
            await client.check_health()

        client.close()