Skip to content

Upgrade lastest websockets and Exceptions overhaul #543

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 8 additions & 6 deletions gql/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@

from .graphql_request import GraphQLRequest
from .transport.async_transport import AsyncTransport
from .transport.exceptions import TransportClosed, TransportQueryError
from .transport.exceptions import TransportConnectionFailed, TransportQueryError
from .transport.local_schema import LocalSchemaTransport
from .transport.transport import Transport
from .utilities import build_client_schema, get_introspection_query_ast
Expand Down Expand Up @@ -1730,6 +1730,7 @@ async def _connection_loop(self):
# Then wait for the reconnect event
self._reconnect_request_event.clear()
await self._reconnect_request_event.wait()
await self.transport.close()

async def start_connecting_task(self):
"""Start the task responsible to restart the connection
Expand Down Expand Up @@ -1758,7 +1759,7 @@ async def _execute_once(
**kwargs: Any,
) -> ExecutionResult:
"""Same Coroutine as parent method _execute but requesting a
reconnection if we receive a TransportClosed exception.
reconnection if we receive a TransportConnectionFailed exception.
"""

try:
Expand All @@ -1770,7 +1771,7 @@ async def _execute_once(
parse_result=parse_result,
**kwargs,
)
except TransportClosed:
except TransportConnectionFailed:
self._reconnect_request_event.set()
raise

Expand All @@ -1786,7 +1787,8 @@ async def _execute(
**kwargs: Any,
) -> ExecutionResult:
"""Same Coroutine as parent, but with optional retries
and requesting a reconnection if we receive a TransportClosed exception.
and requesting a reconnection if we receive a
TransportConnectionFailed exception.
"""

return await self._execute_with_retries(
Expand All @@ -1808,7 +1810,7 @@ async def _subscribe(
**kwargs: Any,
) -> AsyncGenerator[ExecutionResult, None]:
"""Same Async generator as parent method _subscribe but requesting a
reconnection if we receive a TransportClosed exception.
reconnection if we receive a TransportConnectionFailed exception.
"""

inner_generator: AsyncGenerator[ExecutionResult, None] = super()._subscribe(
Expand All @@ -1824,7 +1826,7 @@ async def _subscribe(
async for result in inner_generator:
yield result

except TransportClosed:
except TransportConnectionFailed:
self._reconnect_request_event.set()
raise

Expand Down
11 changes: 8 additions & 3 deletions gql/transport/common/adapters/aiohttp.py
Original file line number Diff line number Diff line change
Expand Up @@ -178,12 +178,14 @@ async def send(self, message: str) -> None:
TransportConnectionFailed: If connection closed
"""
if self.websocket is None:
raise TransportConnectionFailed("Connection is already closed")
raise TransportConnectionFailed("WebSocket connection is already closed")

try:
await self.websocket.send_str(message)
except ConnectionResetError as e:
raise TransportConnectionFailed("Connection was closed") from e
except Exception as e:
raise TransportConnectionFailed(
f"Error trying to send data: {type(e).__name__}"
) from e

