Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Issue/#707 fix disconnect handlers not always triggering #708

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 14 additions & 0 deletions integration_tests/test_matchmaking.py
Original file line number Diff line number Diff line change
Expand Up @@ -158,3 +158,17 @@ async def test_multiqueue(client_factory):
"queue_name": "tmm2v2",
"state": "stop"
}


async def test_party_cleanup_on_abort(client_factory):
for _ in range(2):
client, _ = await client_factory.login("test")
await client.read_until_command("game_info")

# This would time out on failure.
await client.join_queue("tmm2v2")

# Trigger an abort
await client.send_message({"some": "garbage"})

# Loop to reconnect
11 changes: 6 additions & 5 deletions server/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -149,13 +149,13 @@ def do_report_dirties():

if dirty_queues:
self.write_broadcast({
"command": "matchmaker_info",
"queues": [queue.to_dict() for queue in dirty_queues]
}
)
"command": "matchmaker_info",
"queues": [queue.to_dict() for queue in dirty_queues]
})

if dirty_players:
self.write_broadcast({
self.write_broadcast(
{
"command": "player_info",
"players": [player.to_dict() for player in dirty_players]
},
Expand Down Expand Up @@ -199,6 +199,7 @@ async def listen(
ctx = ServerContext(
f"{self.name}[{protocol_class.__name__}]",
self.connection_factory,
list(self.services.values()),
protocol_class
)
self.contexts.add(ctx)
Expand Down
6 changes: 6 additions & 0 deletions server/core/service.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,12 @@ async def shutdown(self) -> None:
"""
pass # pragma: no cover

def on_connection_lost(self, conn) -> None:
"""
Called every time a connection ends.
"""
pass # pragma: no cover


def create_services(injectables: Dict[str, object] = {}) -> Dict[str, Service]:
"""
Expand Down
8 changes: 6 additions & 2 deletions server/ladder_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -455,7 +455,7 @@ async def get_game_history(
self,
players: List[Player],
queue_id: int,
limit=3
limit: int = 3
) -> List[int]:
async with self._db.acquire() as conn:
result = []
Expand All @@ -482,7 +482,11 @@ async def get_game_history(
])
return result

async def on_connection_lost(self, player):
def on_connection_lost(self, conn: "LobbyConnection") -> None:
if not conn.player:
return

player = conn.player
self.cancel_search(player)
del self._searches[player]
if player in self._informed_players:
Expand Down
27 changes: 8 additions & 19 deletions server/lobbyconnection.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,16 +97,12 @@ async def on_connection_made(self, protocol: Protocol, peername: Address):

async def abort(self, logspam=""):
self._authenticated = False
if self.player:
self._logger.warning(
"Client %s dropped. %s", self.player.login, logspam
)
self.player_service.remove_player(self.player)
self.player = None
else:
self._logger.warning(
"Aborting %s. %s", self.peer_address.host, logspam
)

identity = self.player.login if self.player else self.peer_address.host
self._logger.warning(
"Aborting connection for %s. %s", identity, logspam
)

if self.game_connection:
await self.game_connection.abort()

Expand Down Expand Up @@ -836,7 +832,7 @@ async def command_game_matchmaking(self, message):

party = self.party_service.get_party(self.player)

if self.player is not party.owner:
if self.player != party.owner:
raise ClientError(
"Only the party owner may enter the party into a queue.",
recoverable=True
Expand Down Expand Up @@ -1129,21 +1125,14 @@ async def on_connection_lost(self):
async def nop(*args, **kwargs):
return
self.send = nop

if self.game_connection:
self._logger.debug(
"Lost lobby connection killing game connection for player %s",
self.game_connection.player.id
)
await self.game_connection.on_connection_lost()

if self.player:
self._logger.debug(
"Lost lobby connection removing player %s", self.player.id
)
await self.ladder_service.on_connection_lost(self.player)
self.player_service.remove_player(self.player)
await self.party_service.on_player_disconnected(self.player)

async def abort_connection_if_banned(self):
async with self._db.acquire() as conn:
now = datetime.utcnow()
Expand Down
15 changes: 11 additions & 4 deletions server/party_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,10 +143,12 @@ async def leave_party(self, player: Player):
raise ClientError("You are not in a party.", recoverable=True)

party = self.player_parties[player]
party.remove_player(player)
self._remove_player_from_party(player, party)
# TODO: Remove?
await party.send_party(player)

def _remove_player_from_party(self, player, party):
party.remove_player(player)
del self.player_parties[player]

if party.is_disbanded():
Expand Down Expand Up @@ -180,6 +182,11 @@ def remove_party(self, party):
# TODO: Send a special "disbanded" command?
self.write_broadcast_party(party, members=members)

async def on_player_disconnected(self, player):
if player in self.player_parties:
await self.leave_party(player)
def on_connection_lost(self, conn: "LobbyConnection") -> None:
if not conn.player or conn.player not in self.player_parties:
return

self._remove_player_from_party(
conn.player,
self.player_parties[conn.player]
)
13 changes: 13 additions & 0 deletions server/player_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -275,3 +275,16 @@ async def shutdown(self):
self._logger.debug(
"Could not send shutdown message to %s: %s", player, ex
)

def on_connection_lost(self, conn: "LobbyConnection") -> None:
if not conn.player:
return

self.remove_player(conn.player)

self._logger.debug(
"Removed player %d, %s, %d",
conn.player.id,
conn.player.login,
conn.session
)
21 changes: 19 additions & 2 deletions server/servercontext.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
import asyncio
import socket
from typing import Callable, Dict, Type
from typing import Callable, Dict, Iterable, Type

import server.metrics as metrics

from .core import Service
from .decorators import with_logger
from .lobbyconnection import LobbyConnection
from .protocol import Protocol, QDataStreamProtocol
Expand All @@ -20,12 +21,14 @@ def __init__(
self,
name: str,
connection_factory: Callable[[], LobbyConnection],
services: Iterable[Service],
protocol_class: Type[Protocol] = QDataStreamProtocol,
):
super().__init__()
self.name = name
self._server = None
self._connection_factory = connection_factory
self._services = services
self.connections: Dict[LobbyConnection, Protocol] = {}
self.protocol_class = protocol_class

Expand Down Expand Up @@ -98,7 +101,21 @@ async def client_connected(self, stream_reader, stream_writer):
self._logger.exception(ex)
finally:
del self.connections[connection]
metrics.user_connections.labels(connection.user_agent, connection.version).dec()
await protocol.close()
await connection.on_connection_lost()

for service in self._services:
try:
service.on_connection_lost(connection)
except Exception:
self._logger.warning(
"Unexpected exception in %s.on_connection_lost",
service.__class__.__name__,
exc_info=True
)

self._logger.debug("%s: Client disconnected", self.name)
metrics.user_connections.labels(
connection.user_agent,
connection.version
).dec()
1 change: 1 addition & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -194,6 +194,7 @@ def make(
# to live for the full lifetime of the player object
p.__owned_lobby_connection = asynctest.create_autospec(LobbyConnection)
p.lobby_connection = p.__owned_lobby_connection
p.lobby_connection.player = p
return p

return make
Expand Down
59 changes: 38 additions & 21 deletions tests/integration_tests/test_servercontext.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,44 +2,55 @@
import contextlib
from unittest import mock

import asynctest
import pytest
from asynctest import CoroutineMock, exhaust_callbacks

from server import ServerContext
from server.core import Service
from server.lobbyconnection import LobbyConnection
from server.protocol import DisconnectedError, QDataStreamProtocol

pytestmark = pytest.mark.asyncio


@pytest.fixture
def mock_server(event_loop):
class MockServer:
def __init__(self):
self.protocol, self.peername, self.user_agent, self.version = None, None, None, None
self.on_connection_lost = CoroutineMock()
class MockConnection:
def __init__(self):
self.protocol = None
self.peername = None
self.user_agent = None
self.version = None
self.on_connection_lost = CoroutineMock()

async def on_connection_made(self, protocol, peername):
self.protocol = protocol
self.peername = peername
self.protocol.writer.write_eof()
self.protocol.reader.feed_eof()

async def on_message_received(self, msg):
pass

async def on_connection_made(self, protocol, peername):
self.protocol = protocol
self.peername = peername
self.protocol.writer.write_eof()
self.protocol.reader.feed_eof()

async def on_message_received(self, msg):
pass
@pytest.fixture
def mock_connection():
return MockConnection()


return MockServer()
@pytest.fixture
def mock_service():
return asynctest.create_autospec(Service)


@pytest.fixture
async def mock_context(mock_server):
ctx = ServerContext("TestServer", lambda: mock_server)
async def mock_context(mock_connection, mock_service):
ctx = ServerContext("TestServer", lambda: mock_connection, [mock_service])
yield await ctx.listen("127.0.0.1", None), ctx
ctx.close()


@pytest.fixture
async def context():
async def context(mock_service):
def make_connection() -> LobbyConnection:
return LobbyConnection(
database=mock.Mock(),
Expand All @@ -51,22 +62,28 @@ def make_connection() -> LobbyConnection:
party_service=mock.Mock()
)

ctx = ServerContext("TestServer", make_connection)
ctx = ServerContext("TestServer", make_connection, [mock_service])
yield await ctx.listen("127.0.0.1", None), ctx
ctx.close()


async def test_serverside_abort(event_loop, mock_context, mock_server):
async def test_serverside_abort(
event_loop,
mock_context,
mock_connection,
mock_service
):
srv, ctx = mock_context
(reader, writer) = await asyncio.open_connection(*srv.sockets[0].getsockname())
proto = QDataStreamProtocol(reader, writer)
await proto.send_message({"some_junk": True})
await exhaust_callbacks(event_loop)

mock_server.on_connection_lost.assert_any_call()
mock_connection.on_connection_lost.assert_any_call()
mock_service.on_connection_lost.assert_called_once()


async def test_connection_broken_external(context, mock_server):
async def test_connection_broken_external(context):
"""
When the connection breaks while the server is calling protocol.send from
somewhere other than the main read - response loop. Make sure that this
Expand Down
22 changes: 22 additions & 0 deletions tests/integration_tests/test_teammatchmaker.py
Original file line number Diff line number Diff line change
Expand Up @@ -484,3 +484,25 @@ async def test_game_ratings_initialized_based_on_global(lobby_server):
}
]
}


@fast_forward(30)
async def test_party_cleanup_on_abort(lobby_server):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What is the implicit "assert" here? If it's not properly cleaned up then the second time the start command for tmm fails and waiting for search_info times out?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yep exactly. If the party isn't cleaned up, then you'd get a notice message instead of search_info.

I think I should also specify that state="start" in the search_info message.

for _ in range(3):
_, _, proto = await connect_and_sign_in(
("test", "test_password"), lobby_server
)
await read_until_command(proto, "game_info")

await proto.send_message({
"command": "game_matchmaking",
"state": "start",
"mod": "tmm2v2"
})
# The queue was successful. This would time out on failure.
await read_until_command(proto, "search_info", state="start")

# Trigger an abort
await proto.send_message({"some": "garbage"})

# Loop to reconnect
6 changes: 3 additions & 3 deletions tests/unit_tests/test_ladder_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -234,16 +234,16 @@ async def test_write_rating_progress(ladder_service: LadderService, player_facto
)

ladder_service.write_rating_progress(p1, RatingType.LADDER_1V1)

# Message is sent after the first call
p1.lobby_connection.write.assert_called_once()

ladder_service.write_rating_progress(p1, RatingType.LADDER_1V1)
p1.lobby_connection.write.reset_mock()
# But not after the second
p1.lobby_connection.write.assert_not_called()
await ladder_service.on_connection_lost(p1)
ladder_service.write_rating_progress(p1, RatingType.LADDER_1V1)

ladder_service.on_connection_lost(p1.lobby_connection)
ladder_service.write_rating_progress(p1, RatingType.LADDER_1V1)
# But it is called if the player relogs
p1.lobby_connection.write.assert_called_once()

Expand Down
Loading