Skip to content

Commit

Permalink
feat: support lavalink v4 (#68)
Browse files Browse the repository at this point in the history
* refactor(node): use new resuming

* refactor(player): use new track field in events

* feat(track): add artwork and isrc

* refactor: use new type capitalisation

* refactor(node): use new load types

* refactor: remove utils.decode_track

This was always experimental anyway.

* refactor: update to pyright 1.1.306
  • Loading branch information
ooliver1 authored May 3, 2023
1 parent c39d54f commit cd63b41
Show file tree
Hide file tree
Showing 21 changed files with 333 additions and 315 deletions.
2 changes: 2 additions & 0 deletions mafic/errors.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,8 @@ class TrackLoadException(PlayerException):
The message returned by the node.
severity: :data:`~typing.Literal`\[``"COMMON"``, ``"SUSPICIOUS"``, ``"FATAL"``]
The severity of the error.
This is lowercase in Lavalink v4.
cause: :class:`str`
The cause of the error.
"""
Expand Down
2 changes: 1 addition & 1 deletion mafic/events.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,7 +130,7 @@ def __init__(
self, *, track: Track, payload: TrackEndEventPayload, player: PlayerT
) -> None:
self.track: Track = track
self.reason: EndReason = EndReason(payload["reason"])
self.reason: EndReason = EndReason(payload["reason"].upper())
self.player: PlayerT = player

def __repr__(self) -> str:
Expand Down
170 changes: 130 additions & 40 deletions mafic/node.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,7 @@
TrackLoadingResult,
UpdatePlayerParams,
UpdatePlayerPayload,
UpdateSessionPayload,
)

_log = getLogger(__name__)
Expand Down Expand Up @@ -141,12 +142,25 @@ class Node(Generic[ClientT]):
resume_key:
The key to use when resuming the node.
If not provided, the key will be generated from the host, port and label.
.. warning::
This is ignored in lavalink V4, use ``resuming_session_id`` instead.
regions:
The voice regions that the node can be used in.
This is used to determine when to use this node.
shard_ids:
The shard IDs that the node can be used in.
This is used to determine when to use this node.
resuming_session_id:
The session ID to use when resuming the node.
If not provided, the node will not resume.
This should be stored from :func:`~mafic.on_node_ready` with :attr:`session_id`
to resume the session and gain control of the players. If the node is not
resuming, players will be destroyed if Lavalink loses connection to us.
.. versionadded:: 2.2
Attributes
----------
Expand All @@ -162,6 +176,7 @@ class Node(Generic[ClientT]):
"__password",
"__session",
"_available",
"_checked_version",
"_client",
"_connect_task",
"_heartbeat",
Expand All @@ -175,8 +190,10 @@ class Node(Generic[ClientT]):
"_timeout",
"_ready",
"_rest_uri",
"_resuming_session_id",
"_session_id",
"_stats",
"_version",
"_ws",
"_ws_uri",
"_ws_task",
Expand All @@ -199,6 +216,7 @@ def __init__(
resume_key: str | None = None,
regions: Sequence[Group | Region | VoiceRegion] | None = None,
shard_ids: Sequence[int] | None = None,
resuming_session_id: str | None = None,
) -> None:
self._host = host
self._port = port
Expand All @@ -217,6 +235,7 @@ def __init__(
)
self._ws_uri = yarl.URL.build(scheme=f"ws{'s'*secure}", host=host, port=port)
self._resume_key = resume_key or f"{host}:{port}:{label}"
self._resuming_session_id: str = resuming_session_id or ""

self._ws: ClientWebSocketResponse | None = None
self._ws_task: Task[None] | None = None
Expand All @@ -232,6 +251,9 @@ def __init__(
self._msg_tasks: set[Task[None]] = set()
self._connect_task: Task[None] | None = None

self._checked_version: bool = False
self._version: int = 3

@property
def host(self) -> str:
"""The host of the node."""
Expand Down Expand Up @@ -365,6 +387,27 @@ def players(self) -> list[Player[ClientT]]:
"""
return [*self._players.values()]

@property
def session_id(self) -> str | None:
"""The session ID of the node.
This is ``None`` if the node is not connected.
.. versionadded:: 2.2
"""
return self._session_id

@property
def version(self) -> int:
"""The major semver version of the node.
This is ``3`` if the node is not connected.
This is mostly used in :class:`Player` for version checks.
.. versionadded:: 2.2
"""
return self._version

def get_player(self, guild_id: int) -> Player[ClientT] | None:
r"""Get a player from the node.
Expand Down Expand Up @@ -407,29 +450,31 @@ def remove_player(self, guild_id: int) -> None:
"""
self._players.pop(guild_id, None)

async def _check_version(self) -> None:
"""Check the version of the node.
async def _check_version(self) -> int:
""":class:`int`: The major version of the node.
This also does checks based on if that is supported.
Raises
------
:exc:`RuntimeError`
If the
- major version is not 3
- minor version is less than 7
- major version is not in (3, 4)
- minor version is less than 7 when the major version is 3
This is because the rest api is in 3.7, and v4 will have breaking changes.
This is because the rest api is in 3.7, and v5 will have breaking changes.
Warns
-----
:class:`UnsupportedVersionWarning`
If the minor version is greater than 7.
If the
- major version is 3 and the minor version is more than 7
- major version is 4 and the minor version is more than 0
Some features may not work.
"""
if self._rest_uri.path.endswith("/v3") or self._ws_uri.path.endswith(
"/websocket"
):
if self._checked_version:
# This process was already ran likely.
return
return self._version

if self.__session is None:
self.__session = await self._create_session()
Expand All @@ -444,21 +489,34 @@ async def _check_version(self) -> None:
try:
major, minor, _ = version.split(".", maxsplit=2)
except ValueError:
message = UnknownVersionWarning.message
warnings.warn(message, UnknownVersionWarning)
if version.endswith("-SNAPSHOT"):
major = 4
minor = 0
else:
major = 3
minor = 7
message = UnknownVersionWarning.message
warnings.warn(message, UnknownVersionWarning)
else:
major = int(major)
minor = int(minor)

if major != 3 or minor < 7:
msg = f"Unsupported lavalink version {version} (expected 3.7.x)"
if major not in (3, 4) or (major == 3 and minor < 7):
msg = (
f"Unsupported lavalink version {version} "
"(expected 3.7.x or 4.x.x)"
)
raise RuntimeError(msg)
elif minor > 7:
elif (major == 3 and minor > 7) or (major == 4 and minor > 0):
message = UnsupportedVersionWarning.message
warnings.warn(message, UnsupportedVersionWarning)

self._rest_uri /= "v3"
self._ws_uri /= "v3/websocket"
self._rest_uri /= f"v{major}"
self._ws_uri /= f"v{major}/websocket"

self._version = major
self._checked_version = True
return major

async def _connect_to_websocket(
self, headers: dict[str, str], session: aiohttp.ClientSession
Expand Down Expand Up @@ -522,15 +580,20 @@ async def connect(

session = self.__session

_log.debug("Checking lavalink version...", extra={"label": self._label})
version = await self._check_version()

headers: dict[str, str] = {
"Authorization": self.__password,
"User-Id": str(self._client.user.id),
"Client-Name": f"Mafic/{__import__('mafic').__version__}",
"Resume-Key": self._resume_key,
}

_log.debug("Checking lavalink version...", extra={"label": self._label})
await self._check_version()
# V4 uses session ID resuming
if version == 3:
headers["Resume-Key"] = self._resume_key
else:
headers["Session-Id"] = self._resuming_session_id

_log.info(
"Connecting to lavalink at %s...",
Expand Down Expand Up @@ -590,6 +653,7 @@ def remove_task(_: Task[None]) -> None:
)
await self.sync_players()
self._available = True
self._client.dispatch("node_ready", self)

async def _ws_listener(self) -> None:
"""Listen for messages from the websocket."""
Expand All @@ -610,9 +674,7 @@ async def _ws_listener(self) -> None:
_log.debug("Received message from websocket.", extra={"label": self._label})

# Please aiohttp, fix your typehints.
_type: aiohttp.WSMsgType = (
msg.type
) # pyright: ignore[reportUnknownMemberType]
_type: aiohttp.WSMsgType = msg.type # pyright: ignore

if _type is aiohttp.WSMsgType.CLOSED:
self._available = False
Expand Down Expand Up @@ -656,8 +718,8 @@ async def _handle_msg(self, data: IncomingMessage) -> None:
data:
The data to handle.
"""
_log.debug("Received event with op %s", data["op"])
_log.debug("Event data: %s", data)
_log.debug("Received event with op %s", data["op"])

if data["op"] == "playerUpdate":
guild_id = int(data["guildId"])
Expand Down Expand Up @@ -692,11 +754,18 @@ async def _handle_msg(self, data: IncomingMessage) -> None:
extra={"label": self._label},
)
else:
_log.debug(
"Sending configuration to resume with key %s",
self._resume_key,
extra={"label": self._label},
)
if self._version == 3:
_log.debug(
"Sending configuration to resume with key %s",
self._resume_key,
extra={"label": self._label},
)
else:
_log.debug(
"Sending configuration to resume with session ID %s",
self._session_id,
extra={"label": self._label},
)
await self.configure_resuming()

