diff --git a/server/__init__.py b/server/__init__.py
index 7b2eca1f2..5b6df70ba 100644
--- a/server/__init__.py
+++ b/server/__init__.py
@@ -11,6 +11,7 @@
from typing import Any, Dict, Optional
import aiomeasures
+import asyncio
from server.db import FAFDatabase
from . import config as config
@@ -26,9 +27,11 @@
from .game_service import GameService
from .ladder_service import LadderService
from .control import init as run_control_server
+from .timing import at_interval
+
__version__ = '0.9.17'
-__author__ = 'Chris Kitching, Dragonfire, Gael Honorez, Jeroen De Dauw, Crotalus, Michael Søndergaard, Michel Jung'
+__author__ = 'Askaholic, Chris Kitching, Dragonfire, Gael Honorez, Jeroen De Dauw, Crotalus, Michael Søndergaard, Michel Jung'
__contact__ = 'admin@faforever.com'
__license__ = 'GPLv3'
__copyright__ = 'Copyright (c) 2011-2015 ' + __author__
@@ -94,7 +97,8 @@ def run_lobby_server(
Run the lobby server
"""
- def report_dirties():
+ @at_interval(DIRTY_REPORT_INTERVAL)
+ async def do_report_dirties():
try:
dirty_games = games.dirty_games
dirty_queues = games.dirty_queues
@@ -103,13 +107,20 @@ def report_dirties():
player_service.clear_dirty()
if len(dirty_queues) > 0:
- ctx.broadcast_raw(encode_queues(dirty_queues), lambda lobby_conn: lobby_conn.authenticated)
+ await ctx.broadcast_raw(
+ encode_queues(dirty_queues),
+ lambda lobby_conn: lobby_conn.authenticated
+ )
if len(dirty_players) > 0:
- ctx.broadcast_raw(encode_players(dirty_players), lambda lobby_conn: lobby_conn.authenticated)
-
- # TODO: This spams squillions of messages: we should implement per-connection message
- # aggregation at the next abstraction layer down :P
+ await ctx.broadcast_raw(
+ encode_players(dirty_players),
+ lambda lobby_conn: lobby_conn.authenticated
+ )
+
+ # TODO: This spams squillions of messages: we should implement per-
+ # connection message aggregation at the next abstraction layer down :P
+ tasks = []
for game in dirty_games:
if game.state == GameState.ENDED:
games.remove_game(game)
@@ -117,25 +128,33 @@ def report_dirties():
# So we're going to be broadcasting this to _somebody_...
message = encode_dict(game.to_dict())
- # These games shouldn't be broadcast, but instead privately sent to those who are
- # allowed to see them.
+ # These games shouldn't be broadcast, but instead privately sent
+ # to those who are allowed to see them.
if game.visibility == VisibilityState.FRIENDS:
- # To see this game, you must have an authenticated connection and be a friend of the host, or the host.
- validation_func = lambda lobby_conn: lobby_conn.player.id in game.host.friends or lobby_conn.player == game.host
+ # To see this game, you must have an authenticated
+ # connection and be a friend of the host, or the host.
+ def validation_func(lobby_conn):
+ return lobby_conn.player.id in game.host.friends or \
+ lobby_conn.player == game.host
else:
- validation_func = lambda lobby_conn: lobby_conn.player.id not in game.host.foes
+ def validation_func(lobby_conn):
+ return lobby_conn.player.id not in game.host.foes
+
+ tasks.append(ctx.broadcast_raw(
+ message,
+ lambda lobby_conn: lobby_conn.authenticated and validation_func(lobby_conn)
+ ))
+
+ await asyncio.gather(*tasks)
- ctx.broadcast_raw(message, lambda lobby_conn: lobby_conn.authenticated and validation_func(lobby_conn))
except Exception as e:
logging.getLogger().exception(e)
- finally:
- loop.call_later(DIRTY_REPORT_INTERVAL, report_dirties)
ping_msg = encode_message('PING')
- def ping_broadcast():
- ctx.broadcast_raw(ping_msg)
- loop.call_later(45, ping_broadcast)
+ @at_interval(45)
+ async def ping_broadcast():
+ await ctx.broadcast_raw(ping_msg)
def make_connection() -> LobbyConnection:
return LobbyConnection(
@@ -147,7 +166,5 @@ def make_connection() -> LobbyConnection:
ladder_service=ladder_service
)
ctx = ServerContext(make_connection, name="LobbyServer")
- loop.call_later(DIRTY_REPORT_INTERVAL, report_dirties)
- loop.call_soon(ping_broadcast)
loop.run_until_complete(ctx.listen(*address))
return ctx
diff --git a/server/api/oauth_session.py b/server/api/oauth_session.py
index 656bfb425..5d78afd3b 100644
--- a/server/api/oauth_session.py
+++ b/server/api/oauth_session.py
@@ -3,8 +3,9 @@
from typing import Dict
import aiohttp
-from oauthlib.oauth2.rfc6749.errors import (InsecureTransportError,
- MissingTokenError)
+from oauthlib.oauth2.rfc6749.errors import (
+ InsecureTransportError, MissingTokenError
+)
class OAuth2Session(object):
diff --git a/server/control.py b/server/control.py
index 0bd013301..9d5a35206 100644
--- a/server/control.py
+++ b/server/control.py
@@ -2,12 +2,15 @@
Tiny local-only http server for getting stats and performing various tasks
"""
+import logging
import socket
+from json import dumps
from aiohttp import web
-import logging
-from server import PlayerService, GameService, config
-from json import dumps
+
+from . import config
+from .game_service import GameService
+from .player_service import PlayerService
logger = logging.getLogger(__name__)
diff --git a/server/game_service.py b/server/game_service.py
index ba7169e6d..d112c5e55 100644
--- a/server/game_service.py
+++ b/server/game_service.py
@@ -2,11 +2,10 @@
from typing import Dict, List, Optional, Union, ValuesView
import aiocron
-from server import GameState, VisibilityState
from server.db import FAFDatabase
from server.decorators import with_logger
from server.games import CoopGame, CustomGame, FeaturedMod, LadderGame
-from server.games.game import Game
+from server.games.game import Game, GameState, VisibilityState
from server.matchmaker import MatchmakerQueue
from server.players import Player
diff --git a/server/gameconnection.py b/server/gameconnection.py
index 0bea4b86b..63cf97439 100644
--- a/server/gameconnection.py
+++ b/server/gameconnection.py
@@ -1,10 +1,11 @@
import asyncio
from server.db import FAFDatabase
-from sqlalchemy import text, select
+from sqlalchemy import select, text
from .abc.base_game import GameConnectionState
from .config import TRACE
+from .db.models import login, moderation_report, reported_user
from .decorators import with_logger
from .game_service import GameService
from .games.game import Game, GameState, ValidityState, Victory
@@ -12,8 +13,6 @@
from .players import Player, PlayerState
from .protocol import GpgNetServerProtocol, QDataStreamProtocol
-from .db.models import (reported_user, moderation_report, login)
-
@with_logger
class GameConnection(GpgNetServerProtocol):
@@ -69,11 +68,11 @@ def player(self) -> Player:
def player(self, val: Player):
self._player = val
- def send_message(self, message):
+ async def send_message(self, message):
message['target'] = "game"
self._logger.log(TRACE, ">>: %s", message)
- self.protocol.send_message(message)
+ await self.protocol.send_message(message)
async def _handle_idle_state(self):
"""
@@ -92,7 +91,7 @@ async def _handle_idle_state(self):
pass
else:
self._logger.exception("Unknown PlayerState: %s", state)
- self.abort()
+ await self.abort()
async def _handle_lobby_state(self):
"""
@@ -104,7 +103,7 @@ async def _handle_lobby_state(self):
try:
player_state = self.player.state
if player_state == PlayerState.HOSTING:
- self.send_HostGame(self.game.map_folder_name)
+ await self.send_HostGame(self.game.map_folder_name)
self.game.set_hosted()
# If the player is joining, we connect him to host
# followed by the rest of the players.
@@ -116,10 +115,12 @@ async def _handle_lobby_state(self):
self._state = GameConnectionState.CONNECTED_TO_HOST
self.game.add_game_connection(self)
+ tasks = []
for peer in self.game.connections:
if peer != self and peer.player != self.game.host:
self._logger.debug("%s connecting to %s", self.player, peer)
- asyncio.ensure_future(self.connect_to_peer(peer))
+ tasks.append(self.connect_to_peer(peer))
+ await asyncio.gather(*tasks)
except Exception as e: # pragma: no cover
self._logger.exception(e)
@@ -129,24 +130,29 @@ async def connect_to_host(self, peer: "GameConnection"):
:return:
"""
assert peer.player.state == PlayerState.HOSTING
- self.send_JoinGame(peer.player.login,
- peer.player.id)
+ await self.send_JoinGame(peer.player.login, peer.player.id)
- peer.send_ConnectToPeer(player_name=self.player.login,
- player_uid=self.player.id,
- offer=True)
+ await peer.send_ConnectToPeer(
+ player_name=self.player.login,
+ player_uid=self.player.id,
+ offer=True
+ )
async def connect_to_peer(self, peer: "GameConnection"):
"""
Connect two peers
:return: None
"""
- self.send_ConnectToPeer(player_name=peer.player.login,
- player_uid=peer.player.id,
- offer=True)
- peer.send_ConnectToPeer(player_name=self.player.login,
- player_uid=self.player.id,
- offer=False)
+ await self.send_ConnectToPeer(
+ player_name=peer.player.login,
+ player_uid=peer.player.id,
+ offer=True
+ )
+ await peer.send_ConnectToPeer(
+ player_name=self.player.login,
+ player_uid=self.player.id,
+ offer=False
+ )
async def handle_action(self, command, args):
"""
@@ -167,7 +173,7 @@ async def handle_action(self, command, args):
except Exception as e: # pragma: no cover
self._logger.exception(e)
self._logger.exception("Something awful happened in a game thread!")
- self.abort()
+ await self.abort()
async def handle_desync(self, *_args): # pragma: no cover
self.game.desyncs += 1
@@ -303,7 +309,7 @@ async def handle_teamkill_report(self, gametime, reporter_id, reporter_name, tea
:param teamkiller_id: teamkiller id
:param teamkiller_name: teamkiller nickname - Used as a failsafe in case ID is wrong
"""
-
+
async with self._db.acquire() as conn:
"""
Sometime the game sends a wrong ID - but a correct player name
@@ -403,7 +409,7 @@ async def handle_ice_message(self, receiver_id, ice_msg):
)
return
- game_connection.send_message({
+ await game_connection.send_message({
"command": "IceMsg",
"args": [int(self.player.id), ice_msg]
})
@@ -499,7 +505,7 @@ def _mark_dirty(self):
if self.game:
self.game_service.mark_dirty(self.game)
- def abort(self, log_message: str=''):
+ async def abort(self, log_message: str=''):
"""
Abort the connection
@@ -513,10 +519,10 @@ def abort(self, log_message: str=''):
self._logger.debug("%s.abort(%s)", self, log_message)
if self.game.state == GameState.LOBBY:
- self.disconnect_all_peers()
+ await self.disconnect_all_peers()
self._state = GameConnectionState.ENDED
- asyncio.ensure_future(self.game.remove_game_connection(self))
+ await self.game.remove_game_connection(self)
self._mark_dirty()
self.player.state = PlayerState.IDLE
del self.player.game
@@ -524,14 +530,16 @@ def abort(self, log_message: str=''):
except Exception as ex: # pragma: no cover
self._logger.debug("Exception in abort(): %s", ex)
- def disconnect_all_peers(self):
+ async def disconnect_all_peers(self):
+ tasks = []
for peer in self.game.connections:
if peer == self:
continue
- try:
- peer.send_DisconnectFromPeer(self.player.id)
- except Exception: # pragma no cover
+ tasks.append(peer.send_DisconnectFromPeer(self.player.id))
+
+ for result in await asyncio.gather(*tasks, return_exceptions=True):
+ if isinstance(result, Exception):
self._logger.exception(
"peer_sendDisconnectFromPeer failed for player %i",
self.player.id)
@@ -542,7 +550,7 @@ async def on_connection_lost(self):
except Exception as e: # pragma: no cover
self._logger.exception(e)
finally:
- self.abort()
+ await self.abort()
def __str__(self):
return "GameConnection({}, {})".format(self.player, self.game)
diff --git a/server/games/game.py b/server/games/game.py
index ef88f0fdd..1e5cfef7a 100644
--- a/server/games/game.py
+++ b/server/games/game.py
@@ -3,17 +3,17 @@
import logging
import re
import time
-from collections import Counter, defaultdict
+from collections import defaultdict
from enum import Enum, unique
-from typing import Any, Dict, Optional, Tuple, Union
+from typing import Any, Dict, Optional, Tuple
import trueskill
+from server.games.game_results import GameOutcome, GameResult, GameResults
+from server.rating import RatingType
from trueskill import Rating
from ..abc.base_game import GameConnectionState, InitMode
from ..players import Player, PlayerState
-from server.rating import RatingType
-from server.games.game_results import GameOutcome, GameResult, GameResults
FFA_TEAM = 1
diff --git a/server/geoip_service.py b/server/geoip_service.py
index 03caf1a62..f9e81a40e 100644
--- a/server/geoip_service.py
+++ b/server/geoip_service.py
@@ -28,7 +28,8 @@ def __init__(self):
self.db = None
# crontab: min hour day month day_of_week
- # Run every Wednesday because GeoLite2 is updated every Tuesday
+ # Run every Wednesday because GeoLite2 is updated every first Tuesday
+ # of the month.
self._update_cron = aiocron.crontab(
'0 0 0 * * 3', func=self.check_update_geoip_db
)
diff --git a/server/ladder_service.py b/server/ladder_service.py
index 2e0e73386..c48d727ac 100644
--- a/server/ladder_service.py
+++ b/server/ladder_service.py
@@ -3,16 +3,16 @@
from collections import defaultdict
from typing import Dict, List, NamedTuple, Set
+from server.db import FAFDatabase
+from server.rating import RatingType
from sqlalchemy import and_, func, select, text
-from server.db import FAFDatabase
from .config import LADDER_ANTI_REPETITION_LIMIT
from .db.models import game_featuredMods, game_player_stats, game_stats
from .decorators import with_logger
from .game_service import GameService
from .matchmaker import MatchmakerQueue, Search
from .players import Player, PlayerState
-from server.rating import RatingType
MapDescription = NamedTuple('Map', [("id", int), ("name", str), ("path", str)])
@@ -37,7 +37,7 @@ def __init__(self, database: FAFDatabase, games_service: GameService):
asyncio.ensure_future(self.handle_queue_matches())
- def start_search(self, initiator: Player, search: Search, queue_name: str):
+ async def start_search(self, initiator: Player, search: Search, queue_name: str):
self._cancel_existing_searches(initiator)
for player in search.players:
@@ -45,7 +45,7 @@ def start_search(self, initiator: Player, search: Search, queue_name: str):
# For now, inform_player is only designed for ladder1v1
if queue_name == "ladder1v1":
- self.inform_player(player)
+ await self.inform_player(player)
self.searches[queue_name][initiator] = search
@@ -76,13 +76,13 @@ def _cancel_existing_searches(self, initiator: Player) -> List[Search]:
del self.searches[queue_name][initiator]
return searches
- def inform_player(self, player: Player):
+ async def inform_player(self, player: Player):
if player not in self._informed_players:
self._informed_players.add(player)
mean, deviation = player.ratings[RatingType.LADDER_1V1]
if deviation > 490:
- player.lobby_connection.send({
+ await player.lobby_connection.send({
"command": "notice",
"style": "info",
"text": (
@@ -97,7 +97,7 @@ def inform_player(self, player: Player):
})
elif deviation > 250:
progress = (500.0 - deviation) / 2.5
- player.lobby_connection.send({
+ await player.lobby_connection.send({
"command": "notice",
"style": "info",
"text": (
@@ -113,8 +113,10 @@ async def handle_queue_matches(self):
assert len(s2.players) == 1
p1, p2 = s1.players[0], s2.players[0]
msg = {"command": "match_found", "queue": "ladder1v1"}
- p1.lobby_connection.send(msg)
- p2.lobby_connection.send(msg)
+ await asyncio.gather(
+ p1.lobby_connection.send(msg),
+ p2.lobby_connection.send(msg)
+ )
asyncio.ensure_future(self.start_game(p1, p2))
except Exception as e:
self._logger.exception(
@@ -159,15 +161,17 @@ async def start_game(self, host: Player, guest: Player):
# FIXME: Database filenames contain the maps/ prefix and .zip suffix.
# Really in the future, just send a better description
self._logger.debug("Starting ladder game: %s", game)
- host.lobby_connection.launch_game(game, is_host=True, use_map=mapname)
+ await host.lobby_connection.launch_game(game, is_host=True, use_map=mapname)
try:
hosted = await game.await_hosted()
if not hosted:
raise TimeoutError("Host left lobby")
except TimeoutError:
msg = {"command": "game_launch_timeout"}
- host.lobby_connection.send(msg)
- guest.lobby_connection.send(msg)
+ await asyncio.gather(
+ host.lobby_connection.send(msg),
+ guest.lobby_connection.send(msg)
+ )
# TODO: Uncomment this line once the client supports `game_launch_timeout`.
# Until then, returning here will cause the client to think it is
# searching for ladder, even though the server has already removed it
@@ -175,7 +179,7 @@ async def start_game(self, host: Player, guest: Player):
# return
self._logger.debug("Ladder game failed to launch due to a timeout")
- guest.lobby_connection.launch_game(
+ await guest.lobby_connection.launch_game(
game, is_host=False, use_map=mapname
)
self._logger.debug("Ladder game launched successfully")
diff --git a/server/lobbyconnection.py b/server/lobbyconnection.py
index af7189d28..438bd79db 100644
--- a/server/lobbyconnection.py
+++ b/server/lobbyconnection.py
@@ -97,13 +97,13 @@ def on_connection_made(self, protocol: QDataStreamProtocol, peername: Address):
self.peer_address = peername
server.stats.incr('server.connections')
- def abort(self, logspam=""):
+ async def abort(self, logspam=""):
if self.player:
self._logger.warning("Client %s dropped. %s" % (self.player.login, logspam))
else:
self._logger.warning("Aborting %s. %s" % (self.peer_address.host, logspam))
if self.game_connection:
- self.game_connection.abort()
+ await self.game_connection.abort()
self.game_connection = None
self._authenticated = False
self.protocol.writer.close()
@@ -113,11 +113,11 @@ def abort(self, logspam=""):
self.player = None
server.stats.incr('server.connections.aborted')
- def ensure_authenticated(self, cmd):
+ async def ensure_authenticated(self, cmd):
if not self._authenticated:
if cmd not in ['hello', 'ask_session', 'create_account', 'ping', 'pong', 'Bottleneck']: # Bottleneck is sent by the game during reconnect
server.stats.incr('server.received_messages.unauthenticated', tags={"command": cmd})
- self.abort("Message invalid for unauthenticated connection: %s" % cmd)
+ await self.abort("Message invalid for unauthenticated connection: %s" % cmd)
return False
return True
@@ -129,7 +129,7 @@ async def on_message_received(self, message):
try:
cmd = message['command']
- if not self.ensure_authenticated(cmd):
+ if not await self.ensure_authenticated(cmd):
return
target = message.get('target')
if target == 'game':
@@ -144,43 +144,41 @@ async def on_message_received(self, message):
raise ClientError("Your client version is no longer supported. Please update to the newest version: https://faforever.com")
handler = getattr(self, 'command_{}'.format(cmd))
- if asyncio.iscoroutinefunction(handler):
- await handler(message)
- else:
- handler(message)
+ await handler(message)
+
except AuthenticationError as ex:
- self.protocol.send_message(
- {'command': 'authentication_failed',
- 'text': ex.message}
- )
+ await self.send({
+ 'command': 'authentication_failed',
+ 'text': ex.message
+ })
except ClientError as ex:
self._logger.warning("Client error: %s", ex.message)
- self.protocol.send_message(
- {'command': 'notice',
- 'style': 'error',
- 'text': ex.message}
- )
+ await self.send({
+ 'command': 'notice',
+ 'style': 'error',
+ 'text': ex.message
+ })
if not ex.recoverable:
- self.abort(ex.message)
+ await self.abort(ex.message)
except (KeyError, ValueError) as ex:
self._logger.exception(ex)
- self.abort("Garbage command: {}".format(message))
+ await self.abort("Garbage command: {}".format(message))
except Exception as ex: # pragma: no cover
- self.protocol.send_message({'command': 'invalid'})
+ await self.send({'command': 'invalid'})
self._logger.exception(ex)
- self.abort("Error processing command")
+ await self.abort("Error processing command")
- def command_ping(self, msg):
- self.protocol.send_raw(self.protocol.pack_message('PONG'))
+ async def command_ping(self, msg):
+ await self.protocol.send_raw(self.protocol.pack_message('PONG'))
- def command_pong(self, msg):
+ async def command_pong(self, msg):
pass
- @asyncio.coroutine
- def command_create_account(self, message):
+ async def command_create_account(self, message):
raise ClientError("FAF no longer supports direct registration. Please use the website to register.", recoverable=True)
- async def send_coop_maps(self):
+ async def command_coop_list(self, message):
+ """ Request for coop map list"""
async with self._db.acquire() as conn:
result = await conn.execute("SELECT name, description, filename, type, id FROM `coop_map`")
@@ -209,19 +207,16 @@ async def send_coop_maps(self):
json_to_send["uid"] = row["id"]
maps.append(json_to_send)
- self.protocol.send_messages(maps)
+ await self.protocol.send_messages(maps)
async def command_matchmaker_info(self, message):
- self.send({
+ await self.send({
'command': 'matchmaker_info',
'queues': [queue.to_dict() for queue in self.ladder_service.queues.values()]
})
- await self.protocol.drain()
-
- @timed()
- def send_game_list(self):
- self.send({
+ async def send_game_list(self):
+ await self.send({
'command': 'game_info',
'games': [game.to_dict() for game in self.game_service.open_games]
})
@@ -232,7 +227,7 @@ async def command_social_remove(self, message):
elif "foe" in message:
subject_id = message["foe"]
else:
- self.abort("No-op social_remove.")
+ await self.abort("No-op social_remove.")
return
async with self._db.acquire() as conn:
@@ -258,15 +253,15 @@ async def command_social_add(self, message):
subject_id=subject_id,
))
- def kick(self):
- self.send({
+ async def kick(self):
+ await self.send({
"command": "notice",
"style": "kick",
})
- self.abort()
+ await self.abort()
- def send_updated_achievements(self, updated_achievements):
- self.send({
+ async def send_updated_achievements(self, updated_achievements):
+ await self.send({
"command": "updated_achievements",
"updated_achievements": updated_achievements
})
@@ -333,7 +328,7 @@ async def command_admin(self, message):
for player in self.player_service:
try:
if player.lobby_connection:
- player.lobby_connection.send_warning(message.get('message'))
+ await player.lobby_connection.send_warning(message.get('message'))
except Exception as ex:
self._logger.debug("Could not send broadcast message to %s: %s".format(player, ex))
@@ -403,7 +398,7 @@ async def check_user_login(self, conn, username, password):
return player_id, real_username, steamid
- def check_version(self, message):
+ async def check_version(self, message):
versionDB, updateFile = self.player_service.client_version_info
update_msg = {
'command': 'update',
@@ -417,7 +412,7 @@ def check_version(self, message):
server.stats.gauge('user.agents.{}'.format(self.user_agent), 1, delta=True)
if not self.user_agent or 'downlords-faf-client' not in self.user_agent:
- self.send_warning(
+ await self.send_warning(
"You are using an unofficial client version! "
"Some features might not work as expected. "
"If you experience any problems please download the latest "
@@ -428,7 +423,7 @@ def check_version(self, message):
if not version or not self.user_agent:
update_msg['command'] = 'welcome'
# For compatibility with 0.10.x updating mechanism
- self.send(update_msg)
+ await self.send(update_msg)
return False
# Check their client is reporting the right version number.
@@ -439,10 +434,10 @@ def check_version(self, message):
if "+" in version:
version = version.split('+')[0]
if semver.compare(versionDB, version) > 0:
- self.send(update_msg)
+ await self.send(update_msg)
return False
except ValueError:
- self.send(update_msg)
+ await self.send(update_msg)
return False
return True
@@ -467,7 +462,7 @@ async def check_policy_conformity(self, player_id, uid_hash, session, ignore_res
if response.get('result', '') == 'vm':
self._logger.debug("Using VM: %d: %s", player_id, uid_hash)
- self.send({
+ await self.send({
"command": "notice",
"style": "error",
"text": (
@@ -477,7 +472,7 @@ async def check_policy_conformity(self, player_id, uid_hash, session, ignore_res
"positive."
)
})
- self.send_warning("Your computer seems to be a virtual machine.
In order to "
+ await self.send_warning("Your computer seems to be a virtual machine.
In order to "
"log in from a VM, you have to link your account to Steam: " +
config.WWW_URL + "/account/link.
If you need an exception, please contact an "
@@ -485,7 +480,7 @@ async def check_policy_conformity(self, player_id, uid_hash, session, ignore_res
if response.get('result', '') == 'already_associated':
self._logger.warning("UID hit: %d: %s", player_id, uid_hash)
- self.send_warning("Your computer is already associated with another FAF account.
In order to "
+ await self.send_warning("Your computer is already associated with another FAF account.
In order to "
"log in with an additional account, you have to link it to Steam: " +
config.WWW_URL + "/account/link.
If you need an exception, please contact an "
@@ -494,7 +489,7 @@ async def check_policy_conformity(self, player_id, uid_hash, session, ignore_res
if response.get('result', '') == 'fraudulent':
self._logger.info("Banning player %s for fraudulent looking login.", player_id)
- self.send_warning("Fraudulent login attempt detected. As a precautionary measure, your account has been "
+ await self.send_warning("Fraudulent login attempt detected. As a precautionary measure, your account has been "
"banned permanently. Please contact an admin or moderator on the forums if you feel this is "
"a false positive.",
fatal=True)
@@ -566,7 +561,7 @@ async def command_hello(self, message):
if old_player:
self._logger.debug("player {} already signed in: {}".format(self.player.id, old_player))
if old_player.lobby_connection:
- old_player.lobby_connection.send_warning("You have been signed out because you signed in elsewhere.", fatal=True)
+ await old_player.lobby_connection.send_warning("You have been signed out because you signed in elsewhere.", fatal=True)
old_player.lobby_connection.game_connection = None
old_player.lobby_connection.player = None
self._logger.debug("Removing previous game_connection and player reference of player {} in hope on_connection_lost() wouldn't drop her out of the game".format(self.player.id))
@@ -581,7 +576,7 @@ async def command_hello(self, message):
self.player.country = self.geoip_service.country(self.peer_address.host)
# Send the player their own player info.
- self.send({
+ await self.send({
"command": "welcome",
"me": self.player.to_dict(),
@@ -591,12 +586,10 @@ async def command_hello(self, message):
})
# Tell player about everybody online. This must happen after "welcome".
- self.send(
- {
- "command": "player_info",
- "players": [player.to_dict() for player in self.player_service]
- }
- )
+ await self.send({
+ "command": "player_info",
+ "players": [player.to_dict() for player in self.player_service]
+ })
# Tell everyone else online about us. This must happen after all the player_info messages.
# This ensures that no other client will perform an operation that interacts with the
@@ -630,21 +623,21 @@ async def command_hello(self, message):
channels.append("#%s_clan" % self.player.clan)
json_to_send = {"command": "social", "autojoin": channels, "channels": channels, "friends": friends, "foes": foes, "power": permission_group}
- self.send(json_to_send)
+ await self.send(json_to_send)
- self.send_game_list()
+ await self.send_game_list()
- def command_restore_game_session(self, message):
+ async def command_restore_game_session(self, message):
game_id = int(message.get('game_id'))
# Restore the player's game connection, if the game still exists and is live
if not game_id or game_id not in self.game_service:
- self.send_warning("The game you were connected to does no longer exist")
+ await self.send_warning("The game you were connected to does no longer exist")
return
game = self.game_service[game_id] # type: Game
if game.state != GameState.LOBBY and game.state != GameState.LIVE:
- self.send_warning("The game you were connected to is no longer available")
+ await self.send_warning("The game you were connected to is no longer available")
return
self._logger.debug("Restoring game session of player %s to game %s", self.player, game)
@@ -663,10 +656,9 @@ def command_restore_game_session(self, message):
if not hasattr(self.player, "game"):
self.player.game = game
- @timed
- def command_ask_session(self, message):
- if self.check_version(message):
- self.send({"command": "session", "session": self.session})
+ async def command_ask_session(self, message):
+ if await self.check_version(message):
+ await self.send({"command": "session", "session": self.session})
async def command_avatar(self, message):
action = message['action']
@@ -684,7 +676,7 @@ async def command_avatar(self, message):
avatarList.append(avatar)
if avatarList:
- self.send({"command": "avatar", "avatarlist": avatarList})
+ await self.send({"command": "avatar", "avatarlist": avatarList})
elif action == "select":
avatar = message['avatar']
@@ -719,7 +711,7 @@ async def command_game_join(self, message):
game = self.game_service[uuid]
if not game or game.state != GameState.LOBBY:
self._logger.debug("Game not in lobby state: %s", game)
- self.send({
+ await self.send({
"command": "notice",
"style": "info",
"text": "The game you are trying to join is not ready."
@@ -727,17 +719,17 @@ async def command_game_join(self, message):
return
if game.password != password:
- self.send({
+ await self.send({
"command": "notice",
"style": "info",
"text": "Bad password (it's case sensitive)"
})
return
- self.launch_game(game, is_host=False)
+ await self.launch_game(game, is_host=False)
except KeyError:
- self.send({
+ await self.send({
"command": "notice",
"style": "info",
"text": "The host has left the game"
@@ -765,11 +757,7 @@ async def command_game_matchmaking(self, message):
# TODO: Put player parties here
search = Search([self.player])
- self.ladder_service.start_search(self.player, search, queue_name=mod)
-
- def command_coop_list(self, message):
- """ Request for coop map list"""
- asyncio.ensure_future(self.send_coop_maps())
+ await self.ladder_service.start_search(self.player, search, queue_name=mod)
async def command_game_host(self, message):
assert isinstance(self.player, Player)
@@ -782,7 +770,7 @@ async def command_game_host(self, message):
visibility = VisibilityState.from_string(message.get('visibility'))
if not isinstance(visibility, VisibilityState):
# Protocol violation.
- self.abort("{} sent a nonsense visibility code: {}".format(self.player.login, message.get('visibility')))
+ await self.abort("{} sent a nonsense visibility code: {}".format(self.player.login, message.get('visibility')))
return
title = html.escape(message.get('title') or f"{self.player.login}'s game")
@@ -790,7 +778,7 @@ async def command_game_host(self, message):
try:
title.encode('ascii')
except UnicodeEncodeError:
- self.send({
+ await self.send({
"command": "notice",
"style": "error",
"text": "Non-ascii characters in game name detected."
@@ -810,13 +798,13 @@ async def command_game_host(self, message):
mapname=mapname,
password=password
)
- self.launch_game(game, is_host=True)
+ await self.launch_game(game, is_host=True)
server.stats.incr('game.hosted', tags={'game_mode': game_mode})
- def launch_game(self, game, is_host=False, use_map=None):
+ async def launch_game(self, game, is_host=False, use_map=None):
# TODO: Fix setting up a ridiculous amount of cyclic pointers here
if self.game_connection:
- self.game_connection.abort("Player launched a new game")
+ await self.game_connection.abort("Player launched a new game")
if is_host:
game.host = self.player
@@ -840,7 +828,7 @@ def launch_game(self, game, is_host=False, use_map=None):
}
if use_map:
cmd['mapname'] = use_map
- self.send(cmd)
+ await self.send(cmd)
async def command_modvault(self, message):
type = message["type"]
@@ -861,7 +849,7 @@ async def command_modvault(self, message):
comments=[], description=description, played=played, likes=likes,
downloads=downloads, date=int(date.timestamp()), uid=uid, name=name, version=version, author=author,
ui=ui)
- self.send(out)
+ await self.send(out)
except:
self._logger.error("Error handling table_mod row (uid: {})".format(uid), exc_info=True)
pass
@@ -899,7 +887,7 @@ async def command_modvault(self, message):
"JOIN mod_version v ON v.mod_id = s.mod_id "
"SET s.likes = s.likes + 1, likers=%s WHERE v.uid = %s",
json.dumps(likers), uid)
- self.send(out)
+ await self.send(out)
elif type == "download":
uid = message["uid"]
@@ -910,7 +898,6 @@ async def command_modvault(self, message):
else:
raise ValueError('invalid type argument')
- @asyncio.coroutine
async def command_ice_servers(self, message):
if not self.player:
return
@@ -924,13 +911,13 @@ async def command_ice_servers(self, message):
if self.nts_client:
ice_servers = ice_servers + await self.nts_client.server_tokens(ttl=ttl)
- self.send({
+ await self.send({
'command': 'ice_servers',
'ice_servers': ice_servers,
'ttl': ttl
})
- def send_warning(self, message: str, fatal: bool=False):
+ async def send_warning(self, message: str, fatal: bool=False):
"""
Display a warning message to the client
:param message: Warning message to display
@@ -939,20 +926,22 @@ def send_warning(self, message: str, fatal: bool=False):
and not attempt to reconnect.
:return: None
"""
- self.send({'command': 'notice',
- 'style': 'info' if not fatal else 'error',
- 'text': message})
+ await self.send({
+ 'command': 'notice',
+ 'style': 'info' if not fatal else 'error',
+ 'text': message
+ })
if fatal:
- self.abort(message)
+ await self.abort(message)
- def send(self, message):
+ async def send(self, message):
"""
:param message:
:return:
"""
self._logger.log(TRACE, ">>: %s", message)
- self.protocol.send_message(message)
+ await self.protocol.send_message(message)
async def drain(self):
await self.protocol.drain()
diff --git a/server/matchmaker/search.py b/server/matchmaker/search.py
index 9831307cf..8bef98abe 100644
--- a/server/matchmaker/search.py
+++ b/server/matchmaker/search.py
@@ -3,12 +3,12 @@
import time
from typing import List, Optional, Tuple
+from server.rating import RatingType
from trueskill import Rating, quality
from .. import config
from ..decorators import with_logger
from ..players import Player
-from server.rating import RatingType
@with_logger
@@ -71,7 +71,6 @@ def has_no_top_player(self) -> bool:
max_rating = max(map(lambda rating_tuple: rating_tuple[0], self.ratings))
return max_rating < config.TOP_PLAYER_MIN_RATING
-
@property
def ratings(self):
ratings = []
diff --git a/server/protocol/gpgnet.py b/server/protocol/gpgnet.py
index 72dd1ca8a..41af9ccc5 100644
--- a/server/protocol/gpgnet.py
+++ b/server/protocol/gpgnet.py
@@ -6,44 +6,44 @@ class GpgNetServerProtocol(metaclass=ABCMeta):
"""
Defines an interface for the server side GPGNet protocol
"""
- def send_ConnectToPeer(self, player_name: str, player_uid: int, offer: bool):
+ async def send_ConnectToPeer(self, player_name: str, player_uid: int, offer: bool):
"""
Tells a client that has a listening LobbyComm instance to connect to the given peer
:param player_name: Remote player name
:param player_uid: Remote player identifier
"""
- self.send_gpgnet_message('ConnectToPeer', [player_name, player_uid, offer])
+ await self.send_gpgnet_message('ConnectToPeer', [player_name, player_uid, offer])
- def send_JoinGame(self, remote_player_name: str, remote_player_uid: int):
+ async def send_JoinGame(self, remote_player_name: str, remote_player_uid: int):
"""
Tells the game to join the given peer by ID
:param remote_player_name:
:param remote_player_uid:
"""
- self.send_gpgnet_message('JoinGame', [remote_player_name, remote_player_uid])
+ await self.send_gpgnet_message('JoinGame', [remote_player_name, remote_player_uid])
- def send_HostGame(self, map):
+ async def send_HostGame(self, map):
"""
Tells the game to start listening for incoming connections as a host
:param map: Which scenario to use
"""
- self.send_gpgnet_message('HostGame', [str(map)])
+ await self.send_gpgnet_message('HostGame', [str(map)])
- def send_DisconnectFromPeer(self, id: int):
+ async def send_DisconnectFromPeer(self, id: int):
"""
Instructs the game to disconnect from the peer given by id
:param id:
:return:
"""
- self.send_gpgnet_message('DisconnectFromPeer', [id])
+ await self.send_gpgnet_message('DisconnectFromPeer', [id])
- def send_gpgnet_message(self, command_id, arguments):
+ async def send_gpgnet_message(self, command_id, arguments):
message = {"command": command_id, "args": arguments}
- self.send_message(message)
+ await self.send_message(message)
@abstractmethod
- def send_message(self, message):
+ async def send_message(self, message):
pass # pragma: no cover
diff --git a/server/protocol/protocol.py b/server/protocol/protocol.py
index 99018cd76..403839257 100644
--- a/server/protocol/protocol.py
+++ b/server/protocol/protocol.py
@@ -14,7 +14,7 @@ async def read_message(self: 'Protocol') -> dict:
pass # pragma: no cover
@abstractmethod
- def send_message(self, message: dict) -> None:
+ async def send_message(self, message: dict) -> None:
"""
Send a single message in the form of a dictionary
@@ -23,7 +23,7 @@ def send_message(self, message: dict) -> None:
pass # pragma: no cover
@abstractmethod
- def send_messages(self, messages: List[dict]) -> None:
+ async def send_messages(self, messages: List[dict]) -> None:
"""
Send multiple messages in the form of a list of dictionaries.
@@ -34,7 +34,7 @@ def send_messages(self, messages: List[dict]) -> None:
pass # pragma: no cover
@abstractmethod
- def send_raw(self, data: bytes) -> None:
+ async def send_raw(self, data: bytes) -> None:
"""
Send raw bytes. Should generally not be used.
diff --git a/server/protocol/qdatastreamprotocol.py b/server/protocol/qdatastreamprotocol.py
index 78baa5404..d0ef572a4 100644
--- a/server/protocol/qdatastreamprotocol.py
+++ b/server/protocol/qdatastreamprotocol.py
@@ -121,15 +121,6 @@ async def read_message(self):
pass
return message
- async def drain(self):
- """
- Await the write buffer to empty.
-
- See StreamWriter.drain()
- """
- await asyncio.sleep(0)
- await self.writer.drain()
-
def close(self):
"""
Close writer stream
@@ -137,20 +128,23 @@ def close(self):
"""
self.writer.close()
- def send_message(self, message: dict):
+ async def send_message(self, message: dict):
+ server.stats.incr('server.sent_messages')
self.writer.write(
self.pack_message(json.dumps(message, separators=(',', ':')))
)
- server.stats.incr('server.sent_messages')
+ await self.writer.drain()
- def send_messages(self, messages):
+ async def send_messages(self, messages):
server.stats.incr('server.sent_messages')
payload = [
self.pack_message(json.dumps(msg, separators=(',', ':')))
for msg in messages
]
self.writer.writelines(payload)
+ await self.writer.drain()
- def send_raw(self, data):
+ async def send_raw(self, data):
server.stats.incr('server.sent_messages')
self.writer.write(data)
+ await self.writer.drain()
diff --git a/server/servercontext.py b/server/servercontext.py
index ff20a6ead..d75f3ea13 100644
--- a/server/servercontext.py
+++ b/server/servercontext.py
@@ -46,11 +46,14 @@ def close(self):
def __contains__(self, connection):
return connection in self.connections.keys()
- def broadcast_raw(self, message, validate_fn=lambda a: True):
+ async def broadcast_raw(self, message, validate_fn=lambda a: True):
server.stats.incr('server.broadcasts')
+ tasks = []
for conn, proto in self.connections.items():
if validate_fn(conn):
- proto.send_raw(message)
+ tasks.append(proto.send_raw(message))
+
+ await asyncio.gather(*tasks)
async def client_connected(self, stream_reader, stream_writer):
self._logger.debug("%s: Client connected", self)
@@ -64,9 +67,6 @@ async def client_connected(self, stream_reader, stream_writer):
message = await protocol.read_message()
with server.stats.timer('connection.on_message_received'):
await connection.on_message_received(message)
- with server.stats.timer('servercontext.drain'):
- await asyncio.sleep(0)
- await connection.drain()
except ConnectionResetError:
pass
except ConnectionAbortedError:
diff --git a/server/timing/__init__.py b/server/timing/__init__.py
new file mode 100644
index 000000000..c23da5060
--- /dev/null
+++ b/server/timing/__init__.py
@@ -0,0 +1,3 @@
+from .timer import Timer, at_interval
+
+__all__ = ("Timer", "at_interval")
diff --git a/server/timing/timer.py b/server/timing/timer.py
new file mode 100644
index 000000000..1adea10a4
--- /dev/null
+++ b/server/timing/timer.py
@@ -0,0 +1,98 @@
+"""
+This code is a modified version of Gael Pasgrimaud's library `aiocron`.
+
+See the original code here:
+https://github.com/gawel/aiocron/blob/e82a53c3f9a7950209cee7b3e493204c1dfc8b12/aiocron/__init__.py
+"""
+
+import asyncio
+import functools
+
+
+async def null_callback(*args):
+ return args
+
+
+def wrap_func(func):
+ """wrap in a coroutine"""
+ if not asyncio.iscoroutinefunction(func):
+ return asyncio.coroutine(func)
+ return func
+
+
+class Timer(object):
+ """Schedules a function to be called asynchronously on a fixed interval"""
+
+ def __init__(self, interval, func=None, args=(), start=False, loop=None):
+ self.interval = interval
+ if func is not None:
+ self.func = func if not args else functools.partial(func, *args)
+ else:
+ self.func = null_callback
+ self.cron = wrap_func(self.func)
+ self.auto_start = start
+ self.handle = self.future = None
+ self.loop = loop if loop is not None else asyncio.get_event_loop()
+ if start and self.func is not null_callback:
+ self.handle = self.loop.call_soon_threadsafe(self.start)
+
+ def start(self):
+ """Start scheduling"""
+ self.stop()
+ self.handle = self.loop.call_later(self.get_delay(), self.call_next)
+
+ def stop(self):
+ """Stop scheduling"""
+ if self.handle is not None:
+ self.handle.cancel()
+ self.handle = self.future = None
+
+ def get_delay(self):
+ """Return next interval to wait between calls"""
+ return self.interval
+
+ def call_next(self):
+ """Set next hop in the loop. Call task"""
+ if self.handle is not None:
+ self.handle.cancel()
+ delay = self.get_delay()
+ self.handle = self.loop.call_later(delay, self.call_next)
+ self.call_func()
+
+ def call_func(self, *args, **kwargs):
+ """Called. Take care of exceptions using gather"""
+ asyncio.gather(
+ self.cron(*args, **kwargs),
+ loop=self.loop, return_exceptions=True
+ ).add_done_callback(self.set_result)
+
+ def set_result(self, result):
+ """Set future's result if needed (can be an exception).
+ Else raise if needed."""
+ result = result.result()[0]
+ if self.future is not None:
+ if isinstance(result, Exception):
+ self.future.set_exception(result)
+ else:
+ self.future.set_result(result)
+ self.future = None
+ elif isinstance(result, Exception):
+ raise result
+
+ def __call__(self, func):
+ """Used as a decorator"""
+ self.func = func
+ self.cron = wrap_func(func)
+ if self.auto_start:
+ self.loop.call_soon_threadsafe(self.start)
+ return self
+
+ def __str__(self):
+ return f"{self.interval} {self.func}"
+
+ def __repr__(self):
+ return f""
+
+
+def at_interval(interval, func=None, args=(), start=True, loop=None):
+ return Timer(interval, func=func, args=args, start=start, loop=loop)
diff --git a/tests/conftest.py b/tests/conftest.py
index 967420786..deb371192 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -11,20 +11,20 @@
from typing import Iterable
from unittest import mock
+import asynctest
import pytest
+from asynctest import CoroutineMock
from server.api.api_accessor import ApiAccessor
from server.config import DB_LOGIN, DB_PASSWORD, DB_PORT, DB_SERVER
+from server.db import FAFDatabase
from server.game_service import GameService
from server.geoip_service import GeoIpService
+from server.lobbyconnection import LobbyConnection
from server.matchmaker import MatchmakerQueue
from server.player_service import PlayerService
from server.rating import RatingType
-from server.db import FAFDatabase
from tests.utils import MockDatabase
-from asynctest import CoroutineMock
-import asynctest
-
logging.getLogger().setLevel(logging.DEBUG)
@@ -159,6 +159,11 @@ def make(state=PlayerState.IDLE, global_rating=None, ladder_rating=None,
p = Player(ratings=ratings, game_count=games, **kwargs)
p.state = state
+
+ # lobby_connection is a weak reference, but we want the mock
+ # to live for the full lifetime of the player object
+ p.__owned_lobby_connection = asynctest.create_autospec(LobbyConnection)
+ p.lobby_connection = p.__owned_lobby_connection
return p
return make
diff --git a/tests/data/test-data.sql b/tests/data/test-data.sql
index 4d9c2f5c5..f9add180f 100644
--- a/tests/data/test-data.sql
+++ b/tests/data/test-data.sql
@@ -1,4 +1,5 @@
insert into login (id, login, email, password, steamid, create_time) values
+ (10, 'friends', 'friends@example.com', SHA2('friends', 256), null, '2000-01-01 00:00:00'),
(50, 'player_service1', 'ps1@example.com', SHA2('player_service1', 256), null, '2000-01-01 00:00:00'),
(51, 'player_service2', 'ps2@example.com', SHA2('player_service2', 256), null, '2000-01-01 00:00:00'),
(52, 'player_service3', 'ps3@example.com', SHA2('player_service3', 256), null, '2000-01-01 00:00:00'),
@@ -69,7 +70,8 @@ insert into game_player_stats (gameId, playerId, AI, faction, color, team, place
delete from friends_and_foes where user_id = 1 and subject_id = 2;
insert into friends_and_foes (user_id, subject_id, status) values
- (2, 1, 'FRIEND');
+ (2, 1, 'FRIEND'),
+ (10, 1, 'FRIEND');
insert into `mod` (id, display_name, author) values
diff --git a/tests/integration_tests/conftest.py b/tests/integration_tests/conftest.py
index 441d92a28..b379874f1 100644
--- a/tests/integration_tests/conftest.py
+++ b/tests/integration_tests/conftest.py
@@ -110,7 +110,7 @@ async def perform_login(
) -> None:
login, pw = credentials
pw_hash = hashlib.sha256(pw.encode('utf-8'))
- proto.send_message({
+ await proto.send_message({
'command': 'hello',
'version': '1.0.0-dev',
'user_agent': 'faf-client',
@@ -118,7 +118,6 @@ async def perform_login(
'password': pw_hash.hexdigest(),
'unique_id': 'some_id'
})
- await proto.drain()
async def read_until(
@@ -139,8 +138,7 @@ async def read_until_command(proto: QDataStreamProtocol, command: str) -> Dict[s
async def get_session(proto):
- proto.send_message({'command': 'ask_session', 'user_agent': 'faf-client', 'version': '0.11.16'})
- await proto.drain()
+ await proto.send_message({'command': 'ask_session', 'user_agent': 'faf-client', 'version': '0.11.16'})
msg = await read_until_command(proto, 'session')
return msg['session']
diff --git a/tests/integration_tests/test_matchmaker.py b/tests/integration_tests/test_matchmaker.py
index d9ba3e03b..fe99ce27c 100644
--- a/tests/integration_tests/test_matchmaker.py
+++ b/tests/integration_tests/test_matchmaker.py
@@ -19,19 +19,17 @@ async def queue_players_for_matchmaking(lobby_server):
await read_until_command(proto1, 'game_info')
await read_until_command(proto2, 'game_info')
- proto1.send_message({
+ await proto1.send_message({
'command': 'game_matchmaking',
'state': 'start',
'faction': 'uef'
})
- await proto1.drain()
- proto2.send_message({
+ await proto2.send_message({
'command': 'game_matchmaking',
'state': 'start',
'faction': 1 # Python client sends factions as numbers
})
- await proto2.drain()
# If the players did not match, this will fail due to a timeout error
await read_until_command(proto1, 'match_found')
@@ -46,7 +44,7 @@ async def test_game_matchmaking(lobby_server):
# The player that queued last will be the host
msg2 = await read_until_command(proto2, 'game_launch')
- proto2.send_message({
+ await proto2.send_message({
'command': 'GameState',
'target': 'game',
'args': ['Lobby']
@@ -58,7 +56,7 @@ async def test_game_matchmaking(lobby_server):
assert msg2['mod'] == 'ladder1v1'
-@fast_forward(1000)
+@fast_forward(100)
async def test_matchmaker_info_message(lobby_server, mocker):
mocker.patch('server.matchmaker.pop_timer.time', return_value=1_562_000_000)
@@ -95,8 +93,7 @@ async def test_command_matchmaker_info(lobby_server, mocker):
await read_until_command(proto, "game_info")
- proto.send_message({"command": "matchmaker_info"})
- await proto.drain()
+ await proto.send_message({"command": "matchmaker_info"})
msg = await read_until_command(proto, "matchmaker_info")
assert msg == {
@@ -120,7 +117,7 @@ async def test_matchmaker_info_message_on_cancel(lobby_server):
await read_until_command(proto, 'game_info')
- proto.send_message({
+ await proto.send_message({
'command': 'game_matchmaking',
'state': 'start',
'faction': 'uef'
@@ -132,7 +129,7 @@ async def test_matchmaker_info_message_on_cancel(lobby_server):
assert msg["queues"][0]["queue_name"] == "ladder1v1"
assert len(msg["queues"][0]["boundary_80s"]) == 1
- proto.send_message({
+ await proto.send_message({
'command': 'game_matchmaking',
'state': 'stop',
})
diff --git a/tests/integration_tests/test_modvault.py b/tests/integration_tests/test_modvault.py
index 0d6e7118f..855ad4da9 100644
--- a/tests/integration_tests/test_modvault.py
+++ b/tests/integration_tests/test_modvault.py
@@ -1,6 +1,7 @@
-from .conftest import connect_and_sign_in, read_until_command
import pytest
+from .conftest import connect_and_sign_in, read_until_command
+
pytestmark = pytest.mark.asyncio
@@ -12,11 +13,10 @@ async def test_modvault_start(lobby_server):
await read_until_command(proto, 'game_info')
- proto.send_message({
+ await proto.send_message({
'command': 'modvault',
'type': 'start'
})
- await proto.drain()
# Make sure all 5 mod version messages are sent
for _ in range(5):
@@ -31,12 +31,11 @@ async def test_modvault_like(lobby_server):
await read_until_command(proto, 'game_info')
- proto.send_message({
+ await proto.send_message({
'command': 'modvault',
'type': 'like',
'uid': 'FFF'
})
- await proto.drain()
msg = await read_until_command(proto, 'modvault_info')
# Not going to verify the date
diff --git a/tests/integration_tests/test_server.py b/tests/integration_tests/test_server.py
index 35f9ecf41..e25cdec29 100644
--- a/tests/integration_tests/test_server.py
+++ b/tests/integration_tests/test_server.py
@@ -1,9 +1,9 @@
import asyncio
-from unittest import mock
import pytest
from server import VisibilityState
from server.db.models import ban
+from tests.utils import fast_forward
from .conftest import (
connect_and_sign_in, connect_client, perform_login, read_until,
@@ -17,15 +17,13 @@
async def test_server_deprecated_client(lobby_server):
proto = await connect_client(lobby_server)
- proto.send_message({'command': 'ask_session', 'user_agent': 'faf-client', 'version': '0.0.0'})
- await proto.drain()
+ await proto.send_message({'command': 'ask_session', 'user_agent': 'faf-client', 'version': '0.0.0'})
msg = await proto.read_message()
assert msg['command'] == 'notice'
proto = await connect_client(lobby_server)
- proto.send_message({'command': 'ask_session', 'version': '0.0.0'})
- await proto.drain()
+ await proto.send_message({'command': 'ask_session', 'version': '0.0.0'})
msg = await proto.read_message()
assert msg['command'] == 'notice'
@@ -132,6 +130,14 @@ async def test_server_double_login(lobby_server):
await lobby_server.wait_closed()
+@fast_forward(50)
+async def test_ping_message(lobby_server):
+ _, _, proto = await connect_and_sign_in(('test', 'test_password'), lobby_server)
+
+ # We should receive the message every 45 seconds
+ await asyncio.wait_for(read_until_command(proto, 'ping'), 46)
+
+
async def test_player_info_broadcast(lobby_server):
p1 = await connect_client(lobby_server)
p2 = await connect_client(lobby_server)
@@ -155,13 +161,12 @@ async def test_info_broadcast_authenticated(lobby_server):
await perform_login(proto1, ('test', 'test_password'))
await perform_login(proto2, ('Rhiza', 'puff_the_magic_dragon'))
- proto1.send_message({
+ await proto1.send_message({
"command": "game_matchmaking",
"state": "start",
"mod": "ladder1v1",
"faction": "uef"
})
- await proto1.drain()
# Will timeout if the message is never received
await read_until_command(proto2, "matchmaker_info")
with pytest.raises(asyncio.TimeoutError):
@@ -170,6 +175,70 @@ async def test_info_broadcast_authenticated(lobby_server):
assert False
+async def test_game_info_not_broadcast_to_foes(lobby_server):
+ # Rhiza is foed by test
+ _, _, proto1 = await connect_and_sign_in(
+ ("test", "test_password"), lobby_server
+ )
+ _, _, proto2 = await connect_and_sign_in(
+ ("Rhiza", "puff_the_magic_dragon"), lobby_server
+ )
+ await read_until_command(proto1, "game_info")
+ await read_until_command(proto2, "game_info")
+
+ await proto1.send_message({
+ "command": "game_host",
+ "title": "No Foes Allowed",
+ "mod": "faf",
+ "visibility": "public"
+ })
+
+ msg = await read_until_command(proto1, "game_info")
+
+ assert msg["featured_mod"] == "faf"
+ assert msg["title"] == "No Foes Allowed"
+ assert msg["visibility"] == "public"
+
+ with pytest.raises(asyncio.TimeoutError):
+ await asyncio.wait_for(read_until_command(proto2, "game_info"), 0.2)
+
+
+async def test_game_info_broadcast_to_friends(lobby_server):
+ # test is the friend of friends
+ _, _, proto1 = await connect_and_sign_in(
+ ("friends", "friends"), lobby_server
+ )
+ _, _, proto2 = await connect_and_sign_in(
+ ("test", "test_password"), lobby_server
+ )
+ _, _, proto3 = await connect_and_sign_in(
+ ("Rhiza", "puff_the_magic_dragon"), lobby_server
+ )
+ await read_until_command(proto1, "game_info")
+ await read_until_command(proto2, "game_info")
+ await read_until_command(proto3, "game_info")
+
+ await proto1.send_message({
+ "command": "game_host",
+ "title": "Friends Only",
+ "mod": "faf",
+ "visibility": "friends"
+ })
+
+ # The host and his friend should see the game
+ msg = await read_until_command(proto1, "game_info")
+ msg2 = await read_until_command(proto2, "game_info")
+
+ assert msg == msg2
+ assert msg["featured_mod"] == "faf"
+ assert msg["title"] == "Friends Only"
+ assert msg["visibility"] == "friends"
+
+ # However, the other person should not see the game
+ with pytest.raises(asyncio.TimeoutError):
+ await asyncio.wait_for(read_until_command(proto3, "game_info"), 0.2)
+
+
@pytest.mark.parametrize("user", [
("test", "test_password"),
("ban_revoked", "ban_revoked"),
@@ -181,13 +250,12 @@ async def test_game_host_authenticated(lobby_server, user):
_, _, proto = await connect_and_sign_in(user, lobby_server)
await read_until_command(proto, 'game_info')
- proto.send_message({
+ await proto.send_message({
'command': 'game_host',
'title': 'My Game',
'mod': 'faf',
'visibility': 'public',
})
- await proto.drain()
msg = await read_until_command(proto, 'game_launch')
@@ -205,13 +273,12 @@ async def test_host_missing_fields(event_loop, lobby_server, player_service):
await read_until_command(proto, 'game_info')
- proto.send_message({
+ await proto.send_message({
'command': 'game_host',
'mod': '',
- 'visibility': VisibilityState.to_string(VisibilityState.PUBLIC),
+ 'visibility': 'public',
'title': ''
})
- await proto.drain()
msg = await read_until_command(proto, 'game_info')
@@ -229,8 +296,7 @@ async def test_coop_list(lobby_server):
await read_until_command(proto, 'game_info')
- proto.send_message({"command": "coop_list"})
- await proto.drain()
+ await proto.send_message({"command": "coop_list"})
msg = await read_until_command(proto, "coop_info")
assert "name" in msg
@@ -261,8 +327,7 @@ async def test_server_ban_prevents_hosting(lobby_server, database, command):
)
)
- proto.send_message({"command": command})
- await proto.drain()
+ await proto.send_message({"command": command})
msg = await proto.read_message()
assert msg == {
diff --git a/tests/integration_tests/test_servercontext.py b/tests/integration_tests/test_servercontext.py
index a02013d99..d54919f84 100644
--- a/tests/integration_tests/test_servercontext.py
+++ b/tests/integration_tests/test_servercontext.py
@@ -1,12 +1,10 @@
import asyncio
-from asynctest import exhaust_callbacks
-import pytest
from unittest import mock
-from server import ServerContext
+import pytest
+from asynctest import exhaust_callbacks
+from server import ServerContext, fake_statsd
from server.protocol import QDataStreamProtocol
-from server import fake_statsd
-
pytestmark = pytest.mark.asyncio
@@ -45,8 +43,7 @@ def fin():
async def test_serverside_abort(event_loop, mock_context, mock_server):
(reader, writer) = await asyncio.open_connection(*mock_context.sockets[0].getsockname())
proto = QDataStreamProtocol(reader, writer)
- proto.send_message({"some_junk": True})
- await writer.drain()
+ await proto.send_message({"some_junk": True})
await exhaust_callbacks(event_loop)
mock_server.on_connection_lost.assert_any_call()
diff --git a/tests/unit_tests/conftest.py b/tests/unit_tests/conftest.py
index 5e3b4ed2f..2e321e70b 100644
--- a/tests/unit_tests/conftest.py
+++ b/tests/unit_tests/conftest.py
@@ -1,14 +1,14 @@
from unittest import mock
+import asynctest
import pytest
+from asynctest import CoroutineMock
from server import GameStatsService
from server.game_service import GameService
from server.gameconnection import GameConnection, GameConnectionState
from server.games import Game
from server.geoip_service import GeoIpService
from server.ladder_service import LadderService
-import asynctest
-from asynctest import CoroutineMock
@pytest.fixture
@@ -19,7 +19,15 @@ def lobbythread():
@pytest.fixture
-def game_connection(request, database, game, players, game_service, player_service):
+def game_connection(
+ request,
+ database,
+ game,
+ players,
+ game_service,
+ player_service,
+ event_loop
+):
from server import GameConnection
conn = GameConnection(
database=database,
@@ -33,7 +41,7 @@ def game_connection(request, database, game, players, game_service, player_servi
conn.finished_sim = False
def fin():
- conn.abort()
+ event_loop.run_until_complete(conn.abort())
request.addfinalizer(fin)
return conn
diff --git a/tests/unit_tests/test_gameconnection.py b/tests/unit_tests/test_gameconnection.py
index dbfa60d44..a5ef526d0 100644
--- a/tests/unit_tests/test_gameconnection.py
+++ b/tests/unit_tests/test_gameconnection.py
@@ -1,17 +1,24 @@
import asyncio
from unittest import mock
-import pytest
-import asynctest
+import asynctest
+import pytest
+from asynctest import CoroutineMock, exhaust_callbacks
from server import GameConnection
+from server.abc.base_game import GameConnectionState
from server.games import Game
-from server.games.game import ValidityState, Victory
+from server.games.game import GameState, ValidityState, Victory
from server.players import PlayerState
-from asynctest import CoroutineMock, exhaust_callbacks
pytestmark = pytest.mark.asyncio
+@pytest.fixture
+def real_game(event_loop, database, game_service, game_stats_service):
+ game = Game(42, database, game_service, game_stats_service)
+ yield game
+
+
def assert_message_sent(game_connection: GameConnection, command, args):
game_connection.protocol.send_message.assert_called_with({
'command': command,
@@ -24,11 +31,46 @@ async def test_abort(game_connection: GameConnection, game: Game, players):
game_connection.player = players.hosting
game_connection.game = game
- game_connection.abort()
+ await game_connection.abort()
game.remove_game_connection.assert_called_with(game_connection)
+async def test_disconnect_all_peers(
+ game_connection: GameConnection,
+ real_game: Game,
+ players
+):
+ real_game.state = GameState.LOBBY
+ game_connection.player = players.hosting
+ game_connection.game = real_game
+
+ disconnect_done = mock.Mock()
+
+ async def fake_send_dc(player_id):
+ await asyncio.sleep(1) # Take some time
+ disconnect_done.success()
+ return "OK"
+
+ # Set up a peer that will disconnect without error
+ ok_disconnect = asynctest.create_autospec(GameConnection)
+ ok_disconnect.state = GameConnectionState.CONNECTED_TO_HOST
+ ok_disconnect.send_DisconnectFromPeer = fake_send_dc
+
+ # Set up a peer that will throw an exception
+ fail_disconnect = asynctest.create_autospec(GameConnection)
+ fail_disconnect.send_DisconnectFromPeer.return_value = Exception("Test exception")
+ fail_disconnect.state = GameConnectionState.CONNECTED_TO_HOST
+
+ # Add the peers to the game
+ real_game.add_game_connection(fail_disconnect)
+ real_game.add_game_connection(ok_disconnect)
+
+ await game_connection.disconnect_all_peers()
+
+ disconnect_done.success.assert_called_once()
+
+
async def test_handle_action_GameState_idle_adds_connection(
game: Game,
game_connection: GameConnection,
@@ -49,7 +91,7 @@ async def test_handle_action_GameState_idle_non_searching_player_aborts(
):
game_connection.player = players.hosting
game_connection.lobby = mock.Mock()
- game_connection.abort = mock.Mock()
+ game_connection.abort = CoroutineMock()
players.hosting.state = PlayerState.IDLE
await game_connection.handle_action('GameState', ['Idle'])
@@ -233,12 +275,12 @@ async def test_handle_action_TeamkillReport(game: Game, game_connection: GameCon
result = await conn.execute("select game_id,id from moderation_report where reporter_id=2 and game_id=%s and game_incident_timecode=200", (game.id))
report = await result.fetchone()
assert game.id == report["game_id"]
-
+
reported_user_query = await conn.execute("select player_id from reported_user where report_id=%s", (report["id"]))
data = await reported_user_query.fetchone()
assert data["player_id"] == 3
-
-
+
+
async def test_handle_action_TeamkillReport_invalid_ids(game: Game, game_connection: GameConnection, database):
game.launch = CoroutineMock()
await game_connection.handle_action('TeamkillReport', ['230', 0, 'Dostya', 0, 'Rhiza'])
@@ -247,7 +289,7 @@ async def test_handle_action_TeamkillReport_invalid_ids(game: Game, game_connect
result = await conn.execute("select game_id,id from moderation_report where reporter_id=2 and game_id=%s and game_incident_timecode=230", (game.id))
report = await result.fetchone()
assert game.id == report["game_id"]
-
+
reported_user_query = await conn.execute("select player_id from reported_user where report_id=%s", (report["id"]))
data = await reported_user_query.fetchone()
assert data["player_id"] == 3
@@ -288,7 +330,7 @@ async def test_handle_action_TeamkillHappened(game: Game, game_connection: GameC
async def test_handle_action_TeamkillHappened_AI(game: Game, game_connection: GameConnection, database):
# Should fail with a sql constraint error if this isn't handled correctly
- game_connection.abort = mock.Mock()
+ game_connection.abort = CoroutineMock()
await game_connection.handle_action('TeamkillHappened', ['200', 0, 'Dostya', '0', 'Rhiza'])
game_connection.abort.assert_not_called()
diff --git a/tests/unit_tests/test_ladder.py b/tests/unit_tests/test_ladder.py
index aa6227659..06737358f 100644
--- a/tests/unit_tests/test_ladder.py
+++ b/tests/unit_tests/test_ladder.py
@@ -1,12 +1,11 @@
import asyncio
from unittest import mock
-from asynctest import exhaust_callbacks
import pytest
+from asynctest import CoroutineMock, exhaust_callbacks
from server import GameService, LadderService
from server.matchmaker import Search
from server.players import PlayerState
-from asynctest import CoroutineMock
from tests.utils import fast_forward
pytestmark = pytest.mark.asyncio
@@ -17,11 +16,6 @@ async def test_start_game(ladder_service: LadderService, game_service:
p1 = player_factory('Dostya', player_id=1)
p2 = player_factory('Rhiza', player_id=2)
- mock_lc1 = mock.Mock()
- mock_lc2 = mock.Mock()
- p1.lobby_connection = mock_lc1
- p2.lobby_connection = mock_lc2
-
game_service.ladder_maps = [(1, 'scmp_007', 'maps/scmp_007.zip')]
with mock.patch('server.games.game.Game.await_hosted', CoroutineMock()):
@@ -37,11 +31,6 @@ async def test_start_game_timeout(ladder_service: LadderService, game_service:
p1 = player_factory('Dostya', player_id=1)
p2 = player_factory('Rhiza', player_id=2)
- mock_lc1 = mock.Mock()
- mock_lc2 = mock.Mock()
- p1.lobby_connection = mock_lc1
- p2.lobby_connection = mock_lc2
-
game_service.ladder_maps = [(1, 'scmp_007', 'maps/scmp_007.zip')]
await ladder_service.start_game(p1, p2)
@@ -56,19 +45,16 @@ async def test_start_game_timeout(ladder_service: LadderService, game_service:
async def test_inform_player(ladder_service: LadderService, player_factory):
p1 = player_factory('Dostya', player_id=1, ladder_rating=(1500, 500))
- mock_lc = mock.Mock()
- p1.lobby_connection = mock_lc
-
- ladder_service.inform_player(p1)
+ await ladder_service.inform_player(p1)
# Message is sent after the first call
p1.lobby_connection.send.assert_called_once()
- ladder_service.inform_player(p1)
+ await ladder_service.inform_player(p1)
p1.lobby_connection.send.reset_mock()
# But not after the second
p1.lobby_connection.send.assert_not_called()
ladder_service.on_connection_lost(p1)
- ladder_service.inform_player(p1)
+ await ladder_service.inform_player(p1)
# But it is called if the player relogs
p1.lobby_connection.send.assert_called_once()
@@ -78,12 +64,9 @@ async def test_start_and_cancel_search(ladder_service: LadderService,
player_factory, event_loop):
p1 = player_factory('Dostya', player_id=1, ladder_rating=(1500, 500), ladder_games=0)
- mock_lc = mock.Mock()
- p1.lobby_connection = mock_lc
-
search = Search([p1])
- ladder_service.start_search(p1, search, 'ladder1v1')
+ await ladder_service.start_search(p1, search, 'ladder1v1')
await exhaust_callbacks(event_loop)
assert p1.state == PlayerState.SEARCHING_LADDER
@@ -100,12 +83,9 @@ async def test_start_search_cancels_previous_search(
ladder_service: LadderService, player_factory, event_loop):
p1 = player_factory('Dostya', player_id=1, ladder_rating=(1500, 500), ladder_games=0)
- mock_lc = mock.Mock()
- p1.lobby_connection = mock_lc
-
search1 = Search([p1])
- ladder_service.start_search(p1, search1, 'ladder1v1')
+ await ladder_service.start_search(p1, search1, 'ladder1v1')
await exhaust_callbacks(event_loop)
assert p1.state == PlayerState.SEARCHING_LADDER
@@ -113,7 +93,7 @@ async def test_start_search_cancels_previous_search(
search2 = Search([p1])
- ladder_service.start_search(p1, search2, 'ladder1v1')
+ await ladder_service.start_search(p1, search2, 'ladder1v1')
await exhaust_callbacks(event_loop)
assert p1.state == PlayerState.SEARCHING_LADDER
@@ -126,12 +106,9 @@ async def test_cancel_all_searches(ladder_service: LadderService,
player_factory, event_loop):
p1 = player_factory('Dostya', player_id=1, ladder_rating=(1500, 500), ladder_games=0)
- mock_lc = mock.Mock()
- p1.lobby_connection = mock_lc
-
search = Search([p1])
- ladder_service.start_search(p1, search, 'ladder1v1')
+ await ladder_service.start_search(p1, search, 'ladder1v1')
await exhaust_callbacks(event_loop)
assert p1.state == PlayerState.SEARCHING_LADDER
@@ -149,16 +126,11 @@ async def test_cancel_twice(ladder_service: LadderService, player_factory):
p1 = player_factory('Dostya', player_id=1, ladder_rating=(1500, 500), ladder_games=0)
p2 = player_factory('Brackman', player_id=2, ladder_rating=(2000, 500), ladder_games=0)
- mock_lc1 = mock.Mock()
- mock_lc2 = mock.Mock()
- p1.lobby_connection = mock_lc1
- p2.lobby_connection = mock_lc2
-
search = Search([p1])
search2 = Search([p2])
- ladder_service.start_search(p1, search, 'ladder1v1')
- ladder_service.start_search(p2, search2, 'ladder1v1')
+ await ladder_service.start_search(p1, search, 'ladder1v1')
+ await ladder_service.start_search(p2, search2, 'ladder1v1')
searches = ladder_service._cancel_existing_searches(p1)
assert search.is_cancelled
@@ -179,18 +151,13 @@ async def test_start_game_called_on_match(ladder_service: LadderService,
p1 = player_factory('Dostya', player_id=1, ladder_rating=(2300, 64), ladder_games=0)
p2 = player_factory('QAI', player_id=2, ladder_rating=(2350, 125), ladder_games=0)
- mock_lc1 = mock.Mock()
- mock_lc2 = mock.Mock()
- p1.lobby_connection = mock_lc1
- p2.lobby_connection = mock_lc2
-
ladder_service.start_game = CoroutineMock()
- ladder_service.inform_player = mock.Mock()
+ ladder_service.inform_player = CoroutineMock()
- ladder_service.start_search(p1, Search([p1]), 'ladder1v1')
- ladder_service.start_search(p2, Search([p2]), 'ladder1v1')
+ await ladder_service.start_search(p1, Search([p1]), 'ladder1v1')
+ await ladder_service.start_search(p2, Search([p2]), 'ladder1v1')
- await asyncio.sleep(1)
+ await asyncio.sleep(2)
ladder_service.inform_player.assert_called()
ladder_service.start_game.assert_called_once()
diff --git a/tests/unit_tests/test_lobbyconnection.py b/tests/unit_tests/test_lobbyconnection.py
index e43b403e2..f6f2dc32f 100644
--- a/tests/unit_tests/test_lobbyconnection.py
+++ b/tests/unit_tests/test_lobbyconnection.py
@@ -2,12 +2,14 @@
from unittest import mock
from unittest.mock import Mock
+import asynctest
import pytest
from aiohttp import web
from asynctest import CoroutineMock
from server import GameState, VisibilityState
from server.db.models import ban, friends_and_foes
from server.game_service import GameService
+from server.gameconnection import GameConnection
from server.games import CustomGame, Game
from server.geoip_service import GeoIpService
from server.ice_servers.nts import TwilioNTS
@@ -71,7 +73,7 @@ def mock_games(database, mock_players, game_stats_service):
@pytest.fixture
def mock_protocol():
- return mock.create_autospec(QDataStreamProtocol(mock.Mock(), mock.Mock()))
+ return asynctest.create_autospec(QDataStreamProtocol(mock.Mock(), mock.Mock()))
@pytest.fixture
@@ -95,6 +97,7 @@ def lobbyconnection(event_loop, database, mock_protocol, mock_games, mock_player
lc.player_service.get_permission_group.return_value = 0
lc.player_service.fetch_player_data = CoroutineMock()
lc.peer_address = Address('127.0.0.1', 1234)
+ lc._authenticated = True
return lc
@@ -125,13 +128,62 @@ async def start_app():
event_loop.run_until_complete(runner.cleanup())
+async def test_unauthenticated_calls_abort(lobbyconnection, test_game_info):
+ lobbyconnection._authenticated = False
+ lobbyconnection.abort = CoroutineMock()
+
+ await lobbyconnection.on_message_received({
+ "command": "game_host",
+ **test_game_info
+ })
+
+ lobbyconnection.abort.assert_called_once_with(
+ "Message invalid for unauthenticated connection: game_host"
+ )
+
+
+async def test_bad_command_calls_abort(lobbyconnection):
+ lobbyconnection.send = CoroutineMock()
+ lobbyconnection.abort = CoroutineMock()
+
+ await lobbyconnection.on_message_received({
+ "command": "this_isnt_real"
+ })
+
+ lobbyconnection.send.assert_called_once_with({"command": "invalid"})
+ lobbyconnection.abort.assert_called_once_with("Error processing command")
+
+
+async def test_command_pong_does_nothing(lobbyconnection):
+ lobbyconnection.send = CoroutineMock()
+
+ await lobbyconnection.on_message_received({
+ "command": "pong"
+ })
+
+ lobbyconnection.send.assert_not_called()
+
+
+async def test_command_create_account_returns_error(lobbyconnection):
+ lobbyconnection.send = CoroutineMock()
+
+ await lobbyconnection.on_message_received({
+ "command": "create_account"
+ })
+
+ lobbyconnection.send.assert_called_once_with({
+ "command": "notice",
+ "style": "error",
+ "text": ("FAF no longer supports direct registration. "
+ "Please use the website to register.")
+ })
+
+
async def test_command_game_host_creates_game(lobbyconnection,
mock_games,
test_game_info,
players):
lobbyconnection.player = players.hosting
- lobbyconnection.protocol = mock.Mock()
- lobbyconnection._authenticated = True
await lobbyconnection.on_message_received({
"command": "game_host",
**test_game_info
@@ -148,12 +200,12 @@ async def test_command_game_host_creates_game(lobbyconnection,
async def test_launch_game(lobbyconnection, game, player_factory):
- old_game_conn = mock.Mock()
+ old_game_conn = asynctest.create_autospec(GameConnection)
lobbyconnection.player = player_factory()
lobbyconnection.game_connection = old_game_conn
- lobbyconnection.send = mock.Mock()
- lobbyconnection.launch_game(game)
+ lobbyconnection.send = CoroutineMock()
+ await lobbyconnection.launch_game(game)
# Verify all side effects of launch_game here
old_game_conn.abort.assert_called_with("Player launched a new game")
@@ -170,10 +222,8 @@ async def test_command_game_host_creates_correct_game(
lobbyconnection, game_service, test_game_info, players):
lobbyconnection.player = players.hosting
lobbyconnection.game_service = game_service
- lobbyconnection.launch_game = mock.Mock()
+ lobbyconnection.launch_game = CoroutineMock()
- lobbyconnection.protocol = mock.Mock()
- lobbyconnection._authenticated = True
await lobbyconnection.on_message_received({
"command": "game_host",
**test_game_info
@@ -192,7 +242,6 @@ async def test_command_game_join_calls_join_game(mocker,
players,
game_stats_service):
lobbyconnection.game_service = game_service
- mock_protocol = mocker.patch.object(lobbyconnection, 'protocol')
game = mock.create_autospec(Game(42, database, game_service, game_stats_service))
game.state = GameState.LOBBY
game.password = None
@@ -202,7 +251,6 @@ async def test_command_game_join_calls_join_game(mocker,
lobbyconnection.player = players.hosting
test_game_info['uid'] = 42
- lobbyconnection._authenticated = True
await lobbyconnection.on_message_received({
"command": "game_join",
**test_game_info
@@ -213,7 +261,7 @@ async def test_command_game_join_calls_join_game(mocker,
'uid': 42,
'args': ['/numgames {}'.format(players.hosting.game_count[RatingType.GLOBAL])]
}
- mock_protocol.send_message.assert_called_with(expected_reply)
+ lobbyconnection.protocol.send_message.assert_called_with(expected_reply)
async def test_command_game_join_uid_as_str(mocker,
@@ -224,7 +272,6 @@ async def test_command_game_join_uid_as_str(mocker,
players,
game_stats_service):
lobbyconnection.game_service = game_service
- mock_protocol = mocker.patch.object(lobbyconnection, 'protocol')
game = mock.create_autospec(Game(42, database, game_service, game_stats_service))
game.state = GameState.LOBBY
game.password = None
@@ -234,7 +281,6 @@ async def test_command_game_join_uid_as_str(mocker,
lobbyconnection.player = players.hosting
test_game_info['uid'] = '42' # Pass in uid as string
- lobbyconnection._authenticated = True
await lobbyconnection.on_message_received({
"command": "game_join",
**test_game_info
@@ -245,7 +291,7 @@ async def test_command_game_join_uid_as_str(mocker,
'uid': 42,
'args': ['/numgames {}'.format(players.hosting.game_count[RatingType.GLOBAL])]
}
- mock_protocol.send_message.assert_called_with(expected_reply)
+ lobbyconnection.protocol.send_message.assert_called_with(expected_reply)
async def test_command_game_join_without_password(lobbyconnection,
@@ -254,7 +300,7 @@ async def test_command_game_join_without_password(lobbyconnection,
test_game_info,
players,
game_stats_service):
- lobbyconnection.send = mock.Mock()
+ lobbyconnection.send = CoroutineMock()
lobbyconnection.game_service = game_service
game = mock.create_autospec(Game(42, database, game_service, game_stats_service))
game.state = GameState.LOBBY
@@ -266,7 +312,6 @@ async def test_command_game_join_without_password(lobbyconnection,
test_game_info['uid'] = 42
del test_game_info['password']
- lobbyconnection._authenticated = True
await lobbyconnection.on_message_received({
"command": "game_join",
**test_game_info
@@ -279,12 +324,11 @@ async def test_command_game_join_game_not_found(lobbyconnection,
game_service,
test_game_info,
players):
- lobbyconnection.send = mock.Mock()
+ lobbyconnection.send = CoroutineMock()
lobbyconnection.game_service = game_service
lobbyconnection.player = players.hosting
test_game_info['uid'] = 42
- lobbyconnection._authenticated = True
await lobbyconnection.on_message_received({
"command": "game_join",
**test_game_info
@@ -295,12 +339,9 @@ async def test_command_game_join_game_not_found(lobbyconnection,
async def test_command_game_host_calls_host_game_invalid_title(lobbyconnection,
mock_games,
- test_game_info_invalid,
- players):
- lobbyconnection.send = mock.Mock()
+ test_game_info_invalid):
+ lobbyconnection.send = CoroutineMock()
mock_games.create_game = mock.Mock()
- lobbyconnection.player = players.hosting
- lobbyconnection._authenticated = True
await lobbyconnection.on_message_received({
"command": "game_host",
**test_game_info_invalid
@@ -311,33 +352,31 @@ async def test_command_game_host_calls_host_game_invalid_title(lobbyconnection,
async def test_abort(mocker, lobbyconnection):
- proto = mocker.patch.object(lobbyconnection, 'protocol')
-
- lobbyconnection.abort()
+ lobbyconnection.protocol.writer.close = mock.Mock()
+ await lobbyconnection.abort()
- proto.writer.close.assert_any_call()
+ lobbyconnection.protocol.writer.close.assert_any_call()
async def test_send_game_list(mocker, database, lobbyconnection, game_stats_service):
- protocol = mocker.patch.object(lobbyconnection, 'protocol')
games = mocker.patch.object(lobbyconnection, 'game_service') # type: GameService
game1, game2 = mock.create_autospec(Game(42, database, mock.Mock(), game_stats_service)), \
mock.create_autospec(Game(22, database, mock.Mock(), game_stats_service))
games.open_games = [game1, game2]
- lobbyconnection.send_game_list()
-
- protocol.send_message.assert_any_call({'command': 'game_info',
- 'games': [game1.to_dict(), game2.to_dict()]})
+ await lobbyconnection.send_game_list()
+ lobbyconnection.protocol.send_message.assert_any_call({
+ 'command': 'game_info',
+ 'games': [game1.to_dict(), game2.to_dict()]
+ })
-async def test_send_coop_maps(mocker, lobbyconnection):
- protocol = mocker.patch.object(lobbyconnection, 'protocol')
- await lobbyconnection.send_coop_maps()
+async def test_coop_list(mocker, lobbyconnection):
+ await lobbyconnection.command_coop_list({})
- args = protocol.send_messages.call_args_list
+ args = lobbyconnection.protocol.send_messages.call_args_list
assert len(args) == 1
coop_maps = args[0][0][0]
for info in coop_maps:
@@ -394,7 +433,7 @@ async def test_command_admin_closelobby(mocker, lobbyconnection):
tuna.id = 55
lobbyconnection.player_service = {1: player, 55: tuna}
- await lobbyconnection.command_admin({
+ await lobbyconnection.on_message_received({
'command': 'admin',
'action': 'closelobby',
'user_id': 55
@@ -411,7 +450,6 @@ async def test_command_admin_closelobby_with_ban(mocker, lobbyconnection, databa
banme = mock.Mock()
banme.id = 200
lobbyconnection.player_service = {1: player, banme.id: banme}
- lobbyconnection._authenticated = True
await lobbyconnection.on_message_received({
'command': 'admin',
@@ -443,7 +481,6 @@ async def test_command_admin_closelobby_with_ban_but_already_banned(mocker, lobb
banme = mock.Mock()
banme.id = 200
lobbyconnection.player_service = {1: player, banme.id: banme}
- lobbyconnection._authenticated = True
await lobbyconnection.on_message_received({
'command': 'admin',
@@ -482,7 +519,6 @@ async def test_command_admin_closelobby_with_ban_but_already_banned(mocker, lobb
async def test_command_admin_closelobby_with_ban_duration_no_period(mocker, lobbyconnection, database):
- mocker.patch.object(lobbyconnection, 'protocol')
player = mocker.patch.object(lobbyconnection, 'player')
player.login = 'Sheeo'
player.id = 1
@@ -490,7 +526,6 @@ async def test_command_admin_closelobby_with_ban_duration_no_period(mocker, lobb
banme = mock.Mock()
banme.id = 200
lobbyconnection.player_service = {1: player, banme.id: banme}
- lobbyconnection._authenticated = True
mocker.patch('server.lobbyconnection.func.now', return_value=text('FROM_UNIXTIME(1000)'))
await lobbyconnection.on_message_received({
@@ -515,13 +550,11 @@ async def test_command_admin_closelobby_with_ban_duration_no_period(mocker, lobb
async def test_command_admin_closelobby_with_ban_bad_period(mocker, lobbyconnection, database):
- proto = mocker.patch.object(lobbyconnection, 'protocol')
player = mocker.patch.object(lobbyconnection, 'player')
player.admin = True
banme = mock.Mock()
banme.id = 1
lobbyconnection.player_service = {1: player, banme.id: banme}
- lobbyconnection._authenticated = True
await lobbyconnection.on_message_received({
'command': 'admin',
@@ -535,7 +568,7 @@ async def test_command_admin_closelobby_with_ban_bad_period(mocker, lobbyconnect
})
banme.lobbyconnection.kick.assert_not_called()
- proto.send_message.assert_called_once_with({
+ lobbyconnection.protocol.send_message.assert_called_once_with({
'command': 'notice',
'style': 'error',
'text': "Period ') INJECTED!' is not allowed!"
@@ -550,13 +583,11 @@ async def test_command_admin_closelobby_with_ban_bad_period(mocker, lobbyconnect
async def test_command_admin_closelobby_with_ban_injection(mocker, lobbyconnection, database):
- proto = mocker.patch.object(lobbyconnection, 'protocol')
player = mocker.patch.object(lobbyconnection, 'player')
player.admin = True
banme = mock.Mock()
banme.id = 1
lobbyconnection.player_service = {1: player, banme.id: banme}
- lobbyconnection._authenticated = True
await lobbyconnection.on_message_received({
'command': 'admin',
@@ -570,7 +601,7 @@ async def test_command_admin_closelobby_with_ban_injection(mocker, lobbyconnecti
})
banme.lobbyconnection.kick.assert_not_called()
- proto.send_message.assert_called_once_with({
+ lobbyconnection.protocol.send_message.assert_called_once_with({
'command': 'notice',
'style': 'error',
'text': "Period ') INJECTED!' is not allowed!"
@@ -585,7 +616,6 @@ async def test_command_admin_closelobby_with_ban_injection(mocker, lobbyconnecti
async def test_command_admin_closeFA(mocker, lobbyconnection):
- mocker.patch.object(lobbyconnection, 'protocol')
mocker.patch.object(lobbyconnection, '_logger')
player = mocker.patch.object(lobbyconnection, 'player')
player.login = 'Sheeo'
@@ -593,7 +623,6 @@ async def test_command_admin_closeFA(mocker, lobbyconnection):
player.id = 42
tuna = mock.Mock()
tuna.id = 55
- lobbyconnection._authenticated = True
lobbyconnection.player_service = {42: player, 55: tuna}
await lobbyconnection.on_message_received({
@@ -612,7 +641,7 @@ async def test_game_subscription(lobbyconnection: LobbyConnection):
game = Mock()
game.handle_action = CoroutineMock()
lobbyconnection.game_connection = game
- lobbyconnection.ensure_authenticated = lambda _: True
+ lobbyconnection.ensure_authenticated = CoroutineMock(return_value=True)
await lobbyconnection.on_message_received({'command': 'test',
'args': ['foo', 42],
@@ -622,15 +651,15 @@ async def test_game_subscription(lobbyconnection: LobbyConnection):
async def test_command_avatar_list(mocker, lobbyconnection: LobbyConnection, mock_player: Player):
- protocol = mocker.patch.object(lobbyconnection, 'protocol')
lobbyconnection.player = mock_player
lobbyconnection.player.id = 2 # Dostya test user
- await lobbyconnection.command_avatar({
+ await lobbyconnection.on_message_received({
+ 'command': 'avatar',
'action': 'list_avatar'
})
- protocol.send_message.assert_any_call({
+ lobbyconnection.protocol.send_message.assert_any_call({
"command": "avatar",
"avatarlist": [{'url': 'http://content.faforever.com/faf/avatars/qai2.png', 'tooltip': 'QAI'}, {'url': 'http://content.faforever.com/faf/avatars/UEF.png', 'tooltip': 'UEF'}]
})
@@ -639,7 +668,6 @@ async def test_command_avatar_list(mocker, lobbyconnection: LobbyConnection, moc
async def test_command_avatar_select(mocker, database, lobbyconnection: LobbyConnection, mock_player: Player):
lobbyconnection.player = mock_player
lobbyconnection.player.id = 2 # Dostya test user
- lobbyconnection._authenticated = True
await lobbyconnection.on_message_received({
'command': 'avatar',
@@ -670,7 +698,6 @@ async def get_friends(player_id, database):
async def test_command_social_add_friend(lobbyconnection, mock_player, database):
lobbyconnection.player = mock_player
lobbyconnection.player.id = 1
- lobbyconnection._authenticated = True
friends = await get_friends(lobbyconnection.player.id, database)
assert friends == []
@@ -687,7 +714,6 @@ async def test_command_social_add_friend(lobbyconnection, mock_player, database)
async def test_command_social_remove_friend(lobbyconnection, mock_player, database):
lobbyconnection.player = mock_player
lobbyconnection.player.id = 2
- lobbyconnection._authenticated = True
friends = await get_friends(lobbyconnection.player.id, database)
assert friends == [1]
@@ -702,7 +728,6 @@ async def test_command_social_remove_friend(lobbyconnection, mock_player, databa
async def test_broadcast(lobbyconnection: LobbyConnection, mocker):
- mocker.patch.object(lobbyconnection, 'protocol')
player = mocker.patch.object(lobbyconnection, 'player')
player.login = 'Sheeo'
player.admin = True
@@ -710,7 +735,7 @@ async def test_broadcast(lobbyconnection: LobbyConnection, mocker):
tuna.id = 55
lobbyconnection.player_service = [player, tuna]
- await lobbyconnection.command_admin({
+ await lobbyconnection.on_message_received({
'command': 'admin',
'action': 'broadcast',
'message': "This is a test message"
@@ -721,16 +746,18 @@ async def test_broadcast(lobbyconnection: LobbyConnection, mocker):
async def test_game_connection_not_restored_if_no_such_game_exists(lobbyconnection: LobbyConnection, mocker, mock_player):
- protocol = mocker.patch.object(lobbyconnection, 'protocol')
lobbyconnection.player = mock_player
del lobbyconnection.player.game_connection
lobbyconnection.player.state = PlayerState.IDLE
- lobbyconnection.command_restore_game_session({'game_id': 123})
+ await lobbyconnection.on_message_received({
+ 'command': 'restore_game_session',
+ 'game_id': 123
+ })
assert not lobbyconnection.player.game_connection
assert lobbyconnection.player.state == PlayerState.IDLE
- protocol.send_message.assert_any_call({
+ lobbyconnection.protocol.send_message.assert_any_call({
"command": "notice",
"style": "info",
"text": "The game you were connected to does no longer exist"
@@ -741,7 +768,6 @@ async def test_game_connection_not_restored_if_no_such_game_exists(lobbyconnecti
async def test_game_connection_not_restored_if_game_state_prohibits(lobbyconnection: LobbyConnection, game_service: GameService,
game_stats_service, game_state, mock_player, mocker,
database):
- protocol = mocker.patch.object(lobbyconnection, 'protocol')
lobbyconnection.player = mock_player
del lobbyconnection.player.game_connection
lobbyconnection.player.state = PlayerState.IDLE
@@ -753,12 +779,15 @@ async def test_game_connection_not_restored_if_game_state_prohibits(lobbyconnect
game.id = 42
game_service.games[42] = game
- lobbyconnection.command_restore_game_session({'game_id': 42})
+ await lobbyconnection.on_message_received({
+ 'command': 'restore_game_session',
+ 'game_id': 42
+ })
assert not lobbyconnection.game_connection
assert lobbyconnection.player.state == PlayerState.IDLE
- protocol.send_message.assert_any_call({
+ lobbyconnection.protocol.send_message.assert_any_call({
"command": "notice",
"style": "info",
"text": "The game you were connected to is no longer available"
@@ -779,7 +808,10 @@ async def test_game_connection_restored_if_game_exists(lobbyconnection: LobbyCon
game.id = 42
game_service.games[42] = game
- lobbyconnection.command_restore_game_session({'game_id': 42})
+ await lobbyconnection.on_message_received({
+ 'command': 'restore_game_session',
+ 'game_id': 42
+ })
assert lobbyconnection.game_connection
assert lobbyconnection.player.state == PlayerState.PLAYING
@@ -788,7 +820,6 @@ async def test_game_connection_restored_if_game_exists(lobbyconnection: LobbyCon
async def test_command_game_matchmaking(lobbyconnection, mock_player):
lobbyconnection.player = mock_player
lobbyconnection.player.id = 1
- lobbyconnection._authenticated = True
await lobbyconnection.on_message_received({
'command': 'game_matchmaking',
@@ -822,11 +853,11 @@ async def test_check_policy_conformity_fraudulent(lobbyconnection, policy_server
f'http://{host}:{port}'
):
# 42 is not a valid player ID which should cause a SQL constraint error
- lobbyconnection.abort = mock.Mock()
+ lobbyconnection.abort = CoroutineMock()
with pytest.raises(ClientError):
await lobbyconnection.check_policy_conformity(42, "fraudulent", session=100)
- lobbyconnection.abort = mock.Mock()
+ lobbyconnection.abort = CoroutineMock()
player_id = 200
honest = await lobbyconnection.check_policy_conformity(player_id, "fraudulent", session=100)
assert honest is False
@@ -849,7 +880,7 @@ async def test_check_policy_conformity_fatal(lobbyconnection, policy_server):
f'http://{host}:{port}'
):
for result in ('vm', 'already_associated', 'fraudulent'):
- lobbyconnection.abort = mock.Mock()
+ lobbyconnection.abort = CoroutineMock()
honest = await lobbyconnection.check_policy_conformity(1, result, session=100)
assert honest is False
lobbyconnection.abort.assert_called_once()