Skip to content

Commit 0ff7aa3

Browse files
committed
port all checks from sir-lancebot and bot
1 parent ee2501e commit 0ff7aa3

File tree

1 file changed

+187
-0
lines changed

1 file changed

+187
-0
lines changed

pydis_core/utils/checks.py

Lines changed: 187 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,187 @@
1+
import datetime
2+
from collections.abc import Callable, Container, Iterable
3+
4+
from discord.ext.commands import (
5+
BucketType,
6+
CheckFailure,
7+
Cog,
8+
Command,
9+
CommandOnCooldown,
10+
Context,
11+
Cooldown,
12+
CooldownMapping,
13+
)
14+
15+
from pydis_core.utils.logging import get_logger
16+
17+
log = get_logger(__name__)
18+
19+
20+
class ContextCheckFailure(CheckFailure):
21+
"""Raised when a context-specific check fails."""
22+
23+
def __init__(self, redirect_channel: int | None) -> None:
24+
self.redirect_channel = redirect_channel
25+
26+
if redirect_channel:
27+
redirect_message = f" here. Please use the <#{redirect_channel}> channel instead"
28+
else:
29+
redirect_message = ""
30+
31+
error_message = f"You are not allowed to use that command{redirect_message}."
32+
33+
super().__init__(error_message)
34+
35+
36+
class InWhitelistCheckFailure(ContextCheckFailure):
37+
"""Raised when the `in_whitelist` check fails."""
38+
39+
40+
def in_whitelist_check(
41+
ctx: Context,
42+
redirect: int,
43+
channels: Container[int] = (),
44+
categories: Container[int] = (),
45+
roles: Container[int] = (),
46+
fail_silently: bool = False,
47+
) -> bool:
48+
"""
49+
Check if a command was issued in a whitelisted context.
50+
51+
The whitelists that can be provided are:
52+
53+
- `channels`: a container with channel ids for whitelisted channels
54+
- `categories`: a container with category ids for whitelisted categories
55+
- `roles`: a container with with role ids for whitelisted roles
56+
57+
If the command was invoked in a context that was not whitelisted, the member is either
58+
redirected to the `redirect` channel that was passed (default: #bot-commands) or simply
59+
told that they're not allowed to use this particular command (if `None` was passed).
60+
"""
61+
if redirect not in channels:
62+
# It does not make sense for the channel whitelist to not contain the redirection
63+
# channel (if applicable). That's why we add the redirection channel to the `channels`
64+
# container if it's not already in it. As we allow any container type to be passed,
65+
# we first create a tuple in order to safely add the redirection channel.
66+
#
67+
# Note: It's possible for the redirect channel to be in a whitelisted category, but
68+
# there's no easy way to check that and as a channel can easily be moved in and out of
69+
# categories, it's probably not wise to rely on its category in any case.
70+
channels = tuple(channels) + (redirect,)
71+
72+
if channels and ctx.channel.id in channels:
73+
log.trace(f"{ctx.author} may use the `{ctx.command.name}` command as they are in a whitelisted channel.")
74+
return True
75+
76+
# Only check the category id if we have a category whitelist and the channel has a `category_id`
77+
if categories and hasattr(ctx.channel, "category_id") and ctx.channel.category_id in categories:
78+
log.trace(f"{ctx.author} may use the `{ctx.command.name}` command as they are in a whitelisted category.")
79+
return True
80+
81+
# Only check the roles whitelist if we have one and ensure the author's roles attribute returns
82+
# an iterable to prevent breakage in DM channels (for if we ever decide to enable commands there).
83+
if roles and any(r.id in roles for r in getattr(ctx.author, "roles", ())):
84+
log.trace(f"{ctx.author} may use the `{ctx.command.name}` command as they have a whitelisted role.")
85+
return True
86+
87+
log.trace(f"{ctx.author} may not use the `{ctx.command.name}` command within this context.")
88+
89+
# Some commands are secret, and should produce no feedback at all.
90+
if not fail_silently:
91+
raise InWhitelistCheckFailure(redirect)
92+
return False
93+
94+
95+
def cooldown_with_role_bypass(rate: int, per: float, type: BucketType = BucketType.default, *,
96+
bypass_roles: Iterable[int]) -> Callable:
97+
"""
98+
Applies a cooldown to a command, but allows members with certain roles to be ignored.
99+
100+
NOTE: this replaces the `Command.before_invoke` callback, which *might* introduce problems in the future.
101+
"""
102+
# Make it a set so lookup is hash based.
103+
bypass = set(bypass_roles)
104+
105+
# This handles the actual cooldown logic.
106+
buckets = CooldownMapping(Cooldown(rate, per, type))
107+
108+
# Will be called after the command has been parse but before it has been invoked, ensures that
109+
# the cooldown won't be updated if the user screws up their input to the command.
110+
async def predicate(cog: Cog, ctx: Context) -> None:
111+
nonlocal bypass, buckets
112+
113+
if any(role.id in bypass for role in ctx.author.roles):
114+
return
115+
116+
# Cooldown logic, taken from discord.py internals.
117+
current = ctx.message.created_at.replace(tzinfo=datetime.UTC).timestamp()
118+
bucket = buckets.get_bucket(ctx.message)
119+
retry_after = bucket.update_rate_limit(current)
120+
if retry_after:
121+
raise CommandOnCooldown(bucket, retry_after)
122+
123+
def wrapper(command: Command) -> Command:
124+
# NOTE: this could be changed if a subclass of Command were to be used. I didn't see the need for it
125+
# so I just made it raise an error when the decorator is applied before the actual command object exists.
126+
#
127+
# If the `before_invoke` detail is ever a problem then I can quickly just swap over.
128+
if not isinstance(command, Command):
129+
raise TypeError(
130+
"Decorator `cooldown_with_role_bypass` must be applied after the command decorator. "
131+
"This means it has to be above the command decorator in the code."
132+
)
133+
134+
command._before_invoke = predicate
135+
136+
return command
137+
138+
return wrapper
139+
140+
141+
async def has_any_role_check(ctx: Context, *roles: str | int) -> bool:
142+
"""
143+
Returns True if the context's author has any of the specified roles.
144+
145+
`roles` are the names or IDs of the roles for which to check.
146+
False is always returns if the context is outside a guild.
147+
"""
148+
if not ctx.guild: # Return False in a DM
149+
log.trace(
150+
f"{ctx.author} tried to use the '{ctx.command.name}'command from a DM. "
151+
"This command is restricted by the with_role decorator. Rejecting request."
152+
)
153+
return False
154+
155+
for role in ctx.author.roles:
156+
if role.id in roles:
157+
log.trace(f"{ctx.author} has the '{role.name}' role, and passes the check.")
158+
return True
159+
160+
log.trace(
161+
f"{ctx.author} does not have the required role to use "
162+
f"the '{ctx.command.name}' command, so the request is rejected."
163+
)
164+
return False
165+
166+
167+
async def has_no_roles_check(ctx: Context, *roles: str | int) -> bool:
168+
"""
169+
Returns True if the context's author doesn't have any of the specified roles.
170+
171+
`roles` are the names or IDs of the roles for which to check.
172+
False is always returns if the context is outside a guild.
173+
"""
174+
if not ctx.guild: # Return False in a DM
175+
log.trace(
176+
f"{ctx.author} tried to use the '{ctx.command.name}' command from a DM. "
177+
"This command is restricted by the without_role decorator. Rejecting request."
178+
)
179+
return False
180+
181+
author_roles = [role.id for role in ctx.author.roles]
182+
check = all(role not in author_roles for role in roles)
183+
log.trace(
184+
f"{ctx.author} tried to call the '{ctx.command.name}' command. "
185+
f"The result of the without_role check was {check}."
186+
)
187+
return check

0 commit comments

Comments
 (0)