diff --git a/server/__init__.py b/server/__init__.py index 7b2eca1f2..5b6df70ba 100644 --- a/server/__init__.py +++ b/server/__init__.py @@ -11,6 +11,7 @@ from typing import Any, Dict, Optional import aiomeasures +import asyncio from server.db import FAFDatabase from . import config as config @@ -26,9 +27,11 @@ from .game_service import GameService from .ladder_service import LadderService from .control import init as run_control_server +from .timing import at_interval + __version__ = '0.9.17' -__author__ = 'Chris Kitching, Dragonfire, Gael Honorez, Jeroen De Dauw, Crotalus, Michael Søndergaard, Michel Jung' +__author__ = 'Askaholic, Chris Kitching, Dragonfire, Gael Honorez, Jeroen De Dauw, Crotalus, Michael Søndergaard, Michel Jung' __contact__ = 'admin@faforever.com' __license__ = 'GPLv3' __copyright__ = 'Copyright (c) 2011-2015 ' + __author__ @@ -94,7 +97,8 @@ def run_lobby_server( Run the lobby server """ - def report_dirties(): + @at_interval(DIRTY_REPORT_INTERVAL) + async def do_report_dirties(): try: dirty_games = games.dirty_games dirty_queues = games.dirty_queues @@ -103,13 +107,20 @@ def report_dirties(): player_service.clear_dirty() if len(dirty_queues) > 0: - ctx.broadcast_raw(encode_queues(dirty_queues), lambda lobby_conn: lobby_conn.authenticated) + await ctx.broadcast_raw( + encode_queues(dirty_queues), + lambda lobby_conn: lobby_conn.authenticated + ) if len(dirty_players) > 0: - ctx.broadcast_raw(encode_players(dirty_players), lambda lobby_conn: lobby_conn.authenticated) - - # TODO: This spams squillions of messages: we should implement per-connection message - # aggregation at the next abstraction layer down :P + await ctx.broadcast_raw( + encode_players(dirty_players), + lambda lobby_conn: lobby_conn.authenticated + ) + + # TODO: This spams squillions of messages: we should implement per- + # connection message aggregation at the next abstraction layer down :P + tasks = [] for game in dirty_games: if game.state == GameState.ENDED: games.remove_game(game) @@ -117,25 +128,33 @@ def report_dirties(): # So we're going to be broadcasting this to _somebody_... message = encode_dict(game.to_dict()) - # These games shouldn't be broadcast, but instead privately sent to those who are - # allowed to see them. + # These games shouldn't be broadcast, but instead privately sent + # to those who are allowed to see them. if game.visibility == VisibilityState.FRIENDS: - # To see this game, you must have an authenticated connection and be a friend of the host, or the host. - validation_func = lambda lobby_conn: lobby_conn.player.id in game.host.friends or lobby_conn.player == game.host + # To see this game, you must have an authenticated + # connection and be a friend of the host, or the host. + def validation_func(lobby_conn): + return lobby_conn.player.id in game.host.friends or \ + lobby_conn.player == game.host else: - validation_func = lambda lobby_conn: lobby_conn.player.id not in game.host.foes + def validation_func(lobby_conn): + return lobby_conn.player.id not in game.host.foes + + tasks.append(ctx.broadcast_raw( + message, + lambda lobby_conn: lobby_conn.authenticated and validation_func(lobby_conn) + )) + + await asyncio.gather(*tasks) - ctx.broadcast_raw(message, lambda lobby_conn: lobby_conn.authenticated and validation_func(lobby_conn)) except Exception as e: logging.getLogger().exception(e) - finally: - loop.call_later(DIRTY_REPORT_INTERVAL, report_dirties) ping_msg = encode_message('PING') - def ping_broadcast(): - ctx.broadcast_raw(ping_msg) - loop.call_later(45, ping_broadcast) + @at_interval(45) + async def ping_broadcast(): + await ctx.broadcast_raw(ping_msg) def make_connection() -> LobbyConnection: return LobbyConnection( @@ -147,7 +166,5 @@ def make_connection() -> LobbyConnection: ladder_service=ladder_service ) ctx = ServerContext(make_connection, name="LobbyServer") - loop.call_later(DIRTY_REPORT_INTERVAL, report_dirties) - loop.call_soon(ping_broadcast) loop.run_until_complete(ctx.listen(*address)) return ctx diff --git a/server/api/oauth_session.py b/server/api/oauth_session.py index 656bfb425..5d78afd3b 100644 --- a/server/api/oauth_session.py +++ b/server/api/oauth_session.py @@ -3,8 +3,9 @@ from typing import Dict import aiohttp -from oauthlib.oauth2.rfc6749.errors import (InsecureTransportError, - MissingTokenError) +from oauthlib.oauth2.rfc6749.errors import ( + InsecureTransportError, MissingTokenError +) class OAuth2Session(object): diff --git a/server/control.py b/server/control.py index 0bd013301..9d5a35206 100644 --- a/server/control.py +++ b/server/control.py @@ -2,12 +2,15 @@ Tiny local-only http server for getting stats and performing various tasks """ +import logging import socket +from json import dumps from aiohttp import web -import logging -from server import PlayerService, GameService, config -from json import dumps + +from . import config +from .game_service import GameService +from .player_service import PlayerService logger = logging.getLogger(__name__) diff --git a/server/game_service.py b/server/game_service.py index ba7169e6d..d112c5e55 100644 --- a/server/game_service.py +++ b/server/game_service.py @@ -2,11 +2,10 @@ from typing import Dict, List, Optional, Union, ValuesView import aiocron -from server import GameState, VisibilityState from server.db import FAFDatabase from server.decorators import with_logger from server.games import CoopGame, CustomGame, FeaturedMod, LadderGame -from server.games.game import Game +from server.games.game import Game, GameState, VisibilityState from server.matchmaker import MatchmakerQueue from server.players import Player diff --git a/server/gameconnection.py b/server/gameconnection.py index 0bea4b86b..63cf97439 100644 --- a/server/gameconnection.py +++ b/server/gameconnection.py @@ -1,10 +1,11 @@ import asyncio from server.db import FAFDatabase -from sqlalchemy import text, select +from sqlalchemy import select, text from .abc.base_game import GameConnectionState from .config import TRACE +from .db.models import login, moderation_report, reported_user from .decorators import with_logger from .game_service import GameService from .games.game import Game, GameState, ValidityState, Victory @@ -12,8 +13,6 @@ from .players import Player, PlayerState from .protocol import GpgNetServerProtocol, QDataStreamProtocol -from .db.models import (reported_user, moderation_report, login) - @with_logger class GameConnection(GpgNetServerProtocol): @@ -69,11 +68,11 @@ def player(self) -> Player: def player(self, val: Player): self._player = val - def send_message(self, message): + async def send_message(self, message): message['target'] = "game" self._logger.log(TRACE, ">>: %s", message) - self.protocol.send_message(message) + await self.protocol.send_message(message) async def _handle_idle_state(self): """ @@ -92,7 +91,7 @@ async def _handle_idle_state(self): pass else: self._logger.exception("Unknown PlayerState: %s", state) - self.abort() + await self.abort() async def _handle_lobby_state(self): """ @@ -104,7 +103,7 @@ async def _handle_lobby_state(self): try: player_state = self.player.state if player_state == PlayerState.HOSTING: - self.send_HostGame(self.game.map_folder_name) + await self.send_HostGame(self.game.map_folder_name) self.game.set_hosted() # If the player is joining, we connect him to host # followed by the rest of the players. @@ -116,10 +115,12 @@ async def _handle_lobby_state(self): self._state = GameConnectionState.CONNECTED_TO_HOST self.game.add_game_connection(self) + tasks = [] for peer in self.game.connections: if peer != self and peer.player != self.game.host: self._logger.debug("%s connecting to %s", self.player, peer) - asyncio.ensure_future(self.connect_to_peer(peer)) + tasks.append(self.connect_to_peer(peer)) + await asyncio.gather(*tasks) except Exception as e: # pragma: no cover self._logger.exception(e) @@ -129,24 +130,29 @@ async def connect_to_host(self, peer: "GameConnection"): :return: """ assert peer.player.state == PlayerState.HOSTING - self.send_JoinGame(peer.player.login, - peer.player.id) + await self.send_JoinGame(peer.player.login, peer.player.id) - peer.send_ConnectToPeer(player_name=self.player.login, - player_uid=self.player.id, - offer=True) + await peer.send_ConnectToPeer( + player_name=self.player.login, + player_uid=self.player.id, + offer=True + ) async def connect_to_peer(self, peer: "GameConnection"): """ Connect two peers :return: None """ - self.send_ConnectToPeer(player_name=peer.player.login, - player_uid=peer.player.id, - offer=True) - peer.send_ConnectToPeer(player_name=self.player.login, - player_uid=self.player.id, - offer=False) + await self.send_ConnectToPeer( + player_name=peer.player.login, + player_uid=peer.player.id, + offer=True + ) + await peer.send_ConnectToPeer( + player_name=self.player.login, + player_uid=self.player.id, + offer=False + ) async def handle_action(self, command, args): """ @@ -167,7 +173,7 @@ async def handle_action(self, command, args): except Exception as e: # pragma: no cover self._logger.exception(e) self._logger.exception("Something awful happened in a game thread!") - self.abort() + await self.abort() async def handle_desync(self, *_args): # pragma: no cover self.game.desyncs += 1 @@ -303,7 +309,7 @@ async def handle_teamkill_report(self, gametime, reporter_id, reporter_name, tea :param teamkiller_id: teamkiller id :param teamkiller_name: teamkiller nickname - Used as a failsafe in case ID is wrong """ - + async with self._db.acquire() as conn: """ Sometime the game sends a wrong ID - but a correct player name @@ -403,7 +409,7 @@ async def handle_ice_message(self, receiver_id, ice_msg): ) return - game_connection.send_message({ + await game_connection.send_message({ "command": "IceMsg", "args": [int(self.player.id), ice_msg] }) @@ -499,7 +505,7 @@ def _mark_dirty(self): if self.game: self.game_service.mark_dirty(self.game) - def abort(self, log_message: str=''): + async def abort(self, log_message: str=''): """ Abort the connection @@ -513,10 +519,10 @@ def abort(self, log_message: str=''): self._logger.debug("%s.abort(%s)", self, log_message) if self.game.state == GameState.LOBBY: - self.disconnect_all_peers() + await self.disconnect_all_peers() self._state = GameConnectionState.ENDED - asyncio.ensure_future(self.game.remove_game_connection(self)) + await self.game.remove_game_connection(self) self._mark_dirty() self.player.state = PlayerState.IDLE del self.player.game @@ -524,14 +530,16 @@ def abort(self, log_message: str=''): except Exception as ex: # pragma: no cover self._logger.debug("Exception in abort(): %s", ex) - def disconnect_all_peers(self): + async def disconnect_all_peers(self): + tasks = [] for peer in self.game.connections: if peer == self: continue - try: - peer.send_DisconnectFromPeer(self.player.id) - except Exception: # pragma no cover + tasks.append(peer.send_DisconnectFromPeer(self.player.id)) + + for result in await asyncio.gather(*tasks, return_exceptions=True): + if isinstance(result, Exception): self._logger.exception( "peer_sendDisconnectFromPeer failed for player %i", self.player.id) @@ -542,7 +550,7 @@ async def on_connection_lost(self): except Exception as e: # pragma: no cover self._logger.exception(e) finally: - self.abort() + await self.abort() def __str__(self): return "GameConnection({}, {})".format(self.player, self.game) diff --git a/server/games/game.py b/server/games/game.py index ef88f0fdd..1e5cfef7a 100644 --- a/server/games/game.py +++ b/server/games/game.py @@ -3,17 +3,17 @@ import logging import re import time -from collections import Counter, defaultdict +from collections import defaultdict from enum import Enum, unique -from typing import Any, Dict, Optional, Tuple, Union +from typing import Any, Dict, Optional, Tuple import trueskill +from server.games.game_results import GameOutcome, GameResult, GameResults +from server.rating import RatingType from trueskill import Rating from ..abc.base_game import GameConnectionState, InitMode from ..players import Player, PlayerState -from server.rating import RatingType -from server.games.game_results import GameOutcome, GameResult, GameResults FFA_TEAM = 1 diff --git a/server/geoip_service.py b/server/geoip_service.py index 03caf1a62..f9e81a40e 100644 --- a/server/geoip_service.py +++ b/server/geoip_service.py @@ -28,7 +28,8 @@ def __init__(self): self.db = None # crontab: min hour day month day_of_week - # Run every Wednesday because GeoLite2 is updated every Tuesday + # Run every Wednesday because GeoLite2 is updated every first Tuesday + # of the month. self._update_cron = aiocron.crontab( '0 0 0 * * 3', func=self.check_update_geoip_db ) diff --git a/server/ladder_service.py b/server/ladder_service.py index 2e0e73386..c48d727ac 100644 --- a/server/ladder_service.py +++ b/server/ladder_service.py @@ -3,16 +3,16 @@ from collections import defaultdict from typing import Dict, List, NamedTuple, Set +from server.db import FAFDatabase +from server.rating import RatingType from sqlalchemy import and_, func, select, text -from server.db import FAFDatabase from .config import LADDER_ANTI_REPETITION_LIMIT from .db.models import game_featuredMods, game_player_stats, game_stats from .decorators import with_logger from .game_service import GameService from .matchmaker import MatchmakerQueue, Search from .players import Player, PlayerState -from server.rating import RatingType MapDescription = NamedTuple('Map', [("id", int), ("name", str), ("path", str)]) @@ -37,7 +37,7 @@ def __init__(self, database: FAFDatabase, games_service: GameService): asyncio.ensure_future(self.handle_queue_matches()) - def start_search(self, initiator: Player, search: Search, queue_name: str): + async def start_search(self, initiator: Player, search: Search, queue_name: str): self._cancel_existing_searches(initiator) for player in search.players: @@ -45,7 +45,7 @@ def start_search(self, initiator: Player, search: Search, queue_name: str): # For now, inform_player is only designed for ladder1v1 if queue_name == "ladder1v1": - self.inform_player(player) + await self.inform_player(player) self.searches[queue_name][initiator] = search @@ -76,13 +76,13 @@ def _cancel_existing_searches(self, initiator: Player) -> List[Search]: del self.searches[queue_name][initiator] return searches - def inform_player(self, player: Player): + async def inform_player(self, player: Player): if player not in self._informed_players: self._informed_players.add(player) mean, deviation = player.ratings[RatingType.LADDER_1V1] if deviation > 490: - player.lobby_connection.send({ + await player.lobby_connection.send({ "command": "notice", "style": "info", "text": ( @@ -97,7 +97,7 @@ def inform_player(self, player: Player): }) elif deviation > 250: progress = (500.0 - deviation) / 2.5 - player.lobby_connection.send({ + await player.lobby_connection.send({ "command": "notice", "style": "info", "text": ( @@ -113,8 +113,10 @@ async def handle_queue_matches(self): assert len(s2.players) == 1 p1, p2 = s1.players[0], s2.players[0] msg = {"command": "match_found", "queue": "ladder1v1"} - p1.lobby_connection.send(msg) - p2.lobby_connection.send(msg) + await asyncio.gather( + p1.lobby_connection.send(msg), + p2.lobby_connection.send(msg) + ) asyncio.ensure_future(self.start_game(p1, p2)) except Exception as e: self._logger.exception( @@ -159,15 +161,17 @@ async def start_game(self, host: Player, guest: Player): # FIXME: Database filenames contain the maps/ prefix and .zip suffix. # Really in the future, just send a better description self._logger.debug("Starting ladder game: %s", game) - host.lobby_connection.launch_game(game, is_host=True, use_map=mapname) + await host.lobby_connection.launch_game(game, is_host=True, use_map=mapname) try: hosted = await game.await_hosted() if not hosted: raise TimeoutError("Host left lobby") except TimeoutError: msg = {"command": "game_launch_timeout"} - host.lobby_connection.send(msg) - guest.lobby_connection.send(msg) + await asyncio.gather( + host.lobby_connection.send(msg), + guest.lobby_connection.send(msg) + ) # TODO: Uncomment this line once the client supports `game_launch_timeout`. # Until then, returning here will cause the client to think it is # searching for ladder, even though the server has already removed it @@ -175,7 +179,7 @@ async def start_game(self, host: Player, guest: Player): # return self._logger.debug("Ladder game failed to launch due to a timeout") - guest.lobby_connection.launch_game( + await guest.lobby_connection.launch_game( game, is_host=False, use_map=mapname ) self._logger.debug("Ladder game launched successfully") diff --git a/server/lobbyconnection.py b/server/lobbyconnection.py index af7189d28..438bd79db 100644 --- a/server/lobbyconnection.py +++ b/server/lobbyconnection.py @@ -97,13 +97,13 @@ def on_connection_made(self, protocol: QDataStreamProtocol, peername: Address): self.peer_address = peername server.stats.incr('server.connections') - def abort(self, logspam=""): + async def abort(self, logspam=""): if self.player: self._logger.warning("Client %s dropped. %s" % (self.player.login, logspam)) else: self._logger.warning("Aborting %s. %s" % (self.peer_address.host, logspam)) if self.game_connection: - self.game_connection.abort() + await self.game_connection.abort() self.game_connection = None self._authenticated = False self.protocol.writer.close() @@ -113,11 +113,11 @@ def abort(self, logspam=""): self.player = None server.stats.incr('server.connections.aborted') - def ensure_authenticated(self, cmd): + async def ensure_authenticated(self, cmd): if not self._authenticated: if cmd not in ['hello', 'ask_session', 'create_account', 'ping', 'pong', 'Bottleneck']: # Bottleneck is sent by the game during reconnect server.stats.incr('server.received_messages.unauthenticated', tags={"command": cmd}) - self.abort("Message invalid for unauthenticated connection: %s" % cmd) + await self.abort("Message invalid for unauthenticated connection: %s" % cmd) return False return True @@ -129,7 +129,7 @@ async def on_message_received(self, message): try: cmd = message['command'] - if not self.ensure_authenticated(cmd): + if not await self.ensure_authenticated(cmd): return target = message.get('target') if target == 'game': @@ -144,43 +144,41 @@ async def on_message_received(self, message): raise ClientError("Your client version is no longer supported. Please update to the newest version: https://faforever.com") handler = getattr(self, 'command_{}'.format(cmd)) - if asyncio.iscoroutinefunction(handler): - await handler(message) - else: - handler(message) + await handler(message) + except AuthenticationError as ex: - self.protocol.send_message( - {'command': 'authentication_failed', - 'text': ex.message} - ) + await self.send({ + 'command': 'authentication_failed', + 'text': ex.message + }) except ClientError as ex: self._logger.warning("Client error: %s", ex.message) - self.protocol.send_message( - {'command': 'notice', - 'style': 'error', - 'text': ex.message} - ) + await self.send({ + 'command': 'notice', + 'style': 'error', + 'text': ex.message + }) if not ex.recoverable: - self.abort(ex.message) + await self.abort(ex.message) except (KeyError, ValueError) as ex: self._logger.exception(ex) - self.abort("Garbage command: {}".format(message)) + await self.abort("Garbage command: {}".format(message)) except Exception as ex: # pragma: no cover - self.protocol.send_message({'command': 'invalid'}) + await self.send({'command': 'invalid'}) self._logger.exception(ex) - self.abort("Error processing command") + await self.abort("Error processing command") - def command_ping(self, msg): - self.protocol.send_raw(self.protocol.pack_message('PONG')) + async def command_ping(self, msg): + await self.protocol.send_raw(self.protocol.pack_message('PONG')) - def command_pong(self, msg): + async def command_pong(self, msg): pass - @asyncio.coroutine - def command_create_account(self, message): + async def command_create_account(self, message): raise ClientError("FAF no longer supports direct registration. Please use the website to register.", recoverable=True) - async def send_coop_maps(self): + async def command_coop_list(self, message): + """ Request for coop map list""" async with self._db.acquire() as conn: result = await conn.execute("SELECT name, description, filename, type, id FROM `coop_map`") @@ -209,19 +207,16 @@ async def send_coop_maps(self): json_to_send["uid"] = row["id"] maps.append(json_to_send) - self.protocol.send_messages(maps) + await self.protocol.send_messages(maps) async def command_matchmaker_info(self, message): - self.send({ + await self.send({ 'command': 'matchmaker_info', 'queues': [queue.to_dict() for queue in self.ladder_service.queues.values()] }) - await self.protocol.drain() - - @timed() - def send_game_list(self): - self.send({ + async def send_game_list(self): + await self.send({ 'command': 'game_info', 'games': [game.to_dict() for game in self.game_service.open_games] }) @@ -232,7 +227,7 @@ async def command_social_remove(self, message): elif "foe" in message: subject_id = message["foe"] else: - self.abort("No-op social_remove.") + await self.abort("No-op social_remove.") return async with self._db.acquire() as conn: @@ -258,15 +253,15 @@ async def command_social_add(self, message): subject_id=subject_id, )) - def kick(self): - self.send({ + async def kick(self): + await self.send({ "command": "notice", "style": "kick", }) - self.abort() + await self.abort() - def send_updated_achievements(self, updated_achievements): - self.send({ + async def send_updated_achievements(self, updated_achievements): + await self.send({ "command": "updated_achievements", "updated_achievements": updated_achievements }) @@ -333,7 +328,7 @@ async def command_admin(self, message): for player in self.player_service: try: if player.lobby_connection: - player.lobby_connection.send_warning(message.get('message')) + await player.lobby_connection.send_warning(message.get('message')) except Exception as ex: self._logger.debug("Could not send broadcast message to %s: %s".format(player, ex)) @@ -403,7 +398,7 @@ async def check_user_login(self, conn, username, password): return player_id, real_username, steamid - def check_version(self, message): + async def check_version(self, message): versionDB, updateFile = self.player_service.client_version_info update_msg = { 'command': 'update', @@ -417,7 +412,7 @@ def check_version(self, message): server.stats.gauge('user.agents.{}'.format(self.user_agent), 1, delta=True) if not self.user_agent or 'downlords-faf-client' not in self.user_agent: - self.send_warning( + await self.send_warning( "You are using an unofficial client version! " "Some features might not work as expected. " "If you experience any problems please download the latest " @@ -428,7 +423,7 @@ def check_version(self, message): if not version or not self.user_agent: update_msg['command'] = 'welcome' # For compatibility with 0.10.x updating mechanism - self.send(update_msg) + await self.send(update_msg) return False # Check their client is reporting the right version number. @@ -439,10 +434,10 @@ def check_version(self, message): if "+" in version: version = version.split('+')[0] if semver.compare(versionDB, version) > 0: - self.send(update_msg) + await self.send(update_msg) return False except ValueError: - self.send(update_msg) + await self.send(update_msg) return False return True @@ -467,7 +462,7 @@ async def check_policy_conformity(self, player_id, uid_hash, session, ignore_res if response.get('result', '') == 'vm': self._logger.debug("Using VM: %d: %s", player_id, uid_hash) - self.send({ + await self.send({ "command": "notice", "style": "error", "text": ( @@ -477,7 +472,7 @@ async def check_policy_conformity(self, player_id, uid_hash, session, ignore_res "positive." ) }) - self.send_warning("Your computer seems to be a virtual machine.

In order to " + await self.send_warning("Your computer seems to be a virtual machine.

In order to " "log in from a VM, you have to link your account to Steam: " + config.WWW_URL + "/account/link.
If you need an exception, please contact an " @@ -485,7 +480,7 @@ async def check_policy_conformity(self, player_id, uid_hash, session, ignore_res if response.get('result', '') == 'already_associated': self._logger.warning("UID hit: %d: %s", player_id, uid_hash) - self.send_warning("Your computer is already associated with another FAF account.

In order to " + await self.send_warning("Your computer is already associated with another FAF account.

In order to " "log in with an additional account, you have to link it to Steam: " + config.WWW_URL + "/account/link.
If you need an exception, please contact an " @@ -494,7 +489,7 @@ async def check_policy_conformity(self, player_id, uid_hash, session, ignore_res if response.get('result', '') == 'fraudulent': self._logger.info("Banning player %s for fraudulent looking login.", player_id) - self.send_warning("Fraudulent login attempt detected. As a precautionary measure, your account has been " + await self.send_warning("Fraudulent login attempt detected. As a precautionary measure, your account has been " "banned permanently. Please contact an admin or moderator on the forums if you feel this is " "a false positive.", fatal=True) @@ -566,7 +561,7 @@ async def command_hello(self, message): if old_player: self._logger.debug("player {} already signed in: {}".format(self.player.id, old_player)) if old_player.lobby_connection: - old_player.lobby_connection.send_warning("You have been signed out because you signed in elsewhere.", fatal=True) + await old_player.lobby_connection.send_warning("You have been signed out because you signed in elsewhere.", fatal=True) old_player.lobby_connection.game_connection = None old_player.lobby_connection.player = None self._logger.debug("Removing previous game_connection and player reference of player {} in hope on_connection_lost() wouldn't drop her out of the game".format(self.player.id)) @@ -581,7 +576,7 @@ async def command_hello(self, message): self.player.country = self.geoip_service.country(self.peer_address.host) # Send the player their own player info. - self.send({ + await self.send({ "command": "welcome", "me": self.player.to_dict(), @@ -591,12 +586,10 @@ async def command_hello(self, message): }) # Tell player about everybody online. This must happen after "welcome". - self.send( - { - "command": "player_info", - "players": [player.to_dict() for player in self.player_service] - } - ) + await self.send({ + "command": "player_info", + "players": [player.to_dict() for player in self.player_service] + }) # Tell everyone else online about us. This must happen after all the player_info messages. # This ensures that no other client will perform an operation that interacts with the @@ -630,21 +623,21 @@ async def command_hello(self, message): channels.append("#%s_clan" % self.player.clan) json_to_send = {"command": "social", "autojoin": channels, "channels": channels, "friends": friends, "foes": foes, "power": permission_group} - self.send(json_to_send) + await self.send(json_to_send) - self.send_game_list() + await self.send_game_list() - def command_restore_game_session(self, message): + async def command_restore_game_session(self, message): game_id = int(message.get('game_id')) # Restore the player's game connection, if the game still exists and is live if not game_id or game_id not in self.game_service: - self.send_warning("The game you were connected to does no longer exist") + await self.send_warning("The game you were connected to does no longer exist") return game = self.game_service[game_id] # type: Game if game.state != GameState.LOBBY and game.state != GameState.LIVE: - self.send_warning("The game you were connected to is no longer available") + await self.send_warning("The game you were connected to is no longer available") return self._logger.debug("Restoring game session of player %s to game %s", self.player, game) @@ -663,10 +656,9 @@ def command_restore_game_session(self, message): if not hasattr(self.player, "game"): self.player.game = game - @timed - def command_ask_session(self, message): - if self.check_version(message): - self.send({"command": "session", "session": self.session}) + async def command_ask_session(self, message): + if await self.check_version(message): + await self.send({"command": "session", "session": self.session}) async def command_avatar(self, message): action = message['action'] @@ -684,7 +676,7 @@ async def command_avatar(self, message): avatarList.append(avatar) if avatarList: - self.send({"command": "avatar", "avatarlist": avatarList}) + await self.send({"command": "avatar", "avatarlist": avatarList}) elif action == "select": avatar = message['avatar'] @@ -719,7 +711,7 @@ async def command_game_join(self, message): game = self.game_service[uuid] if not game or game.state != GameState.LOBBY: self._logger.debug("Game not in lobby state: %s", game) - self.send({ + await self.send({ "command": "notice", "style": "info", "text": "The game you are trying to join is not ready." @@ -727,17 +719,17 @@ async def command_game_join(self, message): return if game.password != password: - self.send({ + await self.send({ "command": "notice", "style": "info", "text": "Bad password (it's case sensitive)" }) return - self.launch_game(game, is_host=False) + await self.launch_game(game, is_host=False) except KeyError: - self.send({ + await self.send({ "command": "notice", "style": "info", "text": "The host has left the game" @@ -765,11 +757,7 @@ async def command_game_matchmaking(self, message): # TODO: Put player parties here search = Search([self.player]) - self.ladder_service.start_search(self.player, search, queue_name=mod) - - def command_coop_list(self, message): - """ Request for coop map list""" - asyncio.ensure_future(self.send_coop_maps()) + await self.ladder_service.start_search(self.player, search, queue_name=mod) async def command_game_host(self, message): assert isinstance(self.player, Player) @@ -782,7 +770,7 @@ async def command_game_host(self, message): visibility = VisibilityState.from_string(message.get('visibility')) if not isinstance(visibility, VisibilityState): # Protocol violation. - self.abort("{} sent a nonsense visibility code: {}".format(self.player.login, message.get('visibility'))) + await self.abort("{} sent a nonsense visibility code: {}".format(self.player.login, message.get('visibility'))) return title = html.escape(message.get('title') or f"{self.player.login}'s game") @@ -790,7 +778,7 @@ async def command_game_host(self, message): try: title.encode('ascii') except UnicodeEncodeError: - self.send({ + await self.send({ "command": "notice", "style": "error", "text": "Non-ascii characters in game name detected." @@ -810,13 +798,13 @@ async def command_game_host(self, message): mapname=mapname, password=password ) - self.launch_game(game, is_host=True) + await self.launch_game(game, is_host=True) server.stats.incr('game.hosted', tags={'game_mode': game_mode}) - def launch_game(self, game, is_host=False, use_map=None): + async def launch_game(self, game, is_host=False, use_map=None): # TODO: Fix setting up a ridiculous amount of cyclic pointers here if self.game_connection: - self.game_connection.abort("Player launched a new game") + await self.game_connection.abort("Player launched a new game") if is_host: game.host = self.player @@ -840,7 +828,7 @@ def launch_game(self, game, is_host=False, use_map=None): } if use_map: cmd['mapname'] = use_map - self.send(cmd) + await self.send(cmd) async def command_modvault(self, message): type = message["type"] @@ -861,7 +849,7 @@ async def command_modvault(self, message): comments=[], description=description, played=played, likes=likes, downloads=downloads, date=int(date.timestamp()), uid=uid, name=name, version=version, author=author, ui=ui) - self.send(out) + await self.send(out) except: self._logger.error("Error handling table_mod row (uid: {})".format(uid), exc_info=True) pass @@ -899,7 +887,7 @@ async def command_modvault(self, message): "JOIN mod_version v ON v.mod_id = s.mod_id " "SET s.likes = s.likes + 1, likers=%s WHERE v.uid = %s", json.dumps(likers), uid) - self.send(out) + await self.send(out) elif type == "download": uid = message["uid"] @@ -910,7 +898,6 @@ async def command_modvault(self, message): else: raise ValueError('invalid type argument') - @asyncio.coroutine async def command_ice_servers(self, message): if not self.player: return @@ -924,13 +911,13 @@ async def command_ice_servers(self, message): if self.nts_client: ice_servers = ice_servers + await self.nts_client.server_tokens(ttl=ttl) - self.send({ + await self.send({ 'command': 'ice_servers', 'ice_servers': ice_servers, 'ttl': ttl }) - def send_warning(self, message: str, fatal: bool=False): + async def send_warning(self, message: str, fatal: bool=False): """ Display a warning message to the client :param message: Warning message to display @@ -939,20 +926,22 @@ def send_warning(self, message: str, fatal: bool=False): and not attempt to reconnect. :return: None """ - self.send({'command': 'notice', - 'style': 'info' if not fatal else 'error', - 'text': message}) + await self.send({ + 'command': 'notice', + 'style': 'info' if not fatal else 'error', + 'text': message + }) if fatal: - self.abort(message) + await self.abort(message) - def send(self, message): + async def send(self, message): """ :param message: :return: """ self._logger.log(TRACE, ">>: %s", message) - self.protocol.send_message(message) + await self.protocol.send_message(message) async def drain(self): await self.protocol.drain() diff --git a/server/matchmaker/search.py b/server/matchmaker/search.py index 9831307cf..8bef98abe 100644 --- a/server/matchmaker/search.py +++ b/server/matchmaker/search.py @@ -3,12 +3,12 @@ import time from typing import List, Optional, Tuple +from server.rating import RatingType from trueskill import Rating, quality from .. import config from ..decorators import with_logger from ..players import Player -from server.rating import RatingType @with_logger @@ -71,7 +71,6 @@ def has_no_top_player(self) -> bool: max_rating = max(map(lambda rating_tuple: rating_tuple[0], self.ratings)) return max_rating < config.TOP_PLAYER_MIN_RATING - @property def ratings(self): ratings = [] diff --git a/server/protocol/gpgnet.py b/server/protocol/gpgnet.py index 72dd1ca8a..41af9ccc5 100644 --- a/server/protocol/gpgnet.py +++ b/server/protocol/gpgnet.py @@ -6,44 +6,44 @@ class GpgNetServerProtocol(metaclass=ABCMeta): """ Defines an interface for the server side GPGNet protocol """ - def send_ConnectToPeer(self, player_name: str, player_uid: int, offer: bool): + async def send_ConnectToPeer(self, player_name: str, player_uid: int, offer: bool): """ Tells a client that has a listening LobbyComm instance to connect to the given peer :param player_name: Remote player name :param player_uid: Remote player identifier """ - self.send_gpgnet_message('ConnectToPeer', [player_name, player_uid, offer]) + await self.send_gpgnet_message('ConnectToPeer', [player_name, player_uid, offer]) - def send_JoinGame(self, remote_player_name: str, remote_player_uid: int): + async def send_JoinGame(self, remote_player_name: str, remote_player_uid: int): """ Tells the game to join the given peer by ID :param remote_player_name: :param remote_player_uid: """ - self.send_gpgnet_message('JoinGame', [remote_player_name, remote_player_uid]) + await self.send_gpgnet_message('JoinGame', [remote_player_name, remote_player_uid]) - def send_HostGame(self, map): + async def send_HostGame(self, map): """ Tells the game to start listening for incoming connections as a host :param map: Which scenario to use """ - self.send_gpgnet_message('HostGame', [str(map)]) + await self.send_gpgnet_message('HostGame', [str(map)]) - def send_DisconnectFromPeer(self, id: int): + async def send_DisconnectFromPeer(self, id: int): """ Instructs the game to disconnect from the peer given by id :param id: :return: """ - self.send_gpgnet_message('DisconnectFromPeer', [id]) + await self.send_gpgnet_message('DisconnectFromPeer', [id]) - def send_gpgnet_message(self, command_id, arguments): + async def send_gpgnet_message(self, command_id, arguments): message = {"command": command_id, "args": arguments} - self.send_message(message) + await self.send_message(message) @abstractmethod - def send_message(self, message): + async def send_message(self, message): pass # pragma: no cover diff --git a/server/protocol/protocol.py b/server/protocol/protocol.py index 99018cd76..403839257 100644 --- a/server/protocol/protocol.py +++ b/server/protocol/protocol.py @@ -14,7 +14,7 @@ async def read_message(self: 'Protocol') -> dict: pass # pragma: no cover @abstractmethod - def send_message(self, message: dict) -> None: + async def send_message(self, message: dict) -> None: """ Send a single message in the form of a dictionary @@ -23,7 +23,7 @@ def send_message(self, message: dict) -> None: pass # pragma: no cover @abstractmethod - def send_messages(self, messages: List[dict]) -> None: + async def send_messages(self, messages: List[dict]) -> None: """ Send multiple messages in the form of a list of dictionaries. @@ -34,7 +34,7 @@ def send_messages(self, messages: List[dict]) -> None: pass # pragma: no cover @abstractmethod - def send_raw(self, data: bytes) -> None: + async def send_raw(self, data: bytes) -> None: """ Send raw bytes. Should generally not be used. diff --git a/server/protocol/qdatastreamprotocol.py b/server/protocol/qdatastreamprotocol.py index 78baa5404..d0ef572a4 100644 --- a/server/protocol/qdatastreamprotocol.py +++ b/server/protocol/qdatastreamprotocol.py @@ -121,15 +121,6 @@ async def read_message(self): pass return message - async def drain(self): - """ - Await the write buffer to empty. - - See StreamWriter.drain() - """ - await asyncio.sleep(0) - await self.writer.drain() - def close(self): """ Close writer stream @@ -137,20 +128,23 @@ def close(self): """ self.writer.close() - def send_message(self, message: dict): + async def send_message(self, message: dict): + server.stats.incr('server.sent_messages') self.writer.write( self.pack_message(json.dumps(message, separators=(',', ':'))) ) - server.stats.incr('server.sent_messages') + await self.writer.drain() - def send_messages(self, messages): + async def send_messages(self, messages): server.stats.incr('server.sent_messages') payload = [ self.pack_message(json.dumps(msg, separators=(',', ':'))) for msg in messages ] self.writer.writelines(payload) + await self.writer.drain() - def send_raw(self, data): + async def send_raw(self, data): server.stats.incr('server.sent_messages') self.writer.write(data) + await self.writer.drain() diff --git a/server/servercontext.py b/server/servercontext.py index ff20a6ead..d75f3ea13 100644 --- a/server/servercontext.py +++ b/server/servercontext.py @@ -46,11 +46,14 @@ def close(self): def __contains__(self, connection): return connection in self.connections.keys() - def broadcast_raw(self, message, validate_fn=lambda a: True): + async def broadcast_raw(self, message, validate_fn=lambda a: True): server.stats.incr('server.broadcasts') + tasks = [] for conn, proto in self.connections.items(): if validate_fn(conn): - proto.send_raw(message) + tasks.append(proto.send_raw(message)) + + await asyncio.gather(*tasks) async def client_connected(self, stream_reader, stream_writer): self._logger.debug("%s: Client connected", self) @@ -64,9 +67,6 @@ async def client_connected(self, stream_reader, stream_writer): message = await protocol.read_message() with server.stats.timer('connection.on_message_received'): await connection.on_message_received(message) - with server.stats.timer('servercontext.drain'): - await asyncio.sleep(0) - await connection.drain() except ConnectionResetError: pass except ConnectionAbortedError: diff --git a/server/timing/__init__.py b/server/timing/__init__.py new file mode 100644 index 000000000..c23da5060 --- /dev/null +++ b/server/timing/__init__.py @@ -0,0 +1,3 @@ +from .timer import Timer, at_interval + +__all__ = ("Timer", "at_interval") diff --git a/server/timing/timer.py b/server/timing/timer.py new file mode 100644 index 000000000..1adea10a4 --- /dev/null +++ b/server/timing/timer.py @@ -0,0 +1,98 @@ +""" +This code is a modified version of Gael Pasgrimaud's library `aiocron`. + +See the original code here: +https://github.com/gawel/aiocron/blob/e82a53c3f9a7950209cee7b3e493204c1dfc8b12/aiocron/__init__.py +""" + +import asyncio +import functools + + +async def null_callback(*args): + return args + + +def wrap_func(func): + """wrap in a coroutine""" + if not asyncio.iscoroutinefunction(func): + return asyncio.coroutine(func) + return func + + +class Timer(object): + """Schedules a function to be called asynchronously on a fixed interval""" + + def __init__(self, interval, func=None, args=(), start=False, loop=None): + self.interval = interval + if func is not None: + self.func = func if not args else functools.partial(func, *args) + else: + self.func = null_callback + self.cron = wrap_func(self.func) + self.auto_start = start + self.handle = self.future = None + self.loop = loop if loop is not None else asyncio.get_event_loop() + if start and self.func is not null_callback: + self.handle = self.loop.call_soon_threadsafe(self.start) + + def start(self): + """Start scheduling""" + self.stop() + self.handle = self.loop.call_later(self.get_delay(), self.call_next) + + def stop(self): + """Stop scheduling""" + if self.handle is not None: + self.handle.cancel() + self.handle = self.future = None + + def get_delay(self): + """Return next interval to wait between calls""" + return self.interval + + def call_next(self): + """Set next hop in the loop. Call task""" + if self.handle is not None: + self.handle.cancel() + delay = self.get_delay() + self.handle = self.loop.call_later(delay, self.call_next) + self.call_func() + + def call_func(self, *args, **kwargs): + """Called. Take care of exceptions using gather""" + asyncio.gather( + self.cron(*args, **kwargs), + loop=self.loop, return_exceptions=True + ).add_done_callback(self.set_result) + + def set_result(self, result): + """Set future's result if needed (can be an exception). + Else raise if needed.""" + result = result.result()[0] + if self.future is not None: + if isinstance(result, Exception): + self.future.set_exception(result) + else: + self.future.set_result(result) + self.future = None + elif isinstance(result, Exception): + raise result + + def __call__(self, func): + """Used as a decorator""" + self.func = func + self.cron = wrap_func(func) + if self.auto_start: + self.loop.call_soon_threadsafe(self.start) + return self + + def __str__(self): + return f"{self.interval} {self.func}" + + def __repr__(self): + return f"" + + +def at_interval(interval, func=None, args=(), start=True, loop=None): + return Timer(interval, func=func, args=args, start=start, loop=loop) diff --git a/tests/conftest.py b/tests/conftest.py index 967420786..deb371192 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -11,20 +11,20 @@ from typing import Iterable from unittest import mock +import asynctest import pytest +from asynctest import CoroutineMock from server.api.api_accessor import ApiAccessor from server.config import DB_LOGIN, DB_PASSWORD, DB_PORT, DB_SERVER +from server.db import FAFDatabase from server.game_service import GameService from server.geoip_service import GeoIpService +from server.lobbyconnection import LobbyConnection from server.matchmaker import MatchmakerQueue from server.player_service import PlayerService from server.rating import RatingType -from server.db import FAFDatabase from tests.utils import MockDatabase -from asynctest import CoroutineMock -import asynctest - logging.getLogger().setLevel(logging.DEBUG) @@ -159,6 +159,11 @@ def make(state=PlayerState.IDLE, global_rating=None, ladder_rating=None, p = Player(ratings=ratings, game_count=games, **kwargs) p.state = state + + # lobby_connection is a weak reference, but we want the mock + # to live for the full lifetime of the player object + p.__owned_lobby_connection = asynctest.create_autospec(LobbyConnection) + p.lobby_connection = p.__owned_lobby_connection return p return make diff --git a/tests/data/test-data.sql b/tests/data/test-data.sql index 4d9c2f5c5..f9add180f 100644 --- a/tests/data/test-data.sql +++ b/tests/data/test-data.sql @@ -1,4 +1,5 @@ insert into login (id, login, email, password, steamid, create_time) values + (10, 'friends', 'friends@example.com', SHA2('friends', 256), null, '2000-01-01 00:00:00'), (50, 'player_service1', 'ps1@example.com', SHA2('player_service1', 256), null, '2000-01-01 00:00:00'), (51, 'player_service2', 'ps2@example.com', SHA2('player_service2', 256), null, '2000-01-01 00:00:00'), (52, 'player_service3', 'ps3@example.com', SHA2('player_service3', 256), null, '2000-01-01 00:00:00'), @@ -69,7 +70,8 @@ insert into game_player_stats (gameId, playerId, AI, faction, color, team, place delete from friends_and_foes where user_id = 1 and subject_id = 2; insert into friends_and_foes (user_id, subject_id, status) values - (2, 1, 'FRIEND'); + (2, 1, 'FRIEND'), + (10, 1, 'FRIEND'); insert into `mod` (id, display_name, author) values diff --git a/tests/integration_tests/conftest.py b/tests/integration_tests/conftest.py index 441d92a28..b379874f1 100644 --- a/tests/integration_tests/conftest.py +++ b/tests/integration_tests/conftest.py @@ -110,7 +110,7 @@ async def perform_login( ) -> None: login, pw = credentials pw_hash = hashlib.sha256(pw.encode('utf-8')) - proto.send_message({ + await proto.send_message({ 'command': 'hello', 'version': '1.0.0-dev', 'user_agent': 'faf-client', @@ -118,7 +118,6 @@ async def perform_login( 'password': pw_hash.hexdigest(), 'unique_id': 'some_id' }) - await proto.drain() async def read_until( @@ -139,8 +138,7 @@ async def read_until_command(proto: QDataStreamProtocol, command: str) -> Dict[s async def get_session(proto): - proto.send_message({'command': 'ask_session', 'user_agent': 'faf-client', 'version': '0.11.16'}) - await proto.drain() + await proto.send_message({'command': 'ask_session', 'user_agent': 'faf-client', 'version': '0.11.16'}) msg = await read_until_command(proto, 'session') return msg['session'] diff --git a/tests/integration_tests/test_matchmaker.py b/tests/integration_tests/test_matchmaker.py index d9ba3e03b..fe99ce27c 100644 --- a/tests/integration_tests/test_matchmaker.py +++ b/tests/integration_tests/test_matchmaker.py @@ -19,19 +19,17 @@ async def queue_players_for_matchmaking(lobby_server): await read_until_command(proto1, 'game_info') await read_until_command(proto2, 'game_info') - proto1.send_message({ + await proto1.send_message({ 'command': 'game_matchmaking', 'state': 'start', 'faction': 'uef' }) - await proto1.drain() - proto2.send_message({ + await proto2.send_message({ 'command': 'game_matchmaking', 'state': 'start', 'faction': 1 # Python client sends factions as numbers }) - await proto2.drain() # If the players did not match, this will fail due to a timeout error await read_until_command(proto1, 'match_found') @@ -46,7 +44,7 @@ async def test_game_matchmaking(lobby_server): # The player that queued last will be the host msg2 = await read_until_command(proto2, 'game_launch') - proto2.send_message({ + await proto2.send_message({ 'command': 'GameState', 'target': 'game', 'args': ['Lobby'] @@ -58,7 +56,7 @@ async def test_game_matchmaking(lobby_server): assert msg2['mod'] == 'ladder1v1' -@fast_forward(1000) +@fast_forward(100) async def test_matchmaker_info_message(lobby_server, mocker): mocker.patch('server.matchmaker.pop_timer.time', return_value=1_562_000_000) @@ -95,8 +93,7 @@ async def test_command_matchmaker_info(lobby_server, mocker): await read_until_command(proto, "game_info") - proto.send_message({"command": "matchmaker_info"}) - await proto.drain() + await proto.send_message({"command": "matchmaker_info"}) msg = await read_until_command(proto, "matchmaker_info") assert msg == { @@ -120,7 +117,7 @@ async def test_matchmaker_info_message_on_cancel(lobby_server): await read_until_command(proto, 'game_info') - proto.send_message({ + await proto.send_message({ 'command': 'game_matchmaking', 'state': 'start', 'faction': 'uef' @@ -132,7 +129,7 @@ async def test_matchmaker_info_message_on_cancel(lobby_server): assert msg["queues"][0]["queue_name"] == "ladder1v1" assert len(msg["queues"][0]["boundary_80s"]) == 1 - proto.send_message({ + await proto.send_message({ 'command': 'game_matchmaking', 'state': 'stop', }) diff --git a/tests/integration_tests/test_modvault.py b/tests/integration_tests/test_modvault.py index 0d6e7118f..855ad4da9 100644 --- a/tests/integration_tests/test_modvault.py +++ b/tests/integration_tests/test_modvault.py @@ -1,6 +1,7 @@ -from .conftest import connect_and_sign_in, read_until_command import pytest +from .conftest import connect_and_sign_in, read_until_command + pytestmark = pytest.mark.asyncio @@ -12,11 +13,10 @@ async def test_modvault_start(lobby_server): await read_until_command(proto, 'game_info') - proto.send_message({ + await proto.send_message({ 'command': 'modvault', 'type': 'start' }) - await proto.drain() # Make sure all 5 mod version messages are sent for _ in range(5): @@ -31,12 +31,11 @@ async def test_modvault_like(lobby_server): await read_until_command(proto, 'game_info') - proto.send_message({ + await proto.send_message({ 'command': 'modvault', 'type': 'like', 'uid': 'FFF' }) - await proto.drain() msg = await read_until_command(proto, 'modvault_info') # Not going to verify the date diff --git a/tests/integration_tests/test_server.py b/tests/integration_tests/test_server.py index 35f9ecf41..e25cdec29 100644 --- a/tests/integration_tests/test_server.py +++ b/tests/integration_tests/test_server.py @@ -1,9 +1,9 @@ import asyncio -from unittest import mock import pytest from server import VisibilityState from server.db.models import ban +from tests.utils import fast_forward from .conftest import ( connect_and_sign_in, connect_client, perform_login, read_until, @@ -17,15 +17,13 @@ async def test_server_deprecated_client(lobby_server): proto = await connect_client(lobby_server) - proto.send_message({'command': 'ask_session', 'user_agent': 'faf-client', 'version': '0.0.0'}) - await proto.drain() + await proto.send_message({'command': 'ask_session', 'user_agent': 'faf-client', 'version': '0.0.0'}) msg = await proto.read_message() assert msg['command'] == 'notice' proto = await connect_client(lobby_server) - proto.send_message({'command': 'ask_session', 'version': '0.0.0'}) - await proto.drain() + await proto.send_message({'command': 'ask_session', 'version': '0.0.0'}) msg = await proto.read_message() assert msg['command'] == 'notice' @@ -132,6 +130,14 @@ async def test_server_double_login(lobby_server): await lobby_server.wait_closed() +@fast_forward(50) +async def test_ping_message(lobby_server): + _, _, proto = await connect_and_sign_in(('test', 'test_password'), lobby_server) + + # We should receive the message every 45 seconds + await asyncio.wait_for(read_until_command(proto, 'ping'), 46) + + async def test_player_info_broadcast(lobby_server): p1 = await connect_client(lobby_server) p2 = await connect_client(lobby_server) @@ -155,13 +161,12 @@ async def test_info_broadcast_authenticated(lobby_server): await perform_login(proto1, ('test', 'test_password')) await perform_login(proto2, ('Rhiza', 'puff_the_magic_dragon')) - proto1.send_message({ + await proto1.send_message({ "command": "game_matchmaking", "state": "start", "mod": "ladder1v1", "faction": "uef" }) - await proto1.drain() # Will timeout if the message is never received await read_until_command(proto2, "matchmaker_info") with pytest.raises(asyncio.TimeoutError): @@ -170,6 +175,70 @@ async def test_info_broadcast_authenticated(lobby_server): assert False +async def test_game_info_not_broadcast_to_foes(lobby_server): + # Rhiza is foed by test + _, _, proto1 = await connect_and_sign_in( + ("test", "test_password"), lobby_server + ) + _, _, proto2 = await connect_and_sign_in( + ("Rhiza", "puff_the_magic_dragon"), lobby_server + ) + await read_until_command(proto1, "game_info") + await read_until_command(proto2, "game_info") + + await proto1.send_message({ + "command": "game_host", + "title": "No Foes Allowed", + "mod": "faf", + "visibility": "public" + }) + + msg = await read_until_command(proto1, "game_info") + + assert msg["featured_mod"] == "faf" + assert msg["title"] == "No Foes Allowed" + assert msg["visibility"] == "public" + + with pytest.raises(asyncio.TimeoutError): + await asyncio.wait_for(read_until_command(proto2, "game_info"), 0.2) + + +async def test_game_info_broadcast_to_friends(lobby_server): + # test is the friend of friends + _, _, proto1 = await connect_and_sign_in( + ("friends", "friends"), lobby_server + ) + _, _, proto2 = await connect_and_sign_in( + ("test", "test_password"), lobby_server + ) + _, _, proto3 = await connect_and_sign_in( + ("Rhiza", "puff_the_magic_dragon"), lobby_server + ) + await read_until_command(proto1, "game_info") + await read_until_command(proto2, "game_info") + await read_until_command(proto3, "game_info") + + await proto1.send_message({ + "command": "game_host", + "title": "Friends Only", + "mod": "faf", + "visibility": "friends" + }) + + # The host and his friend should see the game + msg = await read_until_command(proto1, "game_info") + msg2 = await read_until_command(proto2, "game_info") + + assert msg == msg2 + assert msg["featured_mod"] == "faf" + assert msg["title"] == "Friends Only" + assert msg["visibility"] == "friends" + + # However, the other person should not see the game + with pytest.raises(asyncio.TimeoutError): + await asyncio.wait_for(read_until_command(proto3, "game_info"), 0.2) + + @pytest.mark.parametrize("user", [ ("test", "test_password"), ("ban_revoked", "ban_revoked"), @@ -181,13 +250,12 @@ async def test_game_host_authenticated(lobby_server, user): _, _, proto = await connect_and_sign_in(user, lobby_server) await read_until_command(proto, 'game_info') - proto.send_message({ + await proto.send_message({ 'command': 'game_host', 'title': 'My Game', 'mod': 'faf', 'visibility': 'public', }) - await proto.drain() msg = await read_until_command(proto, 'game_launch') @@ -205,13 +273,12 @@ async def test_host_missing_fields(event_loop, lobby_server, player_service): await read_until_command(proto, 'game_info') - proto.send_message({ + await proto.send_message({ 'command': 'game_host', 'mod': '', - 'visibility': VisibilityState.to_string(VisibilityState.PUBLIC), + 'visibility': 'public', 'title': '' }) - await proto.drain() msg = await read_until_command(proto, 'game_info') @@ -229,8 +296,7 @@ async def test_coop_list(lobby_server): await read_until_command(proto, 'game_info') - proto.send_message({"command": "coop_list"}) - await proto.drain() + await proto.send_message({"command": "coop_list"}) msg = await read_until_command(proto, "coop_info") assert "name" in msg @@ -261,8 +327,7 @@ async def test_server_ban_prevents_hosting(lobby_server, database, command): ) ) - proto.send_message({"command": command}) - await proto.drain() + await proto.send_message({"command": command}) msg = await proto.read_message() assert msg == { diff --git a/tests/integration_tests/test_servercontext.py b/tests/integration_tests/test_servercontext.py index a02013d99..d54919f84 100644 --- a/tests/integration_tests/test_servercontext.py +++ b/tests/integration_tests/test_servercontext.py @@ -1,12 +1,10 @@ import asyncio -from asynctest import exhaust_callbacks -import pytest from unittest import mock -from server import ServerContext +import pytest +from asynctest import exhaust_callbacks +from server import ServerContext, fake_statsd from server.protocol import QDataStreamProtocol -from server import fake_statsd - pytestmark = pytest.mark.asyncio @@ -45,8 +43,7 @@ def fin(): async def test_serverside_abort(event_loop, mock_context, mock_server): (reader, writer) = await asyncio.open_connection(*mock_context.sockets[0].getsockname()) proto = QDataStreamProtocol(reader, writer) - proto.send_message({"some_junk": True}) - await writer.drain() + await proto.send_message({"some_junk": True}) await exhaust_callbacks(event_loop) mock_server.on_connection_lost.assert_any_call() diff --git a/tests/unit_tests/conftest.py b/tests/unit_tests/conftest.py index 5e3b4ed2f..2e321e70b 100644 --- a/tests/unit_tests/conftest.py +++ b/tests/unit_tests/conftest.py @@ -1,14 +1,14 @@ from unittest import mock +import asynctest import pytest +from asynctest import CoroutineMock from server import GameStatsService from server.game_service import GameService from server.gameconnection import GameConnection, GameConnectionState from server.games import Game from server.geoip_service import GeoIpService from server.ladder_service import LadderService -import asynctest -from asynctest import CoroutineMock @pytest.fixture @@ -19,7 +19,15 @@ def lobbythread(): @pytest.fixture -def game_connection(request, database, game, players, game_service, player_service): +def game_connection( + request, + database, + game, + players, + game_service, + player_service, + event_loop +): from server import GameConnection conn = GameConnection( database=database, @@ -33,7 +41,7 @@ def game_connection(request, database, game, players, game_service, player_servi conn.finished_sim = False def fin(): - conn.abort() + event_loop.run_until_complete(conn.abort()) request.addfinalizer(fin) return conn diff --git a/tests/unit_tests/test_gameconnection.py b/tests/unit_tests/test_gameconnection.py index dbfa60d44..a5ef526d0 100644 --- a/tests/unit_tests/test_gameconnection.py +++ b/tests/unit_tests/test_gameconnection.py @@ -1,17 +1,24 @@ import asyncio from unittest import mock -import pytest -import asynctest +import asynctest +import pytest +from asynctest import CoroutineMock, exhaust_callbacks from server import GameConnection +from server.abc.base_game import GameConnectionState from server.games import Game -from server.games.game import ValidityState, Victory +from server.games.game import GameState, ValidityState, Victory from server.players import PlayerState -from asynctest import CoroutineMock, exhaust_callbacks pytestmark = pytest.mark.asyncio +@pytest.fixture +def real_game(event_loop, database, game_service, game_stats_service): + game = Game(42, database, game_service, game_stats_service) + yield game + + def assert_message_sent(game_connection: GameConnection, command, args): game_connection.protocol.send_message.assert_called_with({ 'command': command, @@ -24,11 +31,46 @@ async def test_abort(game_connection: GameConnection, game: Game, players): game_connection.player = players.hosting game_connection.game = game - game_connection.abort() + await game_connection.abort() game.remove_game_connection.assert_called_with(game_connection) +async def test_disconnect_all_peers( + game_connection: GameConnection, + real_game: Game, + players +): + real_game.state = GameState.LOBBY + game_connection.player = players.hosting + game_connection.game = real_game + + disconnect_done = mock.Mock() + + async def fake_send_dc(player_id): + await asyncio.sleep(1) # Take some time + disconnect_done.success() + return "OK" + + # Set up a peer that will disconnect without error + ok_disconnect = asynctest.create_autospec(GameConnection) + ok_disconnect.state = GameConnectionState.CONNECTED_TO_HOST + ok_disconnect.send_DisconnectFromPeer = fake_send_dc + + # Set up a peer that will throw an exception + fail_disconnect = asynctest.create_autospec(GameConnection) + fail_disconnect.send_DisconnectFromPeer.return_value = Exception("Test exception") + fail_disconnect.state = GameConnectionState.CONNECTED_TO_HOST + + # Add the peers to the game + real_game.add_game_connection(fail_disconnect) + real_game.add_game_connection(ok_disconnect) + + await game_connection.disconnect_all_peers() + + disconnect_done.success.assert_called_once() + + async def test_handle_action_GameState_idle_adds_connection( game: Game, game_connection: GameConnection, @@ -49,7 +91,7 @@ async def test_handle_action_GameState_idle_non_searching_player_aborts( ): game_connection.player = players.hosting game_connection.lobby = mock.Mock() - game_connection.abort = mock.Mock() + game_connection.abort = CoroutineMock() players.hosting.state = PlayerState.IDLE await game_connection.handle_action('GameState', ['Idle']) @@ -233,12 +275,12 @@ async def test_handle_action_TeamkillReport(game: Game, game_connection: GameCon result = await conn.execute("select game_id,id from moderation_report where reporter_id=2 and game_id=%s and game_incident_timecode=200", (game.id)) report = await result.fetchone() assert game.id == report["game_id"] - + reported_user_query = await conn.execute("select player_id from reported_user where report_id=%s", (report["id"])) data = await reported_user_query.fetchone() assert data["player_id"] == 3 - - + + async def test_handle_action_TeamkillReport_invalid_ids(game: Game, game_connection: GameConnection, database): game.launch = CoroutineMock() await game_connection.handle_action('TeamkillReport', ['230', 0, 'Dostya', 0, 'Rhiza']) @@ -247,7 +289,7 @@ async def test_handle_action_TeamkillReport_invalid_ids(game: Game, game_connect result = await conn.execute("select game_id,id from moderation_report where reporter_id=2 and game_id=%s and game_incident_timecode=230", (game.id)) report = await result.fetchone() assert game.id == report["game_id"] - + reported_user_query = await conn.execute("select player_id from reported_user where report_id=%s", (report["id"])) data = await reported_user_query.fetchone() assert data["player_id"] == 3 @@ -288,7 +330,7 @@ async def test_handle_action_TeamkillHappened(game: Game, game_connection: GameC async def test_handle_action_TeamkillHappened_AI(game: Game, game_connection: GameConnection, database): # Should fail with a sql constraint error if this isn't handled correctly - game_connection.abort = mock.Mock() + game_connection.abort = CoroutineMock() await game_connection.handle_action('TeamkillHappened', ['200', 0, 'Dostya', '0', 'Rhiza']) game_connection.abort.assert_not_called() diff --git a/tests/unit_tests/test_ladder.py b/tests/unit_tests/test_ladder.py index aa6227659..06737358f 100644 --- a/tests/unit_tests/test_ladder.py +++ b/tests/unit_tests/test_ladder.py @@ -1,12 +1,11 @@ import asyncio from unittest import mock -from asynctest import exhaust_callbacks import pytest +from asynctest import CoroutineMock, exhaust_callbacks from server import GameService, LadderService from server.matchmaker import Search from server.players import PlayerState -from asynctest import CoroutineMock from tests.utils import fast_forward pytestmark = pytest.mark.asyncio @@ -17,11 +16,6 @@ async def test_start_game(ladder_service: LadderService, game_service: p1 = player_factory('Dostya', player_id=1) p2 = player_factory('Rhiza', player_id=2) - mock_lc1 = mock.Mock() - mock_lc2 = mock.Mock() - p1.lobby_connection = mock_lc1 - p2.lobby_connection = mock_lc2 - game_service.ladder_maps = [(1, 'scmp_007', 'maps/scmp_007.zip')] with mock.patch('server.games.game.Game.await_hosted', CoroutineMock()): @@ -37,11 +31,6 @@ async def test_start_game_timeout(ladder_service: LadderService, game_service: p1 = player_factory('Dostya', player_id=1) p2 = player_factory('Rhiza', player_id=2) - mock_lc1 = mock.Mock() - mock_lc2 = mock.Mock() - p1.lobby_connection = mock_lc1 - p2.lobby_connection = mock_lc2 - game_service.ladder_maps = [(1, 'scmp_007', 'maps/scmp_007.zip')] await ladder_service.start_game(p1, p2) @@ -56,19 +45,16 @@ async def test_start_game_timeout(ladder_service: LadderService, game_service: async def test_inform_player(ladder_service: LadderService, player_factory): p1 = player_factory('Dostya', player_id=1, ladder_rating=(1500, 500)) - mock_lc = mock.Mock() - p1.lobby_connection = mock_lc - - ladder_service.inform_player(p1) + await ladder_service.inform_player(p1) # Message is sent after the first call p1.lobby_connection.send.assert_called_once() - ladder_service.inform_player(p1) + await ladder_service.inform_player(p1) p1.lobby_connection.send.reset_mock() # But not after the second p1.lobby_connection.send.assert_not_called() ladder_service.on_connection_lost(p1) - ladder_service.inform_player(p1) + await ladder_service.inform_player(p1) # But it is called if the player relogs p1.lobby_connection.send.assert_called_once() @@ -78,12 +64,9 @@ async def test_start_and_cancel_search(ladder_service: LadderService, player_factory, event_loop): p1 = player_factory('Dostya', player_id=1, ladder_rating=(1500, 500), ladder_games=0) - mock_lc = mock.Mock() - p1.lobby_connection = mock_lc - search = Search([p1]) - ladder_service.start_search(p1, search, 'ladder1v1') + await ladder_service.start_search(p1, search, 'ladder1v1') await exhaust_callbacks(event_loop) assert p1.state == PlayerState.SEARCHING_LADDER @@ -100,12 +83,9 @@ async def test_start_search_cancels_previous_search( ladder_service: LadderService, player_factory, event_loop): p1 = player_factory('Dostya', player_id=1, ladder_rating=(1500, 500), ladder_games=0) - mock_lc = mock.Mock() - p1.lobby_connection = mock_lc - search1 = Search([p1]) - ladder_service.start_search(p1, search1, 'ladder1v1') + await ladder_service.start_search(p1, search1, 'ladder1v1') await exhaust_callbacks(event_loop) assert p1.state == PlayerState.SEARCHING_LADDER @@ -113,7 +93,7 @@ async def test_start_search_cancels_previous_search( search2 = Search([p1]) - ladder_service.start_search(p1, search2, 'ladder1v1') + await ladder_service.start_search(p1, search2, 'ladder1v1') await exhaust_callbacks(event_loop) assert p1.state == PlayerState.SEARCHING_LADDER @@ -126,12 +106,9 @@ async def test_cancel_all_searches(ladder_service: LadderService, player_factory, event_loop): p1 = player_factory('Dostya', player_id=1, ladder_rating=(1500, 500), ladder_games=0) - mock_lc = mock.Mock() - p1.lobby_connection = mock_lc - search = Search([p1]) - ladder_service.start_search(p1, search, 'ladder1v1') + await ladder_service.start_search(p1, search, 'ladder1v1') await exhaust_callbacks(event_loop) assert p1.state == PlayerState.SEARCHING_LADDER @@ -149,16 +126,11 @@ async def test_cancel_twice(ladder_service: LadderService, player_factory): p1 = player_factory('Dostya', player_id=1, ladder_rating=(1500, 500), ladder_games=0) p2 = player_factory('Brackman', player_id=2, ladder_rating=(2000, 500), ladder_games=0) - mock_lc1 = mock.Mock() - mock_lc2 = mock.Mock() - p1.lobby_connection = mock_lc1 - p2.lobby_connection = mock_lc2 - search = Search([p1]) search2 = Search([p2]) - ladder_service.start_search(p1, search, 'ladder1v1') - ladder_service.start_search(p2, search2, 'ladder1v1') + await ladder_service.start_search(p1, search, 'ladder1v1') + await ladder_service.start_search(p2, search2, 'ladder1v1') searches = ladder_service._cancel_existing_searches(p1) assert search.is_cancelled @@ -179,18 +151,13 @@ async def test_start_game_called_on_match(ladder_service: LadderService, p1 = player_factory('Dostya', player_id=1, ladder_rating=(2300, 64), ladder_games=0) p2 = player_factory('QAI', player_id=2, ladder_rating=(2350, 125), ladder_games=0) - mock_lc1 = mock.Mock() - mock_lc2 = mock.Mock() - p1.lobby_connection = mock_lc1 - p2.lobby_connection = mock_lc2 - ladder_service.start_game = CoroutineMock() - ladder_service.inform_player = mock.Mock() + ladder_service.inform_player = CoroutineMock() - ladder_service.start_search(p1, Search([p1]), 'ladder1v1') - ladder_service.start_search(p2, Search([p2]), 'ladder1v1') + await ladder_service.start_search(p1, Search([p1]), 'ladder1v1') + await ladder_service.start_search(p2, Search([p2]), 'ladder1v1') - await asyncio.sleep(1) + await asyncio.sleep(2) ladder_service.inform_player.assert_called() ladder_service.start_game.assert_called_once() diff --git a/tests/unit_tests/test_lobbyconnection.py b/tests/unit_tests/test_lobbyconnection.py index e43b403e2..f6f2dc32f 100644 --- a/tests/unit_tests/test_lobbyconnection.py +++ b/tests/unit_tests/test_lobbyconnection.py @@ -2,12 +2,14 @@ from unittest import mock from unittest.mock import Mock +import asynctest import pytest from aiohttp import web from asynctest import CoroutineMock from server import GameState, VisibilityState from server.db.models import ban, friends_and_foes from server.game_service import GameService +from server.gameconnection import GameConnection from server.games import CustomGame, Game from server.geoip_service import GeoIpService from server.ice_servers.nts import TwilioNTS @@ -71,7 +73,7 @@ def mock_games(database, mock_players, game_stats_service): @pytest.fixture def mock_protocol(): - return mock.create_autospec(QDataStreamProtocol(mock.Mock(), mock.Mock())) + return asynctest.create_autospec(QDataStreamProtocol(mock.Mock(), mock.Mock())) @pytest.fixture @@ -95,6 +97,7 @@ def lobbyconnection(event_loop, database, mock_protocol, mock_games, mock_player lc.player_service.get_permission_group.return_value = 0 lc.player_service.fetch_player_data = CoroutineMock() lc.peer_address = Address('127.0.0.1', 1234) + lc._authenticated = True return lc @@ -125,13 +128,62 @@ async def start_app(): event_loop.run_until_complete(runner.cleanup()) +async def test_unauthenticated_calls_abort(lobbyconnection, test_game_info): + lobbyconnection._authenticated = False + lobbyconnection.abort = CoroutineMock() + + await lobbyconnection.on_message_received({ + "command": "game_host", + **test_game_info + }) + + lobbyconnection.abort.assert_called_once_with( + "Message invalid for unauthenticated connection: game_host" + ) + + +async def test_bad_command_calls_abort(lobbyconnection): + lobbyconnection.send = CoroutineMock() + lobbyconnection.abort = CoroutineMock() + + await lobbyconnection.on_message_received({ + "command": "this_isnt_real" + }) + + lobbyconnection.send.assert_called_once_with({"command": "invalid"}) + lobbyconnection.abort.assert_called_once_with("Error processing command") + + +async def test_command_pong_does_nothing(lobbyconnection): + lobbyconnection.send = CoroutineMock() + + await lobbyconnection.on_message_received({ + "command": "pong" + }) + + lobbyconnection.send.assert_not_called() + + +async def test_command_create_account_returns_error(lobbyconnection): + lobbyconnection.send = CoroutineMock() + + await lobbyconnection.on_message_received({ + "command": "create_account" + }) + + lobbyconnection.send.assert_called_once_with({ + "command": "notice", + "style": "error", + "text": ("FAF no longer supports direct registration. " + "Please use the website to register.") + }) + + async def test_command_game_host_creates_game(lobbyconnection, mock_games, test_game_info, players): lobbyconnection.player = players.hosting - lobbyconnection.protocol = mock.Mock() - lobbyconnection._authenticated = True await lobbyconnection.on_message_received({ "command": "game_host", **test_game_info @@ -148,12 +200,12 @@ async def test_command_game_host_creates_game(lobbyconnection, async def test_launch_game(lobbyconnection, game, player_factory): - old_game_conn = mock.Mock() + old_game_conn = asynctest.create_autospec(GameConnection) lobbyconnection.player = player_factory() lobbyconnection.game_connection = old_game_conn - lobbyconnection.send = mock.Mock() - lobbyconnection.launch_game(game) + lobbyconnection.send = CoroutineMock() + await lobbyconnection.launch_game(game) # Verify all side effects of launch_game here old_game_conn.abort.assert_called_with("Player launched a new game") @@ -170,10 +222,8 @@ async def test_command_game_host_creates_correct_game( lobbyconnection, game_service, test_game_info, players): lobbyconnection.player = players.hosting lobbyconnection.game_service = game_service - lobbyconnection.launch_game = mock.Mock() + lobbyconnection.launch_game = CoroutineMock() - lobbyconnection.protocol = mock.Mock() - lobbyconnection._authenticated = True await lobbyconnection.on_message_received({ "command": "game_host", **test_game_info @@ -192,7 +242,6 @@ async def test_command_game_join_calls_join_game(mocker, players, game_stats_service): lobbyconnection.game_service = game_service - mock_protocol = mocker.patch.object(lobbyconnection, 'protocol') game = mock.create_autospec(Game(42, database, game_service, game_stats_service)) game.state = GameState.LOBBY game.password = None @@ -202,7 +251,6 @@ async def test_command_game_join_calls_join_game(mocker, lobbyconnection.player = players.hosting test_game_info['uid'] = 42 - lobbyconnection._authenticated = True await lobbyconnection.on_message_received({ "command": "game_join", **test_game_info @@ -213,7 +261,7 @@ async def test_command_game_join_calls_join_game(mocker, 'uid': 42, 'args': ['/numgames {}'.format(players.hosting.game_count[RatingType.GLOBAL])] } - mock_protocol.send_message.assert_called_with(expected_reply) + lobbyconnection.protocol.send_message.assert_called_with(expected_reply) async def test_command_game_join_uid_as_str(mocker, @@ -224,7 +272,6 @@ async def test_command_game_join_uid_as_str(mocker, players, game_stats_service): lobbyconnection.game_service = game_service - mock_protocol = mocker.patch.object(lobbyconnection, 'protocol') game = mock.create_autospec(Game(42, database, game_service, game_stats_service)) game.state = GameState.LOBBY game.password = None @@ -234,7 +281,6 @@ async def test_command_game_join_uid_as_str(mocker, lobbyconnection.player = players.hosting test_game_info['uid'] = '42' # Pass in uid as string - lobbyconnection._authenticated = True await lobbyconnection.on_message_received({ "command": "game_join", **test_game_info @@ -245,7 +291,7 @@ async def test_command_game_join_uid_as_str(mocker, 'uid': 42, 'args': ['/numgames {}'.format(players.hosting.game_count[RatingType.GLOBAL])] } - mock_protocol.send_message.assert_called_with(expected_reply) + lobbyconnection.protocol.send_message.assert_called_with(expected_reply) async def test_command_game_join_without_password(lobbyconnection, @@ -254,7 +300,7 @@ async def test_command_game_join_without_password(lobbyconnection, test_game_info, players, game_stats_service): - lobbyconnection.send = mock.Mock() + lobbyconnection.send = CoroutineMock() lobbyconnection.game_service = game_service game = mock.create_autospec(Game(42, database, game_service, game_stats_service)) game.state = GameState.LOBBY @@ -266,7 +312,6 @@ async def test_command_game_join_without_password(lobbyconnection, test_game_info['uid'] = 42 del test_game_info['password'] - lobbyconnection._authenticated = True await lobbyconnection.on_message_received({ "command": "game_join", **test_game_info @@ -279,12 +324,11 @@ async def test_command_game_join_game_not_found(lobbyconnection, game_service, test_game_info, players): - lobbyconnection.send = mock.Mock() + lobbyconnection.send = CoroutineMock() lobbyconnection.game_service = game_service lobbyconnection.player = players.hosting test_game_info['uid'] = 42 - lobbyconnection._authenticated = True await lobbyconnection.on_message_received({ "command": "game_join", **test_game_info @@ -295,12 +339,9 @@ async def test_command_game_join_game_not_found(lobbyconnection, async def test_command_game_host_calls_host_game_invalid_title(lobbyconnection, mock_games, - test_game_info_invalid, - players): - lobbyconnection.send = mock.Mock() + test_game_info_invalid): + lobbyconnection.send = CoroutineMock() mock_games.create_game = mock.Mock() - lobbyconnection.player = players.hosting - lobbyconnection._authenticated = True await lobbyconnection.on_message_received({ "command": "game_host", **test_game_info_invalid @@ -311,33 +352,31 @@ async def test_command_game_host_calls_host_game_invalid_title(lobbyconnection, async def test_abort(mocker, lobbyconnection): - proto = mocker.patch.object(lobbyconnection, 'protocol') - - lobbyconnection.abort() + lobbyconnection.protocol.writer.close = mock.Mock() + await lobbyconnection.abort() - proto.writer.close.assert_any_call() + lobbyconnection.protocol.writer.close.assert_any_call() async def test_send_game_list(mocker, database, lobbyconnection, game_stats_service): - protocol = mocker.patch.object(lobbyconnection, 'protocol') games = mocker.patch.object(lobbyconnection, 'game_service') # type: GameService game1, game2 = mock.create_autospec(Game(42, database, mock.Mock(), game_stats_service)), \ mock.create_autospec(Game(22, database, mock.Mock(), game_stats_service)) games.open_games = [game1, game2] - lobbyconnection.send_game_list() - - protocol.send_message.assert_any_call({'command': 'game_info', - 'games': [game1.to_dict(), game2.to_dict()]}) + await lobbyconnection.send_game_list() + lobbyconnection.protocol.send_message.assert_any_call({ + 'command': 'game_info', + 'games': [game1.to_dict(), game2.to_dict()] + }) -async def test_send_coop_maps(mocker, lobbyconnection): - protocol = mocker.patch.object(lobbyconnection, 'protocol') - await lobbyconnection.send_coop_maps() +async def test_coop_list(mocker, lobbyconnection): + await lobbyconnection.command_coop_list({}) - args = protocol.send_messages.call_args_list + args = lobbyconnection.protocol.send_messages.call_args_list assert len(args) == 1 coop_maps = args[0][0][0] for info in coop_maps: @@ -394,7 +433,7 @@ async def test_command_admin_closelobby(mocker, lobbyconnection): tuna.id = 55 lobbyconnection.player_service = {1: player, 55: tuna} - await lobbyconnection.command_admin({ + await lobbyconnection.on_message_received({ 'command': 'admin', 'action': 'closelobby', 'user_id': 55 @@ -411,7 +450,6 @@ async def test_command_admin_closelobby_with_ban(mocker, lobbyconnection, databa banme = mock.Mock() banme.id = 200 lobbyconnection.player_service = {1: player, banme.id: banme} - lobbyconnection._authenticated = True await lobbyconnection.on_message_received({ 'command': 'admin', @@ -443,7 +481,6 @@ async def test_command_admin_closelobby_with_ban_but_already_banned(mocker, lobb banme = mock.Mock() banme.id = 200 lobbyconnection.player_service = {1: player, banme.id: banme} - lobbyconnection._authenticated = True await lobbyconnection.on_message_received({ 'command': 'admin', @@ -482,7 +519,6 @@ async def test_command_admin_closelobby_with_ban_but_already_banned(mocker, lobb async def test_command_admin_closelobby_with_ban_duration_no_period(mocker, lobbyconnection, database): - mocker.patch.object(lobbyconnection, 'protocol') player = mocker.patch.object(lobbyconnection, 'player') player.login = 'Sheeo' player.id = 1 @@ -490,7 +526,6 @@ async def test_command_admin_closelobby_with_ban_duration_no_period(mocker, lobb banme = mock.Mock() banme.id = 200 lobbyconnection.player_service = {1: player, banme.id: banme} - lobbyconnection._authenticated = True mocker.patch('server.lobbyconnection.func.now', return_value=text('FROM_UNIXTIME(1000)')) await lobbyconnection.on_message_received({ @@ -515,13 +550,11 @@ async def test_command_admin_closelobby_with_ban_duration_no_period(mocker, lobb async def test_command_admin_closelobby_with_ban_bad_period(mocker, lobbyconnection, database): - proto = mocker.patch.object(lobbyconnection, 'protocol') player = mocker.patch.object(lobbyconnection, 'player') player.admin = True banme = mock.Mock() banme.id = 1 lobbyconnection.player_service = {1: player, banme.id: banme} - lobbyconnection._authenticated = True await lobbyconnection.on_message_received({ 'command': 'admin', @@ -535,7 +568,7 @@ async def test_command_admin_closelobby_with_ban_bad_period(mocker, lobbyconnect }) banme.lobbyconnection.kick.assert_not_called() - proto.send_message.assert_called_once_with({ + lobbyconnection.protocol.send_message.assert_called_once_with({ 'command': 'notice', 'style': 'error', 'text': "Period ') INJECTED!' is not allowed!" @@ -550,13 +583,11 @@ async def test_command_admin_closelobby_with_ban_bad_period(mocker, lobbyconnect async def test_command_admin_closelobby_with_ban_injection(mocker, lobbyconnection, database): - proto = mocker.patch.object(lobbyconnection, 'protocol') player = mocker.patch.object(lobbyconnection, 'player') player.admin = True banme = mock.Mock() banme.id = 1 lobbyconnection.player_service = {1: player, banme.id: banme} - lobbyconnection._authenticated = True await lobbyconnection.on_message_received({ 'command': 'admin', @@ -570,7 +601,7 @@ async def test_command_admin_closelobby_with_ban_injection(mocker, lobbyconnecti }) banme.lobbyconnection.kick.assert_not_called() - proto.send_message.assert_called_once_with({ + lobbyconnection.protocol.send_message.assert_called_once_with({ 'command': 'notice', 'style': 'error', 'text': "Period ') INJECTED!' is not allowed!" @@ -585,7 +616,6 @@ async def test_command_admin_closelobby_with_ban_injection(mocker, lobbyconnecti async def test_command_admin_closeFA(mocker, lobbyconnection): - mocker.patch.object(lobbyconnection, 'protocol') mocker.patch.object(lobbyconnection, '_logger') player = mocker.patch.object(lobbyconnection, 'player') player.login = 'Sheeo' @@ -593,7 +623,6 @@ async def test_command_admin_closeFA(mocker, lobbyconnection): player.id = 42 tuna = mock.Mock() tuna.id = 55 - lobbyconnection._authenticated = True lobbyconnection.player_service = {42: player, 55: tuna} await lobbyconnection.on_message_received({ @@ -612,7 +641,7 @@ async def test_game_subscription(lobbyconnection: LobbyConnection): game = Mock() game.handle_action = CoroutineMock() lobbyconnection.game_connection = game - lobbyconnection.ensure_authenticated = lambda _: True + lobbyconnection.ensure_authenticated = CoroutineMock(return_value=True) await lobbyconnection.on_message_received({'command': 'test', 'args': ['foo', 42], @@ -622,15 +651,15 @@ async def test_game_subscription(lobbyconnection: LobbyConnection): async def test_command_avatar_list(mocker, lobbyconnection: LobbyConnection, mock_player: Player): - protocol = mocker.patch.object(lobbyconnection, 'protocol') lobbyconnection.player = mock_player lobbyconnection.player.id = 2 # Dostya test user - await lobbyconnection.command_avatar({ + await lobbyconnection.on_message_received({ + 'command': 'avatar', 'action': 'list_avatar' }) - protocol.send_message.assert_any_call({ + lobbyconnection.protocol.send_message.assert_any_call({ "command": "avatar", "avatarlist": [{'url': 'http://content.faforever.com/faf/avatars/qai2.png', 'tooltip': 'QAI'}, {'url': 'http://content.faforever.com/faf/avatars/UEF.png', 'tooltip': 'UEF'}] }) @@ -639,7 +668,6 @@ async def test_command_avatar_list(mocker, lobbyconnection: LobbyConnection, moc async def test_command_avatar_select(mocker, database, lobbyconnection: LobbyConnection, mock_player: Player): lobbyconnection.player = mock_player lobbyconnection.player.id = 2 # Dostya test user - lobbyconnection._authenticated = True await lobbyconnection.on_message_received({ 'command': 'avatar', @@ -670,7 +698,6 @@ async def get_friends(player_id, database): async def test_command_social_add_friend(lobbyconnection, mock_player, database): lobbyconnection.player = mock_player lobbyconnection.player.id = 1 - lobbyconnection._authenticated = True friends = await get_friends(lobbyconnection.player.id, database) assert friends == [] @@ -687,7 +714,6 @@ async def test_command_social_add_friend(lobbyconnection, mock_player, database) async def test_command_social_remove_friend(lobbyconnection, mock_player, database): lobbyconnection.player = mock_player lobbyconnection.player.id = 2 - lobbyconnection._authenticated = True friends = await get_friends(lobbyconnection.player.id, database) assert friends == [1] @@ -702,7 +728,6 @@ async def test_command_social_remove_friend(lobbyconnection, mock_player, databa async def test_broadcast(lobbyconnection: LobbyConnection, mocker): - mocker.patch.object(lobbyconnection, 'protocol') player = mocker.patch.object(lobbyconnection, 'player') player.login = 'Sheeo' player.admin = True @@ -710,7 +735,7 @@ async def test_broadcast(lobbyconnection: LobbyConnection, mocker): tuna.id = 55 lobbyconnection.player_service = [player, tuna] - await lobbyconnection.command_admin({ + await lobbyconnection.on_message_received({ 'command': 'admin', 'action': 'broadcast', 'message': "This is a test message" @@ -721,16 +746,18 @@ async def test_broadcast(lobbyconnection: LobbyConnection, mocker): async def test_game_connection_not_restored_if_no_such_game_exists(lobbyconnection: LobbyConnection, mocker, mock_player): - protocol = mocker.patch.object(lobbyconnection, 'protocol') lobbyconnection.player = mock_player del lobbyconnection.player.game_connection lobbyconnection.player.state = PlayerState.IDLE - lobbyconnection.command_restore_game_session({'game_id': 123}) + await lobbyconnection.on_message_received({ + 'command': 'restore_game_session', + 'game_id': 123 + }) assert not lobbyconnection.player.game_connection assert lobbyconnection.player.state == PlayerState.IDLE - protocol.send_message.assert_any_call({ + lobbyconnection.protocol.send_message.assert_any_call({ "command": "notice", "style": "info", "text": "The game you were connected to does no longer exist" @@ -741,7 +768,6 @@ async def test_game_connection_not_restored_if_no_such_game_exists(lobbyconnecti async def test_game_connection_not_restored_if_game_state_prohibits(lobbyconnection: LobbyConnection, game_service: GameService, game_stats_service, game_state, mock_player, mocker, database): - protocol = mocker.patch.object(lobbyconnection, 'protocol') lobbyconnection.player = mock_player del lobbyconnection.player.game_connection lobbyconnection.player.state = PlayerState.IDLE @@ -753,12 +779,15 @@ async def test_game_connection_not_restored_if_game_state_prohibits(lobbyconnect game.id = 42 game_service.games[42] = game - lobbyconnection.command_restore_game_session({'game_id': 42}) + await lobbyconnection.on_message_received({ + 'command': 'restore_game_session', + 'game_id': 42 + }) assert not lobbyconnection.game_connection assert lobbyconnection.player.state == PlayerState.IDLE - protocol.send_message.assert_any_call({ + lobbyconnection.protocol.send_message.assert_any_call({ "command": "notice", "style": "info", "text": "The game you were connected to is no longer available" @@ -779,7 +808,10 @@ async def test_game_connection_restored_if_game_exists(lobbyconnection: LobbyCon game.id = 42 game_service.games[42] = game - lobbyconnection.command_restore_game_session({'game_id': 42}) + await lobbyconnection.on_message_received({ + 'command': 'restore_game_session', + 'game_id': 42 + }) assert lobbyconnection.game_connection assert lobbyconnection.player.state == PlayerState.PLAYING @@ -788,7 +820,6 @@ async def test_game_connection_restored_if_game_exists(lobbyconnection: LobbyCon async def test_command_game_matchmaking(lobbyconnection, mock_player): lobbyconnection.player = mock_player lobbyconnection.player.id = 1 - lobbyconnection._authenticated = True await lobbyconnection.on_message_received({ 'command': 'game_matchmaking', @@ -822,11 +853,11 @@ async def test_check_policy_conformity_fraudulent(lobbyconnection, policy_server f'http://{host}:{port}' ): # 42 is not a valid player ID which should cause a SQL constraint error - lobbyconnection.abort = mock.Mock() + lobbyconnection.abort = CoroutineMock() with pytest.raises(ClientError): await lobbyconnection.check_policy_conformity(42, "fraudulent", session=100) - lobbyconnection.abort = mock.Mock() + lobbyconnection.abort = CoroutineMock() player_id = 200 honest = await lobbyconnection.check_policy_conformity(player_id, "fraudulent", session=100) assert honest is False @@ -849,7 +880,7 @@ async def test_check_policy_conformity_fatal(lobbyconnection, policy_server): f'http://{host}:{port}' ): for result in ('vm', 'already_associated', 'fraudulent'): - lobbyconnection.abort = mock.Mock() + lobbyconnection.abort = CoroutineMock() honest = await lobbyconnection.check_policy_conformity(1, result, session=100) assert honest is False lobbyconnection.abort.assert_called_once()