async def receive(self) -> str:
"""Receive message from the WebSocket server.
Expand All @@ -200,6 +202,9 @@ async def receive(self) -> str:
raise TransportConnectionFailed("Connection is already closed")

while True:
# Should not raise any exception:
# https://docs.aiohttp.org/en/stable/_modules/aiohttp/client_ws.html
# #ClientWebSocketResponse.receive
ws_message = await self.websocket.receive()

# Ignore low-level ping and pong received
Expand Down
22 changes: 14 additions & 8 deletions gql/transport/common/adapters/websockets.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from typing import Any, Dict, Optional, Union

import websockets
from websockets.client import WebSocketClientProtocol
from websockets import ClientConnection
from websockets.datastructures import Headers, HeadersLike

from ...exceptions import TransportConnectionFailed, TransportProtocolError
Expand Down Expand Up @@ -40,7 +40,7 @@ def __init__(
self._headers: Optional[HeadersLike] = headers
self.ssl = ssl

self.websocket: Optional[WebSocketClientProtocol] = None
self.websocket: Optional[ClientConnection] = None
self._response_headers: Optional[Headers] = None

async def connect(self) -> None:
Expand All @@ -57,7 +57,7 @@ async def connect(self) -> None:
# Set default arguments used in the websockets.connect call
connect_args: Dict[str, Any] = {
"ssl": ssl,
"extra_headers": self.headers,
"additional_headers": self.headers,
}

if self.subprotocols:
Expand All @@ -68,11 +68,13 @@ async def connect(self) -> None:

# Connection to the specified url
try:
self.websocket = await websockets.client.connect(self.url, **connect_args)
self.websocket = await websockets.connect(self.url, **connect_args)
except Exception as e:
raise TransportConnectionFailed("Connect failed") from e

self._response_headers = self.websocket.response_headers
assert self.websocket.response is not None

self._response_headers = self.websocket.response.headers

async def send(self, message: str) -> None:
"""Send message to the WebSocket server.
Expand All @@ -84,12 +86,14 @@ async def send(self, message: str) -> None:
TransportConnectionFailed: If connection closed
"""
if self.websocket is None:
raise TransportConnectionFailed("Connection is already closed")
raise TransportConnectionFailed("WebSocket connection is already closed")

try:
await self.websocket.send(message)
except Exception as e:
raise TransportConnectionFailed("Connection was closed") from e
raise TransportConnectionFailed(
f"Error trying to send data: {type(e).__name__}"
) from e

async def receive(self) -> str:
"""Receive message from the WebSocket server.
Expand All @@ -109,7 +113,9 @@ async def receive(self) -> str:
try:
data = await self.websocket.recv()
except Exception as e:
raise TransportConnectionFailed("Connection was closed") from e
raise TransportConnectionFailed(
f"Error trying to receive data: {type(e).__name__}"
) from e

# websocket.recv() can return either str or bytes
# In our case, we should receive only str here
Expand Down
35 changes: 23 additions & 12 deletions gql/transport/common/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,11 +127,13 @@ async def _send(self, message: str) -> None:
"""Send the provided message to the adapter connection and log the message"""

if not self._connected:
raise TransportClosed(
"Transport is not connected"
) from self.close_exception
if isinstance(self.close_exception, TransportConnectionFailed):
raise self.close_exception
else:
raise TransportConnectionFailed() from self.close_exception

try:
# Can raise TransportConnectionFailed
await self.adapter.send(message)
log.info(">>> %s", message)
except TransportConnectionFailed as e:
Expand All @@ -143,7 +145,7 @@ async def _receive(self) -> str:

# It is possible that the connection has been already closed in another task
if not self._connected:
raise TransportClosed("Transport is already closed")
raise TransportConnectionFailed() from self.close_exception

# Wait for the next frame.
# Can raise TransportConnectionFailed or TransportProtocolError
Expand Down Expand Up @@ -214,8 +216,6 @@ async def _receive_data_loop(self) -> None:
except (TransportConnectionFailed, TransportProtocolError) as e:
await self._fail(e, clean_close=False)
break
except TransportClosed:
break

# Parse the answer
try:
Expand Down Expand Up @@ -482,6 +482,10 @@ async def _close_coro(self, e: Exception, clean_close: bool = True) -> None:
# We should always have an active websocket connection here
assert self._connected

# Saving exception to raise it later if trying to use the transport
# after it has already closed.
self.close_exception = e

# Properly shut down liveness checker if enabled
if self.check_keep_alive_task is not None:
# More info: https://stackoverflow.com/a/43810272/1113207
Expand All @@ -492,18 +496,17 @@ async def _close_coro(self, e: Exception, clean_close: bool = True) -> None:
# Calling the subclass close hook
await self._close_hook()

# Saving exception to raise it later if trying to use the transport
# after it has already closed.
self.close_exception = e

if clean_close:
log.debug("_close_coro: starting clean_close")
try:
await self._clean_close(e)
except Exception as exc: # pragma: no cover
log.warning("Ignoring exception in _clean_close: " + repr(exc))

log.debug("_close_coro: sending exception to listeners")
if log.isEnabledFor(logging.DEBUG):
log.debug(
f"_close_coro: sending exception to {len(self.listeners)} listeners"
)

# Send an exception to all remaining listeners
for query_id, listener in self.listeners.items():
Expand All @@ -530,7 +533,15 @@ async def _close_coro(self, e: Exception, clean_close: bool = True) -> None:
log.debug("_close_coro: exiting")

