diff --git a/.github/workflows/unittests.yml b/.github/workflows/unittests.yml index b2530bd06c7d..3ad29b007772 100644 --- a/.github/workflows/unittests.yml +++ b/.github/workflows/unittests.yml @@ -24,7 +24,7 @@ on: - '.github/workflows/unittests.yml' jobs: - build: + unit: runs-on: ${{ matrix.os }} name: Test Python ${{ matrix.python.version }} ${{ matrix.os }} @@ -60,3 +60,32 @@ jobs: - name: Unittests run: | pytest -n auto + + hosting: + runs-on: ${{ matrix.os }} + name: Test hosting with ${{ matrix.python.version }} on ${{ matrix.os }} + + strategy: + matrix: + os: + - ubuntu-latest + python: + - {version: '3.11'} # current + + steps: + - uses: actions/checkout@v4 + - name: Set up Python ${{ matrix.python.version }} + uses: actions/setup-python@v5 + with: + python-version: ${{ matrix.python.version }} + - name: Install dependencies + run: | + python -m venv venv + source venv/bin/activate + python -m pip install --upgrade pip + python ModuleUpdate.py --yes --force --append "WebHostLib/requirements.txt" + - name: Test hosting + run: | + source venv/bin/activate + export PYTHONPATH=$(pwd) + python test/hosting/__main__.py diff --git a/MultiServer.py b/MultiServer.py index 6816f381ea82..e1d2ac81aa85 100644 --- a/MultiServer.py +++ b/MultiServer.py @@ -3,6 +3,7 @@ import argparse import asyncio import collections +import contextlib import copy import datetime import functools @@ -176,7 +177,7 @@ class Context: location_name_groups: typing.Dict[str, typing.Dict[str, typing.Set[str]]] all_item_and_group_names: typing.Dict[str, typing.Set[str]] all_location_and_group_names: typing.Dict[str, typing.Set[str]] - non_hintable_names: typing.Dict[str, typing.Set[str]] + non_hintable_names: typing.Dict[str, typing.AbstractSet[str]] spheres: typing.List[typing.Dict[int, typing.Set[int]]] """ each sphere is { player: { location_id, ... } } """ logger: logging.Logger @@ -231,7 +232,7 @@ def __init__(self, host: str, port: int, server_password: str, password: str, lo self.embedded_blacklist = {"host", "port"} self.client_ids: typing.Dict[typing.Tuple[int, int], datetime.datetime] = {} self.auto_save_interval = 60 # in seconds - self.auto_saver_thread = None + self.auto_saver_thread: typing.Optional[threading.Thread] = None self.save_dirty = False self.tags = ['AP'] self.games: typing.Dict[int, str] = {} @@ -268,6 +269,11 @@ def _load_game_data(self): for world_name, world in worlds.AutoWorldRegister.world_types.items(): self.non_hintable_names[world_name] = world.hint_blacklist + for game_package in self.gamespackage.values(): + # remove groups from data sent to clients + del game_package["item_name_groups"] + del game_package["location_name_groups"] + def _init_game_data(self): for game_name, game_package in self.gamespackage.items(): if "checksum" in game_package: @@ -1926,8 +1932,6 @@ def _cmd_status(self, tag: str = "") -> bool: def _cmd_exit(self) -> bool: """Shutdown the server""" self.ctx.server.ws_server.close() - if self.ctx.shutdown_task: - self.ctx.shutdown_task.cancel() self.ctx.exit_event.set() return True @@ -2286,7 +2290,8 @@ def parse_args() -> argparse.Namespace: async def auto_shutdown(ctx, to_cancel=None): - await asyncio.sleep(ctx.auto_shutdown) + with contextlib.suppress(asyncio.TimeoutError): + await asyncio.wait_for(ctx.exit_event.wait(), ctx.auto_shutdown) def inactivity_shutdown(): ctx.server.ws_server.close() @@ -2306,7 +2311,8 @@ def inactivity_shutdown(): if seconds < 0: inactivity_shutdown() else: - await asyncio.sleep(seconds) + with contextlib.suppress(asyncio.TimeoutError): + await asyncio.wait_for(ctx.exit_event.wait(), seconds) def load_server_cert(path: str, cert_key: typing.Optional[str]) -> "ssl.SSLContext": diff --git a/WebHost.py b/WebHost.py index 9b5edd322f91..afacd6288ec2 100644 --- a/WebHost.py +++ b/WebHost.py @@ -12,6 +12,9 @@ import Utils import settings +if typing.TYPE_CHECKING: + from flask import Flask + Utils.local_path.cached_path = os.path.dirname(__file__) or "." # py3.8 is not abs. remove "." when dropping 3.8 settings.no_gui = True configpath = os.path.abspath("config.yaml") @@ -19,7 +22,7 @@ configpath = os.path.abspath(Utils.user_path('config.yaml')) -def get_app(): +def get_app() -> "Flask": from WebHostLib import register, cache, app as raw_app from WebHostLib.models import db diff --git a/WebHostLib/customserver.py b/WebHostLib/customserver.py index 3a86cb551d27..9f70165b61e5 100644 --- a/WebHostLib/customserver.py +++ b/WebHostLib/customserver.py @@ -168,17 +168,28 @@ def get_random_port(): def get_static_server_data() -> dict: import worlds data = { - "non_hintable_names": {}, - "gamespackage": worlds.network_data_package["games"], - "item_name_groups": {world_name: world.item_name_groups for world_name, world in - worlds.AutoWorldRegister.world_types.items()}, - "location_name_groups": {world_name: world.location_name_groups for world_name, world in - worlds.AutoWorldRegister.world_types.items()}, + "non_hintable_names": { + world_name: world.hint_blacklist + for world_name, world in worlds.AutoWorldRegister.world_types.items() + }, + "gamespackage": { + world_name: { + key: value + for key, value in game_package.items() + if key not in ("item_name_groups", "location_name_groups") + } + for world_name, game_package in worlds.network_data_package["games"].items() + }, + "item_name_groups": { + world_name: world.item_name_groups + for world_name, world in worlds.AutoWorldRegister.world_types.items() + }, + "location_name_groups": { + world_name: world.location_name_groups + for world_name, world in worlds.AutoWorldRegister.world_types.items() + }, } - for world_name, world in worlds.AutoWorldRegister.world_types.items(): - data["non_hintable_names"][world_name] = world.hint_blacklist - return data @@ -266,12 +277,15 @@ async def start_room(room_id): ctx.logger.exception("Could not determine port. Likely hosting failure.") with db_session: ctx.auto_shutdown = Room.get(id=room_id).timeout + if ctx.saving: + setattr(asyncio.current_task(), "save", lambda: ctx._save(True)) ctx.shutdown_task = asyncio.create_task(auto_shutdown(ctx, [])) await ctx.shutdown_task except (KeyboardInterrupt, SystemExit): if ctx.saving: ctx._save() + setattr(asyncio.current_task(), "save", None) except Exception as e: with db_session: room = Room.get(id=room_id) @@ -281,8 +295,12 @@ async def start_room(room_id): else: if ctx.saving: ctx._save() + setattr(asyncio.current_task(), "save", None) finally: try: + ctx.save_dirty = False # make sure the saving thread does not write to DB after final wakeup + ctx.exit_event.set() # make sure the saving thread stops at some point + # NOTE: async saving should probably be an async task and could be merged with shutdown_task with (db_session): # ensure the Room does not spin up again on its own, minute of safety buffer room = Room.get(id=room_id) @@ -294,13 +312,32 @@ async def start_room(room_id): rooms_shutting_down.put(room_id) class Starter(threading.Thread): + _tasks: typing.List[asyncio.Future] + + def __init__(self): + super().__init__() + self._tasks = [] + + def _done(self, task: asyncio.Future): + self._tasks.remove(task) + task.result() + def run(self): while 1: next_room = rooms_to_run.get(block=True, timeout=None) - asyncio.run_coroutine_threadsafe(start_room(next_room), loop) + task = asyncio.run_coroutine_threadsafe(start_room(next_room), loop) + self._tasks.append(task) + task.add_done_callback(self._done) logging.info(f"Starting room {next_room} on {name}.") starter = Starter() starter.daemon = True starter.start() - loop.run_forever() + try: + loop.run_forever() + finally: + # save all tasks that want to be saved during shutdown + for task in asyncio.all_tasks(loop): + save: typing.Optional[typing.Callable[[], typing.Any]] = getattr(task, "save", None) + if save: + save() diff --git a/test/hosting/__init__.py b/test/hosting/__init__.py new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/test/hosting/__main__.py b/test/hosting/__main__.py new file mode 100644 index 000000000000..6640c637b5bd --- /dev/null +++ b/test/hosting/__main__.py @@ -0,0 +1,191 @@ +# A bunch of tests to verify MultiServer and custom webhost server work as expected. +# This spawns processes and may modify your local AP, so this is not run as part of unit testing. +# Run with `python test/hosting` instead, +import logging +import traceback +from tempfile import TemporaryDirectory +from time import sleep +from typing import Any + +from test.hosting.client import Client +from test.hosting.generate import generate_local +from test.hosting.serve import ServeGame, LocalServeGame, WebHostServeGame +from test.hosting.webhost import (create_room, get_app, get_multidata_for_room, set_multidata_for_room, start_room, + stop_autohost, upload_multidata) +from test.hosting.world import copy as copy_world, delete as delete_world + +failure = False +fail_fast = True + + +def assert_true(condition: Any, msg: str = "") -> None: + global failure + if not condition: + failure = True + msg = f": {msg}" if msg else "" + raise AssertionError(f"Assertion failed{msg}") + + +def assert_equal(first: Any, second: Any, msg: str = "") -> None: + global failure + if first != second: + failure = True + msg = f": {msg}" if msg else "" + raise AssertionError(f"Assertion failed: {first} == {second}{msg}") + + +if fail_fast: + expect_true = assert_true + expect_equal = assert_equal +else: + def expect_true(condition: Any, msg: str = "") -> None: + global failure + if not condition: + failure = True + tb = "".join(traceback.format_stack()[:-1]) + msg = f": {msg}" if msg else "" + logging.error(f"Expectation failed{msg}\n{tb}") + + def expect_equal(first: Any, second: Any, msg: str = "") -> None: + global failure + if first != second: + failure = True + tb = "".join(traceback.format_stack()[:-1]) + msg = f": {msg}" if msg else "" + logging.error(f"Expectation failed {first} == {second}{msg}\n{tb}") + + +if __name__ == "__main__": + import warnings + warnings.simplefilter("ignore", ResourceWarning) + warnings.simplefilter("ignore", UserWarning) + + spacer = '=' * 80 + + with TemporaryDirectory() as tempdir: + multis = [["Clique"], ["Temp World"], ["Clique", "Temp World"]] + p1_games = [] + data_paths = [] + rooms = [] + + copy_world("Clique", "Temp World") + try: + for n, games in enumerate(multis, 1): + print(f"Generating [{n}] {', '.join(games)}") + multidata = generate_local(games, tempdir) + print(f"Generated [{n}] {', '.join(games)} as {multidata}\n") + p1_games.append(games[0]) + data_paths.append(multidata) + finally: + delete_world("Temp World") + + webapp = get_app(tempdir) + webhost_client = webapp.test_client() + for n, multidata in enumerate(data_paths, 1): + seed = upload_multidata(webhost_client, multidata) + room = create_room(webhost_client, seed) + print(f"Uploaded [{n}] {multidata} as {room}\n") + rooms.append(room) + + print("Starting autohost") + from WebHostLib.autolauncher import autohost + try: + autohost(webapp.config) + + host: ServeGame + for n, (multidata, room, game, multi_games) in enumerate(zip(data_paths, rooms, p1_games, multis), 1): + involved_games = {"Archipelago"} | set(multi_games) + for collected_items in range(3): + print(f"\nTesting [{n}] {game} in {multidata} on MultiServer with {collected_items} items collected") + with LocalServeGame(multidata) as host: + with Client(host.address, game, "Player1") as client: + local_data_packages = client.games_packages + local_collected_items = len(client.checked_locations) + if collected_items < 2: # Clique only has 2 Locations + client.collect_any() + # TODO: Ctrl+C test here as well + + for game_name in sorted(involved_games): + expect_true(game_name in local_data_packages, + f"{game_name} missing from MultiServer datap ackage") + expect_true("item_name_groups" not in local_data_packages.get(game_name, {}), + f"item_name_groups are not supposed to be in MultiServer data for {game_name}") + expect_true("location_name_groups" not in local_data_packages.get(game_name, {}), + f"location_name_groups are not supposed to be in MultiServer data for {game_name}") + for game_name in local_data_packages: + expect_true(game_name in involved_games, + f"Received unexpected extra data package for {game_name} from MultiServer") + assert_equal(local_collected_items, collected_items, + "MultiServer did not load or save correctly") + + print(f"\nTesting [{n}] {game} in {multidata} on customserver with {collected_items} items collected") + prev_host_adr: str + with WebHostServeGame(webhost_client, room) as host: + prev_host_adr = host.address + with Client(host.address, game, "Player1") as client: + web_data_packages = client.games_packages + web_collected_items = len(client.checked_locations) + if collected_items < 2: # Clique only has 2 Locations + client.collect_any() + if collected_items == 1: + sleep(1) # wait for the server to collect the item + stop_autohost(True) # simulate Ctrl+C + sleep(3) + autohost(webapp.config) # this will spin the room right up again + sleep(1) # make log less annoying + # if saving failed, the next iteration will fail below + + # verify server shut down + try: + with Client(prev_host_adr, game, "Player1") as client: + assert_true(False, "Server did not shut down") + except ConnectionError: + pass + + for game_name in sorted(involved_games): + expect_true(game_name in web_data_packages, + f"{game_name} missing from customserver data package") + expect_true("item_name_groups" not in web_data_packages.get(game_name, {}), + f"item_name_groups are not supposed to be in customserver data for {game_name}") + expect_true("location_name_groups" not in web_data_packages.get(game_name, {}), + f"location_name_groups are not supposed to be in customserver data for {game_name}") + for game_name in web_data_packages: + expect_true(game_name in involved_games, + f"Received unexpected extra data package for {game_name} from customserver") + assert_equal(web_collected_items, collected_items, + "customserver did not load or save correctly during/after " + + ("Ctrl+C" if collected_items == 2 else "/exit")) + + # compare customserver to MultiServer + expect_equal(local_data_packages, web_data_packages, + "customserver datapackage differs from MultiServer") + + sleep(5.5) # make sure all tasks actually stopped + + # raise an exception in customserver and verify the save doesn't get destroyed + # local variables room is the last room's id here + old_data = get_multidata_for_room(webhost_client, room) + print(f"Destroying multidata for {room}") + set_multidata_for_room(webhost_client, room, bytes([0])) + try: + start_room(webhost_client, room, timeout=7) + except TimeoutError: + pass + else: + assert_true(False, "Room started with destroyed multidata") + print(f"Restoring multidata for {room}") + set_multidata_for_room(webhost_client, room, old_data) + with WebHostServeGame(webhost_client, room) as host: + with Client(host.address, game, "Player1") as client: + assert_equal(len(client.checked_locations), 2, + "Save was destroyed during exception in customserver") + print("Save file is not busted 🥳") + + finally: + print("Stopping autohost") + stop_autohost(False) + + if failure: + print("Some tests failed") + exit(1) + exit(0) diff --git a/test/hosting/client.py b/test/hosting/client.py new file mode 100644 index 000000000000..b805bb6a2638 --- /dev/null +++ b/test/hosting/client.py @@ -0,0 +1,110 @@ +import json +import sys +from typing import Any, Collection, Dict, Iterable, Optional +from websockets import ConnectionClosed +from websockets.sync.client import connect, ClientConnection +from threading import Thread + + +__all__ = [ + "Client" +] + + +class Client: + """Incomplete, minimalistic sync test client for AP network protocol""" + + recv_timeout = 1.0 + + host: str + game: str + slot: str + password: Optional[str] + + _ws: Optional[ClientConnection] + + games: Iterable[str] + data_package_checksums: Dict[str, Any] + games_packages: Dict[str, Any] + missing_locations: Collection[int] + checked_locations: Collection[int] + + def __init__(self, host: str, game: str, slot: str, password: Optional[str] = None) -> None: + self.host = host + self.game = game + self.slot = slot + self.password = password + self._ws = None + self.games = [] + self.data_package_checksums = {} + self.games_packages = {} + self.missing_locations = [] + self.checked_locations = [] + + def __enter__(self) -> "Client": + try: + self.connect() + except BaseException: + self.__exit__(*sys.exc_info()) + raise + return self + + def __exit__(self, exc_type, exc_val, exc_tb) -> None: # type: ignore + self.close() + + def _poll(self) -> None: + assert self._ws + try: + while True: + self._ws.recv() + except (TimeoutError, ConnectionClosed, KeyboardInterrupt, SystemExit): + pass + + def connect(self) -> None: + self._ws = connect(f"ws://{self.host}") + room_info = json.loads(self._ws.recv(self.recv_timeout))[0] + self.games = sorted(room_info["games"]) + self.data_package_checksums = room_info["datapackage_checksums"] + self._ws.send(json.dumps([{ + "cmd": "GetDataPackage", + "games": list(self.games), + }])) + data_package_msg = json.loads(self._ws.recv(self.recv_timeout))[0] + self.games_packages = data_package_msg["data"]["games"] + self._ws.send(json.dumps([{ + "cmd": "Connect", + "game": self.game, + "name": self.slot, + "password": self.password, + "uuid": "", + "version": { + "class": "Version", + "major": 0, + "minor": 4, + "build": 6, + }, + "items_handling": 0, + "tags": [], + "slot_data": False, + }])) + connect_result_msg = json.loads(self._ws.recv(self.recv_timeout))[0] + if connect_result_msg["cmd"] != "Connected": + raise ConnectionError(", ".join(connect_result_msg.get("errors", [connect_result_msg["cmd"]]))) + self.missing_locations = connect_result_msg["missing_locations"] + self.checked_locations = connect_result_msg["checked_locations"] + + def close(self) -> None: + if self._ws: + Thread(target=self._poll).start() + self._ws.close() + + def collect(self, locations: Iterable[int]) -> None: + if not self._ws: + raise ValueError("Not connected") + self._ws.send(json.dumps([{ + "cmd": "LocationChecks", + "locations": locations, + }])) + + def collect_any(self) -> None: + self.collect([next(iter(self.missing_locations))]) diff --git a/test/hosting/generate.py b/test/hosting/generate.py new file mode 100644 index 000000000000..356cbcca25a0 --- /dev/null +++ b/test/hosting/generate.py @@ -0,0 +1,75 @@ +import json +import sys +import warnings +from pathlib import Path +from typing import Iterable, Union, TYPE_CHECKING + +if TYPE_CHECKING: + from multiprocessing.managers import ListProxy # noqa + +__all__ = [ + "generate_local", +] + + +def _generate_local_inner(games: Iterable[str], + dest: Union[Path, str], + results: "ListProxy[Union[Path, BaseException]]") -> None: + original_argv = sys.argv + warnings.simplefilter("ignore") + try: + from tempfile import TemporaryDirectory + + if not isinstance(dest, Path): + dest = Path(dest) + + with TemporaryDirectory() as players_dir: + with TemporaryDirectory() as output_dir: + import Generate + + for n, game in enumerate(games, 1): + player_path = Path(players_dir) / f"{n}.yaml" + with open(player_path, "w", encoding="utf-8") as f: + f.write(json.dumps({ + "name": f"Player{n}", + "game": game, + game: {"hard_mode": "true"}, + "description": f"generate_local slot {n} ('Player{n}'): {game}", + })) + + # this is basically copied from test/programs/test_generate.py + # uses a reproducible seed that is different for each set of games + sys.argv = [sys.argv[0], "--seed", str(hash(tuple(games))), + "--player_files_path", players_dir, + "--outputpath", output_dir] + Generate.main() + output_files = list(Path(output_dir).glob('*.zip')) + assert len(output_files) == 1 + final_file = dest / output_files[0].name + output_files[0].rename(final_file) + results.append(final_file) + except BaseException as e: + results.append(e) + raise e + finally: + sys.argv = original_argv + + +def generate_local(games: Iterable[str], dest: Union[Path, str]) -> Path: + from multiprocessing import Manager, Process, set_start_method + + try: + set_start_method("spawn") + except RuntimeError: + pass + + manager = Manager() + results: "ListProxy[Union[Path, Exception]]" = manager.list() + + p = Process(target=_generate_local_inner, args=(games, dest, results)) + p.start() + p.join() + result = results[0] + if isinstance(result, BaseException): + raise Exception("Could not generate multiworld") from result + return result diff --git a/test/hosting/serve.py b/test/hosting/serve.py new file mode 100644 index 000000000000..c3eaac87cc08 --- /dev/null +++ b/test/hosting/serve.py @@ -0,0 +1,115 @@ +import sys +from pathlib import Path +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + from threading import Event + from werkzeug.test import Client as FlaskClient + +__all__ = [ + "ServeGame", + "LocalServeGame", + "WebHostServeGame", +] + + +class ServeGame: + address: str + + +def _launch_multiserver(multidata: Path, ready: "Event", stop: "Event") -> None: + import os + import warnings + + original_argv = sys.argv + original_stdin = sys.stdin + warnings.simplefilter("ignore") + try: + import asyncio + from MultiServer import main, parse_args + + sys.argv = [sys.argv[0], str(multidata), "--host", "127.0.0.1"] + r, w = os.pipe() + sys.stdin = os.fdopen(r, "r") + + async def set_ready() -> None: + await asyncio.sleep(.01) # switch back to other task once more + ready.set() # server should be up, set ready state + + async def wait_stop() -> None: + await asyncio.get_event_loop().run_in_executor(None, stop.wait) + os.fdopen(w, "w").write("/exit") + + async def run() -> None: + # this will run main() until first await, then switch to set_ready() + await asyncio.gather( + main(parse_args()), + set_ready(), + wait_stop(), + ) + + asyncio.run(run()) + finally: + sys.argv = original_argv + sys.stdin = original_stdin + + +class LocalServeGame(ServeGame): + from multiprocessing import Process + + _multidata: Path + _proc: Process + _stop: "Event" + + def __init__(self, multidata: Path) -> None: + self.address = "" + self._multidata = multidata + + def __enter__(self) -> "LocalServeGame": + from multiprocessing import Manager, Process, set_start_method + + try: + set_start_method("spawn") + except RuntimeError: + pass + + manager = Manager() + ready: "Event" = manager.Event() + self._stop = manager.Event() + + self._proc = Process(target=_launch_multiserver, args=(self._multidata, ready, self._stop)) + try: + self._proc.start() + ready.wait(30) + self.address = "localhost:38281" + return self + except BaseException: + self.__exit__(*sys.exc_info()) + raise + + def __exit__(self, exc_type, exc_val, exc_tb) -> None: # type: ignore + try: + self._stop.set() + self._proc.join(30) + except TimeoutError: + self._proc.terminate() + self._proc.join() + + +class WebHostServeGame(ServeGame): + _client: "FlaskClient" + _room: str + + def __init__(self, app_client: "FlaskClient", room: str) -> None: + self.address = "" + self._client = app_client + self._room = room + + def __enter__(self) -> "WebHostServeGame": + from .webhost import start_room + self.address = start_room(self._client, self._room) + return self + + def __exit__(self, exc_type, exc_val, exc_tb) -> None: # type: ignore + from .webhost import stop_room + stop_room(self._client, self._room, timeout=30) diff --git a/test/hosting/webhost.py b/test/hosting/webhost.py new file mode 100644 index 000000000000..e1e31ae466c4 --- /dev/null +++ b/test/hosting/webhost.py @@ -0,0 +1,201 @@ +import re +from pathlib import Path +from typing import TYPE_CHECKING, Optional, cast + +if TYPE_CHECKING: + from flask import Flask + from werkzeug.test import Client as FlaskClient + +__all__ = [ + "get_app", + "upload_multidata", + "create_room", + "start_room", + "stop_room", + "set_room_timeout", + "get_multidata_for_room", + "set_multidata_for_room", + "stop_autohost", +] + + +def get_app(tempdir: str) -> "Flask": + from WebHostLib import app as raw_app + from WebHost import get_app + raw_app.config["PONY"] = { + "provider": "sqlite", + "filename": str(Path(tempdir) / "host.db"), + "create_db": True, + } + raw_app.config.update({ + "TESTING": True, + "HOST_ADDRESS": "localhost", + "HOSTERS": 1, + }) + return get_app() + + +def upload_multidata(app_client: "FlaskClient", multidata: Path) -> str: + response = app_client.post("/uploads", data={ + "file": multidata.open("rb"), + }) + assert response.status_code < 400, f"Upload of {multidata} failed: status {response.status_code}" + assert "Location" in response.headers, f"Upload of {multidata} failed: no redirect" + location = response.headers["Location"] + assert isinstance(location, str) + assert location.startswith("/seed/"), f"Upload of {multidata} failed: unexpected redirect" + return location[6:] + + +def create_room(app_client: "FlaskClient", seed: str, auto_start: bool = False) -> str: + response = app_client.get(f"/new_room/{seed}") + assert response.status_code < 400, f"Creating room for {seed} failed: status {response.status_code}" + assert "Location" in response.headers, f"Creating room for {seed} failed: no redirect" + location = response.headers["Location"] + assert isinstance(location, str) + assert location.startswith("/room/"), f"Creating room for {seed} failed: unexpected redirect" + room_id = location[6:] + + if not auto_start: + # by default, creating a room will auto-start it, so we update last activity here + stop_room(app_client, room_id, simulate_idle=False) + + return room_id + + +def start_room(app_client: "FlaskClient", room_id: str, timeout: float = 30) -> str: + from time import sleep + + poll_interval = .2 + + print(f"Starting room {room_id}") + no_timeout = timeout <= 0 + while no_timeout or timeout > 0: + response = app_client.get(f"/room/{room_id}") + assert response.status_code == 200, f"Starting room for {room_id} failed: status {response.status_code}" + match = re.search(r"/connect ([\w:.\-]+)", response.text) + if match: + return match[1] + timeout -= poll_interval + sleep(poll_interval) + raise TimeoutError("Room did not start") + + +def stop_room(app_client: "FlaskClient", + room_id: str, + timeout: Optional[float] = None, + simulate_idle: bool = True) -> None: + from datetime import datetime, timedelta + from time import sleep + + from pony.orm import db_session + + from WebHostLib.models import Command, Room + from WebHostLib import app + + poll_interval = 2 + + print(f"Stopping room {room_id}") + room_uuid = app.url_map.converters["suuid"].to_python(None, room_id) # type: ignore[arg-type] + + if timeout is not None: + sleep(.1) # should not be required, but other things might use threading + + with db_session: + room: Room = Room.get(id=room_uuid) + if simulate_idle: + new_last_activity = datetime.utcnow() - timedelta(seconds=room.timeout + 5) + else: + new_last_activity = datetime.utcnow() - timedelta(days=3) + room.last_activity = new_last_activity + address = f"localhost:{room.last_port}" if room.last_port > 0 else None + if address: + original_timeout = room.timeout + room.timeout = 1 # avoid spinning it up again + Command(room=room, commandtext="/exit") + + try: + if address and timeout is not None: + print("waiting for shutdown") + import socket + host_str, port_str = tuple(address.split(":")) + address_tuple = host_str, int(port_str) + + no_timeout = timeout <= 0 + while no_timeout or timeout > 0: + s = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + try: + s.connect(address_tuple) + s.close() + except ConnectionRefusedError: + return + sleep(poll_interval) + timeout -= poll_interval + + raise TimeoutError("Room did not stop") + finally: + with db_session: + room = Room.get(id=room_uuid) + room.last_port = 0 # easier to detect when the host is up this way + if address: + room.timeout = original_timeout + room.last_activity = new_last_activity + print("timeout restored") + + +def set_room_timeout(room_id: str, timeout: float) -> None: + from pony.orm import db_session + + from WebHostLib.models import Room + from WebHostLib import app + + room_uuid = app.url_map.converters["suuid"].to_python(None, room_id) # type: ignore[arg-type] + with db_session: + room: Room = Room.get(id=room_uuid) + room.timeout = timeout + + +def get_multidata_for_room(webhost_client: "FlaskClient", room_id: str) -> bytes: + from pony.orm import db_session + + from WebHostLib.models import Room + from WebHostLib import app + + room_uuid = app.url_map.converters["suuid"].to_python(None, room_id) # type: ignore[arg-type] + with db_session: + room: Room = Room.get(id=room_uuid) + return cast(bytes, room.seed.multidata) + + +def set_multidata_for_room(webhost_client: "FlaskClient", room_id: str, data: bytes) -> None: + from pony.orm import db_session + + from WebHostLib.models import Room + from WebHostLib import app + + room_uuid = app.url_map.converters["suuid"].to_python(None, room_id) # type: ignore[arg-type] + with db_session: + room: Room = Room.get(id=room_uuid) + room.seed.multidata = data + + +def stop_autohost(graceful: bool = True) -> None: + import os + import signal + + import multiprocessing + + from WebHostLib.autolauncher import stop + + stop() + proc: multiprocessing.process.BaseProcess + for proc in filter(lambda child: child.name.startswith("MultiHoster"), multiprocessing.active_children()): + if graceful and proc.pid: + os.kill(proc.pid, getattr(signal, "CTRL_C_EVENT", signal.SIGINT)) + else: + proc.kill() + try: + proc.join(30) + except TimeoutError: + proc.kill() + proc.join() diff --git a/test/hosting/world.py b/test/hosting/world.py new file mode 100644 index 000000000000..e083e027fee1 --- /dev/null +++ b/test/hosting/world.py @@ -0,0 +1,42 @@ +import re +import shutil +from pathlib import Path +from typing import Dict + + +__all__ = ["copy", "delete"] + + +_new_worlds: Dict[str, str] = {} + + +def copy(src: str, dst: str) -> None: + from Utils import get_file_safe_name + from worlds import AutoWorldRegister + + assert dst not in _new_worlds, "World already created" + if '"' in dst or "\\" in dst: # easier to reject than to escape + raise ValueError(f"Unsupported symbols in {dst}") + dst_folder_name = get_file_safe_name(dst.lower()) + src_cls = AutoWorldRegister.world_types[src] + src_folder = Path(src_cls.__file__).parent + worlds_folder = src_folder.parent + if (not src_cls.__file__.endswith("__init__.py") or not src_folder.is_dir() + or not (worlds_folder / "generic").is_dir()): + raise ValueError(f"Unsupported layout for copy_world from {src}") + dst_folder = worlds_folder / dst_folder_name + if dst_folder.is_dir(): + raise ValueError(f"Destination {dst_folder} already exists") + shutil.copytree(src_folder, dst_folder) + _new_worlds[dst] = str(dst_folder) + with open(dst_folder / "__init__.py", "r", encoding="utf-8-sig") as f: + contents = f.read() + contents = re.sub(r'game\s*=\s*[\'"]' + re.escape(src) + r'[\'"]', f'game = "{dst}"', contents) + with open(dst_folder / "__init__.py", "w", encoding="utf-8") as f: + f.write(contents) + + +def delete(name: str) -> None: + assert name in _new_worlds, "World not created by this script" + shutil.rmtree(_new_worlds[name]) + del _new_worlds[name]