Skip to content

Commit

Permalink
feat: make events generic on player
Browse files Browse the repository at this point in the history
  • Loading branch information
ooliver1 committed Feb 7, 2023
1 parent 0e053cb commit 68868e0
Show file tree
Hide file tree
Showing 2 changed files with 23 additions and 23 deletions.
44 changes: 22 additions & 22 deletions mafic/events.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,12 @@
from __future__ import annotations

from enum import Enum
from typing import TYPE_CHECKING
from typing import TYPE_CHECKING, Generic, TypeVar

from .__libraries import VoiceProtocol

if TYPE_CHECKING:
from .player import Player
from .track import Track
from .type_variables import ClientT
from .typings import (
LavalinkException,
TrackEndEvent as TrackEndEventPayload,
Expand All @@ -19,6 +19,10 @@
WebSocketClosedEvent as WebSocketClosedEventPayload,
)


# This needs HKTs in python - as Player is generic on ClientT.
PlayerT = TypeVar("PlayerT", bound=VoiceProtocol)

__all__ = (
"EndReason",
"TrackEndEvent",
Expand Down Expand Up @@ -48,7 +52,7 @@ class EndReason(str, Enum):
"""The track was cleaned up."""


class WebSocketClosedEvent:
class WebSocketClosedEvent(Generic[PlayerT]):
"""Represents an event when the connection to Discord is lost.
Attributes
Expand All @@ -68,13 +72,11 @@ class WebSocketClosedEvent:

__slots__ = ("code", "reason", "by_discord", "player")

def __init__(
self, *, payload: WebSocketClosedEventPayload, player: Player[ClientT]
):
def __init__(self, *, payload: WebSocketClosedEventPayload, player: PlayerT):
self.code: int = payload["code"]
self.reason: str = payload["reason"]
self.by_discord: bool = payload["byRemote"]
self.player: Player[ClientT] = player
self.player: PlayerT = player

def __repr__(self) -> str:
return (
Expand All @@ -83,7 +85,7 @@ def __repr__(self) -> str:
)


class TrackStartEvent:
class TrackStartEvent(Generic[PlayerT]):
"""Represents an event when a track starts playing.
Attributes
Expand All @@ -96,15 +98,15 @@ class TrackStartEvent:

__slots__ = ("track", "player")

def __init__(self, *, track: Track, player: Player[ClientT]):
def __init__(self, *, track: Track, player: PlayerT):
self.track: Track = track
self.player: Player[ClientT] = player
self.player: PlayerT = player

def __repr__(self) -> str:
return f"<TrackStartEvent track={self.track!r}>"


class TrackEndEvent:
class TrackEndEvent(Generic[PlayerT]):
"""Represents an event when a track ends.
Attributes
Expand All @@ -119,18 +121,16 @@ class TrackEndEvent:

__slots__ = ("track", "reason", "player")

def __init__(
self, *, track: Track, payload: TrackEndEventPayload, player: Player[ClientT]
):
def __init__(self, *, track: Track, payload: TrackEndEventPayload, player: PlayerT):
self.track: Track = track
self.reason: EndReason = EndReason(payload["reason"])
self.player: Player[ClientT] = player
self.player: PlayerT = player

def __repr__(self) -> str:
return f"<TrackEndEvent track={self.track!r} reason={self.reason!r}>"


class TrackExceptionEvent:
class TrackExceptionEvent(Generic[PlayerT]):
"""Represents an event when an exception occurs while playing a track.
Attributes
Expand All @@ -150,19 +150,19 @@ def __init__(
*,
track: Track,
payload: TrackExceptionEventPayload,
player: Player[ClientT],
player: PlayerT,
):
self.track: Track = track
self.exception: LavalinkException = payload["exception"]
self.player: Player[ClientT] = player
self.player: PlayerT = player

def __repr__(self) -> str:
return (
f"<TrackExceptionEvent track={self.track!r} exception={self.exception!r}>"
)


class TrackStuckEvent:
class TrackStuckEvent(Generic[PlayerT]):
"""Represents an event when a track gets stuck.
Attributes
Expand All @@ -178,11 +178,11 @@ class TrackStuckEvent:
__slots__ = ("track", "threshold_ms", "player")

def __init__(
self, *, track: Track, payload: TrackStuckEventPayload, player: Player[ClientT]
self, *, track: Track, payload: TrackStuckEventPayload, player: PlayerT
):
self.track: Track = track
self.threshold_ms: int = payload["thresholdMs"]
self.player: Player[ClientT] = player
self.player: PlayerT = player

def __repr__(self) -> str:
return (
Expand Down
2 changes: 1 addition & 1 deletion test_bot/bot/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -190,7 +190,7 @@ async def play(inter: Interaction, query: str):


@bot.listen()
async def on_track_end(event: TrackEndEvent):
async def on_track_end(event: TrackEndEvent[MyPlayer]):
if event.player.queue:
await event.player.play(event.player.queue.pop(0))

Expand Down

0 comments on commit 68868e0

Please sign in to comment.