Skip to content

Commit

Permalink
Threading improvement
Browse files Browse the repository at this point in the history
  • Loading branch information
BattlefieldDuck committed Feb 24, 2024
1 parent cd58ee7 commit a834328
Show file tree
Hide file tree
Showing 3 changed files with 68 additions and 18 deletions.
32 changes: 27 additions & 5 deletions discordgsm/async_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
else:
from typing import ParamSpec

from typing import Awaitable, Callable, Generator, List, TypeVar
from typing import AsyncGenerator, Awaitable, Callable, List, TypeVar

R = TypeVar("R")
P = ParamSpec("P")
Expand All @@ -17,16 +17,38 @@ def run_in_executor(_func: Callable[P, R]) -> Callable[P, Awaitable[R]]:
@wraps(_func)
async def wrapper(*args: P.args, **kwargs: P.kwargs) -> R:
func = partial(_func, *args, **kwargs)
return await asyncio.get_running_loop().run_in_executor(executor=None, func=func)
return await asyncio.get_running_loop().run_in_executor(
executor=None, func=func
)

return wrapper


T = TypeVar('T')
T = TypeVar("T")


async def to_chunks(lst: List[T], n: int) -> Generator[List[T], None, None]:
async def to_chunks(lst: List[T], n: int) -> AsyncGenerator[List[T], None]:
"""Yield successive n-sized chunks from lst."""
for i in range(0, len(lst), n):
await asyncio.sleep(0)
yield lst[i:i + n]
yield lst[i : i + n]


# Define a function to run the async function in a new event loop
def run_in_new_loop(async_func, *args):
# Create a new event loop
loop = asyncio.new_event_loop()
# Set the new event loop as the event loop for the current context
asyncio.set_event_loop(loop)
try:
# Run the async function in the new event loop
loop.run_until_complete(async_func(*args))
except RuntimeError:
# RuntimeError('cannot schedule new futures after shutdown')
pass
except KeyboardInterrupt:
# Optionally show a message or perform other cleanup here
pass
finally:
# Close the loop after use
loop.close()
38 changes: 32 additions & 6 deletions discordgsm/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import os
from datetime import datetime
from enum import Enum
import threading
from typing import Optional

import aiohttp
Expand All @@ -15,7 +16,7 @@
from discord.ext import tasks
from discord.ui import Button, Modal, Select, TextInput, View
from dotenv import load_dotenv
from discordgsm.async_utils import to_chunks
from discordgsm.async_utils import run_in_new_loop, to_chunks

from discordgsm.environment import AdvertiseType, env
from discordgsm.gamedig import GamedigGame
Expand Down Expand Up @@ -67,8 +68,8 @@ async def on_ready():
await sync_commands(whitelist_guilds)
await tasks_fetch_messages()

if not tasks_query.is_running():
tasks_query.start()
if not tasks_query_servers.is_running():
tasks_query_servers.start()

if not cache_guilds.is_running() and env('WEB_API_ENABLE'):
cache_guilds.start()
Expand Down Expand Up @@ -927,9 +928,25 @@ def group_servers_by_message_id(servers: list[Server]) -> dict[int, list[Server]


# region Application tasks
_tasks_query_servers_thread: Optional[threading.Thread] = None
exit_signal = threading.Event()


@tasks.loop(seconds=max(15.0, env('TASK_QUERY_SERVER')))
async def tasks_query():
async def tasks_query_servers():
"""Query servers (Scheduled)"""
global _tasks_query_servers_thread

if _tasks_query_servers_thread is not None:
while _tasks_query_servers_thread.is_alive():
await asyncio.sleep(1)

if _tasks_query_servers_thread is None or not _tasks_query_servers_thread.is_alive():
_tasks_query_servers_thread = threading.Thread(target=run_in_new_loop, args=(__tasks_query_servers, tasks_query_servers.current_loop))
_tasks_query_servers_thread.start()


async def __tasks_query_servers(current_loop: int):
# Pre query servers, some servers cannot be queried one by one
games_servers_count = await database.count_servers_per_game()
pre_query_tasks = [pre_query(protocol({})) for name, protocol in protocols.items() if protocol.pre_query_required and games_servers_count.get(name, 0) > 0]
Expand All @@ -954,13 +971,17 @@ async def tasks_query():
Logger.info(f'Query servers: Total = {len(queried_servers)}, Success = {success}, Failed = {failed} ({percent}% fail)')

# Run the tasks after the server queries
await asyncio.gather(tasks_send_alert(), tasks_edit_messages(), tasks_presence_update(tasks_query.current_loop))
await asyncio.gather(tasks_send_alert(), tasks_edit_messages(), tasks_presence_update(current_loop))


async def query_servers(distinct_servers: dict[tuple[str, str, int, str], list[Server]]):
query_tasks = [query_distinct_server(servers) for servers in distinct_servers.values()]
query_tasks = [asyncio.create_task(query_distinct_server(servers)) for servers in distinct_servers.values()]

async for chunks in to_chunks(query_tasks, int(os.getenv('TASK_QUERY_CHUNK_SIZE', '50'))):
if exit_signal.is_set():
Logger.debug(f'Exit signal received. Terminating server queries.')
break

await asyncio.gather(*chunks)

servers: list[Server] = []
Expand All @@ -970,6 +991,7 @@ async def query_servers(distinct_servers: dict[tuple[str, str, int, str], list[S

return servers


async def query_distinct_server(servers: list[Server]):
"""Query server"""
server = servers[0]
Expand Down Expand Up @@ -1120,6 +1142,10 @@ async def tasks_edit_messages():

# Discord Rate limit: 50 requests per second
async for chunks in to_chunks(tasks, 25):
if exit_signal.is_set():
Logger.debug('Exit signal received. Terminating message editing tasks.')
break

start = datetime.now().timestamp()
results += await asyncio.gather(*chunks, return_exceptions=True)
time_used = datetime.now().timestamp() - start
Expand Down
16 changes: 9 additions & 7 deletions main.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,12 +10,14 @@
# A valid token should contains 2 dots and 3 items
if len(items) != 3:
Logger.critical('Improper token has been passed, please change APP_TOKEN to a valid token. Learn more: https://discordgsm.com/guide/how-to-get-a-discord-bot-token')
else:
hmac_hide = '*' * len(items[2]) # Hide the secret
Logger.debug(f'Static token: {items[0]}.{items[1]}.{hmac_hide}')
exit(1)

# Run the bot
from discordgsm.main import client
client.run(token)
hmac_hide = '*' * len(items[2]) # Hide the secret
Logger.debug(f'Static token: {items[0]}.{items[1]}.{hmac_hide}')

Logger.info('Stopped Discord Game Server Monitor.')
# Run the bot
from discordgsm.main import client, exit_signal
client.run(token)

exit_signal.set()
Logger.info('Stopping Discord Game Server Monitor...')

0 comments on commit a834328

Please sign in to comment.