From 021ef4efab2afb1ffd75c4730e4b474e8aa301e4 Mon Sep 17 00:00:00 2001 From: init0 Date: Fri, 20 Jul 2018 12:11:15 +0200 Subject: [PATCH] Fixed too many things to count --- core/events.py | 1 + core/lavalink.py | 29 ++++++++++++++-------------- core/load_balancing.py | 14 +++----------- core/node.py | 44 +++++++++++++++++++----------------------- core/player.py | 43 +++++++++++++++++++++++++++++++++++++---- requirements.txt | 1 - 6 files changed, 78 insertions(+), 54 deletions(-) diff --git a/core/events.py b/core/events.py index 0497760..6d93827 100644 --- a/core/events.py +++ b/core/events.py @@ -114,6 +114,7 @@ async def track_start(self, event: TrackStartEvent): async def track_end(self, event: TrackEndEvent): event.player.reset() + await event.player.stop() async def track_exception(self, event: TrackExceptionEvent): pass diff --git a/core/lavalink.py b/core/lavalink.py index 8646bae..ef6eeab 100644 --- a/core/lavalink.py +++ b/core/lavalink.py @@ -1,14 +1,14 @@ import asyncio +import logging from enum import Enum -import logging from discord import InvalidArgument from discord.ext.commands import BotMissingPermissions from .exceptions import IllegalAction -from .node import Node -from .player import Player, AudioTrack from .load_balancing import LoadBalancer +from .node import Node +from .player import Player, AudioTrackPlaylist logger = logging.getLogger("magma") @@ -35,7 +35,7 @@ def __init__(self, user_id, shard_count): @property def playing_guilds(self): - return {name: node.stats.playing_players for name, node in self.nodes.items()} + return {name: node.stats.playing_players for name, node in self.nodes.items() if node.stats} @property def total_playing_guilds(self): @@ -60,10 +60,11 @@ def get_link(self, guild_id: int, bot=None): :return: A Link """ guild_id = int(guild_id) - if guild_id not in self.links: - if not bot: - raise IllegalAction("A bot instance was not passed when trying to acquire a Link!") - self.links[guild_id] = Link(self, guild_id, bot) + + if guild_id in self.links or not bot: + return self.links.get(guild_id) + + self.links[guild_id] = Link(self, guild_id, bot) return self.links[guild_id] async def add_node(self, name, uri, rest_uri, password): @@ -157,8 +158,8 @@ async def get_tracks(self, query): :return: """ node = await self.get_node(True) - tracks = await node.get_tracks(query) - return [AudioTrack(track) for track in tracks] + results = await node.get_tracks(query) + return AudioTrackPlaylist(results) async def get_tracks_yt(self, query): return await self.get_tracks("ytsearch:" + query) @@ -173,7 +174,7 @@ async def get_node(self, select_if_absent=False): :param select_if_absent: A boolean that indicates if a Node should be created if there is none :return: A Node """ - if select_if_absent and not self.node: + if select_if_absent and not (self.node and self.node.available): await self.change_node(await self.lavalink.get_best_node()) return self.node @@ -188,8 +189,8 @@ async def change_node(self, node): self.node.links[self.guild_id] = self if self.last_voice_update: await node.send(self.last_voice_update) - if self.player: - await self.player.node_changed() + if self._player: + await self._player.node_changed() async def connect(self, channel): """ @@ -206,7 +207,7 @@ async def connect(self, channel): me = channel.guild.me permissions = me.permissions_in(channel) - if not permissions.connect and not permissions.move_members: + if (not permissions.connect or len(channel.members) >= channel.user_limit) and not permissions.move_members: raise BotMissingPermissions(["connect"]) self.set_state(State.CONNECTING) diff --git a/core/load_balancing.py b/core/load_balancing.py index 8ac5ff4..fe3b904 100644 --- a/core/load_balancing.py +++ b/core/load_balancing.py @@ -66,16 +66,10 @@ def __init__(self, node, lavalink): async def get_total(self): # hard maths stats = self.node.stats - if not stats: - return + if not self.node.available or not stats: + return big_number - if self.lavalink: - # REEEEE complexity levels - for link in self.lavalink.links.values(): - if self.node == await link.get_node() and link.player.current and not link.player.paused: - self.player_penalty += 1 - else: - self.player_penalty = stats.playing_players + self.player_penalty = stats.playing_players self.cpu_penalty = 1.05 ** (100 * stats.system_load) * 10 - 10 if stats.avg_frame_deficit != -1: @@ -83,6 +77,4 @@ async def get_total(self): self.null_frame_penalty = (1.03 ** (500 * (stats.avg_frame_nulled / 3000))) * 300 - 300 self.null_frame_penalty *= 2 - if not self.node.available or not self.node.stats: - return big_number return self.player_penalty + self.cpu_penalty + self.deficit_frame_penalty + self.null_frame_penalty diff --git a/core/node.py b/core/node.py index 3c80fff..8dd4e0e 100644 --- a/core/node.py +++ b/core/node.py @@ -6,13 +6,12 @@ import aiohttp import logging import websockets +from discord.backoff import ExponentialBackoff from .events import TrackEndEvent, TrackStuckEvent, TrackExceptionEvent from .exceptions import NodeException logger = logging.getLogger("magma") -timeout = 5 -tries = 5 class NodeStats: @@ -49,6 +48,7 @@ def __init__(self, msg): class KeepAlive(threading.Thread): def __init__(self, node, interval, *args, **kwargs): super().__init__(*args, **kwargs) + self.name = f"{node.name}-KeepAlive" self.daemon = True self.node = node self.ws = node.ws @@ -68,13 +68,9 @@ def run(self): asyncio.run_coroutine_threadsafe(self.node.on_close(e.code, e.reason), loop=self.loop) return - try: - logger.info(f"Attempting to reconnect `{self.node.name}`") - future = asyncio.run_coroutine_threadsafe(self.node.connect(), loop=self.loop) - future.result() - except NodeException: - future = asyncio.run_coroutine_threadsafe(self.node.on_close(e.code, e.reason), loop=self.loop) - future.result() + logger.info(f"Attempting to reconnect `{self.node.name}`") + future = asyncio.run_coroutine_threadsafe(self.node.connect(), loop=self.loop) + future.result() def stop(self): self._stop_ev.set() @@ -94,19 +90,18 @@ def __init__(self, lavalink, name, uri, rest_uri, headers): self.available = False self.closing = False - async def _connect(self, try_=0): - try: - self.ws = await websockets.connect(self.uri, extra_headers=self.headers) - self.keep_alive = KeepAlive(self, 4) - self.keep_alive.start() - asyncio.ensure_future(self.listen()) - except OSError: - if try_ < tries: - logger.error(f"Connection refused, trying again in {timeout}s, try: {try_+1}/{tries}") - await asyncio.sleep(timeout) - await self._connect(try_+1) - else: - raise NodeException(f"Connection failed after {tries} tries") + async def _connect(self): + backoff = ExponentialBackoff(2) + while not (self.ws and self.ws.open): + try: + self.ws = await websockets.connect(self.uri, extra_headers=self.headers) + asyncio.ensure_future(self.listen()) + self.keep_alive = KeepAlive(self, 3) + self.keep_alive.start() + except OSError: + delay = backoff.delay() + logger.error(f"Connection refused, trying again in {delay:.2f}s") + await asyncio.sleep(delay) async def connect(self): await self._connect() @@ -153,8 +148,8 @@ async def listen(self): pass # ping() handles this for us, no need to hear it twice.. async def on_open(self): - await self.lavalink.load_balancer.on_node_connect(self) self.available = True + await self.lavalink.load_balancer.on_node_connect(self) async def on_close(self, code, reason): self.closing = False @@ -177,7 +172,8 @@ async def on_message(self, msg): op = msg.get("op") if op == "playerUpdate": link = self.lavalink.get_link(msg.get("guildId")) - await link.player.provide_state(msg.get("state")) + if link: + await link.player.provide_state(msg.get("state")) elif op == "stats": self.stats = NodeStats(msg) elif op == "event": diff --git a/core/player.py b/core/player.py index e66a549..1d5eea3 100644 --- a/core/player.py +++ b/core/player.py @@ -1,9 +1,20 @@ +import traceback +from enum import Enum from time import time from .exceptions import IllegalAction from .events import InternalEventAdapter, TrackPauseEvent, TrackResumeEvent, TrackStartEvent +class LoadTypes(Enum): + NO_MATCHES = -2 + LOAD_FAILED = -1 + UNKNOWN = 0 + TRACK_LOADED = 1 + PLAYLIST_LOADED = 2 + SEARCH_RESULT = 3 + + class AudioTrack: """ The base AudioTrack class that is used by the player to play songs @@ -20,6 +31,25 @@ def __init__(self, track): self.user_data = None +class AudioTrackPlaylist: + def __init__(self, results): + self.playlist_info = results["playlistInfo"] + self.playlist_name = self.playlist_info.get("name") + self.selected_track = self.playlist_info.get("selectedTrack") + self.load_type = LoadTypes[results["loadType"]] + self.tracks = [AudioTrack(track) for track in results["tracks"]] + + def __iter__(self): + for track in self.tracks: + yield track + + def __len__(self): + return self.tracks.__len__() + + def __getitem__(self, item): + return self.tracks[item] + + class Player: internal_event_adapter = InternalEventAdapter() @@ -161,10 +191,12 @@ async def destroy(self): "guildId": str(self.link.guild_id), } node = await self.link.get_node() - if node.available: + if node and node.available: await node.send(payload) - await self.event_adapter.destroy() - self.event_adapter = None + + if self.event_adapter: + await self.event_adapter.destroy() + self.event_adapter = None async def node_changed(self): if self.current: @@ -173,4 +205,7 @@ async def node_changed(self): async def trigger_event(self, event): await Player.internal_event_adapter.on_event(event) if self.event_adapter: # If we defined our on adapter - await self.event_adapter.on_event(event) + try: + await self.event_adapter.on_event(event) + except: + traceback.print_exc() diff --git a/requirements.txt b/requirements.txt index dcfec04..299ad24 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,2 +1 @@ git+https://github.com/Rapptz/discord.py@rewrite#egg=discord.py[voice] -aiohttp==2.2.5