Skip to content

Commit

Permalink
Fix sync subscribe graceful shutdown (#395)
Browse files Browse the repository at this point in the history
  • Loading branch information
leszekhanusz authored Feb 23, 2023
1 parent 5e37e6a commit 905b724
Show file tree
Hide file tree
Showing 2 changed files with 65 additions and 6 deletions.
11 changes: 5 additions & 6 deletions gql/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -593,15 +593,14 @@ def subscribe(
except StopAsyncIteration:
pass

except (KeyboardInterrupt, Exception):
except (KeyboardInterrupt, Exception, GeneratorExit):

# Graceful shutdown
asyncio.ensure_future(async_generator.aclose(), loop=loop)

# Graceful shutdown by cancelling the task and waiting clean shutdown
generator_task.cancel()

try:
loop.run_until_complete(generator_task)
except (StopAsyncIteration, asyncio.CancelledError):
pass
loop.run_until_complete(loop.shutdown_asyncgens())

# Then reraise the exception
raise
Expand Down
60 changes: 60 additions & 0 deletions tests/test_websocket_subscription.py
Original file line number Diff line number Diff line change
Expand Up @@ -494,6 +494,66 @@ def test_websocket_subscription_sync(server, subscription_str):
assert count == -1


@pytest.mark.parametrize("server", [server_countdown], indirect=True)
@pytest.mark.parametrize("subscription_str", [countdown_subscription_str])
def test_websocket_subscription_sync_user_exception(server, subscription_str):
from gql.transport.websockets import WebsocketsTransport

url = f"ws://{server.hostname}:{server.port}/graphql"
print(f"url = {url}")

sample_transport = WebsocketsTransport(url=url)

client = Client(transport=sample_transport)

count = 10
subscription = gql(subscription_str.format(count=count))

with pytest.raises(Exception) as exc_info:
for result in client.subscribe(subscription):

number = result["number"]
print(f"Number received: {number}")

assert number == count
count -= 1

if count == 5:
raise Exception("This is an user exception")

assert count == 5
assert "This is an user exception" in str(exc_info.value)


@pytest.mark.parametrize("server", [server_countdown], indirect=True)
@pytest.mark.parametrize("subscription_str", [countdown_subscription_str])
def test_websocket_subscription_sync_break(server, subscription_str):
from gql.transport.websockets import WebsocketsTransport

url = f"ws://{server.hostname}:{server.port}/graphql"
print(f"url = {url}")

sample_transport = WebsocketsTransport(url=url)

client = Client(transport=sample_transport)

count = 10
subscription = gql(subscription_str.format(count=count))

for result in client.subscribe(subscription):

number = result["number"]
print(f"Number received: {number}")

assert number == count
count -= 1

if count == 5:
break

assert count == 5


@pytest.mark.skipif(sys.platform.startswith("win"), reason="test failing on windows")
@pytest.mark.parametrize("server", [server_countdown], indirect=True)
@pytest.mark.parametrize("subscription_str", [countdown_subscription_str])
Expand Down

0 comments on commit 905b724

Please sign in to comment.