self._ready.set()
Expand Down Expand Up @@ -765,19 +834,32 @@ def voice_update(

def configure_resuming(self) -> Coro[None]:
"""Configure the node to resume."""
_log.info(
"Sending resume configuration to lavalink with resume key %s.",
self._resume_key,
extra={"label": self._label},
)
data: UpdateSessionPayload
if self._version == 3:
_log.info(
"Sending resume configuration to lavalink with resume key %s.",
self._resume_key,
extra={"label": self._label},
)
data = {
"resumingKey": self._resume_key,
"timeout": 60,
}
else:
_log.info(
"Sending resume configuration to lavalink with session ID %s.",
self._session_id,
extra={"label": self._label},
)
data = {
"resuming": True,
"timeout": 60,
}

return self.__request(
"PATCH",
f"sessions/{self._session_id}",
{
"resumingKey": self._resume_key,
"timeout": 60,
},
data,
)

def destroy(self, guild_id: int) -> Coro[None]:
Expand Down Expand Up @@ -939,7 +1021,7 @@ async def __request(
_log.debug("Received raw data %s from %s", json, path)
return json

async def fetch_tracks(
async def fetch_tracks( # noqa: PLR0911 # V3/V4 compat.
self, query: str, *, search_type: str
) -> list[Track] | Playlist | None:
r"""Fetch tracks from the node.
Expand Down Expand Up @@ -967,8 +1049,16 @@ async def fetch_tracks(
"GET", "loadtracks", params={"identifier": query}
)

if data["loadType"] == "NO_MATCHES":
if data["loadType"] in ("empty", "NO_MATCHES"):
return []
elif data["loadType"] == "track":
return [Track.from_data_with_info(data["data"])]
elif data["loadType"] == "playlist":
return Playlist(info=data["data"]["info"], tracks=data["data"]["tracks"])
elif data["loadType"] == "search":
return [Track.from_data_with_info(track) for track in data["data"]]
elif data["loadType"] == "error":
raise TrackLoadException.from_data(data["data"])
elif data["loadType"] == "TRACK_LOADED":
return [Track.from_data_with_info(data["tracks"][0])]
elif data["loadType"] == "PLAYLIST_LOADED":
Expand Down
Loading

0 comments on commit cd63b41

Please sign in to comment.