33
44import discord
55import lavalink
6+ from lavalink .errors import ClientError
67from discord .ext import commands
78
89import teapot
910
10- url_rx = re .compile ('https?:\\ / \\ /(?:www\ \ .)?.+' ) # noqa: W605
11+ url_rx = re .compile ('https?:\/\ /(?:www\.)?.+' ) # noqa: W605
1112
13+ class LavalinkVoiceClient (discord .VoiceProtocol ):
14+ """Voice protocol implementation that relays Discord voice events to Lavalink."""
1215
13- class Music (commands .Cog ):
16+ def __init__ (self , client : discord .Client , channel : discord .abc .Connectable ):
17+ super ().__init__ (client , channel )
18+ if not hasattr (client , 'lavalink' ):
19+ raise RuntimeError ('Lavalink client is not initialized on the bot.' )
20+
21+ self .guild_id = channel .guild .id
22+ self ._destroyed = False
23+ self .lavalink : lavalink .Client = client .lavalink
24+
25+ async def connect (self , * , timeout : float , reconnect : bool , self_deaf : bool = False , self_mute : bool = False ):
26+ player = self .lavalink .player_manager .create (self .guild_id )
27+ player .channel_id = str (self .channel .id )
28+ await self .channel .guild .change_voice_state (channel = self .channel , self_deaf = self_deaf , self_mute = self_mute )
29+ self ._destroyed = False
30+
31+ async def disconnect (self , * , force : bool = False ):
32+ player = self .lavalink .player_manager .get (self .guild_id )
33+
34+ if player and not force and not player .is_connected :
35+ return
36+
37+ await self .channel .guild .change_voice_state (channel = None )
38+
39+ if player :
40+ player .channel_id = None
41+
42+ await self ._destroy ()
43+
44+ async def on_voice_server_update (self , data ):
45+ await self .lavalink .voice_update_handler ({'t' : 'VOICE_SERVER_UPDATE' , 'd' : data })
46+
47+ async def on_voice_state_update (self , data ):
48+ channel_id = data .get ('channel_id' )
49+
50+ if channel_id :
51+ maybe_channel = self .client .get_channel (int (channel_id ))
52+ if maybe_channel is not None :
53+ self .channel = maybe_channel
54+ else :
55+ await self ._destroy ()
56+
57+ await self .lavalink .voice_update_handler ({'t' : 'VOICE_STATE_UPDATE' , 'd' : data })
58+
59+ async def _destroy (self ):
60+ if self ._destroyed :
61+ return
62+
63+ self ._destroyed = True
64+ self .cleanup ()
65+
66+ try :
67+ await self .lavalink .player_manager .destroy (self .guild_id )
68+ except ClientError :
69+ pass
70+
71+ class Music (commands .Cog ): # TODO: event check to save-up resources when not one is in voice channels
1472 """Music Time"""
1573
1674 def __init__ (self , bot ):
1775 self .bot = bot
18-
19- if not hasattr (bot , 'lavalink' ): # This ensures the client isn't overwritten during cog reloads.
20- bot .lavalink = lavalink .Client (bot .user .id )
21- bot .lavalink .add_node (teapot .config .lavalink_host (), teapot .config .lavalink_port (),
22- teapot .config .lavalink_password (), 'zz' ,
23- 'default' ) # Host, Port, Password, Region, Name
24- bot .add_listener (bot .lavalink .voice_update_handler , 'on_socket_response' )
25-
26- bot .lavalink .add_event_hook (self .track_hook )
76+ self .lavalink : lavalink .Client = bot .lavalink
77+ self .lavalink .add_event_hook (self .track_hook )
2778
2879 def cog_unload (self ):
2980 self .bot .lavalink ._event_hooks .clear ()
@@ -41,12 +92,12 @@ async def cog_command_error(self, ctx, error):
4192 async def track_hook (self , event ):
4293 if isinstance (event , lavalink .events .QueueEndEvent ):
4394 guild_id = int (event .player .guild_id )
44- await self .connect_to (guild_id , None )
45-
46- async def connect_to ( self , guild_id : int , channel_id : str ) :
47- """ Connects to the given voice channel ID. A channel_id of `None` means disconnect. """
48- ws = self . bot . _connection . _get_websocket ( guild_id )
49- await ws . voice_state ( str ( guild_id ), channel_id )
95+ guild = self .bot . get_guild (guild_id )
96+ if guild and guild . voice_client :
97+ try :
98+ await guild . voice_client . disconnect ( force = True )
99+ except Exception :
100+ pass
50101
51102 @commands .command (aliases = ['p' ])
52103 async def play (self , ctx , * , query : str ):
@@ -263,7 +314,8 @@ async def disconnect(self, ctx):
263314
264315 player .queue .clear ()
265316 await player .stop ()
266- await self .connect_to (ctx .guild .id , None )
317+ if ctx .voice_client :
318+ await ctx .voice_client .disconnect (force = True )
267319 await ctx .send ('*⃣ | Disconnected.' )
268320
269321 async def ensure_voice (self , ctx ):
@@ -284,9 +336,19 @@ async def ensure_voice(self, ctx):
284336 raise commands .CommandInvokeError ('I need the `CONNECT` and `SPEAK` permissions.' )
285337
286338 player .store ('channel' , ctx .channel .id )
287- await self .connect_to (ctx .guild .id , str (ctx .author .voice .channel .id ))
339+ # mark the player's voice channel for lavalink to know which guild/channel
340+ try :
341+ player .channel_id = str (ctx .author .voice .channel .id )
342+ except Exception :
343+ pass
344+
345+ voice_client = ctx .voice_client
346+ if voice_client and voice_client .channel .id != ctx .author .voice .channel .id :
347+ raise commands .CommandInvokeError ('You need to be in my voice channel.' )
348+ if not voice_client :
349+ await ctx .author .voice .channel .connect (cls = LavalinkVoiceClient )
288350 else :
289- if int (player .channel_id ) != ctx .author .voice .channel .id :
351+ if not player . channel_id or int (player .channel_id ) != ctx .author .voice .channel .id :
290352 raise commands .CommandInvokeError ('You need to be in my voice channel.' )
291353
292354
0 commit comments