Skip to content

Commit a834328

Browse files
Threading improvement
1 parent cd58ee7 commit a834328

File tree

3 files changed

+68
-18
lines changed

3 files changed

+68
-18
lines changed

discordgsm/async_utils.py

Lines changed: 27 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
else:
88
from typing import ParamSpec
99

10-
from typing import Awaitable, Callable, Generator, List, TypeVar
10+
from typing import AsyncGenerator, Awaitable, Callable, List, TypeVar
1111

1212
R = TypeVar("R")
1313
P = ParamSpec("P")
@@ -17,16 +17,38 @@ def run_in_executor(_func: Callable[P, R]) -> Callable[P, Awaitable[R]]:
1717
@wraps(_func)
1818
async def wrapper(*args: P.args, **kwargs: P.kwargs) -> R:
1919
func = partial(_func, *args, **kwargs)
20-
return await asyncio.get_running_loop().run_in_executor(executor=None, func=func)
20+
return await asyncio.get_running_loop().run_in_executor(
21+
executor=None, func=func
22+
)
2123

2224
return wrapper
2325

2426

25-
T = TypeVar('T')
27+
T = TypeVar("T")
2628

2729

28-
async def to_chunks(lst: List[T], n: int) -> Generator[List[T], None, None]:
30+
async def to_chunks(lst: List[T], n: int) -> AsyncGenerator[List[T], None]:
2931
"""Yield successive n-sized chunks from lst."""
3032
for i in range(0, len(lst), n):
3133
await asyncio.sleep(0)
32-
yield lst[i:i + n]
34+
yield lst[i : i + n]
35+
36+
37+
# Define a function to run the async function in a new event loop
38+
def run_in_new_loop(async_func, *args):
39+
# Create a new event loop
40+
loop = asyncio.new_event_loop()
41+
# Set the new event loop as the event loop for the current context
42+
asyncio.set_event_loop(loop)
43+
try:
44+
# Run the async function in the new event loop
45+
loop.run_until_complete(async_func(*args))
46+
except RuntimeError:
47+
# RuntimeError('cannot schedule new futures after shutdown')
48+
pass
49+
except KeyboardInterrupt:
50+
# Optionally show a message or perform other cleanup here
51+
pass
52+
finally:
53+
# Close the loop after use
54+
loop.close()

discordgsm/main.py

Lines changed: 32 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
import os
66
from datetime import datetime
77
from enum import Enum
8+
import threading
89
from typing import Optional
910

1011
import aiohttp
@@ -15,7 +16,7 @@
1516
from discord.ext import tasks
1617
from discord.ui import Button, Modal, Select, TextInput, View
1718
from dotenv import load_dotenv
18-
from discordgsm.async_utils import to_chunks
19+
from discordgsm.async_utils import run_in_new_loop, to_chunks
1920

2021
from discordgsm.environment import AdvertiseType, env
2122
from discordgsm.gamedig import GamedigGame
@@ -67,8 +68,8 @@ async def on_ready():
6768
await sync_commands(whitelist_guilds)
6869
await tasks_fetch_messages()
6970

70-
if not tasks_query.is_running():
71-
tasks_query.start()
71+
if not tasks_query_servers.is_running():
72+
tasks_query_servers.start()
7273

