Skip to content

Commit 9c24bc5

Browse files
committed
feat: re-implemented support with lavalink v4
1 parent ed6a751 commit 9c24bc5

File tree

3 files changed

+91
-21
lines changed

3 files changed

+91
-21
lines changed

Teapot.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
import discord
88
from discord.ext import commands as dcmd
99
from dotenv import load_dotenv
10+
import lavalink
1011

1112
import teapot
1213
from teapot.event_handler.loader import EventHandlerLoader
@@ -101,14 +102,20 @@ async def on_ready():
101102

102103
# load cogs
103104
teapot.events.__init__(bot)
105+
# Initialize lavalink client once here so cogs do not need to recreate it
106+
if not hasattr(bot, 'lavalink'):
107+
print("Initializing Lavalink client...")
108+
bot.lavalink = lavalink.Client(bot.user.id)
109+
bot.lavalink.add_node(teapot.config.lavalink_host(), teapot.config.lavalink_port(), teapot.config.lavalink_password(), 'zz', 'default')
110+
bot.add_listener(bot.lavalink.voice_update_handler, 'on_socket_response')
104111
extensions = [
105112
'teapot.cogs.cmds',
106113
'teapot.cogs.osu',
107114
'teapot.cogs.github',
108115
'teapot.cogs.cat',
109116
'teapot.cogs.neko',
110117
'teapot.cogs.nqn',
111-
# 'teapot.cogs.music' -- TODO: WIP
118+
'teapot.cogs.music' # TODO: WIP
112119
]
113120

114121
for extension in extensions:

requirements.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,4 +20,5 @@ websockets>=15.0.1
2020
yarl>=1.20.1
2121
mysql-connector-python>=9.4.0
2222
alt-profanity-check==1.7.2
23+
PyNaCl>=1.6.1
2324
protobuf>=6.32.1 # not directly required, pinned by Snyk to avoid a vulnerability

teapot/cogs/music.py

Lines changed: 82 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -3,27 +3,78 @@
33

44
import discord
55
import lavalink
6+
from lavalink.errors import ClientError
67
from discord.ext import commands
78

89
import teapot
910

10-
url_rx = re.compile('https?:\\/\\/(?:www\\.)?.+') # noqa: W605
11+
url_rx = re.compile('https?:\/\/(?:www\.)?.+') # noqa: W605
1112

13+
class LavalinkVoiceClient(discord.VoiceProtocol):
14+
"""Voice protocol implementation that relays Discord voice events to Lavalink."""
1215

13-
class Music(commands.Cog):
16+
def __init__(self, client: discord.Client, channel: discord.abc.Connectable):
17+
super().__init__(client, channel)
18+
if not hasattr(client, 'lavalink'):
19+
raise RuntimeError('Lavalink client is not initialized on the bot.')
20+
21+
self.guild_id = channel.guild.id
22+
self._destroyed = False
23+
self.lavalink: lavalink.Client = client.lavalink
24+
25+
async def connect(self, *, timeout: float, reconnect: bool, self_deaf: bool = False, self_mute: bool = False):
26+
player = self.lavalink.player_manager.create(self.guild_id)
27+
player.channel_id = str(self.channel.id)
28+
await self.channel.guild.change_voice_state(channel=self.channel, self_deaf=self_deaf, self_mute=self_mute)
29+
self._destroyed = False
30+
31+
async def disconnect(self, *, force: bool = False):
32+
player = self.lavalink.player_manager.get(self.guild_id)
33+
34+
if player and not force and not player.is_connected:
35+
return
36+
37+
await self.channel.guild.change_voice_state(channel=None)
38+
39+
if player:
40+
player.channel_id = None
41+
42+
await self._destroy()
43+
44+
async def on_voice_server_update(self, data):
45+
await self.lavalink.voice_update_handler({'t': 'VOICE_SERVER_UPDATE', 'd': data})
46+
47+
async def on_voice_state_update(self, data):
48+
channel_id = data.get('channel_id')
49+
50+
if channel_id:
51+
maybe_channel = self.client.get_channel(int(channel_id))
52+
if maybe_channel is not None:
53+
self.channel = maybe_channel
54+
else:
55+
await self._destroy()
56+
57+
await self.lavalink.voice_update_handler({'t': 'VOICE_STATE_UPDATE', 'd': data})
58+
59+
async def _destroy(self):
60+
if self._destroyed:
61+
return
62+
63+
self._destroyed = True
64+
self.cleanup()
65+
66+
try:
67+
await self.lavalink.player_manager.destroy(self.guild_id)
68+
except ClientError:
69+
pass
70+
71+
class Music(commands.Cog): # TODO: event check to save-up resources when not one is in voice channels
1472
"""Music Time"""
1573

