Skip to content

Commit

Permalink
[Bugfix][CI/Build] Fix test and improve code for `merge_async_iterato…
Browse files Browse the repository at this point in the history
  • Loading branch information
DarkLight1337 authored and joerunde committed Jun 3, 2024
1 parent 9fa589e commit d603c5d
Show file tree
Hide file tree
Showing 3 changed files with 62 additions and 45 deletions.
41 changes: 0 additions & 41 deletions tests/async_engine/test_merge_async_iterators.py

This file was deleted.

57 changes: 56 additions & 1 deletion tests/test_utils.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,64 @@
import asyncio
import sys
from typing import (TYPE_CHECKING, Any, AsyncIterator, Awaitable, Protocol,
Tuple, TypeVar)

import pytest

from vllm.utils import deprecate_kwargs
from vllm.utils import deprecate_kwargs, merge_async_iterators

from .utils import error_on_warning

if sys.version_info < (3, 10):
if TYPE_CHECKING:
_AwaitableT = TypeVar("_AwaitableT", bound=Awaitable[Any])
_AwaitableT_co = TypeVar("_AwaitableT_co",
bound=Awaitable[Any],
covariant=True)

class _SupportsSynchronousAnext(Protocol[_AwaitableT_co]):

def __anext__(self) -> _AwaitableT_co:
...

def anext(i: "_SupportsSynchronousAnext[_AwaitableT]", /) -> "_AwaitableT":
return i.__anext__()


@pytest.mark.asyncio
async def test_merge_async_iterators():

async def mock_async_iterator(idx: int) -> AsyncIterator[str]:
try:
while True:
yield f"item from iterator {idx}"
await asyncio.sleep(0.1)
except asyncio.CancelledError:
pass

iterators = [mock_async_iterator(i) for i in range(3)]
merged_iterator: AsyncIterator[Tuple[int, str]] = merge_async_iterators(
*iterators)

async def stream_output(generator: AsyncIterator[Tuple[int, str]]):
async for idx, output in generator:
print(f"idx: {idx}, output: {output}")

task = asyncio.create_task(stream_output(merged_iterator))
await asyncio.sleep(0.5)
task.cancel()
with pytest.raises(asyncio.CancelledError):
await task

for iterator in iterators:
try:
await asyncio.wait_for(anext(iterator), 1)
except StopAsyncIteration:
# All iterators should be cancelled and print this message.
print("Iterator was cancelled normally")
except (Exception, asyncio.CancelledError) as e:
raise AssertionError() from e


def test_deprecate_kwargs_always():

Expand Down
9 changes: 6 additions & 3 deletions vllm/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import os
import socket
import subprocess
import sys
import tempfile
import threading
import uuid
Expand Down Expand Up @@ -234,9 +235,11 @@ async def consumer():
yield item
except (Exception, asyncio.CancelledError) as e:
for task in _tasks:
# NOTE: Pass the error msg in cancel()
# when only Python 3.9+ is supported.
task.cancel()
if sys.version_info >= (3, 9):
# msg parameter only supported in Python 3.9+
task.cancel(e)
else:
task.cancel()
raise e
await asyncio.gather(*_tasks)

Expand Down

0 comments on commit d603c5d

Please sign in to comment.