7374
if not cache_guilds.is_running() and env('WEB_API_ENABLE'):
7475
cache_guilds.start()
@@ -927,9 +928,25 @@ def group_servers_by_message_id(servers: list[Server]) -> dict[int, list[Server]
927928

928929

929930
# region Application tasks
931+
_tasks_query_servers_thread: Optional[threading.Thread] = None
932+
exit_signal = threading.Event()
933+
934+
930935
@tasks.loop(seconds=max(15.0, env('TASK_QUERY_SERVER')))
931-
async def tasks_query():
936+
async def tasks_query_servers():
932937
"""Query servers (Scheduled)"""
938+
global _tasks_query_servers_thread
939+
940+
if _tasks_query_servers_thread is not None:
941+
while _tasks_query_servers_thread.is_alive():
942+
await asyncio.sleep(1)
943+
944+
if _tasks_query_servers_thread is None or not _tasks_query_servers_thread.is_alive():
945+
_tasks_query_servers_thread = threading.Thread(target=run_in_new_loop, args=(__tasks_query_servers, tasks_query_servers.current_loop))
946+
_tasks_query_servers_thread.start()
947+
948+
949+
async def __tasks_query_servers(current_loop: int):
933950
# Pre query servers, some servers cannot be queried one by one
934951
games_servers_count = await database.count_servers_per_game()
935952
pre_query_tasks = [pre_query(protocol({})) for name, protocol in protocols.items() if protocol.pre_query_required and games_servers_count.get(name, 0) > 0]
@@ -954,13 +971,17 @@ async def tasks_query():
954971
Logger.info(f'Query servers: Total = {len(queried_servers)}, Success = {success}, Failed = {failed} ({percent}% fail)')
955972

956973
# Run the tasks after the server queries
957-
await asyncio.gather(tasks_send_alert(), tasks_edit_messages(), tasks_presence_update(tasks_query.current_loop))
974+
await asyncio.gather(tasks_send_alert(), tasks_edit_messages(), tasks_presence_update(current_loop))
958975

959976

960977
async def query_servers(distinct_servers: dict[tuple[str, str, int, str], list[Server]]):
961-
query_tasks = [query_distinct_server(servers) for servers in distinct_servers.values()]
978+
query_tasks = [asyncio.create_task(query_distinct_server(servers)) for servers in distinct_servers.values()]
962979

963980
async for chunks in to_chunks(query_tasks, int(os.getenv('TASK_QUERY_CHUNK_SIZE', '50'))):
981+
if exit_signal.is_set():
982+
Logger.debug(f'Exit signal received. Terminating server queries.')
983+
break
984+
964985
await asyncio.gather(*chunks)
965986

966987
servers: list[Server] = []
@@ -970,6 +991,7 @@ async def query_servers(distinct_servers: dict[tuple[str, str, int, str], list[S
970991

971992
return servers
972993

994+
973995
async def query_distinct_server(servers: list[Server]):
974996
"""Query server"""
975997
server = servers[0]
@@ -1120,6 +1142,10 @@ async def tasks_edit_messages():
11201142

11211143
# Discord Rate limit: 50 requests per second
11221144
async for chunks in to_chunks(tasks, 25):
1145+
if exit_signal.is_set():
1146+
Logger.debug('Exit signal received. Terminating message editing tasks.')
1147+
break
1148+
11231149
start = datetime.now().timestamp()
11241150
results += await asyncio.gather(*chunks, return_exceptions=True)
11251151
time_used = datetime.now().timestamp() - start

main.py

Lines changed: 9 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -10,12 +10,14 @@
1010
# A valid token should contains 2 dots and 3 items
1111
if len(items) != 3:
1212
Logger.critical('Improper token has been passed, please change APP_TOKEN to a valid token. Learn more: https://discordgsm.com/guide/how-to-get-a-discord-bot-token')
13-
else:
14-
hmac_hide = '*' * len(items[2]) # Hide the secret
15-
Logger.debug(f'Static token: {items[0]}.{items[1]}.{hmac_hide}')
13+
exit(1)
1614

17-
# Run the bot
18-
from discordgsm.main import client
19-
client.run(token)
15+
hmac_hide = '*' * len(items[2]) # Hide the secret
16+
Logger.debug(f'Static token: {items[0]}.{items[1]}.{hmac_hide}')
2017

21-
Logger.info('Stopped Discord Game Server Monitor.')
18+
# Run the bot
19+
from discordgsm.main import client, exit_signal
20+
client.run(token)
21+
22+
exit_signal.set()
23+
Logger.info('Stopping Discord Game Server Monitor...')

0 commit comments

Comments
 (0)