1674
def __init__(self, bot):
1775
self.bot = bot
18-
19-
if not hasattr(bot, 'lavalink'): # This ensures the client isn't overwritten during cog reloads.
20-
bot.lavalink = lavalink.Client(bot.user.id)
21-
bot.lavalink.add_node(teapot.config.lavalink_host(), teapot.config.lavalink_port(),
22-
teapot.config.lavalink_password(), 'zz',
23-
'default') # Host, Port, Password, Region, Name
24-
bot.add_listener(bot.lavalink.voice_update_handler, 'on_socket_response')
25-
26-
bot.lavalink.add_event_hook(self.track_hook)
76+
self.lavalink: lavalink.Client = bot.lavalink
77+
self.lavalink.add_event_hook(self.track_hook)
2778

2879
def cog_unload(self):
2980
self.bot.lavalink._event_hooks.clear()
@@ -41,12 +92,12 @@ async def cog_command_error(self, ctx, error):
4192
async def track_hook(self, event):
4293
if isinstance(event, lavalink.events.QueueEndEvent):
4394
guild_id = int(event.player.guild_id)
44-
await self.connect_to(guild_id, None)
45-
46-
async def connect_to(self, guild_id: int, channel_id: str):
47-
""" Connects to the given voice channel ID. A channel_id of `None` means disconnect. """
48-
ws = self.bot._connection._get_websocket(guild_id)
49-
await ws.voice_state(str(guild_id), channel_id)
95+
guild = self.bot.get_guild(guild_id)
96+
if guild and guild.voice_client:
97+
try:
98+
await guild.voice_client.disconnect(force=True)
99+
except Exception:
100+
pass
50101

51102
@commands.command(aliases=['p'])
52103
async def play(self, ctx, *, query: str):
@@ -263,7 +314,8 @@ async def disconnect(self, ctx):
263314

264315
player.queue.clear()
265316
await player.stop()
266-
await self.connect_to(ctx.guild.id, None)
317+
if ctx.voice_client:
318+
await ctx.voice_client.disconnect(force=True)
267319
await ctx.send('*⃣ | Disconnected.')
268320

269321
async def ensure_voice(self, ctx):
@@ -284,9 +336,19 @@ async def ensure_voice(self, ctx):
284336
raise commands.CommandInvokeError('I need the `CONNECT` and `SPEAK` permissions.')
285337

286338
player.store('channel', ctx.channel.id)
287-
await self.connect_to(ctx.guild.id, str(ctx.author.voice.channel.id))
339+
# mark the player's voice channel for lavalink to know which guild/channel
340+
try:
341+
player.channel_id = str(ctx.author.voice.channel.id)
342+
except Exception:
343+
pass
344+
345+
voice_client = ctx.voice_client
346+
if voice_client and voice_client.channel.id != ctx.author.voice.channel.id:
347+
raise commands.CommandInvokeError('You need to be in my voice channel.')
348+
if not voice_client:
349+
await ctx.author.voice.channel.connect(cls=LavalinkVoiceClient)
288350
else:
289-
if int(player.channel_id) != ctx.author.voice.channel.id:
351+
if not player.channel_id or int(player.channel_id) != ctx.author.voice.channel.id:
290352
raise commands.CommandInvokeError('You need to be in my voice channel.')
291353

292354

0 commit comments

Comments
 (0)