[Bugfix][Core] Update cancellation logic in generate() to handle Generator exits #19225

Merged: 1 commit, Jun 6, 2025
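
For context, the case this PR handles is a consumer that stops iterating the generate() stream partway through, for example because the HTTP client disconnected. The sketch below is illustrative only and not part of the diff; it assumes an already-constructed AsyncLLM engine, a prompt, and SamplingParams, and mirrors the cancel_after path added to the test helper.

# Illustrative sketch (not from this PR): a caller that abandons the stream
# early. engine, prompt, and sampling_params are assumed to exist.
async def consume_some(engine, prompt, sampling_params, max_outputs: int):
    received = 0
    async for out in engine.generate(request_id="demo-request",
                                     prompt=prompt,
                                     sampling_params=sampling_params):
        received += 1
        if received >= max_outputs:
            # Returning here stops iteration; the async generator inside
            # AsyncLLM.generate() is later closed (GeneratorExit) or the
            # surrounding task is cancelled (asyncio.CancelledError).
            # Before this fix, only the CancelledError path aborted the
            # request in the engine.
            return received
    return received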
tests/v1/engine/test_async_llm.py (126 additions, 43 deletions)
@@ -22,9 +22,11 @@
     pytest.skip(reason="V1 currently only supported on CUDA.",
                 allow_module_level=True)
 
-TEXT_ENGINE_ARGS = AsyncEngineArgs(model="meta-llama/Llama-3.2-1B-Instruct",
-                                   enforce_eager=True,
-                                   disable_log_requests=True)
+TEXT_ENGINE_ARGS = AsyncEngineArgs(
+    model="meta-llama/Llama-3.2-1B-Instruct",
+    enforce_eager=True,
+    disable_log_requests=True,
+)
 
 VISION_ENGINE_ARGS = AsyncEngineArgs(model="Qwen/Qwen2-VL-2B-Instruct",
                                      enforce_eager=True,
@@ -41,28 +43,33 @@
     "prompt": VISION_PROMPT_TEMPLATE,
     "multi_modal_data": {
         "image": ImageAsset("stop_sign").pil_image
-    }
+    },
 }
 
 
-async def generate(engine: AsyncLLM,
-                   request_id: str,
-                   prompt: PromptType,
-                   output_kind: RequestOutputKind,
-                   max_tokens: int,
-                   n: int = 1,
-                   prompt_logprobs: Optional[int] = None) -> tuple[int, str]:
+async def generate(
+    engine: AsyncLLM,
+    request_id: str,
+    prompt: PromptType,
+    output_kind: RequestOutputKind,
+    max_tokens: int,
+    n: int = 1,
+    prompt_logprobs: Optional[int] = None,
+    cancel_after: Optional[int] = None,
+) -> tuple[int, str]:
     # Ensure generate doesn't complete too fast for cancellation test.
     await asyncio.sleep(0.2)
 
     count = 0
-    sampling_params = SamplingParams(max_tokens=max_tokens,
-                                     ignore_eos=True,
-                                     output_kind=output_kind,
-                                     temperature=0.5,
-                                     seed=33,
-                                     n=n,
-                                     prompt_logprobs=prompt_logprobs)
+    sampling_params = SamplingParams(
+        max_tokens=max_tokens,
+        ignore_eos=True,
+        output_kind=output_kind,
+        temperature=0.5,
+        seed=33,
+        n=n,
+        prompt_logprobs=prompt_logprobs,
+    )
     async for out in engine.generate(request_id=request_id,
                                      prompt=prompt,
                                      sampling_params=sampling_params):
@@ -73,20 +80,27 @@ async def generate(engine: AsyncLLM,
         else:
             count = num_tokens
 
-        await asyncio.sleep(0.)
+        if cancel_after is not None and count >= cancel_after:
+            return count, request_id
+
+        await asyncio.sleep(0.0)
 
     return count, request_id
 
 
 @pytest.mark.parametrize(
     "output_kind", [RequestOutputKind.DELTA, RequestOutputKind.FINAL_ONLY])
-@pytest.mark.parametrize("engine_args,prompt",
-                         [(TEXT_ENGINE_ARGS, TEXT_PROMPT),
-                          (VISION_ENGINE_ARGS, VISION_PROMPT)])
+@pytest.mark.parametrize(
+    "engine_args,prompt",
+    [(TEXT_ENGINE_ARGS, TEXT_PROMPT), (VISION_ENGINE_ARGS, VISION_PROMPT)],
+)
 @pytest.mark.asyncio
-async def test_load(monkeypatch: pytest.MonkeyPatch,
-                    output_kind: RequestOutputKind,
-                    engine_args: AsyncEngineArgs, prompt: PromptType):
+async def test_load(
+    monkeypatch: pytest.MonkeyPatch,
+    output_kind: RequestOutputKind,
+    engine_args: AsyncEngineArgs,
+    prompt: PromptType,
+):
     # TODO(rickyx): Remove monkeypatch once we have a better way to test V1
     # so that in the future when we switch, we don't have to change all the
     # tests.
@@ -125,13 +139,17 @@ async def test_load(monkeypatch: pytest.MonkeyPatch,
 
 @pytest.mark.parametrize(
     "output_kind", [RequestOutputKind.DELTA, RequestOutputKind.FINAL_ONLY])
-@pytest.mark.parametrize("engine_args,prompt",
-                         [(TEXT_ENGINE_ARGS, TEXT_PROMPT),
-                          (VISION_ENGINE_ARGS, VISION_PROMPT)])
+@pytest.mark.parametrize(
+    "engine_args,prompt",
+    [(TEXT_ENGINE_ARGS, TEXT_PROMPT), (VISION_ENGINE_ARGS, VISION_PROMPT)],
+)
 @pytest.mark.asyncio
-async def test_abort(monkeypatch: pytest.MonkeyPatch,
-                     output_kind: RequestOutputKind,
-                     engine_args: AsyncEngineArgs, prompt: PromptType):
+async def test_abort(
+    monkeypatch: pytest.MonkeyPatch,
+    output_kind: RequestOutputKind,
+    engine_args: AsyncEngineArgs,
+    prompt: PromptType,
+):
 
     with monkeypatch.context() as m, ExitStack() as after:
         m.setenv("VLLM_USE_V1", "1")
@@ -150,8 +168,9 @@ async def test_abort(monkeypatch: pytest.MonkeyPatch,
         # Create concurrent requests.
         tasks: list[asyncio.Task] = []
         for idx, request_id in enumerate(request_ids):
-            max_tokens = NUM_EXPECTED_TOKENS_LONG if (
-                idx in REQUEST_IDS_TO_ABORT) else NUM_EXPECTED_TOKENS
+            max_tokens = (NUM_EXPECTED_TOKENS_LONG if
+                          (idx
+                           in REQUEST_IDS_TO_ABORT) else NUM_EXPECTED_TOKENS)
             n = 3 if idx in PARALLEL_SAMPLE_REQ_IDS else 1
             tasks.append(
                 asyncio.create_task(
@@ -192,24 +211,31 @@
 
 
 @pytest.mark.parametrize("n", [1, 3])
-@pytest.mark.parametrize("engine_args,prompt",
-                         [(TEXT_ENGINE_ARGS, TEXT_PROMPT),
-                          (VISION_ENGINE_ARGS, VISION_PROMPT)])
+@pytest.mark.parametrize(
+    "engine_args,prompt",
+    [(TEXT_ENGINE_ARGS, TEXT_PROMPT), (VISION_ENGINE_ARGS, VISION_PROMPT)],
+)
 @pytest.mark.asyncio
-async def test_finished_flag(monkeypatch: pytest.MonkeyPatch, n: int,
-                             engine_args: AsyncEngineArgs, prompt: PromptType):
+async def test_finished_flag(
+    monkeypatch: pytest.MonkeyPatch,
+    n: int,
+    engine_args: AsyncEngineArgs,
+    prompt: PromptType,
+):
 
     with monkeypatch.context() as m, ExitStack() as after:
         m.setenv("VLLM_USE_V1", "1")
 
         engine = AsyncLLM.from_engine_args(engine_args)
         after.callback(engine.shutdown)
 
-        sampling_params = SamplingParams(max_tokens=100,
-                                         output_kind=RequestOutputKind.DELTA,
-                                         temperature=1.0,
-                                         seed=33,
-                                         n=n)
+        sampling_params = SamplingParams(
+            max_tokens=100,
+            output_kind=RequestOutputKind.DELTA,
+            temperature=1.0,
+            seed=33,
+            n=n,
+        )
         outputs = [
             out
             async for out in engine.generate(request_id="request-33",
@@ -222,6 +248,63 @@ async def test_finished_flag(monkeypatch: pytest.MonkeyPatch, n: int,
         assert outputs[-1].finished
 
 
+@pytest.mark.parametrize(
+    "engine_args,prompt",
+    [(TEXT_ENGINE_ARGS, TEXT_PROMPT), (VISION_ENGINE_ARGS, VISION_PROMPT)],
+)
+@pytest.mark.asyncio
+async def test_mid_stream_cancellation(monkeypatch: pytest.MonkeyPatch,
+                                       engine_args: AsyncEngineArgs,
+                                       prompt: PromptType):
+    """Test that requests can be cancelled mid-stream."""
+    with monkeypatch.context() as m, ExitStack() as after:
+        m.setenv("VLLM_USE_V1", "1")
+
+        engine = AsyncLLM.from_engine_args(engine_args)
+        after.callback(engine.shutdown)
+
+        NUM_REQUESTS = 100
+        NUM_TOKENS = 1000
+        NUM_EXPECTED_TOKENS = 20
+
+        request_ids = [f"request-{i}" for i in range(NUM_REQUESTS)]
+
+        # Create concurrent requests that will be cancelled mid-stream
+        tasks = []
+        for request_id in request_ids:
+            tasks.append(
+                asyncio.create_task(
+                    generate(
+                        engine,
+                        request_id,
+                        prompt,
+                        RequestOutputKind.DELTA,
+                        NUM_TOKENS,
+                        cancel_after=NUM_EXPECTED_TOKENS,
+                    )))
+
+        # Wait for all tasks to complete
+        results = await asyncio.gather(*tasks)
+
+        # Verify all tasks were cancelled at the expected point
+        for num_generated_tokens, request_id in results:
+            assert num_generated_tokens == NUM_EXPECTED_TOKENS, (
+                f"{request_id} generated {num_generated_tokens} tokens but "
+                f"expected to cancel after {NUM_EXPECTED_TOKENS}")
+
+        # Make sure no requests are left hanging
+        assert not engine.output_processor.has_unfinished_requests()
+
+        # Confirm we can reuse the request id after the cancellations.
+        request_id = request_ids[0]
+        task = asyncio.create_task(
+            generate(engine, request_id, prompt, RequestOutputKind.DELTA,
+                     NUM_EXPECTED_TOKENS))
+        num_generated_tokens, request_id = await task
+        assert num_generated_tokens == NUM_EXPECTED_TOKENS
+        assert not engine.output_processor.has_unfinished_requests()
+
+
 class MockLoggingStatLogger(LoggingStatLogger):
 
     def __init__(self, vllm_config: VllmConfig, engine_index: int = 0):
vllm/v1/engine/async_llm.py (3 additions, 2 deletions)
@@ -332,8 +332,9 @@ async def generate(
                 yield out
 
         # If the request is disconnected by the client, generate()
-        # is cancelled. So, we abort the request if we end up here.
-        except asyncio.CancelledError:
+        # is cancelled or the generator is garbage collected. So,
+        # we abort the request if we end up here.
+        except (asyncio.CancelledError, GeneratorExit):
             await self.abort(request_id)
             if self.log_requests:
                 logger.info("Request %s aborted.", request_id)
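
To make the engine-side change above concrete: when the consumer of an async generator stops iterating and the generator is then closed, either explicitly via aclose() or by the event loop's async-generator garbage-collection hooks, Python raises GeneratorExit at the suspended yield rather than asyncio.CancelledError. The snippet below is a minimal, self-contained illustration using plain asyncio (no vLLM), showing why the except clause now catches both exception types.

import asyncio


async def streaming():
    # Stands in for AsyncLLM.generate(): an async generator yielding outputs.
    try:
        while True:
            yield "token"
            await asyncio.sleep(0)
    except (asyncio.CancelledError, GeneratorExit):
        # Cleanup analogous to the engine aborting the request.
        print("stream closed, abort/cleanup runs here")
        raise


async def main():
    stream = streaming()
    async for _ in stream:
        break  # the consumer gives up mid-stream
    # Closing the generator raises GeneratorExit at the suspended yield;
    # asyncio's async-generator hooks do the same when the object is GC'd.
    await stream.aclose()


asyncio.run(main())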