async def _fail(self, e: Exception, clean_close: bool = True) -> None:
log.debug("_fail: starting with exception: " + repr(e))
if log.isEnabledFor(logging.DEBUG):
import inspect

current_frame = inspect.currentframe()
assert current_frame is not None
caller_frame = current_frame.f_back
assert caller_frame is not None
caller_name = inspect.getframeinfo(caller_frame).function
log.debug(f"_fail from {caller_name}: " + repr(e))

if self.close_task is None:

Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@
]

install_websockets_requires = [
"websockets>=10.1,<14",
"websockets>=14.2,<16",
]

install_botocore_requires = [
Expand Down
39 changes: 13 additions & 26 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -197,7 +197,7 @@ def __init__(self, with_ssl: bool = False):

async def start(self, handler, extra_serve_args=None):

import websockets.server
import websockets

print("Starting server")

Expand All @@ -209,16 +209,21 @@ async def start(self, handler, extra_serve_args=None):
extra_serve_args["ssl"] = ssl_context

# Adding dummy response headers
extra_serve_args["extra_headers"] = {"dummy": "test1234"}
extra_headers = {"dummy": "test1234"}

def process_response(connection, request, response):
response.headers.update(extra_headers)
return response

# Start a server with a random open port
self.start_server = websockets.server.serve(
handler, "127.0.0.1", 0, **extra_serve_args
self.server = await websockets.serve(
handler,
"127.0.0.1",
0,
process_response=process_response,
**extra_serve_args,
)

# Wait that the server is started
self.server = await self.start_server

# Get hostname and port
hostname, port = self.server.sockets[0].getsockname()[:2] # type: ignore
assert hostname == "127.0.0.1"
Expand Down Expand Up @@ -603,32 +608,14 @@ async def graphqlws_server(request):

subprotocol = "graphql-transport-ws"

from websockets.server import WebSocketServerProtocol

class CustomSubprotocol(WebSocketServerProtocol):
def select_subprotocol(self, client_subprotocols, server_subprotocols):
print(f"Client subprotocols: {client_subprotocols!r}")
print(f"Server subprotocols: {server_subprotocols!r}")

return subprotocol

def process_subprotocol(self, headers, available_subprotocols):
# Overwriting available subprotocols
available_subprotocols = [subprotocol]

print(f"headers: {headers!r}")
# print (f"Available subprotocols: {available_subprotocols!r}")

return super().process_subprotocol(headers, available_subprotocols)

server_handler = get_server_handler(request)

try:
test_server = WebSocketServer()

# Starting the server with the fixture param as the handler function
await test_server.start(
server_handler, extra_serve_args={"create_protocol": CustomSubprotocol}
server_handler, extra_serve_args={"subprotocols": [subprotocol]}
)

yield test_server
Expand Down
14 changes: 6 additions & 8 deletions tests/test_aiohttp_online.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,10 +19,10 @@ async def test_aiohttp_simple_query():
url = "https://countries.trevorblades.com/graphql"

# Get transport
sample_transport = AIOHTTPTransport(url=url)
transport = AIOHTTPTransport(url=url)

# Instanciate client
async with Client(transport=sample_transport) as session:
async with Client(transport=transport) as session:

query = gql(
"""
Expand Down Expand Up @@ -60,11 +60,9 @@ async def test_aiohttp_invalid_query():

from gql.transport.aiohttp import AIOHTTPTransport

sample_transport = AIOHTTPTransport(
url="https://countries.trevorblades.com/graphql"
)
transport = AIOHTTPTransport(url="https://countries.trevorblades.com/graphql")

async with Client(transport=sample_transport) as session:
async with Client(transport=transport) as session:

query = gql(
"""
Expand All @@ -89,12 +87,12 @@ async def test_aiohttp_two_queries_in_parallel_using_two_tasks():

from gql.transport.aiohttp import AIOHTTPTransport

sample_transport = AIOHTTPTransport(
transport = AIOHTTPTransport(
url="https://countries.trevorblades.com/graphql",
)

# Instanciate client
async with Client(transport=sample_transport) as session:
async with Client(transport=transport) as session:

query1 = gql(
"""
Expand Down
Loading
Loading